Commit 9f928d92 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

CF for Select (#3359)

parent d8d940d0
...@@ -53,6 +53,7 @@ ...@@ -53,6 +53,7 @@
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
...@@ -89,6 +90,7 @@ ...@@ -89,6 +90,7 @@
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/sign.hpp" #include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/slice.hpp" #include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/sqrt.hpp" #include "ngraph/runtime/reference/sqrt.hpp"
...@@ -2247,3 +2249,102 @@ void pass::ConstantFolding::construct_constant_dyn_slice() ...@@ -2247,3 +2249,102 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
make_shared<pattern::Matcher>(dyn_slice_op, "ConstantFolding.ConstantDynSlice"); make_shared<pattern::Matcher>(dyn_slice_op, "ConstantFolding.ConstantDynSlice");
this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off); this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off);
} }
template <class T>
shared_ptr<op::Constant> fold_constant_select(shared_ptr<op::Constant> selection,
shared_ptr<op::Constant> t,
shared_ptr<op::Constant> f,
shared_ptr<op::Select> select)
{
auto out_shape = select->get_shape();
vector<T> out_vec(shape_size(out_shape));
runtime::reference::select<T>(selection->get_data_ptr<char>(),
t->get_data_ptr<T>(),
f->get_data_ptr<T>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(select->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_select()
{
auto selection_label = make_shared<pattern::op::Label>(
element::boolean, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto t_label = make_shared<pattern::op::Label>(
element::i64, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto f_label = make_shared<pattern::op::Label>(
element::i64, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto select_op = make_shared<op::Select>(selection_label, t_label, f_label);
auto constant_select_callback = [selection_label, t_label, f_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_select_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto selection_node = static_pointer_cast<op::Constant>(pattern_map[selection_label]);
auto t_node = static_pointer_cast<op::Constant>(pattern_map[t_label]);
auto f_node = static_pointer_cast<op::Constant>(pattern_map[f_label]);
auto select = static_pointer_cast<op::Select>(m.get_match_root());
std::shared_ptr<op::Constant> replacement;
switch (select->get_output_element_type(0).get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_select_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_select_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_select<char>(selection_node, t_node, f_node, select);
break;
case element::Type_t::bf16:
replacement = fold_constant_select<bfloat16>(selection_node, t_node, f_node, select);
break;
case element::Type_t::f16:
replacement = fold_constant_select<float16>(selection_node, t_node, f_node, select);
break;
case element::Type_t::f32:
replacement = fold_constant_select<float>(selection_node, t_node, f_node, select);
break;
case element::Type_t::f64:
replacement = fold_constant_select<double>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i8:
replacement = fold_constant_select<int8_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i16:
replacement = fold_constant_select<int16_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i32:
replacement = fold_constant_select<int32_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::i64:
replacement = fold_constant_select<int64_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u8:
replacement = fold_constant_select<uint8_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u16:
replacement = fold_constant_select<uint16_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u32:
replacement = fold_constant_select<uint32_t>(selection_node, t_node, f_node, select);
break;
case element::Type_t::u64:
replacement = fold_constant_select<uint64_t>(selection_node, t_node, f_node, select);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto select_matcher =
make_shared<pattern::Matcher>(select_op, "ConstantFolding.ConstantSelect");
this->add_matcher(select_matcher, constant_select_callback, all_pass_property_off);
}
...@@ -49,7 +49,8 @@ public: ...@@ -49,7 +49,8 @@ public:
SLICE, SLICE,
DYN_SLICE, DYN_SLICE,
DYN_RESHAPE, DYN_RESHAPE,
TRANSPOSE TRANSPOSE,
SELECT
}; };
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap()) ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
...@@ -74,6 +75,7 @@ public: ...@@ -74,6 +75,7 @@ public:
construct_constant_dyn_slice(); construct_constant_dyn_slice();
construct_constant_dyn_reshape(); construct_constant_dyn_reshape();
construct_constant_transpose(); construct_constant_transpose();
construct_constant_select();
} }
//this allows to specify the order in which matchers will be run //this allows to specify the order in which matchers will be run
...@@ -105,6 +107,7 @@ public: ...@@ -105,6 +107,7 @@ public:
case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break; case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break; case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break; case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
case CFTransformations::SELECT: construct_constant_select(); break;
} }
} }
} }
...@@ -128,6 +131,7 @@ private: ...@@ -128,6 +131,7 @@ private:
void construct_constant_dyn_slice(); void construct_constant_dyn_slice();
void construct_constant_dyn_reshape(); void construct_constant_dyn_reshape();
void construct_constant_transpose(); void construct_constant_transpose();
void construct_constant_select();
ngraph::BuildNodeExecutorMap m_cfmap; ngraph::BuildNodeExecutorMap m_cfmap;
}; };
...@@ -891,6 +891,35 @@ TEST(constant_folding, constant_transpose) ...@@ -891,6 +891,35 @@ TEST(constant_folding, constant_transpose)
ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS)); ASSERT_TRUE(test::all_close_f(values_permute, values_out, MIN_FLOAT_TOLERANCE_BITS));
} }
TEST(constant_folding, constant_select)
{
Shape shape{2, 4};
vector<char> values_selection{0, 1, 1, 0, 1, 0, 0, 1};
vector<int64_t> values_t{2, 4, 6, 8, 10, 12, 14, 16};
vector<int64_t> values_f{1, 3, 5, 7, 9, 11, 13, 15};
auto constant_selection = make_shared<op::Constant>(element::boolean, shape, values_selection);
auto constant_t = make_shared<op::Constant>(element::i64, shape, values_t);
auto constant_f = make_shared<op::Constant>(element::i64, shape, values_f);
auto select = make_shared<op::Select>(constant_selection, constant_t, constant_f);
auto f = make_shared<Function>(select, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Select>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int64_t>();
vector<int64_t> values_expected{1, 4, 6, 7, 10, 11, 13, 16};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, pass_property) TEST(constant_folding, pass_property)
{ {
auto pass = std::make_shared<ngraph::pass::ConstantFolding>(); auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment