Unverified Commit d1d27d9e authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into silee2/pragma

parents 66b6f186 f50e12a1
......@@ -32,6 +32,7 @@
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/range.hpp"
......@@ -141,6 +142,92 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_reshape()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
auto constant_reshape_callback = [&, constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Reshape)));
NGRAPH_CHECK(handler != m_cfmap.end(),
"constant folding map should have reshape entry");
func = handler->second(reshape_match.get());
}
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
switch (type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_reshape_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_reshape<bfloat16>(constant_match, reshape_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_reshape<float16>(constant_match, reshape_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_reshape<float>(constant_match, reshape_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_reshape<double>(constant_match, reshape_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_reshape<int8_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_reshape<int16_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_reshape<int32_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_reshape<int64_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_reshape<uint8_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_reshape<uint16_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_reshape<uint32_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_reshape<uint64_t>(constant_match, reshape_match, func);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
shared_ptr<op::Pad> pad,
......@@ -207,123 +294,63 @@ void pass::ConstantFolding::construct_constant_pad()
func = handler->second(pad_match.get());
}
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
fold_constant_pad<int>(constant_match, pad_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_pad<int8_t>(constant_match, pad_match, func));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
fold_constant_pad<float>(constant_match, pad_match, func));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
fold_constant_pad<double>(constant_match, pad_match, func));
return true;
}
return false;
};
auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::ConstantFolding::construct_constant_reshape()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
auto constant_reshape_callback = [&, constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Reshape)));
NGRAPH_CHECK(handler != m_cfmap.end(),
"constant folding map should have reshape entry");
func = handler->second(reshape_match.get());
}
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
switch (type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_reshape_callback");
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_pad_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_pad_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
replacement = fold_constant_pad<char>(constant_match, pad_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_reshape<bfloat16>(constant_match, reshape_match, func);
replacement = fold_constant_pad<bfloat16>(constant_match, pad_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_reshape<float16>(constant_match, reshape_match, func);
replacement = fold_constant_pad<float16>(constant_match, pad_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_reshape<float>(constant_match, reshape_match, func);
replacement = fold_constant_pad<float>(constant_match, pad_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_reshape<double>(constant_match, reshape_match, func);
replacement = fold_constant_pad<double>(constant_match, pad_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_reshape<int8_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<int8_t>(constant_match, pad_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_reshape<int16_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<int16_t>(constant_match, pad_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_reshape<int32_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<int32_t>(constant_match, pad_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_reshape<int64_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<int64_t>(constant_match, pad_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_reshape<uint8_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<uint8_t>(constant_match, pad_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_reshape<uint16_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<uint16_t>(constant_match, pad_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_reshape<uint32_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<uint32_t>(constant_match, pad_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_reshape<uint64_t>(constant_match, reshape_match, func);
replacement = fold_constant_pad<uint64_t>(constant_match, pad_match, func);
break;
}
replace_node(m.get_match_root(), replacement);
return false;
return true;
};
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
......@@ -430,7 +457,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
}
replace_node(m.get_match_root(), replacement);
return false;
return true;
};
auto dyn_reshape_matcher =
......@@ -546,7 +573,7 @@ void pass::ConstantFolding::construct_constant_transpose()
}
replace_node(m.get_match_root(), replacement);
return false;
return true;
};
auto transpose_matcher =
......@@ -609,40 +636,61 @@ void pass::ConstantFolding::construct_constant_broadcast()
func = handler->second(broadcast_match.get());
}
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
fold_constant_broadcast<int>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_broadcast<int8_t>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
fold_constant_broadcast<float>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
fold_constant_broadcast<double>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::bf16)
std::shared_ptr<Node> replacement;
auto type = broadcast_match->get_element_type();
switch (type.get_type_enum())
{
replace_node(
m.get_match_root(),
fold_constant_broadcast<ngraph::bfloat16>(constant_match, broadcast_match, func));
return true;
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_broadcast_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_broadcast_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_broadcast<char>(constant_match, broadcast_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_broadcast<bfloat16>(constant_match, broadcast_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_broadcast<float16>(constant_match, broadcast_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_broadcast<float>(constant_match, broadcast_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_broadcast<double>(constant_match, broadcast_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_broadcast<int8_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_broadcast<int16_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_broadcast<int32_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_broadcast<int64_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_broadcast<uint8_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_broadcast<uint16_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_broadcast<uint32_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_broadcast<uint64_t>(constant_match, broadcast_match, func);
break;
}
return false;
replace_node(m.get_match_root(), replacement);
return true;
};
auto broadcast_matcher =
......@@ -651,6 +699,127 @@ void pass::ConstantFolding::construct_constant_broadcast()
broadcast_matcher, constant_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
shared_ptr<op::Constant> fold_constant_dyn_broadcast(shared_ptr<op::Constant> arg,
shared_ptr<op::Constant> shape,
shared_ptr<op::Constant> axes)
{
auto out_shape = shape->get_shape_val();
vector<T> out_vec(shape_size(out_shape));
runtime::reference::broadcast<T>(arg->get_data_ptr<T>(),
out_vec.data(),
arg->get_shape(),
out_shape,
axes->get_axis_set_val());
return make_shared<op::Constant>(arg->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_dyn_broadcast()
{
auto constant_arg_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
auto constant_shape_label =
make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
auto constant_axes_label =
make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
auto dyn_broadcast = make_shared<op::DynBroadcast>(
constant_arg_label, constant_shape_label, constant_axes_label);
auto constant_dyn_broadcast_callback = [constant_arg_label,
constant_shape_label,
constant_axes_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_dyn_broadcast_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_arg_match =
static_pointer_cast<op::Constant>(pattern_map[constant_arg_label]);
auto constant_shape_match =
static_pointer_cast<op::Constant>(pattern_map[constant_shape_label]);
auto constant_axes_match =
static_pointer_cast<op::Constant>(pattern_map[constant_axes_label]);
auto dyn_broadcast_match = static_pointer_cast<op::DynBroadcast>(m.get_match_root());
std::shared_ptr<Node> replacement;
auto type = dyn_broadcast_match->get_output_element_type(0);
switch (type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_dyn_broadcast_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_dyn_broadcast_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_dyn_broadcast<char>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::bf16:
replacement = fold_constant_dyn_broadcast<bfloat16>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::f16:
replacement = fold_constant_dyn_broadcast<float16>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::f32:
replacement = fold_constant_dyn_broadcast<float>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::f64:
replacement = fold_constant_dyn_broadcast<double>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::i8:
replacement = fold_constant_dyn_broadcast<int8_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::i16:
replacement = fold_constant_dyn_broadcast<int16_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::i32:
replacement = fold_constant_dyn_broadcast<int32_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::i64:
replacement = fold_constant_dyn_broadcast<int64_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::u8:
replacement = fold_constant_dyn_broadcast<uint8_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::u16:
replacement = fold_constant_dyn_broadcast<uint16_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::u32:
replacement = fold_constant_dyn_broadcast<uint32_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
case element::Type_t::u64:
replacement = fold_constant_dyn_broadcast<uint64_t>(
constant_arg_match, constant_shape_match, constant_axes_match);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto dyn_broadcast_matcher =
make_shared<pattern::Matcher>(dyn_broadcast, "ConstantFolding.ConstantDynBroadcast");
this->add_matcher(
dyn_broadcast_matcher, constant_dyn_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class Tin, class Tout>
shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
......@@ -1382,7 +1551,6 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
}
NGRAPH_UNREACHABLE("Unexpected switch case");
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
......@@ -1440,7 +1608,6 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
}
NGRAPH_UNREACHABLE("Unexpected switch case");
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
......
......@@ -34,6 +34,7 @@ public:
{
RESHAPE,
BROADCAST,
DYN_BROADCAST,
PAD,
DEQUANTIZE,
UNARY,
......@@ -60,6 +61,7 @@ public:
m_cfmap = cfmap;
construct_constant_reshape();
construct_constant_broadcast();
construct_constant_dyn_broadcast();
construct_constant_pad();
construct_constant_unary();
construct_constant_binary();
......@@ -93,6 +95,7 @@ public:
{
case CFTransformations::RESHAPE: construct_constant_reshape(); break;
case CFTransformations::BROADCAST: construct_constant_broadcast(); break;
case CFTransformations::DYN_BROADCAST: construct_constant_dyn_broadcast(); break;
case CFTransformations::PAD: construct_constant_pad(); break;
case CFTransformations::UNARY: construct_constant_unary(); break;
case CFTransformations::BINARY: construct_constant_binary(); break;
......@@ -122,6 +125,7 @@ public:
private:
void construct_constant_reshape();
void construct_constant_broadcast();
void construct_constant_dyn_broadcast();
void construct_constant_pad();
void construct_constant_unary();
void construct_constant_binary();
......
......@@ -97,8 +97,35 @@ TEST(constant_folding, constant_broadcast)
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_permute, values_out);
vector<int> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_dyn_broadcast)
{
vector<int32_t> values_in{0, 1};
auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
vector<int64_t> shape_in{2, 4};
auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
vector<int64_t> axes_in{1};
auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
auto dyn_broadcast = make_shared<op::DynBroadcast>(constant_in, constant_shape, constant_axes);
auto f = make_shared<Function>(dyn_broadcast, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynBroadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int32_t>();
vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, constant_pad_exterior)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment