#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

// Checks the Opset1Upgrade pass on a v0 Broadcast: after the pass runs, the node
// feeding the function's first result must be a v1 Broadcast whose target shape and
// axes mapping are supplied as constant inputs.
TEST(opset_transform, opset1_broadcast_upgrade_pass)
{
    // Values used to build the v0 node and to verify the upgraded one.
    const Shape output_shape{3, 5, 4, 6};
    const AxisSet broadcast_axes{0, 2};

    auto input = make_shared<op::Parameter>(element::f32, Shape{5, 6});
    auto legacy_bcast = make_shared<op::v0::Broadcast>(input, output_shape, broadcast_axes);
    auto function = make_shared<Function>(NodeVector{legacy_bcast}, ParameterVector{input});

    ngraph::pass::Manager manager;
    manager.register_pass<pass::Opset1Upgrade>();
    manager.run_passes(function);

    // Fetch whatever node now produces the value consumed by result 0.
    auto result_source = function->get_results().at(0)->input_value(0).get_node_shared_ptr();
    auto upgraded = as_type_ptr<op::v1::Broadcast>(result_source);

    // A null pointer here means the v0 node was not replaced by a v1 Broadcast.
    ASSERT_TRUE(upgraded);

    const std::pair<bool, AxisSet> expected_axes(true, AxisSet{0, 2});
    EXPECT_EQ(upgraded->get_broadcast_spec(), op::AutoBroadcastSpec());
    EXPECT_EQ(upgraded->get_broadcast_axes(), expected_axes);

    // Inputs 1 and 2 must both be constants before their values can be read back.
    auto shape_input = upgraded->input_value(1);
    ASSERT_TRUE(shape_input.get_node()->is_constant());
    auto mapping_input = upgraded->input_value(2);
    ASSERT_TRUE(mapping_input.get_node()->is_constant());

    auto shape_constant = as_type_ptr<op::Constant>(shape_input.get_node_shared_ptr());
    EXPECT_EQ(shape_constant->get_shape_val(), (Shape{3, 5, 4, 6}));

    // The test expects input 2 to hold axes {1, 3} for v0 broadcast axes {0, 2}.
    auto mapping_constant = as_type_ptr<op::Constant>(mapping_input.get_node_shared_ptr());
    EXPECT_EQ(mapping_constant->get_axis_set_val(), (AxisSet{1, 3}));
}

// Checks the Opset0Downgrade pass on a v1 Broadcast whose target shape and axes
// mapping are constants: after the pass runs, the node feeding the function's first
// result must be a v0 Broadcast with the expected broadcast shape and axes.
TEST(opset_transform, opset1_broadcast_downgrade_pass)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});

    // Constant inputs of the v1 node: a 5-element target shape and a 3-element axes mapping.
    auto shape_constant = op::Constant::create<int64_t>(element::i64, Shape{5}, {3, 1, 4, 2, 3});
    auto mapping_constant = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 3, 4});

    auto modern_bcast = make_shared<op::v1::Broadcast>(input, shape_constant, mapping_constant);
    auto function = make_shared<Function>(NodeVector{modern_bcast}, ParameterVector{input});

    ngraph::pass::Manager manager;
    manager.register_pass<pass::Opset0Downgrade>();
    manager.run_passes(function);

    // Fetch whatever node now produces the value consumed by result 0.
    auto result_source = function->get_results().at(0)->input_value(0).get_node_shared_ptr();
    auto downgraded = as_type_ptr<op::v0::Broadcast>(result_source);

    // A null pointer here means the v1 node was not replaced by a v0 Broadcast.
    ASSERT_TRUE(downgraded);

    // The test expects broadcast axes {0, 2} for the axes mapping {1, 3, 4}.
    const Shape expected_shape{3, 1, 4, 2, 3};
    const AxisSet expected_axes{0, 2};
    EXPECT_EQ(downgraded->get_broadcast_shape(), expected_shape);
    EXPECT_EQ(downgraded->get_broadcast_axes(), expected_axes);
}