all_close.hpp 6.42 KB
Newer Older
1
//*****************************************************************************
2
// Copyright 2017-2019 Intel Corporation
3 4 5 6 7 8 9 10 11 12 13 14 15
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
Scott Cyphers's avatar
Scott Cyphers committed
16 17 18

#pragma once

19
#include <cmath>
Scott Cyphers's avatar
Scott Cyphers committed
20
#include <memory>
21
#include <vector>
Scott Cyphers's avatar
Scott Cyphers committed
22

23
#include "gtest/gtest.h"
24
#include "ngraph/type/element_type.hpp"
25
#include "test_tools.hpp"
Scott Cyphers's avatar
Scott Cyphers committed
26 27 28 29 30

namespace ngraph
{
    namespace test
    {
31 32 33 34 35 36
        /// \brief Same as numpy.allclose
        /// \param a First tensor to compare
        /// \param b Second tensor to compare
        /// \param rtol Relative tolerance
        /// \param atol Absolute tolerance
        /// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
Robert Kimball's avatar
Robert Kimball committed
37
        template <typename T>
38
        typename std::enable_if<std::is_floating_point<T>::value, ::testing::AssertionResult>::type
39 40 41 42
            all_close(const std::vector<T>& a,
                      const std::vector<T>& b,
                      T rtol = static_cast<T>(1e-5),
                      T atol = static_cast<T>(1e-8))
Robert Kimball's avatar
Robert Kimball committed
43
        {
44
            bool rc = true;
45
            ::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
46 47 48 49
            if (a.size() != b.size())
            {
                throw std::invalid_argument("all_close: Argument vectors' sizes do not match");
            }
50
            size_t count = 0;
Robert Kimball's avatar
Robert Kimball committed
51 52
            for (size_t i = 0; i < a.size(); ++i)
            {
53 54
                if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]) || !std::isfinite(a[i]) ||
                    !std::isfinite(b[i]))
Robert Kimball's avatar
Robert Kimball committed
55
                {
56 57
                    if (count < 5)
                    {
58
                        ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
59 60
                                << a[i] << " is not close to " << b[i] << " at index " << i
                                << std::endl;
61 62
                    }
                    count++;
63
                    rc = false;
64 65
                }
            }
66
            ar_fail << "diff count: " << count << " out of " << a.size() << std::endl;
67
            return rc ? ::testing::AssertionSuccess() : ar_fail;
68 69 70 71 72 73 74 75 76
        }

        /// \brief Same as numpy.allclose
        /// \param a First tensor to compare
        /// \param b Second tensor to compare
        /// \param rtol Relative tolerance
        /// \param atol Absolute tolerance
        /// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
        template <typename T>
77
        typename std::enable_if<std::is_integral<T>::value, ::testing::AssertionResult>::type
78 79 80 81 82 83
            all_close(const std::vector<T>& a,
                      const std::vector<T>& b,
                      T rtol = static_cast<T>(1e-5),
                      T atol = static_cast<T>(1e-8))
        {
            bool rc = true;
84
            ::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
85 86 87 88
            if (a.size() != b.size())
            {
                throw std::invalid_argument("all_close: Argument vectors' sizes do not match");
            }
89 90 91 92 93
            for (size_t i = 0; i < a.size(); ++i)
            {
                T abs_diff = (a[i] > b[i]) ? (a[i] - b[i]) : (b[i] - a[i]);
                if (abs_diff > atol + rtol * b[i])
                {
94 95 96
                    // use unary + operator to force integral values to be displayed as numbers
                    ar_fail << +a[i] << " is not close to " << +b[i] << " at index " << i
                            << std::endl;
97
                    rc = false;
Robert Kimball's avatar
Robert Kimball committed
98 99
                }
            }
100
            return rc ? ::testing::AssertionSuccess() : ar_fail;
Robert Kimball's avatar
Robert Kimball committed
101
        }
102

103 104 105 106 107
        /// \brief Same as numpy.allclose
        /// \param a First tensor to compare
        /// \param b Second tensor to compare
        /// \param rtol Relative tolerance
        /// \param atol Absolute tolerance
Scott Cyphers's avatar
Scott Cyphers committed
108
        /// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
109
        template <typename T>
110 111 112 113
        ::testing::AssertionResult all_close(const std::shared_ptr<ngraph::runtime::Tensor>& a,
                                             const std::shared_ptr<ngraph::runtime::Tensor>& b,
                                             T rtol = 1e-5f,
                                             T atol = 1e-8f)
Robert Kimball's avatar
Robert Kimball committed
114 115
        {
            // Check that the layouts are compatible
Scott Cyphers's avatar
Scott Cyphers committed
116
            if (*a->get_tensor_layout() != *b->get_tensor_layout())
Robert Kimball's avatar
Robert Kimball committed
117
            {
118 119
                return ::testing::AssertionFailure()
                       << "Cannot compare tensors with different layouts";
Robert Kimball's avatar
Robert Kimball committed
120
            }
Scott Cyphers's avatar
Scott Cyphers committed
121

Robert Kimball's avatar
Robert Kimball committed
122
            if (a->get_shape() != b->get_shape())
123
            {
124 125
                return ::testing::AssertionFailure()
                       << "Cannot compare tensors with different shapes";
126
            }
Scott Cyphers's avatar
Scott Cyphers committed
127

128
            return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol);
Robert Kimball's avatar
Robert Kimball committed
129
        }
Scott Cyphers's avatar
Scott Cyphers committed
130

131 132 133 134 135
        /// \brief Same as numpy.allclose
        /// \param as First tensors to compare
        /// \param bs Second tensors to compare
        /// \param rtol Relative tolerance
        /// \param atol Absolute tolerance
Robert Kimball's avatar
Robert Kimball committed
136
        /// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
137
        template <typename T>
138 139 140 141 142
        ::testing::AssertionResult
            all_close(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& as,
                      const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& bs,
                      T rtol,
                      T atol)
Robert Kimball's avatar
Robert Kimball committed
143 144 145
        {
            if (as.size() != bs.size())
            {
146 147
                return ::testing::AssertionFailure()
                       << "Cannot compare tensors with different sizes";
Robert Kimball's avatar
Robert Kimball committed
148 149 150
            }
            for (size_t i = 0; i < as.size(); ++i)
            {
151 152
                auto ar = all_close(as[i], bs[i], rtol, atol);
                if (!ar)
Robert Kimball's avatar
Robert Kimball committed
153
                {
154
                    return ar;
Robert Kimball's avatar
Robert Kimball committed
155 156
                }
            }
157
            return ::testing::AssertionSuccess();
Robert Kimball's avatar
Robert Kimball committed
158
        }
Scott Cyphers's avatar
Scott Cyphers committed
159 160
    }
}