Commit 75379523 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

CF updates: Reshape, DynReshape, Transpose (#3338)

* Update Reshape CF to support all ETs

* Add CF for DynReshape

* Add CF for Transpose

* Add #include <numeric>, for std::iota

* style, oops
parent 30c7028f
......@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include <stdint.h>
#include "constant_folding.hpp"
......@@ -29,7 +30,9 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
......@@ -243,52 +246,295 @@ void pass::ConstantFolding::construct_constant_reshape()
func = handler->second(reshape_match.get());
}
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
fold_constant_reshape<int32_t>(constant_match, reshape_match, func));
return true;
}
if (type == element::i64)
{
replace_node(m.get_match_root(),
fold_constant_reshape<int64_t>(constant_match, reshape_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_reshape<int8_t>(constant_match, reshape_match, func));
return true;
}
else if (type == element::f32)
switch (type.get_type_enum())
{
replace_node(m.get_match_root(),
fold_constant_reshape<float>(constant_match, reshape_match, func));
return true;
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_reshape_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_reshape<bfloat16>(constant_match, reshape_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_reshape<float16>(constant_match, reshape_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_reshape<float>(constant_match, reshape_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_reshape<double>(constant_match, reshape_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_reshape<int8_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_reshape<int16_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_reshape<int32_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_reshape<int64_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_reshape<uint8_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_reshape<uint16_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_reshape<uint32_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_reshape<uint64_t>(constant_match, reshape_match, func);
break;
}
else if (type == element::f64)
replace_node(m.get_match_root(), replacement);
return false;
};
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
shared_ptr<op::Constant> fold_constant_dyn_reshape(shared_ptr<op::Constant> constant_data,
shared_ptr<op::DynReshape> dyn_reshape)
{
auto out_shape = dyn_reshape->get_shape();
AxisVector input_order(constant_data->get_shape().size());
std::iota(input_order.begin(), input_order.end(), 0);
vector<T> out_vec(shape_size(out_shape));
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
out_vec.data(),
constant_data->get_shape(),
input_order,
out_shape);
return make_shared<op::Constant>(dyn_reshape->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_dyn_reshape()
{
auto constant_data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto constant_shape_label =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::DynReshape>(constant_data_label, constant_shape_label);
// Note: No need to capture or consider constant_shape_label, because
// shape propagation will have transferred the info to dyn_reshape's
// output.
auto constant_dyn_reshape_callback = [constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_dyn_reshape_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_data_match =
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto dyn_reshape_match = static_pointer_cast<op::DynReshape>(m.get_match_root());
std::shared_ptr<Node> replacement;
auto type = dyn_reshape_match->get_element_type();
switch (type.get_type_enum())
{
replace_node(m.get_match_root(),
fold_constant_reshape<double>(constant_match, reshape_match, func));
return true;
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_dyn_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_dyn_reshape<char>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::bf16:
replacement =
fold_constant_dyn_reshape<bfloat16>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f16:
replacement =
fold_constant_dyn_reshape<float16>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f32:
replacement = fold_constant_dyn_reshape<float>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::f64:
replacement = fold_constant_dyn_reshape<double>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i8:
replacement = fold_constant_dyn_reshape<int8_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i16:
replacement =
fold_constant_dyn_reshape<int16_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i32:
replacement =
fold_constant_dyn_reshape<int32_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::i64:
replacement =
fold_constant_dyn_reshape<int64_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u8:
replacement =
fold_constant_dyn_reshape<uint8_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u16:
replacement =
fold_constant_dyn_reshape<uint16_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u32:
replacement =
fold_constant_dyn_reshape<uint32_t>(constant_data_match, dyn_reshape_match);
break;
case element::Type_t::u64:
replacement =
fold_constant_dyn_reshape<uint64_t>(constant_data_match, dyn_reshape_match);
break;
}
else if (type == element::bf16)
replace_node(m.get_match_root(), replacement);
return false;
};
auto dyn_reshape_matcher =
make_shared<pattern::Matcher>(dyn_reshape, "ConstantFolding.ConstantDynReshape");
this->add_matcher(
dyn_reshape_matcher, constant_dyn_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
shared_ptr<op::Constant> fold_constant_transpose(shared_ptr<op::Constant> constant_data,
shared_ptr<op::Constant> constant_perm,
shared_ptr<op::Transpose> transpose)
{
auto out_shape = transpose->get_shape();
auto input_order = constant_perm->get_axis_vector_val();
vector<T> out_vec(shape_size(out_shape));
runtime::reference::reshape<T>(constant_data->get_data_ptr<T>(),
out_vec.data(),
constant_data->get_shape(),
input_order,
out_shape);
return make_shared<op::Constant>(transpose->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_transpose()
{
auto constant_data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto constant_perm_label =
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto transpose = make_shared<op::Transpose>(constant_data_label, constant_perm_label);
auto constant_transpose_callback = [constant_data_label,
constant_perm_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_transpose_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_data_match =
static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto constant_perm_match =
static_pointer_cast<op::Constant>(pattern_map[constant_perm_label]);
auto transpose_match = static_pointer_cast<op::Transpose>(m.get_match_root());
std::shared_ptr<Node> replacement;
auto type = transpose_match->get_element_type();
switch (type.get_type_enum())
{
replace_node(
m.get_match_root(),
fold_constant_reshape<ngraph::bfloat16>(constant_match, reshape_match, func));
return true;
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_transpose_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_transpose_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_transpose<char>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::bf16:
replacement = fold_constant_transpose<bfloat16>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::f16:
replacement = fold_constant_transpose<float16>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::f32:
replacement = fold_constant_transpose<float>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::f64:
replacement = fold_constant_transpose<double>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::i8:
replacement = fold_constant_transpose<int8_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::i16:
replacement = fold_constant_transpose<int16_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::i32:
replacement = fold_constant_transpose<int32_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::i64:
replacement = fold_constant_transpose<int64_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::u8:
replacement = fold_constant_transpose<uint8_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::u16:
replacement = fold_constant_transpose<uint16_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::u32:
replacement = fold_constant_transpose<uint32_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
case element::Type_t::u64:
replacement = fold_constant_transpose<uint64_t>(
constant_data_match, constant_perm_match, transpose_match);
break;
}
replace_node(m.get_match_root(), replacement);
return false;
};
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
auto transpose_matcher =
make_shared<pattern::Matcher>(transpose, "ConstantFolding.ConstantTranspose");
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
transpose_matcher, constant_transpose_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
......
......@@ -44,7 +44,9 @@ public:
REVERSE,
PRODUCT,
SUM,
CONCAT
CONCAT,
DYN_RESHAPE,
TRANSPOSE
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
......@@ -64,6 +66,8 @@ public:
construct_constant_product();
construct_constant_sum();
construct_constant_concat();
construct_constant_dyn_reshape();
construct_constant_transpose();
}
//this allows to specify the order in which matchers will be run
......@@ -90,6 +94,8 @@ public:
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
}
}
}
......@@ -108,6 +114,8 @@ private:
void construct_constant_product();
void construct_constant_sum();
void construct_constant_concat();
void construct_constant_dyn_reshape();
void construct_constant_transpose();
ngraph::BuildNodeExecutorMap m_cfmap;
};
......@@ -739,6 +739,63 @@ TEST(constant_folding, const_floor)
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_dyn_reshape)
{
Shape shape_in{2, 4};
vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
Shape shape_shape{3};
vector<int64_t> values_shape{2, 4, 1};
auto constant_in = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto constant_shape = make_shared<op::Constant>(element::i64, shape_shape, values_shape);
auto dyn_reshape = make_shared<op::DynReshape>(constant_in, constant_shape);
auto f = make_shared<Function>(dyn_reshape, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_transpose)
{
Shape shape_in{2, 4};
vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
Shape shape_perm{2};
vector<int64_t> values_perm{1, 0};
auto constant_in = make_shared<op::Constant>(element::f64, shape_in, values_in);
auto constant_perm = make_shared<op::Constant>(element::i64, shape_perm, values_perm);
auto transpose = make_shared<op::Transpose>(constant_in, constant_perm);
auto f = make_shared<Function>(transpose, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Transpose>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<double>();
vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment