// select_opset_pass.cpp
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

// Builds an opset1 Select whose boolean condition (Shape{2}) is smaller than
// its f32 data inputs (Shape{4, 2}), runs Opset0Downgrade, and checks that the
// function's result is now produced by an opset0 Select with an f32 Shape{4, 2}
// output whose condition input has been brought to the data shape.
TEST(opset_transform, opset0_select_downgrade_pass)
{
    auto cond = make_shared<op::Parameter>(element::boolean, Shape{2});
    auto ptrue = make_shared<op::Parameter>(element::f32, Shape{4, 2});
    auto pfalse = make_shared<op::Parameter>(element::f32, Shape{4, 2});

    auto v1_node = make_shared<op::v1::Select>(cond, ptrue, pfalse);
    auto result = make_shared<op::Result>(v1_node);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{cond, ptrue, pfalse});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    // The pass rewires the Result's input; fetch whatever node now feeds it.
    auto v0_result = f->get_results().at(0);
    auto node = v0_result->input_value(0).get_node_shared_ptr();
    auto v0_node = as_type_ptr<op::v0::Select>(node);

    ASSERT_TRUE(v0_node);
    // The condition was created as Shape{2}. After downgrading, the v0 Select's
    // condition input is expected to match the Shape{4, 2} data inputs and keep
    // its boolean element type. This pins the broadcasting this test sets up.
    EXPECT_EQ(v0_node->get_input_element_type(0), element::boolean);
    EXPECT_EQ(v0_node->get_input_shape(0), (Shape{4, 2}));
    EXPECT_EQ(v0_node->output(0).get_element_type(), element::f32);
    EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{4, 2}));
}

// Builds an opset0 Select over three same-shaped inputs (boolean condition,
// f32 true/false branches, all Shape{4, 2}) and runs Opset1Upgrade. Then it
// checks that the function's result is produced by an opset1 Select. That Select
// must carry a default-constructed AutoBroadcastSpec and an f32 Shape{4, 2} output.
TEST(opset_transform, opset1_select_upgrade_pass)
{
    const Shape data_shape{4, 2};
    const auto condition = make_shared<op::Parameter>(element::boolean, data_shape);
    const auto on_true = make_shared<op::Parameter>(element::f32, data_shape);
    const auto on_false = make_shared<op::Parameter>(element::f32, data_shape);

    const auto select_v0 = make_shared<op::v0::Select>(condition, on_true, on_false);
    const auto func = make_shared<Function>(ResultVector{make_shared<op::Result>(select_v0)},
                                            ParameterVector{condition, on_true, on_false});

    ngraph::pass::Manager manager;
    manager.register_pass<pass::Opset1Upgrade>();
    manager.run_passes(func);

    // Look through the (single) Result at the node that now produces its value.
    const auto producer = func->get_results().at(0)->input_value(0).get_node_shared_ptr();
    const auto select_v1 = as_type_ptr<op::v1::Select>(producer);

    ASSERT_TRUE(select_v1);
    EXPECT_EQ(select_v1->get_auto_broadcast(), op::AutoBroadcastSpec());
    EXPECT_EQ(select_v1->output(0).get_element_type(), element::f32);
    EXPECT_EQ(select_v1->output(0).get_shape(), data_shape);
}