//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/specialize_function.hpp"

using namespace ngraph;

// Baseline: a function whose parameters already carry static element types and static shapes is
// "specialized" to exactly those same types and shapes; the output must keep its type and shape.
TEST(specialize_function, et_shape_static)
{
    // Graph under test: out = lhs + Convert<f32>(rhs)
    const auto lhs = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    const auto rhs = std::make_shared<op::Parameter>(element::i32, Shape{1, 2, 3});
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // One entry per parameter; no parameter gets a substituted value.
    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}};
    const std::vector<void*> new_values(2, nullptr);

    const auto specialized = specialize_function(original, new_types, new_shapes, new_values);

    ASSERT_EQ(specialized->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(specialized->get_output_element_type(0), element::f32);
}

// Parameters declared with element::dynamic are specialized to concrete element types
// (f32 and i32); because rhs is converted to f32 before the Add, the output must be f32.
TEST(specialize_function, et_dynamic_shape_static)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), element types left dynamic.
    const auto lhs = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
    const auto rhs = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // One entry per parameter; no parameter gets a substituted value.
    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}};
    const std::vector<void*> new_values(2, nullptr);

    const auto specialized = specialize_function(original, new_types, new_shapes, new_values);

    ASSERT_EQ(specialized->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(specialized->get_output_element_type(0), element::f32);
}

// Parameters whose shapes are fully dynamic (even the rank is unknown) are specialized to the
// static shape {1, 2, 3}; the output shape must then be static as well.
TEST(specialize_function, et_static_shape_rank_dynamic)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), shapes of unknown rank.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    const auto rhs = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // One entry per parameter; no parameter gets a substituted value.
    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}};
    const std::vector<void*> new_values(2, nullptr);

    const auto specialized = specialize_function(original, new_types, new_shapes, new_values);

    ASSERT_EQ(specialized->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(specialized->get_output_element_type(0), element::f32);
}

// Parameters whose rank is known (3) but whose dimensions are all dynamic are specialized to
// the static shape {1, 2, 3}; the output shape must then be static as well.
TEST(specialize_function, et_static_shape_rank_static_dynamic)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), rank-3 shapes with unknown dimensions.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
    const auto rhs = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // One entry per parameter; no parameter gets a substituted value.
    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}};
    const std::vector<void*> new_values(2, nullptr);

    const auto specialized = specialize_function(original, new_types, new_shapes, new_values);

    ASSERT_EQ(specialized->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(specialized->get_output_element_type(0), element::f32);
}

// Test substitution of concrete values for shape-dynamic parameters: the parameter that is
// given a value buffer must be replaced by a Constant holding those values, while the parameter
// given nullptr must remain a Parameter, specialized to the requested element type and shape.
TEST(specialize_function, et_static_shape_rank_static_dynamic_subst_val)
{
    auto p0 = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
    auto p1 = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));

    auto k = std::make_shared<op::Convert>(p1, element::f32);
    auto a = p0 + k;

    auto f = std::make_shared<Function>(a, ParameterVector{p0, p1});

    // Six values: one per element of the {1, 2, 3} shape that p1 is specialized to.
    std::vector<int32_t> p1_subst_vals{5, 0, 3, 8, 5, 8};

    // Only p1 receives substitute values; nullptr for p0 means "no substitution".
    std::vector<void*> param_vals{nullptr, p1_subst_vals.data()};

    auto g = specialize_function(f,
                                 {element::f32, element::i32},
                                 {PartialShape{1, 2, 3}, PartialShape{1, 2, 3}},
                                 param_vals);

    ASSERT_EQ(g->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(g->get_output_element_type(0), element::f32);

    // Expected structure of the specialized graph: Result <- Add(Parameter, Convert(Constant)).
    auto plus_node = std::dynamic_pointer_cast<op::Add>(g->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(plus_node);

    // p0 had no substitute value, so it must still be a Parameter, now specialized to f32{1,2,3}.
    auto param_node = std::dynamic_pointer_cast<op::Parameter>(plus_node->get_argument(0));
    ASSERT_TRUE(param_node);
    ASSERT_EQ(param_node->get_output_element_type(0), element::f32);
    ASSERT_EQ(param_node->get_output_shape(0), (Shape{1, 2, 3}));

    // p1 must have been replaced by a Constant that holds exactly the substituted values.
    auto convert_node = std::dynamic_pointer_cast<op::Convert>(plus_node->get_argument(1));
    ASSERT_TRUE(convert_node);
    auto const_node = std::dynamic_pointer_cast<op::Constant>(convert_node->get_argument(0));
    ASSERT_TRUE(const_node);

    ASSERT_EQ(const_node->get_output_element_type(0), element::i32);
    ASSERT_EQ(const_node->get_output_shape(0), (Shape{1, 2, 3}));
    ASSERT_EQ(const_node->get_vector<int32_t>(), p1_subst_vals);
}

// Specializing rank-dynamic parameters must fail graph validation when the supplied shapes are
// inconsistent with each other: {1, 2, 3} and {1, 2, 3, 4} cannot be the two operands of the
// same Add, so rebuilding the graph throws NodeValidationFailure.
TEST(specialize_function, et_static_shape_rank_dynamic_validation_fails)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), shapes of unknown rank.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    const auto rhs = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic());
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // Each shape is individually acceptable; only their combination is invalid.
    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}};
    const std::vector<void*> new_values(2, nullptr);

    ASSERT_THROW(specialize_function(original, new_types, new_shapes, new_values),
                 NodeValidationFailure);
}

// Specializing dynamic element types must fail graph validation when the chosen types are
// inconsistent: lhs becomes u32 while the Convert always produces f32, so the Add sees
// mismatched operand types and rebuilding the graph throws NodeValidationFailure.
TEST(specialize_function, et_dynamic_shape_static_validation_fails)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), element types left dynamic.
    const auto lhs = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
    const auto rhs = std::make_shared<op::Parameter>(element::dynamic, Shape{1, 2, 3});
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // u32 for lhs clashes with the f32 output of the Convert.
    const std::vector<element::Type> new_types{element::u32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}};
    const std::vector<void*> new_values(2, nullptr);

    ASSERT_THROW(specialize_function(original, new_types, new_shapes, new_values),
                 NodeValidationFailure);
}

// Specializing rank-static dynamic shapes must be rejected when a replacement shape has the
// wrong rank (rank 4 supplied for a rank-3 parameter).
//
// This is caught by specialize_function's own up-front argument checks (NGRAPH_CHECK), not by
// node validation while the graph is rebuilt, so the expected exception is CheckFailure rather
// than NodeValidationFailure.
TEST(specialize_function, et_static_shape_rank_static_dynamic_rank_mismatch)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), rank-3 shapes with unknown dimensions.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic(3));
    const auto rhs = std::make_shared<op::Parameter>(element::i32, PartialShape::dynamic(3));
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // The second shape has rank 4, contradicting rhs's declared rank of 3.
    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3, 4}};
    const std::vector<void*> new_values(2, nullptr);

    ASSERT_THROW(specialize_function(original, new_types, new_shapes, new_values), CheckFailure);
}

// Specializing rank-static dynamic shapes must be rejected when a replacement shape contradicts
// a dimension that is already static: rhs is declared {1, ?, 3}, and the replacement {1, 9, 4}
// fills the dynamic slot legally but conflicts in the last dimension (4 vs 3).
//
// This is caught by specialize_function's own up-front argument checks (NGRAPH_CHECK), not by
// node validation while the graph is rebuilt, so the expected exception is CheckFailure rather
// than NodeValidationFailure.
TEST(specialize_function, et_static_shape_rank_static_dynamic_dim_mismatch)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), rhs with one dynamic dimension.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    const auto rhs =
        std::make_shared<op::Parameter>(element::i32, PartialShape{1, Dimension::dynamic(), 3});
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 9, 4}};
    const std::vector<void*> new_values(2, nullptr);

    ASSERT_THROW(specialize_function(original, new_types, new_shapes, new_values), CheckFailure);
}

// specialize_function must reject a replacement element-type list whose length differs from
// the number of parameters (three types for a two-parameter function) with CheckFailure.
TEST(specialize_function, et_count_wrong)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), two parameters.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    const auto rhs = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // One element type too many.
    const std::vector<element::Type> new_types{element::f32, element::i32, element::u32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}};
    const std::vector<void*> new_values(2, nullptr);

    ASSERT_THROW(specialize_function(original, new_types, new_shapes, new_values), CheckFailure);
}

// specialize_function must reject a replacement shape list whose length differs from the
// number of parameters (three shapes for a two-parameter function) with CheckFailure.
TEST(specialize_function, shape_count_wrong)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), two parameters.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    const auto rhs = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    // One shape too many.
    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{
        PartialShape{1, 2, 3}, PartialShape{1, 2, 3}, PartialShape{4, 5, 6}};
    const std::vector<void*> new_values(2, nullptr);

    ASSERT_THROW(specialize_function(original, new_types, new_shapes, new_values), CheckFailure);
}

// specialize_function must reject a replacement value list whose length differs from the
// number of parameters (three value pointers for a two-parameter function) with CheckFailure.
TEST(specialize_function, value_count_wrong)
{
    // Graph under test: out = lhs + Convert<f32>(rhs), two parameters.
    const auto lhs = std::make_shared<op::Parameter>(element::f32, PartialShape{1, 2, 3});
    const auto rhs = std::make_shared<op::Parameter>(element::i32, PartialShape{1, 2, 3});
    const auto sum = lhs + std::make_shared<op::Convert>(rhs, element::f32);
    const auto original = std::make_shared<Function>(sum, ParameterVector{lhs, rhs});

    const std::vector<element::Type> new_types{element::f32, element::i32};
    const std::vector<PartialShape> new_shapes{PartialShape{1, 2, 3}, PartialShape{1, 2, 3}};
    // One value pointer too many.
    const std::vector<void*> new_values(3, nullptr);

    ASSERT_THROW(specialize_function(original, new_types, new_shapes, new_values), CheckFailure);
}