/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include <cmath>
#include <cstddef>
#include <memory>
#include <vector>

#include "ngraph/type/element_type.hpp"
#include "test_tools.hpp"

namespace ngraph
{
    namespace test
    {
        /// @brief Same as numpy.allclose
Robert Kimball's avatar
Robert Kimball committed
31 32
        /// @param a First tensor to compare
        /// @param b Second tensor to compare
33 34
        /// @param rtol Relative tolerance
        /// @param atol Absolute tolerance
Robert Kimball's avatar
Robert Kimball committed
35 36 37 38
        /// @returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
        template <typename T>
        bool all_close(const std::vector<T>& a,
                       const std::vector<T>& b,
39 40
                       T rtol = static_cast<T>(1e-5),
                       T atol = static_cast<T>(1e-8))
Robert Kimball's avatar
Robert Kimball committed
41
        {
42
            bool rc = true;
Robert Kimball's avatar
Robert Kimball committed
43 44 45 46 47
            assert(a.size() == b.size());
            for (size_t i = 0; i < a.size(); ++i)
            {
                if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]))
                {
Fenglei's avatar
Fenglei committed
48
                    NGRAPH_INFO << a[i] << " is not close to " << b[i] << " at index " << i;
49
                    rc = false;
Robert Kimball's avatar
Robert Kimball committed
50 51
                }
            }
52
            return rc;
Robert Kimball's avatar
Robert Kimball committed
53
        }
54

Scott Cyphers's avatar
Scott Cyphers committed
55 56 57 58 59 60
        /// @brief Same as numpy.allclose
        /// @param a First tensor to compare
        /// @param b Second tensor to compare
        /// @param rtol Relative tolerance
        /// @param atol Absolute tolerance
        /// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
61 62 63 64 65
        template <typename T>
        bool all_close(const std::shared_ptr<ngraph::runtime::TensorView>& a,
                       const std::shared_ptr<ngraph::runtime::TensorView>& b,
                       T rtol = 1e-5f,
                       T atol = 1e-8f)
Robert Kimball's avatar
Robert Kimball committed
66 67 68 69 70 71
        {
            // Check that the layouts are compatible
            if (*a->get_tensor_view_layout() != *b->get_tensor_view_layout())
            {
                throw ngraph_error("Cannot compare tensors with different layouts");
            }
Scott Cyphers's avatar
Scott Cyphers committed
72

Robert Kimball's avatar
Robert Kimball committed
73
            if (a->get_shape() != b->get_shape())
74
            {
Robert Kimball's avatar
Robert Kimball committed
75
                return false;
76
            }
Scott Cyphers's avatar
Scott Cyphers committed
77

78
            return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol);
Robert Kimball's avatar
Robert Kimball committed
79
        }
Scott Cyphers's avatar
Scott Cyphers committed
80 81

        /// @brief Same as numpy.allclose
Robert Kimball's avatar
Robert Kimball committed
82 83
        /// @param as First tensors to compare
        /// @param bs Second tensors to compare
Scott Cyphers's avatar
Scott Cyphers committed
84 85
        /// @param rtol Relative tolerance
        /// @param atol Absolute tolerance
Robert Kimball's avatar
Robert Kimball committed
86
        /// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
87 88 89 90 91
        template <typename T>
        bool all_close(const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& as,
                       const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& bs,
                       T rtol,
                       T atol)
Robert Kimball's avatar
Robert Kimball committed
92 93 94 95 96 97 98 99 100 101 102 103 104 105
        {
            if (as.size() != bs.size())
            {
                return false;
            }
            for (size_t i = 0; i < as.size(); ++i)
            {
                if (!all_close(as[i], bs[i], rtol, atol))
                {
                    return false;
                }
            }
            return true;
        }
Scott Cyphers's avatar
Scott Cyphers committed
106 107
    }
}