Commit a6c2f23b authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Add ConstantFolding for Gather (#3342)

* Add CF for Gather

* Style
parent 6b90c1bd
......@@ -35,6 +35,7 @@
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
......@@ -70,6 +71,7 @@
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
......@@ -1863,6 +1865,146 @@ void pass::ConstantFolding::construct_constant_concat()
this->add_matcher(concat_matcher, constant_concat_callback, all_pass_property_off);
}
// "Inner" helper for fold_constant_gather, which has to switch on the indices
// element type.
template <typename T, typename U>
static shared_ptr<op::Constant> fold_constant_gather_helper(const shared_ptr<op::Constant>& data,
const shared_ptr<op::Constant>& indices,
const shared_ptr<op::Gather>& gather)
{
std::vector<T> result_vec(shape_size(gather->get_shape()));
runtime::reference::gather<T, U>(data->get_data_ptr<T>(),
indices->get_data_ptr<U>(),
result_vec.data(),
data->get_shape(),
indices->get_shape(),
gather->get_shape(),
gather->get_axis());
return make_shared<op::Constant>(
gather->get_output_element_type(0), gather->get_output_shape(0), result_vec);
}
template <typename T>
static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Constant>& data,
const shared_ptr<op::Constant>& indices,
const shared_ptr<op::Gather>& gather)
{
auto indices_type = indices->get_output_element_type(0);
switch (indices_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_gather_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_gather_callback");
break;
case element::Type_t::boolean:
case element::Type_t::bf16:
case element::Type_t::f16:
case element::Type_t::f32:
case element::Type_t::f64:
case element::Type_t::i8:
case element::Type_t::i16:
case element::Type_t::u8:
case element::Type_t::u16:
case element::Type_t::u32:
case element::Type_t::u64:
NGRAPH_CHECK(false,
"Encountered unsupported indices element type in constant_gather_callback: ",
indices_type);
break;
case element::Type_t::i32:
return fold_constant_gather_helper<T, int32_t>(data, indices, gather);
case element::Type_t::i64:
return fold_constant_gather_helper<T, int64_t>(data, indices, gather);
}
NGRAPH_UNREACHABLE("Unhandled switch case");
}
void pass::ConstantFolding::construct_constant_gather()
{
auto data_label = make_shared<pattern::op::Label>(
element::f32, Shape{10, 20, 30}, pattern::has_class<op::Constant>());
auto indices_label =
make_shared<pattern::op::Label>(element::i64, Shape{5}, pattern::has_class<op::Constant>());
size_t gather_axis = 1;
auto gather_op = make_shared<op::Gather>(data_label, indices_label, gather_axis);
auto constant_gather_callback = [data_label, indices_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_gather_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto data = static_pointer_cast<op::Constant>(pattern_map[data_label]);
auto indices = static_pointer_cast<op::Constant>(pattern_map[indices_label]);
auto gather = static_pointer_cast<op::Gather>(m.get_match_root());
std::shared_ptr<Node> replacement;
auto data_type = data->get_output_element_type(0);
auto indices_type = indices->get_output_element_type(0);
switch (data_type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_gather_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_gather_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_gather<char>(data, indices, gather);
break;
case element::Type_t::bf16:
replacement = fold_constant_gather<bfloat16>(data, indices, gather);
break;
case element::Type_t::f16:
replacement = fold_constant_gather<float16>(data, indices, gather);
break;
case element::Type_t::f32:
replacement = fold_constant_gather<float>(data, indices, gather);
break;
case element::Type_t::f64:
replacement = fold_constant_gather<double>(data, indices, gather);
break;
case element::Type_t::i8:
replacement = fold_constant_gather<int8_t>(data, indices, gather);
break;
case element::Type_t::i16:
replacement = fold_constant_gather<int16_t>(data, indices, gather);
break;
case element::Type_t::i32:
replacement = fold_constant_gather<int32_t>(data, indices, gather);
break;
case element::Type_t::i64:
replacement = fold_constant_gather<int64_t>(data, indices, gather);
break;
case element::Type_t::u8:
replacement = fold_constant_gather<uint8_t>(data, indices, gather);
break;
case element::Type_t::u16:
replacement = fold_constant_gather<uint16_t>(data, indices, gather);
break;
case element::Type_t::u32:
replacement = fold_constant_gather<uint32_t>(data, indices, gather);
break;
case element::Type_t::u64:
replacement = fold_constant_gather<uint64_t>(data, indices, gather);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto gather_matcher =
make_shared<pattern::Matcher>(gather_op, "ConstantFolding.ConstantGather");
this->add_matcher(gather_matcher, constant_gather_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
shared_ptr<op::Slice> slice)
......
......@@ -45,6 +45,7 @@ public:
PRODUCT,
SUM,
CONCAT,
GATHER,
SLICE,
DYN_SLICE,
DYN_RESHAPE,
......@@ -68,6 +69,7 @@ public:
construct_constant_product();
construct_constant_sum();
construct_constant_concat();
construct_constant_gather();
construct_constant_slice();
construct_constant_dyn_slice();
construct_constant_dyn_reshape();
......@@ -98,6 +100,7 @@ public:
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::GATHER: construct_constant_gather(); break;
case CFTransformations::SLICE: construct_constant_slice(); break;
case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
......@@ -120,6 +123,7 @@ private:
void construct_constant_product();
void construct_constant_sum();
void construct_constant_concat();
void construct_constant_gather();
void construct_constant_slice();
void construct_constant_dyn_slice();
void construct_constant_dyn_reshape();
......
......@@ -43,8 +43,8 @@ namespace ngraph
// out' = out[out_index] # rank(out') == rank(params')
// gather_nd(params', indices'', out')
template <typename T, typename U>
void gather(T* params,
U* indices,
void gather(const T* params,
const U* indices,
T* out,
const Shape& params_shape,
const Shape& indices_shape,
......@@ -148,13 +148,14 @@ namespace ngraph
auto out_outer_coord_iter = out_outer_transform.begin();
for (const Coordinate& params_outer_coord : params_outer_transform)
{
T* params_prime = &params[params_outer_transform.index(params_outer_coord)];
const T* params_prime =
&params[params_outer_transform.index(params_outer_coord)];
T* out_outer = &out[out_outer_transform.index(*out_outer_coord_iter)];
auto out_inner_coord_iter = out_inner_transform.begin();
for (const Coordinate& indices_outer_coord : indices_outer_transform)
{
U* indices_prime =
const U* indices_prime =
&indices[indices_outer_transform.index(indices_outer_coord)];
T* out_prime = &out_outer[out_inner_transform.index(*out_inner_coord_iter)];
gather_nd<T, U>(params_prime,
......
......@@ -739,6 +739,35 @@ TEST(constant_folding, const_floor)
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_gather)
{
auto constant_data = op::Constant::create(
element::f32,
Shape{2, 5},
vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f});
auto constant_indices =
op::Constant::create(element::i64, Shape{4}, vector<int64_t>{0, 3, 2, 2});
size_t gather_axis = 1;
auto gather = make_shared<op::Gather>(constant_data, constant_indices, gather_axis);
auto f = make_shared<Function>(gather, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Gather>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
vector<float> values_expected{1.0f, 4.0f, 3.0f, 3.0f, 6.0f, 9.0f, 8.0f, 8.0f};
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_slice)
{
Shape shape_in{16};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment