pad_opset_pass.cpp 2.55 KB
Newer Older
1 2 3 4 5
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
6
#include "ngraph/pass/opset0_downgrade.hpp"
7 8 9 10 11 12
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

13
// Checks that running the Opset1Upgrade pass over a function containing an
// op::v0::Pad leaves an op::v1::Pad feeding the result, with the pad mode and
// the below/above paddings carried over as pads_begin/pads_end.
TEST(opset_transform, opset1_pad_upgrade_pass)
{
    const CoordinateDiff expected_begin{1, 2};
    const CoordinateDiff expected_end{3, 4};
    const auto expected_mode = op::PadMode::EDGE;

    // Graph under test: Parameter(5x6) + scalar pad value -> v0::Pad -> Result.
    auto data = make_shared<op::Parameter>(element::f32, Shape{5, 6});
    auto fill_value = make_shared<op::Parameter>(element::f32, Shape{});
    auto legacy_pad = make_shared<op::v0::Pad>(
        data, fill_value, expected_begin, expected_end, expected_mode);
    auto graph_output = make_shared<op::Result>(legacy_pad);
    auto function =
        make_shared<Function>(ResultVector{graph_output}, ParameterVector{data, fill_value});

    pass::Manager manager;
    manager.register_pass<pass::Opset1Upgrade>();
    manager.run_passes(function);

    // Fetch whatever node now produces the (single) result's input.
    auto producer =
        function->get_results().at(0)->input(0).get_source_output().get_node_shared_ptr();
    auto upgraded_pad = as_type_ptr<op::v1::Pad>(producer);

    // Fatal assertion: the checks below dereference the pointer.
    ASSERT_TRUE(upgraded_pad);
    EXPECT_EQ(upgraded_pad->get_pad_mode(), expected_mode);

    EXPECT_EQ(upgraded_pad->get_pads_begin(), expected_begin);
    EXPECT_EQ(upgraded_pad->get_pads_end(), expected_end);
}
39

40
// Checks that running the Opset0Downgrade pass over a function containing an
// op::v1::Pad (with constant pads_begin/pads_end inputs) leaves an op::v0::Pad
// feeding the result, with the same pad mode and the constant values exposed
// as padding_below/padding_above.
TEST(opset_transform, opset1_pad_downgrade_pass)
{
    // Graph under test: Parameter(5x6), scalar pad value and two i64
    // constants -> v1::Pad -> Result.
    auto data = make_shared<op::Parameter>(element::f32, Shape{5, 6});
    auto fill_value = make_shared<op::Parameter>(element::f32, Shape{});
    const auto begin_const =
        make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
    const auto end_const =
        make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{3, 4});
    const auto expected_mode = op::PadMode::EDGE;

    auto modern_pad =
        make_shared<op::v1::Pad>(data, begin_const, end_const, fill_value, expected_mode);
    auto graph_output = make_shared<op::Result>(modern_pad);
    auto function =
        make_shared<Function>(ResultVector{graph_output}, ParameterVector{data, fill_value});

    pass::Manager manager;
    manager.register_pass<pass::Opset0Downgrade>();
    manager.run_passes(function);

    // Fetch whatever node now produces the (single) result's input.
    auto producer =
        function->get_results().at(0)->input(0).get_source_output().get_node_shared_ptr();
    auto downgraded_pad = as_type_ptr<op::v0::Pad>(producer);

    // Fatal assertion: the checks below dereference the pointer.
    ASSERT_TRUE(downgraded_pad);
    EXPECT_EQ(downgraded_pad->get_pad_mode(), expected_mode);

    EXPECT_EQ(downgraded_pad->get_padding_below(), CoordinateDiff({1, 2}));
    EXPECT_EQ(downgraded_pad->get_padding_above(), CoordinateDiff({3, 4}));
}