test_case.hpp 8.58 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <utility>

21 22
#include "all_close.hpp"
#include "all_close_f.hpp"
23 24
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp"
25
#include "test_tools.hpp"
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40

namespace ngraph
{
    namespace test
    {
        class NgraphTestCase
        {
        public:
            NgraphTestCase(const std::shared_ptr<Function>& function,
                           const std::string& backend_name)
                : m_function(function)
                , m_backend(ngraph::runtime::Backend::create(backend_name))
            {
            }

41 42 43 44 45 46 47 48 49 50
            NgraphTestCase& set_tolerance(int tolerance_bits);

            /// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
            ///
            /// Just before the assertion is done, the current test case will gather expected and computed values,
            /// format them as 2 columns and print out to the console along with a corresponding index in the vector.
            ///
            /// \param dump - Indicates if the test case should perform the console printout
            NgraphTestCase& dump_results(bool dump = true);

51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
            template <typename T>
            void add_input(const std::vector<T>& values)
            {
                auto params = m_function->get_parameters();

                NGRAPH_CHECK(m_input_index < params.size(),
                             "All function parameters already have inputs.");

                auto tensor = m_backend->create_tensor(params.at(m_input_index)->get_element_type(),
                                                       params.at(m_input_index)->get_shape());
                copy_data(tensor, values);

                m_input_tensors.push_back(tensor);

                ++m_input_index;
            }

68 69 70 71 72 73 74 75 76 77 78 79 80 81
            template <typename T>
            void add_input_from_file(const std::string& basepath, const std::string& filename)
            {
                auto filepath = ngraph::file_util::path_join(basepath, filename);
                add_input_from_file<T>(filepath);
            }

            template <typename T>
            void add_input_from_file(const std::string& filepath)
            {
                auto value = read_binary_file<T>(filepath);
                add_input(value);
            }

82 83 84 85 86 87 88 89 90 91
            template <typename T>
            void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values)
            {
                for (const auto& value : vector_of_values)
                {
                    add_input(value);
                }
            }

            template <typename T>
92
            void add_expected_output(ngraph::Shape expected_shape, const std::vector<T>& values)
93 94 95 96 97 98 99 100
            {
                auto results = m_function->get_results();

                NGRAPH_CHECK(m_output_index < results.size(),
                             "All function results already have expected outputs.");

                auto function_output_type = results.at(m_output_index)->get_element_type();
                m_result_tensors.emplace_back(
101
                    m_backend->create_tensor(function_output_type, expected_shape));
102 103

                m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
104
                    function_output_type, expected_shape, values));
105 106 107 108

                ++m_output_index;
            }

109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
            template <typename T>
            void add_expected_output(const std::vector<T>& values)
            {
                auto shape = m_function->get_results().at(m_output_index)->get_shape();
                add_expected_output(shape, values);
            }

            template <typename T>
            void add_expected_output_from_file(ngraph::Shape expected_shape,
                                               const std::string& basepath,
                                               const std::string& filename)
            {
                auto filepath = ngraph::file_util::path_join(basepath, filename);
                add_expected_output_from_file<T>(expected_shape, filepath);
            }

            template <typename T>
            void add_expected_output_from_file(ngraph::Shape expected_shape,
                                               const std::string& filepath)
            {
                auto value = read_binary_file<T>(filepath);
                add_expected_output(expected_shape, value);
            }
132

133 134
            void run();

135
        private:
136
            template <typename T>
137 138
            typename std::enable_if<std::is_floating_point<T>::value,
                                    ::testing::AssertionResult>::type
139 140 141 142 143
                compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
                               const std::shared_ptr<ngraph::runtime::Tensor>& results)
            {
                const auto expected = expected_results->get_vector<T>();
                const auto result = read_vector<T>(results);
144 145 146 147 148 149

                if (m_dump_results)
                {
                    std::cout << get_results_str<T>(expected, result, expected.size());
                }

150
                return ngraph::test::all_close_f(expected, result, m_tolerance_bits);
151 152 153
            }

            template <typename T>
154
            typename std::enable_if<std::is_integral<T>::value, ::testing::AssertionResult>::type
155 156 157 158 159
                compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
                               const std::shared_ptr<ngraph::runtime::Tensor>& results)
            {
                const auto expected = expected_results->get_vector<T>();
                const auto result = read_vector<T>(results);
160 161 162 163 164 165

                if (m_dump_results)
                {
                    std::cout << get_results_str<T>(expected, result, expected.size());
                }

166 167 168 169 170 171 172
                return ngraph::test::all_close(expected, result);
            }

            using value_comparator_function = std::function<::testing::AssertionResult(
                const std::shared_ptr<ngraph::op::Constant>&,
                const std::shared_ptr<ngraph::runtime::Tensor>&)>;

173 174 175 176 177 178 179 180
#define REGISTER_COMPARATOR(element_type_, type_)                                                  \
    {                                                                                              \
        ngraph::element::Type_t::element_type_, std::bind(&NgraphTestCase::compare_values<type_>,  \
                                                          this,                                    \
                                                          std::placeholders::_1,                   \
                                                          std::placeholders::_2)                   \
    }

181 182 183 184 185 186 187 188 189 190 191
            std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators = {
                REGISTER_COMPARATOR(f32, float),
                REGISTER_COMPARATOR(f64, double),
                REGISTER_COMPARATOR(i8, int8_t),
                REGISTER_COMPARATOR(i16, int16_t),
                REGISTER_COMPARATOR(i32, int32_t),
                REGISTER_COMPARATOR(i64, int64_t),
                REGISTER_COMPARATOR(u8, uint8_t),
                REGISTER_COMPARATOR(u16, uint16_t),
                REGISTER_COMPARATOR(u32, uint32_t),
                REGISTER_COMPARATOR(u64, uint64_t),
192 193 194 195
            };
#undef REGISTER_COMPARATOR

        protected:
196
            std::shared_ptr<Function> m_function;
197
            std::shared_ptr<runtime::Backend> m_backend;
198 199 200 201 202
            std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
            std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors;
            std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
            int m_input_index = 0;
            int m_output_index = 0;
203
            bool m_dump_results = false;
204
            int m_tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS;
205 206 207
        };
    }
}