Commit bb400712 authored by Amy Zhuang's avatar Amy Zhuang

Modify Gather constant folding to support v1 op.

parent bc448701
......@@ -26,17 +26,34 @@ using namespace ngraph;
template <typename T, typename U>
static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op::Constant>& data,
const shared_ptr<op::Constant>& indices,
const shared_ptr<op::Gather>& gather)
const shared_ptr<Node>& gather)
{
std::vector<T> result_vec(shape_size(gather->get_shape()));
if (auto gather_v1 = as_type_ptr<op::v1::Gather>(gather))
{
runtime::reference::gather<T, U>(data->get_data_ptr<T>(),
indices->get_data_ptr<U>(),
result_vec.data(),
data->get_shape(),
indices->get_shape(),
gather_v1->get_shape(),
gather_v1->get_axis());
}
else if (auto gather_v0 = as_type_ptr<op::v0::Gather>(gather))
{
runtime::reference::gather<T, U>(data->get_data_ptr<T>(),
indices->get_data_ptr<U>(),
result_vec.data(),
data->get_shape(),
indices->get_shape(),
gather->get_shape(),
gather->get_axis());
gather_v0->get_shape(),
gather_v0->get_axis());
}
else
{
throw ngraph_error("Unsupported op in gather constant folding.");
}
return make_shared<op::Constant>(
gather->get_output_element_type(0), gather->get_output_shape(0), result_vec);
......@@ -45,7 +62,7 @@ static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op:
template <typename T>
static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Constant>& data,
const shared_ptr<op::Constant>& indices,
const shared_ptr<op::Gather>& gather)
const shared_ptr<Node>& gather)
{
auto indices_type = indices->get_output_element_type(0);
......@@ -88,7 +105,11 @@ void pass::ConstantFolding::construct_constant_gather()
auto indices_label =
make_shared<pattern::op::Label>(element::i64, Shape{5}, pattern::has_class<op::Constant>());
size_t gather_axis = 1;
auto gather_op = make_shared<op::Gather>(data_label, indices_label, gather_axis);
auto gather_v0 = make_shared<op::Gather>(data_label, indices_label, gather_axis);
auto axis_label =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto gather_v1 = make_shared<op::v1::Gather>(data_label, indices_label, axis_label);
auto constant_gather_callback = [data_label, indices_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_gather_callback against node = "
......@@ -98,7 +119,7 @@ void pass::ConstantFolding::construct_constant_gather()
auto data = static_pointer_cast<op::Constant>(pattern_map[data_label]);
auto indices = static_pointer_cast<op::Constant>(pattern_map[indices_label]);
auto gather = static_pointer_cast<op::Gather>(m.get_match_root());
auto gather = m.get_match_root();
NGRAPH_CHECK(revalidate_and_ensure_static(gather));
......@@ -158,7 +179,12 @@ void pass::ConstantFolding::construct_constant_gather()
return true;
};
auto gather_matcher =
make_shared<pattern::Matcher>(gather_op, "ConstantFolding.ConstantGather");
this->add_matcher(gather_matcher, constant_gather_callback, PassProperty::CHANGE_DYNAMIC_STATE);
auto gather_matcher_v0 =
make_shared<pattern::Matcher>(gather_v0, "ConstantFolding.ConstantGatherV0");
this->add_matcher(
gather_matcher_v0, constant_gather_callback, PassProperty::CHANGE_DYNAMIC_STATE);
auto gather_matcher_v1 =
make_shared<pattern::Matcher>(gather_v1, "ConstantFolding.ConstantGatherV1");
this->add_matcher(
gather_matcher_v1, constant_gather_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
......@@ -1418,6 +1418,62 @@ TEST(constant_folding, const_gather)
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_gather_v1)
{
auto constant_data = op::Constant::create(
element::f32,
Shape{2, 5},
vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
auto constant_indices =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
auto constant_axis = op::Constant::create(element::i64, Shape{1}, vector<int64_t>{1});
auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
auto f = make_shared<Function>(gather, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Gather>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_gather_v1_scalar)
{
auto constant_data = op::Constant::create(
element::f32,
Shape{2, 5},
vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
auto constant_indices =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
auto constant_axis = op::Constant::create(element::i64, Shape{}, vector<int64_t>{1});
auto gather = make_shared<op::v1::Gather>(constant_data, constant_indices, constant_axis);
auto f = make_shared<Function>(gather, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Gather>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_slice)
{
Shape shape_in{16};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment