Unverified commit 7f57d4e1, authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into cyphers/fuseddesc

parents f6528bf5 f50e12a1
@@ -21,7 +21,9 @@
 #include "ngraph/graph_util.hpp"
 #include "ngraph/op/abs.hpp"
 #include "ngraph/op/add.hpp"
+#include "ngraph/op/all.hpp"
 #include "ngraph/op/and.hpp"
+#include "ngraph/op/any.hpp"
 #include "ngraph/op/broadcast.hpp"
 #include "ngraph/op/ceiling.hpp"
 #include "ngraph/op/concat.hpp"
@@ -30,6 +32,7 @@
 #include "ngraph/op/dequantize.hpp"
 #include "ngraph/op/divide.hpp"
 #include "ngraph/op/equal.hpp"
+#include "ngraph/op/experimental/dyn_broadcast.hpp"
 #include "ngraph/op/experimental/dyn_reshape.hpp"
 #include "ngraph/op/experimental/dyn_slice.hpp"
 #include "ngraph/op/experimental/range.hpp"
@@ -41,7 +44,9 @@
 #include "ngraph/op/greater_eq.hpp"
 #include "ngraph/op/less.hpp"
 #include "ngraph/op/less_eq.hpp"
+#include "ngraph/op/max.hpp"
 #include "ngraph/op/maximum.hpp"
+#include "ngraph/op/min.hpp"
 #include "ngraph/op/minimum.hpp"
 #include "ngraph/op/multiply.hpp"
 #include "ngraph/op/negative.hpp"
@@ -64,7 +69,9 @@
 #include "ngraph/pattern/op/label.hpp"
 #include "ngraph/runtime/reference/abs.hpp"
 #include "ngraph/runtime/reference/add.hpp"
+#include "ngraph/runtime/reference/all.hpp"
 #include "ngraph/runtime/reference/and.hpp"
+#include "ngraph/runtime/reference/any.hpp"
 #include "ngraph/runtime/reference/broadcast.hpp"
 #include "ngraph/runtime/reference/ceiling.hpp"
 #include "ngraph/runtime/reference/concat.hpp"
@@ -78,7 +85,9 @@
 #include "ngraph/runtime/reference/greater_eq.hpp"
 #include "ngraph/runtime/reference/less.hpp"
 #include "ngraph/runtime/reference/less_eq.hpp"
+#include "ngraph/runtime/reference/max.hpp"
 #include "ngraph/runtime/reference/maximum.hpp"
+#include "ngraph/runtime/reference/min.hpp"
 #include "ngraph/runtime/reference/minimum.hpp"
 #include "ngraph/runtime/reference/multiply.hpp"
 #include "ngraph/runtime/reference/negate.hpp"
@@ -133,6 +142,92 @@ shared_ptr<op::Constant> fold_constant_reshape(shared_ptr<op::Constant> constant
     return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
 }
 
+void pass::ConstantFolding::construct_constant_reshape()
+{
+    auto constant_label = make_shared<pattern::op::Label>(
+        element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
+    auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
+
+    auto constant_reshape_callback = [&, constant_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
+                     << m.get_match_root()->get_name();
+
+        auto pattern_map = m.get_pattern_map();
+
+        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
+        auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
+
+        NodeExecutorTy func = nullptr;
+        if (!m_cfmap.empty())
+        {
+            auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Reshape)));
+            NGRAPH_CHECK(handler != m_cfmap.end(),
+                         "constant folding map should have reshape entry");
+            func = handler->second(reshape_match.get());
+        }
+
+        std::shared_ptr<Node> replacement;
+        auto type = constant_match->get_element_type();
+        switch (type.get_type_enum())
+        {
+        case element::Type_t::undefined:
+            NGRAPH_CHECK(false,
+                         "Encountered 'undefined' element type in constant_reshape_callback");
+            break;
+        case element::Type_t::dynamic:
+            NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
+            break;
+        case element::Type_t::boolean:
+            replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::bf16:
+            replacement = fold_constant_reshape<bfloat16>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::f16:
+            replacement = fold_constant_reshape<float16>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::f32:
+            replacement = fold_constant_reshape<float>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::f64:
+            replacement = fold_constant_reshape<double>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::i8:
+            replacement = fold_constant_reshape<int8_t>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::i16:
+            replacement = fold_constant_reshape<int16_t>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::i32:
+            replacement = fold_constant_reshape<int32_t>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::i64:
+            replacement = fold_constant_reshape<int64_t>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::u8:
+            replacement = fold_constant_reshape<uint8_t>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::u16:
+            replacement = fold_constant_reshape<uint16_t>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::u32:
+            replacement = fold_constant_reshape<uint32_t>(constant_match, reshape_match, func);
+            break;
+        case element::Type_t::u64:
+            replacement = fold_constant_reshape<uint64_t>(constant_match, reshape_match, func);
+            break;
+        }
+
+        replace_node(m.get_match_root(), replacement);
+        return true;
+    };
+
+    auto reshape_matcher =
+        make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
+    this->add_matcher(
+        reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
+}
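Every construct_constant_* registration in this pass follows the same shape: a pattern::op::Label standing in for a Constant input, the op to fold built on top of it, and a callback that computes the replacement Constant and swaps it in with replace_node. A minimal sketch of driving the pass, mirroring the tests further down in this diff (names here are illustrative):

    // Reshape of a constant folds away once the pass runs.
    auto c = op::Constant::create(element::f32, Shape{2, 4}, vector<float>(8, 1.0f));
    auto r = make_shared<op::Reshape>(c, AxisVector{0, 1}, Shape{2, 4, 1});
    auto f = make_shared<Function>(r, ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f); // f's result is now fed by a single folded Constant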
 template <class T>
 shared_ptr<op::Constant> fold_constant_pad(shared_ptr<op::Constant> constant,
                                            shared_ptr<op::Pad> pad,
@@ -199,123 +294,63 @@ void pass::ConstantFolding::construct_constant_pad()
             func = handler->second(pad_match.get());
         }
 
-        auto type = constant_match->get_element_type();
-        if (type == element::i32)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_pad<int>(constant_match, pad_match, func));
-            return true;
-        }
-        else if (type == element::i8)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_pad<int8_t>(constant_match, pad_match, func));
-            return true;
-        }
-        else if (type == element::f32)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_pad<float>(constant_match, pad_match, func));
-            return true;
-        }
-        else if (type == element::f64)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_pad<double>(constant_match, pad_match, func));
-            return true;
-        }
-
-        return false;
-    };
-
-    auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
-    this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
-}
-
-void pass::ConstantFolding::construct_constant_reshape()
-{
-    auto constant_label = make_shared<pattern::op::Label>(
-        element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
-    auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
-
-    auto constant_reshape_callback = [&, constant_label](pattern::Matcher& m) {
-        NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
-                     << m.get_match_root()->get_name();
-
-        auto pattern_map = m.get_pattern_map();
-
-        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
-        auto reshape_match = static_pointer_cast<op::Reshape>(m.get_match_root());
-
-        NodeExecutorTy func = nullptr;
-        if (!m_cfmap.empty())
-        {
-            auto handler = m_cfmap.find(type_index(typeid(ngraph::op::Reshape)));
-            NGRAPH_CHECK(handler != m_cfmap.end(),
-                         "constant folding map should have reshape entry");
-            func = handler->second(reshape_match.get());
-        }
-
-        std::shared_ptr<Node> replacement;
-        auto type = constant_match->get_element_type();
-        switch (type.get_type_enum())
-        {
-        case element::Type_t::undefined:
-            NGRAPH_CHECK(false,
-                         "Encountered 'undefined' element type in constant_reshape_callback");
-            break;
-        case element::Type_t::dynamic:
-            NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_reshape_callback");
-            break;
-        case element::Type_t::boolean:
-            replacement = fold_constant_reshape<char>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::bf16:
-            replacement = fold_constant_reshape<bfloat16>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::f16:
-            replacement = fold_constant_reshape<float16>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::f32:
-            replacement = fold_constant_reshape<float>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::f64:
-            replacement = fold_constant_reshape<double>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::i8:
-            replacement = fold_constant_reshape<int8_t>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::i16:
-            replacement = fold_constant_reshape<int16_t>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::i32:
-            replacement = fold_constant_reshape<int32_t>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::i64:
-            replacement = fold_constant_reshape<int64_t>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::u8:
-            replacement = fold_constant_reshape<uint8_t>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::u16:
-            replacement = fold_constant_reshape<uint16_t>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::u32:
-            replacement = fold_constant_reshape<uint32_t>(constant_match, reshape_match, func);
-            break;
-        case element::Type_t::u64:
-            replacement = fold_constant_reshape<uint64_t>(constant_match, reshape_match, func);
-            break;
-        }
-
-        replace_node(m.get_match_root(), replacement);
-        return false;
-    };
-
-    auto reshape_matcher =
-        make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
-    this->add_matcher(
-        reshape_matcher, constant_reshape_callback, PassProperty::REQUIRE_STATIC_SHAPE);
-}
+        std::shared_ptr<Node> replacement;
+        auto type = constant_match->get_element_type();
+        switch (type.get_type_enum())
+        {
+        case element::Type_t::undefined:
+            NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_pad_callback");
+            break;
+        case element::Type_t::dynamic:
+            NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_pad_callback");
+            break;
+        case element::Type_t::boolean:
+            replacement = fold_constant_pad<char>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::bf16:
+            replacement = fold_constant_pad<bfloat16>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::f16:
+            replacement = fold_constant_pad<float16>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::f32:
+            replacement = fold_constant_pad<float>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::f64:
+            replacement = fold_constant_pad<double>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::i8:
+            replacement = fold_constant_pad<int8_t>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::i16:
+            replacement = fold_constant_pad<int16_t>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::i32:
+            replacement = fold_constant_pad<int32_t>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::i64:
+            replacement = fold_constant_pad<int64_t>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::u8:
+            replacement = fold_constant_pad<uint8_t>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::u16:
+            replacement = fold_constant_pad<uint16_t>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::u32:
+            replacement = fold_constant_pad<uint32_t>(constant_match, pad_match, func);
+            break;
+        case element::Type_t::u64:
+            replacement = fold_constant_pad<uint64_t>(constant_match, pad_match, func);
+            break;
+        }
+
+        replace_node(m.get_match_root(), replacement);
+        return true;
+    };
+
+    auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
+    this->add_matcher(pad_matcher, constant_pad_callback, PassProperty::REQUIRE_STATIC_SHAPE);
 }
 template <class T>
@@ -422,7 +457,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
         }
 
         replace_node(m.get_match_root(), replacement);
-        return false;
+        return true;
     };
 
     auto dyn_reshape_matcher =
@@ -538,7 +573,7 @@ void pass::ConstantFolding::construct_constant_transpose()
         }
 
         replace_node(m.get_match_root(), replacement);
-        return false;
+        return true;
     };
 
     auto transpose_matcher =
@@ -601,40 +636,61 @@ void pass::ConstantFolding::construct_constant_broadcast()
             func = handler->second(broadcast_match.get());
         }
 
-        auto type = constant_match->get_element_type();
-        if (type == element::i32)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_broadcast<int>(constant_match, broadcast_match, func));
-            return true;
-        }
-        else if (type == element::i8)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_broadcast<int8_t>(constant_match, broadcast_match, func));
-            return true;
-        }
-        else if (type == element::f32)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_broadcast<float>(constant_match, broadcast_match, func));
-            return true;
-        }
-        else if (type == element::f64)
-        {
-            replace_node(m.get_match_root(),
-                         fold_constant_broadcast<double>(constant_match, broadcast_match, func));
-            return true;
-        }
-        else if (type == element::bf16)
-        {
-            replace_node(
-                m.get_match_root(),
-                fold_constant_broadcast<ngraph::bfloat16>(constant_match, broadcast_match, func));
-            return true;
-        }
-
-        return false;
+        std::shared_ptr<Node> replacement;
+        auto type = broadcast_match->get_element_type();
+        switch (type.get_type_enum())
+        {
+        case element::Type_t::undefined:
+            NGRAPH_CHECK(false,
+                         "Encountered 'undefined' element type in constant_broadcast_callback");
+            break;
+        case element::Type_t::dynamic:
+            NGRAPH_CHECK(false,
+                         "Encountered 'dynamic' element type in constant_broadcast_callback");
+            break;
+        case element::Type_t::boolean:
+            replacement = fold_constant_broadcast<char>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::bf16:
+            replacement = fold_constant_broadcast<bfloat16>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::f16:
+            replacement = fold_constant_broadcast<float16>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::f32:
+            replacement = fold_constant_broadcast<float>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::f64:
+            replacement = fold_constant_broadcast<double>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::i8:
+            replacement = fold_constant_broadcast<int8_t>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::i16:
+            replacement = fold_constant_broadcast<int16_t>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::i32:
+            replacement = fold_constant_broadcast<int32_t>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::i64:
+            replacement = fold_constant_broadcast<int64_t>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::u8:
+            replacement = fold_constant_broadcast<uint8_t>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::u16:
+            replacement = fold_constant_broadcast<uint16_t>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::u32:
+            replacement = fold_constant_broadcast<uint32_t>(constant_match, broadcast_match, func);
+            break;
+        case element::Type_t::u64:
+            replacement = fold_constant_broadcast<uint64_t>(constant_match, broadcast_match, func);
+            break;
+        }
+
+        replace_node(m.get_match_root(), replacement);
+        return true;
     };
 
     auto broadcast_matcher =
@@ -643,6 +699,127 @@ void pass::ConstantFolding::construct_constant_broadcast()
         broadcast_matcher, constant_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
 }
+template <class T>
+shared_ptr<op::Constant> fold_constant_dyn_broadcast(shared_ptr<op::Constant> arg,
+                                                     shared_ptr<op::Constant> shape,
+                                                     shared_ptr<op::Constant> axes)
+{
+    auto out_shape = shape->get_shape_val();
+    vector<T> out_vec(shape_size(out_shape));
+
+    runtime::reference::broadcast<T>(arg->get_data_ptr<T>(),
+                                     out_vec.data(),
+                                     arg->get_shape(),
+                                     out_shape,
+                                     axes->get_axis_set_val());
+
+    return make_shared<op::Constant>(arg->get_element_type(), out_shape, out_vec);
+}
+
+void pass::ConstantFolding::construct_constant_dyn_broadcast()
+{
+    auto constant_arg_label =
+        make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
+    auto constant_shape_label =
+        make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
+    auto constant_axes_label =
+        make_shared<pattern::op::Label>(element::i64, Shape{1}, pattern::has_class<op::Constant>());
+    auto dyn_broadcast = make_shared<op::DynBroadcast>(
+        constant_arg_label, constant_shape_label, constant_axes_label);
+
+    auto constant_dyn_broadcast_callback = [constant_arg_label,
+                                            constant_shape_label,
+                                            constant_axes_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG << "In callback for constant_dyn_broadcast_callback against node = "
+                     << m.get_match_root()->get_name();
+
+        auto pattern_map = m.get_pattern_map();
+
+        auto constant_arg_match =
+            static_pointer_cast<op::Constant>(pattern_map[constant_arg_label]);
+        auto constant_shape_match =
+            static_pointer_cast<op::Constant>(pattern_map[constant_shape_label]);
+        auto constant_axes_match =
+            static_pointer_cast<op::Constant>(pattern_map[constant_axes_label]);
+        auto dyn_broadcast_match = static_pointer_cast<op::DynBroadcast>(m.get_match_root());
+
+        std::shared_ptr<Node> replacement;
+        auto type = dyn_broadcast_match->get_output_element_type(0);
+        switch (type.get_type_enum())
+        {
+        case element::Type_t::undefined:
+            NGRAPH_CHECK(false,
+                         "Encountered 'undefined' element type in constant_dyn_broadcast_callback");
+            break;
+        case element::Type_t::dynamic:
+            NGRAPH_CHECK(false,
+                         "Encountered 'dynamic' element type in constant_dyn_broadcast_callback");
+            break;
+        case element::Type_t::boolean:
+            replacement = fold_constant_dyn_broadcast<char>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::bf16:
+            replacement = fold_constant_dyn_broadcast<bfloat16>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::f16:
+            replacement = fold_constant_dyn_broadcast<float16>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::f32:
+            replacement = fold_constant_dyn_broadcast<float>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::f64:
+            replacement = fold_constant_dyn_broadcast<double>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::i8:
+            replacement = fold_constant_dyn_broadcast<int8_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::i16:
+            replacement = fold_constant_dyn_broadcast<int16_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::i32:
+            replacement = fold_constant_dyn_broadcast<int32_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::i64:
+            replacement = fold_constant_dyn_broadcast<int64_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::u8:
+            replacement = fold_constant_dyn_broadcast<uint8_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::u16:
+            replacement = fold_constant_dyn_broadcast<uint16_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::u32:
+            replacement = fold_constant_dyn_broadcast<uint32_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        case element::Type_t::u64:
+            replacement = fold_constant_dyn_broadcast<uint64_t>(
+                constant_arg_match, constant_shape_match, constant_axes_match);
+            break;
+        }
+
+        replace_node(m.get_match_root(), replacement);
+        return true;
+    };
+
+    auto dyn_broadcast_matcher =
+        make_shared<pattern::Matcher>(dyn_broadcast, "ConstantFolding.ConstantDynBroadcast");
+    this->add_matcher(
+        dyn_broadcast_matcher, constant_dyn_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
+}
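Unlike the static Broadcast op, DynBroadcast takes its target shape and broadcast axes as input tensors; once all three inputs are Constants, the fold above can run the same reference broadcast kernel. A small sketch of that kernel's semantics, with values taken from the constant_dyn_broadcast test later in this diff (the broadcast axes name the output axes being created):

    vector<int32_t> in{0, 1};
    vector<int32_t> out(8);
    runtime::reference::broadcast<int32_t>(
        in.data(), out.data(), Shape{2}, Shape{2, 4}, AxisSet{1});
    // out == {0, 0, 0, 0, 1, 1, 1, 1}: each input element is replicated
    // along the newly created axis 1.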
 template <class Tin, class Tout>
 shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
                                               shared_ptr<op::Constant> b,
@@ -1332,11 +1509,6 @@ template <typename TI>
 shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant> constant,
                                                        const element::Type& output_element_type)
 {
-#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
-#pragma GCC diagnostic push
-#pragma GCC diagnostic error "-Wswitch"
-#pragma GCC diagnostic error "-Wswitch-enum"
-#endif
     switch (output_element_type.get_type_enum())
     {
     case element::Type_t::undefined:
@@ -1374,10 +1546,6 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
     }
 
     NGRAPH_UNREACHABLE("Unexpected switch case");
-
-#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
-#pragma GCC diagnostic pop
-#endif
 }
 
 static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> constant,
@@ -1390,11 +1558,6 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
         return constant;
     }
 
-#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
-#pragma GCC diagnostic push
-#pragma GCC diagnostic error "-Wswitch"
-#pragma GCC diagnostic error "-Wswitch-enum"
-#endif
     switch (input_element_type.get_type_enum())
     {
     case element::Type_t::undefined:
@@ -1432,10 +1595,6 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
     }
 
     NGRAPH_UNREACHABLE("Unexpected switch case");
-
-#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
-#pragma GCC diagnostic pop
-#endif
 }
 
 void pass::ConstantFolding::construct_constant_convert()
@@ -1518,11 +1677,6 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
 {
     auto& input_element_type = constant->get_output_element_type(0);
 
-#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
-#pragma GCC diagnostic push
-#pragma GCC diagnostic error "-Wswitch"
-#pragma GCC diagnostic error "-Wswitch-enum"
-#endif
     switch (input_element_type.get_type_enum())
     {
     case element::Type_t::undefined:
@@ -1556,10 +1710,6 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
     }
 
     NGRAPH_UNREACHABLE("Unexpected switch case");
-
-#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
-#pragma GCC diagnostic pop
-#endif
 }
 
 void pass::ConstantFolding::construct_constant_reverse()
@@ -1588,180 +1738,207 @@ void pass::ConstantFolding::construct_constant_reverse()
 }
 template <typename T>
-static shared_ptr<op::Constant> fold_constant_product_helper(shared_ptr<op::Constant> constant,
-                                                             const AxisSet& reduction_axes,
-                                                             const Shape& result_shape)
-{
-    vector<T> out_vec(shape_size(result_shape));
-
-    runtime::reference::product<T>(constant->get_vector<T>().data(),
-                                   out_vec.data(),
-                                   constant->get_output_shape(0),
-                                   result_shape,
-                                   reduction_axes);
-
-    return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
-}
-
-static shared_ptr<op::Constant> fold_constant_product(shared_ptr<op::Constant> constant,
-                                                      const AxisSet& reduction_axes,
-                                                      const Shape& result_shape)
-{
-    auto& input_element_type = constant->get_output_element_type(0);
-
-    switch (input_element_type.get_type_enum())
-    {
-    case element::Type_t::undefined:
-        NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_product");
-        break;
-    case element::Type_t::dynamic:
-        NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_product");
-        break;
-    case element::Type_t::boolean:
-        return fold_constant_product_helper<char>(constant, reduction_axes, result_shape);
-    case element::Type_t::bf16:
-        return fold_constant_product_helper<bfloat16>(constant, reduction_axes, result_shape);
-    case element::Type_t::f16:
-        return fold_constant_product_helper<float16>(constant, reduction_axes, result_shape);
-    case element::Type_t::f32:
-        return fold_constant_product_helper<float>(constant, reduction_axes, result_shape);
-    case element::Type_t::f64:
-        return fold_constant_product_helper<double>(constant, reduction_axes, result_shape);
-    case element::Type_t::i8:
-        return fold_constant_product_helper<int8_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i16:
-        return fold_constant_product_helper<int16_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i32:
-        return fold_constant_product_helper<int32_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i64:
-        return fold_constant_product_helper<int64_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u8:
-        return fold_constant_product_helper<uint8_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u16:
-        return fold_constant_product_helper<uint16_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u32:
-        return fold_constant_product_helper<uint32_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u64:
-        return fold_constant_product_helper<uint64_t>(constant, reduction_axes, result_shape);
-    }
-
-    NGRAPH_UNREACHABLE("Unexpected switch case");
-}
-
-void pass::ConstantFolding::construct_constant_product()
-{
-    auto constant_label = make_shared<pattern::op::Label>(
-        element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
-    auto convert_op = make_shared<op::Product>(constant_label, AxisSet{0, 1, 2});
-
-    auto constant_product_callback = [constant_label](pattern::Matcher& m) {
-        NGRAPH_DEBUG << "In callback for constant_product_callback against node = "
-                     << m.get_match_root()->get_name();
-
-        auto pattern_map = m.get_pattern_map();
-
-        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
-        auto product_match = static_pointer_cast<op::Product>(m.get_match_root());
-
-        replace_node(m.get_match_root(),
-                     fold_constant_product(constant_match,
-                                           product_match->get_reduction_axes(),
-                                           product_match->get_output_shape(0)));
-        return true;
-    };
-
-    auto convert_matcher =
-        make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantProduct");
-    this->add_matcher(convert_matcher, constant_product_callback, all_pass_property_off);
-}
-
-// TODO(amprocte): Find a way to reduce duplication with Product. (The fact
-// that we bottom out in a reference call makes it a bit tricky.)
-template <typename T>
-static shared_ptr<op::Constant> fold_constant_sum_helper(shared_ptr<op::Constant> constant,
-                                                         const AxisSet& reduction_axes,
-                                                         const Shape& result_shape)
-{
-    vector<T> out_vec(shape_size(result_shape));
-
-    runtime::reference::sum<T>(constant->get_vector<T>().data(),
-                               out_vec.data(),
-                               constant->get_output_shape(0),
-                               result_shape,
-                               reduction_axes);
-
-    return make_shared<op::Constant>(constant->get_output_element_type(0), result_shape, out_vec);
-}
-
-static shared_ptr<op::Constant> fold_constant_sum(shared_ptr<op::Constant> constant,
-                                                  const AxisSet& reduction_axes,
-                                                  const Shape& result_shape)
-{
-    auto& input_element_type = constant->get_output_element_type(0);
-
-    switch (input_element_type.get_type_enum())
-    {
-    case element::Type_t::undefined:
-        NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_sum");
-        break;
-    case element::Type_t::dynamic:
-        NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_sum");
-        break;
-    case element::Type_t::boolean:
-        return fold_constant_sum_helper<char>(constant, reduction_axes, result_shape);
-    case element::Type_t::bf16:
-        return fold_constant_sum_helper<bfloat16>(constant, reduction_axes, result_shape);
-    case element::Type_t::f16:
-        return fold_constant_sum_helper<float16>(constant, reduction_axes, result_shape);
-    case element::Type_t::f32:
-        return fold_constant_sum_helper<float>(constant, reduction_axes, result_shape);
-    case element::Type_t::f64:
-        return fold_constant_sum_helper<double>(constant, reduction_axes, result_shape);
-    case element::Type_t::i8:
-        return fold_constant_sum_helper<int8_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i16:
-        return fold_constant_sum_helper<int16_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i32:
-        return fold_constant_sum_helper<int32_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::i64:
-        return fold_constant_sum_helper<int64_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u8:
-        return fold_constant_sum_helper<uint8_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u16:
-        return fold_constant_sum_helper<uint16_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u32:
-        return fold_constant_sum_helper<uint32_t>(constant, reduction_axes, result_shape);
-    case element::Type_t::u64:
-        return fold_constant_sum_helper<uint64_t>(constant, reduction_axes, result_shape);
-    }
-
-    NGRAPH_UNREACHABLE("Unexpected switch case");
-}
-
-void pass::ConstantFolding::construct_constant_sum()
-{
-    auto constant_label = make_shared<pattern::op::Label>(
-        element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
-    auto convert_op = make_shared<op::Sum>(constant_label, AxisSet{0, 1, 2});
-
-    auto constant_sum_callback = [constant_label](pattern::Matcher& m) {
-        NGRAPH_DEBUG << "In callback for constant_sum_callback against node = "
-                     << m.get_match_root()->get_name();
-
-        auto pattern_map = m.get_pattern_map();
-
-        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_label]);
-        auto sum_match = static_pointer_cast<op::Sum>(m.get_match_root());
-
-        replace_node(m.get_match_root(),
-                     fold_constant_sum(constant_match,
-                                       sum_match->get_reduction_axes(),
-                                       sum_match->get_output_shape(0)));
-        return true;
-    };
-
-    auto convert_matcher = make_shared<pattern::Matcher>(convert_op, "ConstantFolding.ConstantSum");
-    this->add_matcher(convert_matcher, constant_sum_callback, all_pass_property_off);
+static shared_ptr<op::Constant>
+    fold_constant_arithmetic_reduction_helper(shared_ptr<op::Constant> constant,
+                                              shared_ptr<Node> reduction_node)
+{
+    vector<T> out_vec(shape_size(reduction_node->get_shape()));
+
+    if (auto max = dynamic_pointer_cast<op::Max>(reduction_node))
+    {
+        runtime::reference::max<T>(constant->get_vector<T>().data(),
+                                   out_vec.data(),
+                                   constant->get_output_shape(0),
+                                   reduction_node->get_shape(),
+                                   max->get_reduction_axes());
+    }
+    else if (auto min = dynamic_pointer_cast<op::Min>(reduction_node))
+    {
+        runtime::reference::min<T>(constant->get_vector<T>().data(),
+                                   out_vec.data(),
+                                   constant->get_output_shape(0),
+                                   reduction_node->get_shape(),
+                                   min->get_reduction_axes());
+    }
+    else if (auto prod = dynamic_pointer_cast<op::Product>(reduction_node))
+    {
+        runtime::reference::product<T>(constant->get_vector<T>().data(),
+                                       out_vec.data(),
+                                       constant->get_output_shape(0),
+                                       reduction_node->get_shape(),
+                                       prod->get_reduction_axes());
+    }
+    else if (auto sum = dynamic_pointer_cast<op::Sum>(reduction_node))
+    {
+        runtime::reference::sum<T>(constant->get_vector<T>().data(),
+                                   out_vec.data(),
+                                   constant->get_output_shape(0),
+                                   reduction_node->get_shape(),
+                                   sum->get_reduction_axes());
+    }
+    else
+    {
+        NGRAPH_CHECK(false,
+                     "Internal nGraph error: Ops handled in "
+                     "fold_constant_arithmetic_reduction_helper must be consistent with those "
+                     "matched in construct_constant_arithmetic_reduction");
+    }
+
+    return make_shared<op::Constant>(
+        reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
+}
+
+static shared_ptr<op::Constant>
+    fold_constant_arithmetic_reduction(shared_ptr<op::Constant> constant,
+                                       shared_ptr<Node> reduction_node)
+{
+    auto& input_element_type = constant->get_output_element_type(0);
+
+    switch (input_element_type.get_type_enum())
+    {
+    case element::Type_t::undefined:
+        NGRAPH_CHECK(false,
+                     "Encountered 'undefined' element type in fold_constant_arithmetic_reduction");
+        break;
+    case element::Type_t::dynamic:
+        NGRAPH_CHECK(false,
+                     "Encountered 'dynamic' element type in fold_constant_arithmetic_reduction");
+        break;
+    case element::Type_t::boolean:
+        return fold_constant_arithmetic_reduction_helper<char>(constant, reduction_node);
+    case element::Type_t::bf16:
+        return fold_constant_arithmetic_reduction_helper<bfloat16>(constant, reduction_node);
+    case element::Type_t::f16:
+        return fold_constant_arithmetic_reduction_helper<float16>(constant, reduction_node);
+    case element::Type_t::f32:
+        return fold_constant_arithmetic_reduction_helper<float>(constant, reduction_node);
+    case element::Type_t::f64:
+        return fold_constant_arithmetic_reduction_helper<double>(constant, reduction_node);
+    case element::Type_t::i8:
+        return fold_constant_arithmetic_reduction_helper<int8_t>(constant, reduction_node);
+    case element::Type_t::i16:
+        return fold_constant_arithmetic_reduction_helper<int16_t>(constant, reduction_node);
+    case element::Type_t::i32:
+        return fold_constant_arithmetic_reduction_helper<int32_t>(constant, reduction_node);
+    case element::Type_t::i64:
+        return fold_constant_arithmetic_reduction_helper<int64_t>(constant, reduction_node);
+    case element::Type_t::u8:
+        return fold_constant_arithmetic_reduction_helper<uint8_t>(constant, reduction_node);
+    case element::Type_t::u16:
+        return fold_constant_arithmetic_reduction_helper<uint16_t>(constant, reduction_node);
+    case element::Type_t::u32:
+        return fold_constant_arithmetic_reduction_helper<uint32_t>(constant, reduction_node);
+    case element::Type_t::u64:
+        return fold_constant_arithmetic_reduction_helper<uint64_t>(constant, reduction_node);
+    }
+
+    NGRAPH_UNREACHABLE("Unexpected switch case");
+}
+
+void pass::ConstantFolding::construct_constant_arithmetic_reduction()
+{
+    auto constant_data_label = make_shared<pattern::op::Label>(
+        element::i32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
+    auto constant_axes_label =
+        make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
+    auto is_supported_reduction = [](std::shared_ptr<Node> n) {
+        return (pattern::has_class<op::Max>()(n) || pattern::has_class<op::Min>()(n) ||
+                pattern::has_class<op::Product>()(n) || pattern::has_class<op::Sum>()(n));
+    };
+    auto reduction =
+        std::make_shared<pattern::op::Any>(element::i32,
+                                           Shape{2},
+                                           is_supported_reduction,
+                                           NodeVector{constant_data_label, constant_axes_label});
+
+    auto constant_arithmetic_reduction_callback = [constant_data_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG << "In callback for constant_arithmetic_reduction_callback against node = "
+                     << m.get_match_root()->get_name();
+
+        auto pattern_map = m.get_pattern_map();
+
+        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
+        auto reduction_match = m.get_match_root();
+
+        replace_node(reduction_match,
+                     fold_constant_arithmetic_reduction(constant_match, reduction_match));
+        return true;
+    };
+
+    auto arithmetic_reduction_matcher =
+        make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantArithmeticReduction");
+    this->add_matcher(arithmetic_reduction_matcher,
+                      constant_arithmetic_reduction_callback,
+                      all_pass_property_off);
+}
+
+static shared_ptr<op::Constant> fold_constant_logical_reduction(shared_ptr<op::Constant> constant,
+                                                                shared_ptr<Node> reduction_node)
+{
+    vector<char> out_vec(shape_size(reduction_node->get_shape()));
+
+    if (auto all = dynamic_pointer_cast<::ngraph::op::All>(reduction_node))
+    {
+        runtime::reference::all(constant->get_vector<char>().data(),
+                                out_vec.data(),
+                                constant->get_output_shape(0),
+                                reduction_node->get_shape(),
+                                all->get_reduction_axes());
+    }
+    else if (auto any = dynamic_pointer_cast<::ngraph::op::Any>(reduction_node))
+    {
+        runtime::reference::any(constant->get_vector<char>().data(),
+                                out_vec.data(),
+                                constant->get_output_shape(0),
+                                reduction_node->get_shape(),
+                                any->get_reduction_axes());
+    }
+    else
+    {
+        NGRAPH_CHECK(false,
+                     "Internal nGraph error: Ops handled in "
+                     "fold_constant_logical_reduction must be consistent with those "
+                     "matched in construct_constant_logical_reduction");
+    }
+
+    return make_shared<op::Constant>(
+        reduction_node->get_output_element_type(0), reduction_node->get_shape(), out_vec);
+}
+
+void pass::ConstantFolding::construct_constant_logical_reduction()
+{
+    auto constant_data_label = make_shared<pattern::op::Label>(
+        element::boolean, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
+    auto constant_axes_label =
+        make_shared<pattern::op::Label>(element::i64, Shape{2}, pattern::has_class<op::Constant>());
+    auto is_supported_reduction = [](std::shared_ptr<Node> n) {
+        return (pattern::has_class<::ngraph::op::All>()(n) ||
+                pattern::has_class<::ngraph::op::Any>()(n));
+    };
+    auto reduction =
+        std::make_shared<pattern::op::Any>(element::i32,
+                                           Shape{2},
+                                           is_supported_reduction,
+                                           NodeVector{constant_data_label, constant_axes_label});
+
+    auto constant_logical_reduction_callback = [constant_data_label](pattern::Matcher& m) {
+        NGRAPH_DEBUG << "In callback for constant_logical_reduction_callback against node = "
+                     << m.get_match_root()->get_name();
+
+        auto pattern_map = m.get_pattern_map();
+
+        auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
+        auto reduction_match = m.get_match_root();
+
+        replace_node(reduction_match,
+                     fold_constant_logical_reduction(constant_match, reduction_match));
+        return true;
+    };
+
+    auto logical_reduction_matcher =
+        make_shared<pattern::Matcher>(reduction, "ConstantFolding.ConstantLogicalReduction");
+    this->add_matcher(
+        logical_reduction_matcher, constant_logical_reduction_callback, all_pass_property_off);
 }
 template <typename T>
......
@@ -34,6 +34,7 @@ public:
     {
         RESHAPE,
         BROADCAST,
+        DYN_BROADCAST,
         PAD,
         DEQUANTIZE,
         UNARY,
@@ -42,8 +43,8 @@ public:
         CONVERT,
         SHAPE_OF,
         REVERSE,
-        PRODUCT,
-        SUM,
+        ARITHMETIC_REDUCTION,
+        LOGICAL_REDUCTION,
         CONCAT,
         GATHER,
         SLICE,
@@ -60,6 +61,7 @@ public:
         m_cfmap = cfmap;
        construct_constant_reshape();
        construct_constant_broadcast();
+        construct_constant_dyn_broadcast();
        construct_constant_pad();
        construct_constant_unary();
        construct_constant_binary();
@@ -68,8 +70,8 @@ public:
        construct_constant_convert();
        construct_constant_shape_of();
        construct_constant_reverse();
-        construct_constant_product();
-        construct_constant_sum();
+        construct_constant_arithmetic_reduction();
+        construct_constant_logical_reduction();
        construct_constant_concat();
        construct_constant_gather();
        construct_constant_slice();
@@ -93,6 +95,7 @@ public:
            {
            case CFTransformations::RESHAPE: construct_constant_reshape(); break;
            case CFTransformations::BROADCAST: construct_constant_broadcast(); break;
+            case CFTransformations::DYN_BROADCAST: construct_constant_dyn_broadcast(); break;
            case CFTransformations::PAD: construct_constant_pad(); break;
            case CFTransformations::UNARY: construct_constant_unary(); break;
            case CFTransformations::BINARY: construct_constant_binary(); break;
@@ -101,8 +104,12 @@ public:
            case CFTransformations::CONVERT: construct_constant_convert(); break;
            case CFTransformations::SHAPE_OF: construct_constant_shape_of(); break;
            case CFTransformations::REVERSE: construct_constant_reverse(); break;
-            case CFTransformations::PRODUCT: construct_constant_product(); break;
-            case CFTransformations::SUM: construct_constant_sum(); break;
+            case CFTransformations::ARITHMETIC_REDUCTION:
+                construct_constant_arithmetic_reduction();
+                break;
+            case CFTransformations::LOGICAL_REDUCTION:
+                construct_constant_logical_reduction();
+                break;
            case CFTransformations::CONCAT: construct_constant_concat(); break;
            case CFTransformations::GATHER: construct_constant_gather(); break;
            case CFTransformations::SLICE: construct_constant_slice(); break;
@@ -118,6 +125,7 @@ public:
 private:
     void construct_constant_reshape();
     void construct_constant_broadcast();
+    void construct_constant_dyn_broadcast();
     void construct_constant_pad();
     void construct_constant_unary();
     void construct_constant_binary();
@@ -126,8 +134,8 @@ private:
     void construct_constant_convert();
     void construct_constant_shape_of();
     void construct_constant_reverse();
-    void construct_constant_product();
-    void construct_constant_sum();
+    void construct_constant_arithmetic_reduction();
+    void construct_constant_logical_reduction();
     void construct_constant_concat();
     void construct_constant_gather();
     void construct_constant_slice();
......
@@ -36,7 +36,7 @@ namespace ngraph
                       const AxisSet& reduction_axes)
             {
                 T minval = std::numeric_limits<T>::has_infinity
-                               ? -std::numeric_limits<T>::infinity()
+                               ? T(-std::numeric_limits<T>::infinity())
                                : std::numeric_limits<T>::min();
 
                 CoordinateTransform output_transform(out_shape);
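A plausible reason for the new T(...) cast (an inference; the commit itself gives none): for nGraph's wrapper float types such as bfloat16, unary minus on std::numeric_limits<T>::infinity() can promote to a different type (e.g. float), leaving the two arms of the conditional operator with mismatched types. The cast keeps both arms at T:

    // Assumed failure mode, with bfloat16's operator- yielding float:
    //   has_inf ? -std::numeric_limits<bfloat16>::infinity()   // float
    //           : std::numeric_limits<bfloat16>::min();        // bfloat16
    // Wrapping the first arm in T(...) makes both arms the same type T.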
......
@@ -97,8 +97,35 @@ TEST(constant_folding, constant_broadcast)
     ASSERT_TRUE(new_const);
     auto values_out = new_const->get_vector<int>();
 
-    vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1};
-    ASSERT_EQ(values_permute, values_out);
+    vector<int> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
+    ASSERT_EQ(values_expected, values_out);
+}
+
+TEST(constant_folding, constant_dyn_broadcast)
+{
+    vector<int32_t> values_in{0, 1};
+    auto constant_in = make_shared<op::Constant>(element::i32, Shape{2}, values_in);
+    vector<int64_t> shape_in{2, 4};
+    auto constant_shape = make_shared<op::Constant>(element::i64, Shape{2}, shape_in);
+    vector<int64_t> axes_in{1};
+    auto constant_axes = make_shared<op::Constant>(element::i64, Shape{1}, axes_in);
+
+    auto dyn_broadcast = make_shared<op::DynBroadcast>(constant_in, constant_shape, constant_axes);
+    auto f = make_shared<Function>(dyn_broadcast, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::DynBroadcast>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<int32_t>();
+
+    vector<int32_t> values_expected{0, 0, 0, 0, 1, 1, 1, 1};
+
+    ASSERT_EQ(values_expected, values_out);
 }
 
 TEST(constant_folding, constant_pad_exterior)
@@ -434,6 +461,110 @@ TEST(constant_folding, const_sum)
     ASSERT_EQ(values_expected, values_out);
 }
 
+TEST(constant_folding, const_max)
+{
+    Shape input_shape{3, 3};
+
+    vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
+    auto constant = op::Constant::create(element::i32, input_shape, values_in);
+    auto convert = make_shared<op::Max>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Max>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<int32_t>();
+
+    vector<int32_t> values_expected{3, 6, 9};
+
+    ASSERT_EQ(values_expected, values_out);
+}
+
+TEST(constant_folding, const_min)
+{
+    Shape input_shape{3, 3};
+
+    vector<int32_t> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9};
+    auto constant = op::Constant::create(element::i32, input_shape, values_in);
+    auto convert = make_shared<op::Min>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Min>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<int32_t>();
+
+    vector<int32_t> values_expected{1, 4, 7};
+
+    ASSERT_EQ(values_expected, values_out);
+}
+
+TEST(constant_folding, const_all)
+{
+    Shape input_shape{3, 3};
+
+    vector<char> values_in{0, 1, 1, 0, 1, 0, 1, 1, 1};
+    auto constant = op::Constant::create(element::boolean, input_shape, values_in);
+    auto convert = make_shared<op::All>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::All>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<char>();
+
+    vector<char> values_expected{0, 0, 1};
+
+    ASSERT_EQ(values_expected, values_out);
+}
+
+TEST(constant_folding, const_any)
+{
+    Shape input_shape{3, 3};
+
+    vector<char> values_in{1, 0, 0, 1, 0, 1, 0, 0, 0};
+    auto constant = op::Constant::create(element::boolean, input_shape, values_in);
+    auto convert = make_shared<op::Any>(constant, AxisSet{1});
+    auto f = make_shared<Function>(convert, ParameterVector{});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ConstantFolding>();
+    pass_manager.run_passes(f);
+
+    ASSERT_EQ(count_ops_of_type<op::Any>(f), 0);
+    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
+
+    auto new_const =
+        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
+    ASSERT_TRUE(new_const);
+    auto values_out = new_const->get_vector<char>();
+
+    vector<char> values_expected{1, 1, 0};
+
+    ASSERT_EQ(values_expected, values_out);
+}
+
 TEST(constant_folding, const_concat)
 {
     auto constant0 =
......