Unverified Commit df661763 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into aprocter/cf-dyn-broadcast

parents e9e410aa e57c0f0f
...@@ -142,6 +142,92 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant ...@@ -142,6 +142,92 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
} }
void pass::ConstantFolding::construct_constant_reshape()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
auto constant_reshape_callback = [&, constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Reshape)));
NGRAPH_CHECK(handler != m_cfmap.end(),
"constant folding map should have reshape entry");
func = handler->second(reshape_match.get());
}
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
switch (type.get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_reshape_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_reshape<bfloat16>(constant_match, reshape_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_reshape<float16>(constant_match, reshape_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_reshape<float>(constant_match, reshape_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_reshape<double>(constant_match, reshape_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_reshape<int8_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_reshape<int16_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_reshape<int32_t>(constant_match, reshape_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_reshape<int64_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_reshape<uint8_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_reshape<uint16_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_reshape<uint32_t>(constant_match, reshape_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_reshape<uint64_t>(constant_match, reshape_match, func);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T> template <class T>
shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
shared_ptr<op::Pad> pad, shared_ptr<op::Pad> pad,
...@@ -208,123 +294,63 @@ void pass::ConstantFolding::construct_constant_pad() ...@@ -208,123 +294,63 @@ void pass::ConstantFolding::construct_constant_pad()
func = handler->second(pad_match.get()); func = handler->second(pad_match.get());
} }
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
fold_constant_pad<int>(constant_match, pad_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_pad<int8_t>(constant_match, pad_match, func));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
fold_constant_pad<float>(constant_match, pad_match, func));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
fold_constant_pad<double>(constant_match, pad_match, func));
return true;
}
return false;
};
auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
void pass::ConstantFolding::construct_constant_reshape()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
auto constant_reshape_callback = [&, constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
NodeExecutorTy func = nullptr;
if (!m_cfmap.empty())
{
auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Reshape)));
NGRAPH_CHECK(handler != m_cfmap.end(),
"constant folding map should have reshape entry");
func = handler->second(reshape_match.get());
}
std::shared_ptr<Node> replacement; std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type(); auto type = constant_match->get_element_type();
switch (type.get_type_enum()) switch (type.get_type_enum())
{ {
case element::Type_t::undefined: case element::Type_t::undefined:
NGRAPH_CHECK(false, NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_pad_callback");
"Encountered 'undefined' element type in constant_reshape_callback");
break; break;
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback"); NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_pad_callback");
break; break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = fold_constant_reshape<char>(constant_match, reshape_match, func); replacement = fold_constant_pad<char>(constant_match, pad_match, func);
break; break;
case element::Type_t::bf16: case element::Type_t::bf16:
replacement = fold_constant_reshape<bfloat16>(constant_match, reshape_match, func); replacement = fold_constant_pad<bfloat16>(constant_match, pad_match, func);
break; break;
case element::Type_t::f16: case element::Type_t::f16:
replacement = fold_constant_reshape<float16>(constant_match, reshape_match, func); replacement = fold_constant_pad<float16>(constant_match, pad_match, func);
break; break;
case element::Type_t::f32: case element::Type_t::f32:
replacement = fold_constant_reshape<float>(constant_match, reshape_match, func); replacement = fold_constant_pad<float>(constant_match, pad_match, func);
break; break;
case element::Type_t::f64: case element::Type_t::f64:
replacement = fold_constant_reshape<double>(constant_match, reshape_match, func); replacement = fold_constant_pad<double>(constant_match, pad_match, func);
break; break;
case element::Type_t::i8: case element::Type_t::i8:
replacement = fold_constant_reshape<int8_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<int8_t>(constant_match, pad_match, func);
break; break;
case element::Type_t::i16: case element::Type_t::i16:
replacement = fold_constant_reshape<int16_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<int16_t>(constant_match, pad_match, func);
break; break;
case element::Type_t::i32: case element::Type_t::i32:
replacement = fold_constant_reshape<int32_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<int32_t>(constant_match, pad_match, func);
break; break;
case element::Type_t::i64: case element::Type_t::i64:
replacement = fold_constant_reshape<int64_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<int64_t>(constant_match, pad_match, func);
break; break;
case element::Type_t::u8: case element::Type_t::u8:
replacement = fold_constant_reshape<uint8_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<uint8_t>(constant_match, pad_match, func);
break; break;
case element::Type_t::u16: case element::Type_t::u16:
replacement = fold_constant_reshape<uint16_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<uint16_t>(constant_match, pad_match, func);
break; break;
case element::Type_t::u32: case element::Type_t::u32:
replacement = fold_constant_reshape<uint32_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<uint32_t>(constant_match, pad_match, func);
break; break;
case element::Type_t::u64: case element::Type_t::u64:
replacement = fold_constant_reshape<uint64_t>(constant_match, reshape_match, func); replacement = fold_constant_pad<uint64_t>(constant_match, pad_match, func);
break; break;
} }
replace_node(m.get_match_root(), replacement); replace_node(m.get_match_root(), replacement);
return false; return true;
}; };
auto reshape_matcher = auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape"); this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
this->add_matcher(
reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
template <class T> template <class T>
...@@ -431,7 +457,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape() ...@@ -431,7 +457,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
} }
replace_node(m.get_match_root(), replacement); replace_node(m.get_match_root(), replacement);
return false; return true;
}; };
auto dyn_reshape_matcher = auto dyn_reshape_matcher =
...@@ -547,7 +573,7 @@ void pass::ConstantFolding::construct_constant_transpose() ...@@ -547,7 +573,7 @@ void pass::ConstantFolding::construct_constant_transpose()
} }
replace_node(m.get_match_root(), replacement); replace_node(m.get_match_root(), replacement);
return false; return true;
}; };
auto transpose_matcher = auto transpose_matcher =
...@@ -610,40 +636,61 @@ void pass::ConstantFolding::construct_constant_broadcast() ...@@ -610,40 +636,61 @@ void pass::ConstantFolding::construct_constant_broadcast()
func = handler->second(broadcast_match.get()); func = handler->second(broadcast_match.get());
} }
auto type = constant_match->get_element_type(); std::shared_ptr<Node> replacement;
if (type == element::i32) auto type = broadcast_match->get_element_type();
{ switch (type.get_type_enum())
replace_node(m.get_match_root(),
fold_constant_broadcast<int>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_broadcast<int8_t>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
fold_constant_broadcast<float>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
fold_constant_broadcast<double>(constant_match, broadcast_match, func));
return true;
}
else if (type == element::bf16)
{ {
replace_node( case element::Type_t::undefined:
m.get_match_root(), NGRAPH_CHECK(false,
fold_constant_broadcast<ngraph::bfloat16>(constant_match, broadcast_match, func)); "Encountered 'undefined' element type in constant_broadcast_callback");
return true; break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_broadcast_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_broadcast<char>(constant_match, broadcast_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_broadcast<bfloat16>(constant_match, broadcast_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_broadcast<float16>(constant_match, broadcast_match, func);
break;
case element::Type_t::f32:
replacement = fold_constant_broadcast<float>(constant_match, broadcast_match, func);
break;
case element::Type_t::f64:
replacement = fold_constant_broadcast<double>(constant_match, broadcast_match, func);
break;
case element::Type_t::i8:
replacement = fold_constant_broadcast<int8_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_broadcast<int16_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_broadcast<int32_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_broadcast<int64_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_broadcast<uint8_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_broadcast<uint16_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_broadcast<uint32_t>(constant_match, broadcast_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_broadcast<uint64_t>(constant_match, broadcast_match, func);
break;
} }
return false; replace_node(m.get_match_root(), replacement);
return true;
}; };
auto broadcast_matcher = auto broadcast_matcher =
...@@ -1462,11 +1509,6 @@ template <typename TI> ...@@ -1462,11 +1509,6 @@ template <typename TI>
shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant> constant, shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant> constant,
const element::Type& output_element_type) const element::Type& output_element_type)
{ {
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (output_element_type.get_type_enum()) switch (output_element_type.get_type_enum())
{ {
case element::Type_t::undefined: case element::Type_t::undefined:
...@@ -1504,10 +1546,6 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant> ...@@ -1504,10 +1546,6 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
} }
NGRAPH_UNREACHABLE("Unexpected switch case"); NGRAPH_UNREACHABLE("Unexpected switch case");
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
} }
static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> constant, static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> constant,
...@@ -1520,11 +1558,6 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c ...@@ -1520,11 +1558,6 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
return constant; return constant;
} }
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (input_element_type.get_type_enum()) switch (input_element_type.get_type_enum())
{ {
case element::Type_t::undefined: case element::Type_t::undefined:
...@@ -1562,10 +1595,6 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c ...@@ -1562,10 +1595,6 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
} }
NGRAPH_UNREACHABLE("Unexpected switch case"); NGRAPH_UNREACHABLE("Unexpected switch case");
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
} }
void pass::ConstantFolding::construct_constant_convert() void pass::ConstantFolding::construct_constant_convert()
...@@ -1648,11 +1677,6 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c ...@@ -1648,11 +1677,6 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
{ {
auto& input_element_type = constant->get_output_element_type(0); auto& input_element_type = constant->get_output_element_type(0);
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (input_element_type.get_type_enum()) switch (input_element_type.get_type_enum())
{ {
case element::Type_t::undefined: case element::Type_t::undefined:
...@@ -1686,10 +1710,6 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c ...@@ -1686,10 +1710,6 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
} }
NGRAPH_UNREACHABLE("Unexpected switch case"); NGRAPH_UNREACHABLE("Unexpected switch case");
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
} }
void pass::ConstantFolding::construct_constant_reverse() void pass::ConstantFolding::construct_constant_reverse()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment