// broadcast_opset_pass.cpp — tests for Broadcast conversion by the opset0/opset1 passes.
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

// Verifies that Opset1Upgrade replaces a v0::Broadcast with an equivalent
// v1::Broadcast whose target shape and axes mapping are Constant inputs.
TEST(opset_transform, opset1_broadcast_upgrade_pass)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{5, 6});

    // v0 form: output shape {3, 5, 4, 6}; axes 0 and 2 are the broadcast axes.
    auto bcast_v0 = make_shared<op::v0::Broadcast>(arg, Shape{3, 5, 4, 6}, AxisSet{0, 2});
    auto f = make_shared<Function>(NodeVector{bcast_v0}, ParameterVector{arg});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    auto bcast_v1 = as_type_ptr<op::v1::Broadcast>(
        f->get_results().at(0)->input_value(0).get_node_shared_ptr());
    ASSERT_TRUE(bcast_v1);

    EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
    EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair(true, AxisSet{0, 2})));

    // Cast once and assert the result. A non-Constant input then fails the test
    // cleanly instead of dereferencing a null pointer below.
    auto target_shape =
        as_type_ptr<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr());
    auto axes_mapping =
        as_type_ptr<op::Constant>(bcast_v1->input_value(2).get_node_shared_ptr());
    ASSERT_TRUE(target_shape);
    ASSERT_TRUE(axes_mapping);

    EXPECT_EQ(target_shape->get_shape_val(), (Shape{3, 5, 4, 6}));
    // The two dims of arg land on the output axes that are not broadcast (1 and 3).
    EXPECT_EQ(axes_mapping->get_axis_set_val(), (AxisSet{1, 3}));
}

// Verifies that Opset0Downgrade replaces a v1::Broadcast (explicit target shape
// and axes mapping) with an equivalent v0::Broadcast.
TEST(opset_transform, opset1_broadcast_downgrade_pass)
{
    auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{5}, {3, 1, 4, 2, 3});
    // arg dims 0, 1, 2 map onto output axes 1, 3, 4 respectively.
    auto axes_mapping = op::Constant::create<int64_t>(element::i64, Shape{3}, {1, 3, 4});

    auto bcast_v1 = make_shared<op::v1::Broadcast>(arg, target_shape, axes_mapping);
    auto f = make_shared<Function>(NodeVector{bcast_v1}, ParameterVector{arg});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    auto bcast_v0 = as_type_ptr<op::v0::Broadcast>(
        f->get_results().at(0)->input_value(0).get_node_shared_ptr());
    ASSERT_TRUE(bcast_v0);

    EXPECT_EQ(bcast_v0->get_broadcast_shape(), (Shape{3, 1, 4, 2, 3}));
    // The output axes not named in axes_mapping (0 and 2) become the v0 broadcast axes.
    EXPECT_EQ(bcast_v0->get_broadcast_axes(), (AxisSet{0, 2}));
}