#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

// Builds a graph containing a single v0 OneHot, runs the Opset1Upgrade pass and
// verifies that the node feeding the Result is a v1 OneHot with the same axis,
// a depth constant equal to the size of the one-hot dimension, and on/off
// constants of 1 and 0.
TEST(opset_transform, opset1_one_hot_upgrade_pass)
{
    auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
    const auto depth = 4;
    PartialShape shape{1, 3, 2, depth, 3};
    size_t one_hot_axis = 3;
    auto one_hot_v0 = make_shared<op::v0::OneHot>(indices, shape, one_hot_axis);

    auto result = make_shared<op::Result>(one_hot_v0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    const auto one_hot_v1 = as_type_ptr<op::v1::OneHot>(pass_replacement_node);
    ASSERT_TRUE(one_hot_v1);
    EXPECT_EQ(one_hot_v1->get_axis(), one_hot_axis);

    // Each cast below is asserted before use: as_type_ptr yields a null pointer
    // when the input is not a Constant, and dereferencing it would crash the
    // test binary instead of reporting a failure.
    auto one_hot_v1_depth =
        as_type_ptr<op::Constant>(one_hot_v1->input_value(1).get_node_shared_ptr());
    ASSERT_TRUE(one_hot_v1_depth);
    EXPECT_EQ(one_hot_v1_depth->get_vector<int64_t>()[0], depth);

    auto one_hot_v1_on_value =
        as_type_ptr<op::Constant>(one_hot_v1->input_value(2).get_node_shared_ptr());
    ASSERT_TRUE(one_hot_v1_on_value);
    EXPECT_EQ(one_hot_v1_on_value->get_vector<int64_t>()[0], 1);

    auto one_hot_v1_off_value =
        as_type_ptr<op::Constant>(one_hot_v1->input_value(3).get_node_shared_ptr());
    ASSERT_TRUE(one_hot_v1_off_value);
    EXPECT_EQ(one_hot_v1_off_value->get_vector<int64_t>()[0], 0);
}

// Builds a graph containing a single v1 OneHot with constant depth, runs the
// Opset0Downgrade pass and verifies that the node feeding the Result is no
// longer a v1 OneHot and still produces the expected output shape.
TEST(opset_transform, opset1_one_hot_downgrade_pass)
{
    auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
    auto depth = op::Constant::create(element::i64, Shape{}, {4});
    auto on_value = op::Constant::create(element::u32, Shape{}, {5});
    auto off_value = op::Constant::create(element::u32, Shape{}, {10});
    int64_t axis = 3;
    auto one_hot_v1 = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);

    auto result = make_shared<op::Result>(one_hot_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node = f->get_result()->input_value(0).get_node_shared_ptr();
    ASSERT_FALSE(is_type<op::v1::OneHot>(pass_replacement_node));

    EXPECT_EQ(pass_replacement_node->get_shape(), (Shape{1, 3, 2, 4, 3}));
}

// Feeds the v1 OneHot depth input from a Parameter instead of a Constant and
// expects the Opset0Downgrade pass to throw an ngraph_error whose message
// mentions that the depth input must be constant.
TEST(opset_transform, opset1_one_hot_downgrade_pass_depth_not_constant)
{
    auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
    auto depth = make_shared<op::Parameter>(element::i64, Shape{});
    auto on_value = op::Constant::create(element::u32, Shape{}, {5});
    auto off_value = op::Constant::create(element::u32, Shape{}, {10});
    int64_t axis = 3;
    auto one_hot_v1 = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);

    auto result = make_shared<op::Result>(one_hot_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices, depth});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();

    try
    {
        pass_manager.run_passes(f);
        // Should have thrown, so fail if it didn't
        FAIL() << "Not constant depth not detected";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("depth input must be constant"));
    }
    catch (...)
    {
        FAIL() << "OneHot downgrade failed for unexpected reason";
    }
}

// Gives the v1 OneHot an indices input with a fully dynamic shape and expects
// the Opset0Downgrade pass to throw an ngraph_error whose message mentions that
// the indices shape must be static.
TEST(opset_transform, opset1_one_hot_downgrade_pass_indices_shape_not_static)
{
    auto indices = make_shared<op::Parameter>(element::i64, PartialShape::dynamic());
    auto depth = op::Constant::create(element::i64, Shape{}, {4});
    auto on_value = op::Constant::create(element::u32, Shape{}, {5});
    auto off_value = op::Constant::create(element::u32, Shape{}, {10});
    int64_t axis = 3;
    auto one_hot_v1 = make_shared<op::v1::OneHot>(indices, depth, on_value, off_value, axis);

    auto result = make_shared<op::Result>(one_hot_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{indices});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();

    try
    {
        pass_manager.run_passes(f);
        // Should have thrown, so fail if it didn't
        FAIL() << "Not static indices shape not detected";
    }
    catch (const ngraph_error& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("indices shape must be static"));
    }
    catch (...)
    {
        FAIL() << "OneHot downgrade failed for unexpected reason";
    }
}