#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

// Builds a single v1::Select graph, runs Opset0Downgrade over it, and checks that
// the node feeding the function's only Result is now a v0::Select producing an
// f32 tensor of shape {4, 2}.
// The condition input is deliberately only {2} while the data inputs are {4, 2};
// presumably this relies on v1::Select's default auto-broadcast to make the
// v1 node valid (not visible here -- verify against op::v1::Select).
TEST(opset_transform, opset0_select_downgrade_pass)
{
    const Shape data_shape{4, 2};

    const auto condition = make_shared<op::Parameter>(element::boolean, Shape{2});
    const auto then_value = make_shared<op::Parameter>(element::f32, data_shape);
    const auto else_value = make_shared<op::Parameter>(element::f32, data_shape);

    const auto select_v1 = make_shared<op::v1::Select>(condition, then_value, else_value);
    const auto fn = make_shared<Function>(ResultVector{make_shared<op::Result>(select_v1)},
                                          ParameterVector{condition, then_value, else_value});

    pass::Manager manager;
    manager.register_pass<pass::Opset0Downgrade>();
    manager.run_passes(fn);

    // Whatever now drives the single Result is the node produced by the pass.
    const auto producer = fn->get_results().at(0)->input_value(0).get_node_shared_ptr();
    const auto select_v0 = as_type_ptr<op::v0::Select>(producer);
    ASSERT_TRUE(select_v0);
    EXPECT_EQ(select_v0->get_output_element_type(0), element::f32);
    EXPECT_EQ(select_v0->get_output_shape(0), data_shape);
}

// Builds a single v0::Select graph (all inputs {4, 2}), runs Opset1Upgrade over
// it, and checks that the node feeding the function's only Result is now a
// v1::Select whose auto-broadcast spec equals a default-constructed
// op::AutoBroadcastSpec, still producing an f32 tensor of shape {4, 2}.
TEST(opset_transform, opset1_select_upgrade_pass)
{
    const Shape data_shape{4, 2};

    const auto condition = make_shared<op::Parameter>(element::boolean, data_shape);
    const auto then_value = make_shared<op::Parameter>(element::f32, data_shape);
    const auto else_value = make_shared<op::Parameter>(element::f32, data_shape);

    const auto select_v0 = make_shared<op::v0::Select>(condition, then_value, else_value);
    const auto fn = make_shared<Function>(ResultVector{make_shared<op::Result>(select_v0)},
                                          ParameterVector{condition, then_value, else_value});

    pass::Manager manager;
    manager.register_pass<pass::Opset1Upgrade>();
    manager.run_passes(fn);

    // Whatever now drives the single Result is the node produced by the pass.
    const auto producer = fn->get_results().at(0)->input_value(0).get_node_shared_ptr();
    const auto select_v1 = as_type_ptr<op::v1::Select>(producer);
    ASSERT_TRUE(select_v1);
    EXPECT_EQ(select_v1->get_auto_broadcast(), op::AutoBroadcastSpec());
    EXPECT_EQ(select_v1->get_output_element_type(0), element::f32);
    EXPECT_EQ(select_v1->get_output_shape(0), data_shape);
}