Commit bcaf32c4 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

ConstantFolding for Equal, Greater, GreaterEq, Less, LessEq, NotEqual (#3322)

* CF for And and Or

* CF support for comparison ops

* Fix predicate for binary elementwise; add unit tests for non-arithmetic binops

* Update CPU CF builders
parent c693cb7e
......@@ -20,17 +20,25 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
......@@ -45,15 +53,23 @@
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
......@@ -365,17 +381,17 @@ void pass::ConstantFolding::construct_constant_broadcast()
broadcast_matcher, constant_broadcast_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
template <class T>
template <class Tin, class Tout>
shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
vector<T> out_vec(shape_size(out_shape));
if (func != nullptr)
{
vector<Tout> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
......@@ -383,43 +399,160 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
outputs.push_back(out_vec.data());
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
if (std::dynamic_pointer_cast<op::Add>(binary))
{
runtime::reference::add<T>(
a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Subtract>(binary))
{
runtime::reference::subtract<T>(
a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Multiply>(binary))
else if (std::dynamic_pointer_cast<op::And>(binary))
{
runtime::reference::multiply<T>(
a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::logical_and<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
shared_ptr<op::Divide> divop = std::dynamic_pointer_cast<op::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<T>(a->get_data_ptr<T>(),
b->get_data_ptr<T>(),
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Equal>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Greater>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::GreaterEq>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Less>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape),
pythondiv);
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Minimum>(binary))
else if (std::dynamic_pointer_cast<op::LessEq>(binary))
{
runtime::reference::minimum<T>(
a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Maximum>(binary))
{
runtime::reference::maximum<T>(
a->get_data_ptr<T>(), b->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Minimum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Multiply>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::NotEqual>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Or>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::logical_or<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (std::dynamic_pointer_cast<op::Subtract>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
shape_size(out_shape));
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
......@@ -427,16 +560,48 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
"fold_constant_binary must be consistent with is_supported_binary_op");
}
}
return make_shared<op::Constant>(a->get_element_type(), out_shape, out_vec);
}
// Second-stage dispatch for constant-folding a binary elementwise op.
// Tin is the input element type (already selected by the caller from the
// inputs' element type); this helper selects the Tout template argument from
// the op's *output* element type et_out. The two can legitimately differ:
// comparison ops (Equal, Greater, ...) take arithmetic inputs but produce
// boolean output. fold_constant_binary itself re-validates the Tin/Tout
// combination per op.
template <class Tin>
shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out,
shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
switch (et_out.get_type_enum())
{
// A node with undefined/dynamic output type cannot be folded to a constant.
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback");
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_binary_callback");
// boolean constants are stored as char.
case element::Type_t::boolean: return fold_constant_binary<Tin, char>(a, b, binary, func);
case element::Type_t::bf16: return fold_constant_binary<Tin, bfloat16>(a, b, binary, func);
case element::Type_t::f16: return fold_constant_binary<Tin, float16>(a, b, binary, func);
case element::Type_t::f32: return fold_constant_binary<Tin, float>(a, b, binary, func);
case element::Type_t::f64: return fold_constant_binary<Tin, double>(a, b, binary, func);
case element::Type_t::i8: return fold_constant_binary<Tin, int8_t>(a, b, binary, func);
case element::Type_t::i16: return fold_constant_binary<Tin, int16_t>(a, b, binary, func);
case element::Type_t::i32: return fold_constant_binary<Tin, int32_t>(a, b, binary, func);
case element::Type_t::i64: return fold_constant_binary<Tin, int64_t>(a, b, binary, func);
case element::Type_t::u8: return fold_constant_binary<Tin, uint8_t>(a, b, binary, func);
case element::Type_t::u16: return fold_constant_binary<Tin, uint16_t>(a, b, binary, func);
case element::Type_t::u32: return fold_constant_binary<Tin, uint32_t>(a, b, binary, func);
case element::Type_t::u64: return fold_constant_binary<Tin, uint64_t>(a, b, binary, func);
}
// No default case above, so the compiler can warn when a new element type is
// added to the enum; reaching this line means the switch is out of date.
NGRAPH_UNREACHABLE("Unreachable switch case");
}
// Returns true iff `n` is a binary elementwise op that constant folding can
// evaluate. Must be kept in sync with the op dispatch in fold_constant_binary:
// arithmetic (Add, Subtract, Multiply, Divide, Maximum, Minimum), comparison
// (Equal, NotEqual, Greater, GreaterEq, Less, LessEq) and logical (And, Or).
//
// NOTE(review): the diff rendering had left the pre-change body (the old,
// shorter predicate) in place ahead of the updated one, making the updated
// return statement dead code; only the updated predicate is kept here.
bool is_supported_binary_op(std::shared_ptr<Node> n)
{
    return (
        std::dynamic_pointer_cast<op::Add>(n) || std::dynamic_pointer_cast<op::And>(n) ||
        std::dynamic_pointer_cast<op::Divide>(n) || std::dynamic_pointer_cast<op::Equal>(n) ||
        std::dynamic_pointer_cast<op::Greater>(n) || std::dynamic_pointer_cast<op::GreaterEq>(n) ||
        std::dynamic_pointer_cast<op::Less>(n) || std::dynamic_pointer_cast<op::LessEq>(n) ||
        std::dynamic_pointer_cast<op::Maximum>(n) || std::dynamic_pointer_cast<op::Minimum>(n) ||
        std::dynamic_pointer_cast<op::Multiply>(n) || std::dynamic_pointer_cast<op::NotEqual>(n) ||
        std::dynamic_pointer_cast<op::Or>(n) || std::dynamic_pointer_cast<op::Subtract>(n));
}
void pass::ConstantFolding::construct_constant_binary()
......@@ -445,8 +610,12 @@ void pass::ConstantFolding::construct_constant_binary()
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto b = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
auto is_be = [](std::shared_ptr<Node> n) {
return (pattern::has_class<op::util::BinaryElementwiseArithmetic>()(n) ||
pattern::has_class<op::util::BinaryElementwiseComparison>()(n) ||
pattern::has_class<op::util::BinaryElementwiseLogical>()(n));
};
auto be = std::make_shared<pattern::op::Any>(a, is_be, NodeVector{a, b});
auto constant_binary_callback = [&, a, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_binary_callback against node = "
......@@ -474,36 +643,76 @@ void pass::ConstantFolding::construct_constant_binary()
func = handler->second(binary_match.get());
}
auto type = a_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
fold_constant_binary<int>(a_match, b_match, binary_match, func));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
fold_constant_binary<int8_t>(a_match, b_match, binary_match, func));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
fold_constant_binary<float>(a_match, b_match, binary_match, func));
return true;
}
else if (type == element::f64)
std::shared_ptr<Node> replacement;
auto in_type = a_match->get_output_element_type(0);
auto out_type = binary_match->get_output_element_type(0);
switch (in_type.get_type_enum())
{
replace_node(m.get_match_root(),
fold_constant_binary<double>(a_match, b_match, binary_match, func));
return true;
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_binary_callback");
break;
case element::Type_t::boolean:
replacement =
fold_constant_binary_helper<char>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_binary_helper<bfloat16>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_binary_helper<float16>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::f32:
replacement =
fold_constant_binary_helper<float>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::f64:
replacement =
fold_constant_binary_helper<double>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i8:
replacement =
fold_constant_binary_helper<int8_t>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_binary_helper<int16_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_binary_helper<int32_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_binary_helper<int64_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_binary_helper<uint8_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_binary_helper<uint16_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_binary_helper<uint32_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_binary_helper<uint64_t>(
out_type, a_match, b_match, binary_match, func);
break;
}
return false;
replace_node(m.get_match_root(), replacement);
return true;
};
auto reshape_matcher = make_shared<pattern::Matcher>(bea, "ConstantFolding.ConstantBinary");
auto reshape_matcher = make_shared<pattern::Matcher>(be, "ConstantFolding.ConstantBinary");
this->add_matcher(
reshape_matcher, constant_binary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
......
......@@ -470,6 +470,68 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sqrt);
}
// Constant-folding builder for op::Equal: the macro expands to a functor that
// runs the CPU `equal` kernel, dispatched on the node's input element type.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Equal)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::equal);
}
// Constant-folding builder for op::NotEqual: expands to a functor running the
// CPU `not_equal` kernel, dispatched on the node's input element type.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::NotEqual)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::not_equal);
}
// Constant-folding builder for op::Greater: expands to a functor running the
// CPU `greater` kernel, dispatched on the node's input element type.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Greater)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::greater);
}
// Constant-folding builder for op::GreaterEq: expands to a functor running the
// CPU `greater_eq` kernel, dispatched on the node's input element type.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::GreaterEq)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::greater_eq);
}
// Constant-folding builder for op::Less: expands to a functor running the
// CPU `less` kernel, dispatched on the node's input element type.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Less)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::less);
}
// Constant-folding builder for op::LessEq: expands to a functor running the
// CPU `less_eq` kernel, dispatched on the node's input element type.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::LessEq)
{
BUILD_BINARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::less_eq);
}
// Constant-folding builder for op::And: returns a functor that runs the CPU
// logical_and kernel over the flattened element buffer.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::And)
{
    auto element_count = shape_size(node->get_shape());
    // Capture element_count by value ONLY. The returned functor outlives this
    // builder call, so the original default-by-reference capture ([&, ...])
    // would leave dangling references to the builder's locals.
    auto functor = [element_count](const std::vector<void*>& inputs,
                                   std::vector<void*>& outputs) {
        // Trailing 0: extra kernel argument — presumably the arena/thread-pool
        // index; confirm against the kernel signature.
        runtime::cpu::kernel::logical_and(inputs[0], inputs[1], outputs[0], element_count, 0);
    };
    return functor;
}
// Constant-folding builder for op::Or: returns a functor that runs the CPU
// logical_or kernel over the flattened element buffer.
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Or)
{
    auto element_count = shape_size(node->get_shape());
    // Capture element_count by value ONLY. The returned functor outlives this
    // builder call, so the original default-by-reference capture ([&, ...])
    // would leave dangling references to the builder's locals.
    auto functor = [element_count](const std::vector<void*>& inputs,
                                   std::vector<void*>& outputs) {
        // Trailing 0: extra kernel argument — presumably the arena/thread-pool
        // index; confirm against the kernel signature.
        runtime::cpu::kernel::logical_or(inputs[0], inputs[1], outputs[0], element_count, 0);
    };
    return functor;
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Sign)
{
......@@ -542,6 +604,14 @@ namespace ngraph
REGISTER_CF_BUILDER(Negative);
REGISTER_CF_BUILDER(Relu);
REGISTER_CF_BUILDER(Sqrt);
REGISTER_CF_BUILDER(Equal);
REGISTER_CF_BUILDER(NotEqual);
REGISTER_CF_BUILDER(Greater);
REGISTER_CF_BUILDER(GreaterEq);
REGISTER_CF_BUILDER(Less);
REGISTER_CF_BUILDER(LessEq);
REGISTER_CF_BUILDER(And);
REGISTER_CF_BUILDER(Or);
REGISTER_CF_BUILDER(Sign);
}
}
......
......@@ -67,9 +67,12 @@ namespace ngraph
}
}
// In English: return type is void and T must be a floating point type.
// In English: return type is void and T must be a standard floating point type, or
// bfloat16, or float16.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
typename std::enable_if<std::is_floating_point<T>::value ||
std::is_same<T, bfloat16>::value ||
std::is_same<T, float16>::value>::type
divide(const T* arg0, const T* arg1, T* out, size_t count, bool pythondiv)
{
(void)pythondiv;
......
......@@ -459,6 +459,214 @@ TEST(constant_folding, const_concat)
ASSERT_EQ(values_expected, values_out);
}
// Folding Equal over two i32 constants must yield one boolean constant.
TEST(constant_folding, const_equal)
{
    // Inputs differ at flat positions 2 and 3.
    vector<int32_t> lhs{1, 2, 3, 4, 5, 6};
    vector<int32_t> rhs{1, 2, 2, 3, 5, 6};
    auto c0 = op::Constant::create(element::i32, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::i32, Shape{2, 3}, rhs);
    auto cmp = make_shared<op::Equal>(c0, c1);
    auto f = make_shared<Function>(cmp, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The Equal node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::Equal>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{1, 1, 0, 0, 1, 1};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
// Folding NotEqual over two i32 constants must yield one boolean constant.
TEST(constant_folding, const_not_equal)
{
    // Inputs differ at flat positions 2 and 3.
    vector<int32_t> lhs{1, 2, 3, 4, 5, 6};
    vector<int32_t> rhs{1, 2, 2, 3, 5, 6};
    auto c0 = op::Constant::create(element::i32, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::i32, Shape{2, 3}, rhs);
    auto cmp = make_shared<op::NotEqual>(c0, c1);
    auto f = make_shared<Function>(cmp, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The NotEqual node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::NotEqual>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{0, 0, 1, 1, 0, 0};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
// Folding Greater over two i32 constants must yield one boolean constant.
TEST(constant_folding, const_greater)
{
    vector<int32_t> lhs{1, 2, 3, 4, 5, 6};
    vector<int32_t> rhs{2, 2, 2, 5, 5, 5};
    auto c0 = op::Constant::create(element::i32, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::i32, Shape{2, 3}, rhs);
    auto cmp = make_shared<op::Greater>(c0, c1);
    auto f = make_shared<Function>(cmp, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The Greater node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::Greater>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{0, 0, 1, 0, 0, 1};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
// Folding GreaterEq over two i32 constants must yield one boolean constant.
TEST(constant_folding, const_greater_eq)
{
    vector<int32_t> lhs{1, 2, 3, 4, 5, 6};
    vector<int32_t> rhs{2, 2, 2, 5, 5, 5};
    auto c0 = op::Constant::create(element::i32, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::i32, Shape{2, 3}, rhs);
    auto cmp = make_shared<op::GreaterEq>(c0, c1);
    auto f = make_shared<Function>(cmp, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The GreaterEq node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::GreaterEq>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{0, 1, 1, 0, 1, 1};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
// Folding Less over two i32 constants must yield one boolean constant.
TEST(constant_folding, const_less)
{
    vector<int32_t> lhs{1, 2, 3, 4, 5, 6};
    vector<int32_t> rhs{2, 2, 2, 5, 5, 5};
    auto c0 = op::Constant::create(element::i32, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::i32, Shape{2, 3}, rhs);
    auto cmp = make_shared<op::Less>(c0, c1);
    auto f = make_shared<Function>(cmp, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The Less node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::Less>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{1, 0, 0, 1, 0, 0};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
// Folding LessEq over two i32 constants must yield one boolean constant.
TEST(constant_folding, const_less_eq)
{
    vector<int32_t> lhs{1, 2, 3, 4, 5, 6};
    vector<int32_t> rhs{2, 2, 2, 5, 5, 5};
    auto c0 = op::Constant::create(element::i32, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::i32, Shape{2, 3}, rhs);
    auto cmp = make_shared<op::LessEq>(c0, c1);
    auto f = make_shared<Function>(cmp, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The LessEq node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::LessEq>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{1, 1, 0, 1, 1, 0};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
// Folding And over two boolean constants must yield one boolean constant.
TEST(constant_folding, const_and)
{
    vector<int32_t> lhs{0, 0, 1, 0, 1, 1};
    vector<int32_t> rhs{0, 1, 1, 1, 0, 1};
    auto c0 = op::Constant::create(element::boolean, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::boolean, Shape{2, 3}, rhs);
    auto conj = make_shared<op::And>(c0, c1);
    auto f = make_shared<Function>(conj, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The And node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::And>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{0, 0, 1, 0, 0, 1};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
// Folding Or over two boolean constants must yield one boolean constant.
TEST(constant_folding, const_or)
{
    vector<int32_t> lhs{0, 0, 1, 0, 1, 1};
    vector<int32_t> rhs{0, 1, 1, 1, 0, 1};
    auto c0 = op::Constant::create(element::boolean, Shape{2, 3}, lhs);
    auto c1 = op::Constant::create(element::boolean, Shape{2, 3}, rhs);
    auto disj = make_shared<op::Or>(c0, c1);
    auto f = make_shared<Function>(disj, ParameterVector{});

    pass::Manager pm;
    pm.register_pass<pass::ConstantFolding>();
    pm.run_passes(f);

    // The Or node is gone; a single constant remains.
    ASSERT_EQ(count_ops_of_type<op::Or>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto folded =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(folded);
    vector<char> expected{0, 1, 1, 1, 1, 1};
    ASSERT_EQ(expected, folded->get_vector<char>());
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.