//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cstddef>
#include <utility>

#include "all_close.hpp"
#include "all_close_f.hpp"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp"
#include "test_tools.hpp"

namespace ngraph
{
    namespace test
    {
        class NgraphTestCase
        {
        public:
            NgraphTestCase(const std::shared_ptr<Function>& function,
                           const std::string& backend_name)
                : m_function(function)
                , m_backend(ngraph::runtime::Backend::create(backend_name))
            {
            }

            NgraphTestCase& set_tolerance(int tolerance_bits);

            /// \brief Makes the test case print the expected and computed values to the console. This should only be used for debugging purposes.
            ///
            /// Just before the assertion is done, the current test case will gather expected and computed values,
            /// format them as 2 columns and print out to the console along with a corresponding index in the vector.
            ///
            /// \param dump - Indicates if the test case should perform the console printout
            NgraphTestCase& dump_results(bool dump = true);

            template <typename T>
            void add_input(const std::vector<T>& values)
            {
                auto params = m_function->get_parameters();

                NGRAPH_CHECK(m_input_index < params.size(),
                             "All function parameters already have inputs.");

                auto tensor = m_backend->create_tensor(params.at(m_input_index)->get_element_type(),
                                                       params.at(m_input_index)->get_shape());
                copy_data(tensor, values);

                m_input_tensors.push_back(tensor);

                ++m_input_index;
            }

            template <typename T>
            void add_input_from_file(const std::string& basepath, const std::string& filename)
            {
                auto filepath = ngraph::file_util::path_join(basepath, filename);
                add_input_from_file<T>(filepath);
            }

            template <typename T>
            void add_input_from_file(const std::string& filepath)
            {
                auto value = read_binary_file<T>(filepath);
                add_input(value);
            }

            template <typename T>
            void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values)
            {
                for (const auto& value : vector_of_values)
                {
                    add_input(value);
                }
            }

            template <typename T>
            void add_expected_output(ngraph::Shape expected_shape, const std::vector<T>& values)
            {
                auto results = m_function->get_results();

                NGRAPH_CHECK(m_output_index < results.size(),
                             "All function results already have expected outputs.");

                auto function_output_type = results.at(m_output_index)->get_element_type();
                auto function_output_shape = results.at(m_output_index)->get_shape();
                m_result_tensors.emplace_back(
                    m_backend->create_tensor(function_output_type, function_output_shape));

                m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
                    function_output_type, expected_shape, values));

                ++m_output_index;
            }

            template <typename T>
            void add_expected_output(const std::vector<T>& values)
            {
                auto shape = m_function->get_results().at(m_output_index)->get_shape();
                add_expected_output(shape, values);
            }

            template <typename T>
            void add_expected_output_from_file(ngraph::Shape expected_shape,
                                               const std::string& basepath,
                                               const std::string& filename)
            {
                auto filepath = ngraph::file_util::path_join(basepath, filename);
                add_expected_output_from_file<T>(expected_shape, filepath);
            }

            template <typename T>
            void add_expected_output_from_file(ngraph::Shape expected_shape,
                                               const std::string& filepath)
            {
                auto value = read_binary_file<T>(filepath);
                add_expected_output(expected_shape, value);
            }

            void run();

        private:
            template <typename T>
            typename std::enable_if<std::is_floating_point<T>::value,
                                    ::testing::AssertionResult>::type
                compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
                               const std::shared_ptr<ngraph::runtime::Tensor>& results)
            {
                const auto expected = expected_results->get_vector<T>();
                const auto result = read_vector<T>(results);

                if (m_dump_results)
                {
                    std::cout << get_results_str<T>(expected, result, expected.size());
                }

                return ngraph::test::all_close_f(expected, result, m_tolerance_bits);
            }

            template <typename T>
            typename std::enable_if<std::is_integral<T>::value, ::testing::AssertionResult>::type
                compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
                               const std::shared_ptr<ngraph::runtime::Tensor>& results)
            {
                const auto expected = expected_results->get_vector<T>();
                const auto result = read_vector<T>(results);

                if (m_dump_results)
                {
                    std::cout << get_results_str<T>(expected, result, expected.size());
                }

                return ngraph::test::all_close(expected, result);
            }

            using value_comparator_function = std::function<::testing::AssertionResult(
                const std::shared_ptr<ngraph::op::Constant>&,
                const std::shared_ptr<ngraph::runtime::Tensor>&)>;

#define REGISTER_COMPARATOR(element_type_, type_)                                                  \
    {                                                                                              \
        ngraph::element::Type_t::element_type_, std::bind(&NgraphTestCase::compare_values<type_>,  \
                                                          this,                                    \
                                                          std::placeholders::_1,                   \
                                                          std::placeholders::_2)                   \
    }

            std::map<ngraph::element::Type_t, value_comparator_function> m_value_comparators = {
                REGISTER_COMPARATOR(f32, float),
                REGISTER_COMPARATOR(f64, double),
                REGISTER_COMPARATOR(i8, int8_t),
                REGISTER_COMPARATOR(i16, int16_t),
                REGISTER_COMPARATOR(i32, int32_t),
                REGISTER_COMPARATOR(i64, int64_t),
                REGISTER_COMPARATOR(u8, uint8_t),
                REGISTER_COMPARATOR(u16, uint16_t),
                REGISTER_COMPARATOR(u32, uint32_t),
                REGISTER_COMPARATOR(u64, uint64_t),
            };
#undef REGISTER_COMPARATOR

        protected:
            std::shared_ptr<Function> m_function;
            std::shared_ptr<runtime::Backend> m_backend;
            std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
            std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors;
            std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
            int m_input_index = 0;
            int m_output_index = 0;
            bool m_dump_results = false;
            int m_tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS;
        };
    }
}