Commit bc448701 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

Added constant folding for binary ops (#3895)

parent 1d53977a
...@@ -74,7 +74,7 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons ...@@ -74,7 +74,7 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
} }
else else
{ {
if (auto and_node = as_type_ptr<op::And>(binary)) if (auto and_v0_node = as_type_ptr<op::v0::And>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_and<char>(a->get_data_ptr<char>(), runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
...@@ -82,21 +82,21 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons ...@@ -82,21 +82,21 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
and_node->get_autob()); and_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary)) else if (auto logical_and_node = as_type_ptr<op::v1::LogicalAnd>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(), runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(), b->get_data_ptr<char>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
logical_xor_node->get_autob()); logical_and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto or_node = as_type_ptr<op::Or>(binary)) else if (auto or_node = as_type_ptr<op::v0::Or>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(), runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
...@@ -107,6 +107,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons ...@@ -107,6 +107,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
or_node->get_autob()); or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto logical_or_node = as_type_ptr<op::v1::LogicalOr>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
logical_or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto xor_node = as_type_ptr<op::v0::Xor>(binary)) else if (auto xor_node = as_type_ptr<op::v0::Xor>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
...@@ -118,6 +129,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons ...@@ -118,6 +129,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
xor_node->get_autob()); xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
logical_xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else else
{ {
NGRAPH_CHECK( NGRAPH_CHECK(
...@@ -151,7 +173,18 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -151,7 +173,18 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
} }
else else
{ {
if (auto equal_node = as_type_ptr<op::Equal>(binary)) if (auto equal_v0_node = as_type_ptr<op::v0::Equal>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto equal_v1_node = as_type_ptr<op::v1::Equal>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
...@@ -159,10 +192,10 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -159,10 +192,10 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
equal_node->get_autob()); equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto greater_node = as_type_ptr<op::Greater>(binary)) else if (auto greater_v0_node = as_type_ptr<op::v0::Greater>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
...@@ -170,10 +203,32 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -170,10 +203,32 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
greater_node->get_autob()); greater_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto greater_eq_node = as_type_ptr<op::GreaterEq>(binary)) else if (auto greater_v1_node = as_type_ptr<op::v1::Greater>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
greater_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto greater_eq_v0_node = as_type_ptr<op::v0::GreaterEq>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
greater_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto greater_eq_v1_node = as_type_ptr<op::v1::GreaterEq>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
...@@ -181,10 +236,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -181,10 +236,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
greater_eq_node->get_autob()); greater_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto less_v0_node = as_type_ptr<op::v0::Less>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
less_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto less_node = as_type_ptr<op::Less>(binary)) else if (auto less_v1_node = as_type_ptr<op::v1::Less>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
...@@ -192,10 +258,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -192,10 +258,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
less_node->get_autob()); less_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto less_eq_v0_node = as_type_ptr<op::v0::LessEq>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
less_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto less_eq_node = as_type_ptr<op::LessEq>(binary)) else if (auto less_eq_v1_node = as_type_ptr<op::v1::LessEqual>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
...@@ -203,10 +280,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -203,10 +280,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
less_eq_node->get_autob()); less_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto not_equal_v0_node = as_type_ptr<op::v0::NotEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
not_equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto not_equal_node = as_type_ptr<op::NotEqual>(binary)) else if (auto not_equal_v1_node = as_type_ptr<op::v1::NotEqual>(binary))
{ {
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
...@@ -214,7 +302,7 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant ...@@ -214,7 +302,7 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
not_equal_node->get_autob()); not_equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else else
...@@ -249,7 +337,7 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant ...@@ -249,7 +337,7 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
} }
else else
{ {
if (auto add_node = as_type_ptr<op::Add>(binary)) if (auto add_v0_node = as_type_ptr<op::v0::Add>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -259,26 +347,55 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant ...@@ -259,26 +347,55 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
add_node->get_autob()); add_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto divide_node = as_type_ptr<op::Divide>(binary)) else if (auto add_v1_node = as_type_ptr<op::v1::Add>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::Divide> divop = as_type_ptr<op::Divide>(binary); runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
add_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto divide_v0_node = as_type_ptr<op::v0::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Divide> divop = as_type_ptr<op::v0::Divide>(binary);
bool pythondiv = divop->is_pythondiv(); bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(), runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
divide_node->get_autob(), divide_v0_node->get_autob(),
pythondiv); pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto maximum_node = as_type_ptr<op::Maximum>(binary)) else if (auto divide_v1_node = as_type_ptr<op::v1::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v1::Divide> divop = as_type_ptr<op::v1::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
divide_v1_node->get_autob(),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto maximum_v0_node = as_type_ptr<op::v0::Maximum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -288,10 +405,36 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant ...@@ -288,10 +405,36 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
maximum_node->get_autob()); maximum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto minimum_node = as_type_ptr<op::Minimum>(binary)) else if (auto maximum_v1_node = as_type_ptr<op::v1::Maximum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
maximum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto minimum_v0_node = as_type_ptr<op::v0::Minimum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
minimum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto minimum_v1_node = as_type_ptr<op::v1::Minimum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -301,10 +444,23 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant ...@@ -301,10 +444,23 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
minimum_node->get_autob()); minimum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto multiply_v0_node = as_type_ptr<op::v0::Multiply>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
multiply_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto multiply_node = as_type_ptr<op::Multiply>(binary)) else if (auto multiply_v1_node = as_type_ptr<op::v1::Multiply>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
...@@ -314,21 +470,35 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant ...@@ -314,21 +470,35 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
multiply_node->get_autob()); multiply_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto power_v0_node = as_type_ptr<op::v0::Power>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Power> powop = as_type_ptr<op::v0::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
power_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto power_node = as_type_ptr<op::Power>(binary)) else if (auto power_v1_node = as_type_ptr<op::v1::Power>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::Power> powop = as_type_ptr<op::Power>(binary); shared_ptr<op::v1::Power> powop = as_type_ptr<op::v1::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(), runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
power_node->get_autob()); power_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto subtract_node = as_type_ptr<op::Subtract>(binary)) else if (auto subtract_node = as_type_ptr<op::Subtract>(binary))
...@@ -375,12 +545,19 @@ shared_ptr<op::Constant> fold_constant_binary_helper(shared_ptr<op::Constant> a, ...@@ -375,12 +545,19 @@ shared_ptr<op::Constant> fold_constant_binary_helper(shared_ptr<op::Constant> a,
bool is_supported_binary_op(std::shared_ptr<Node> n) bool is_supported_binary_op(std::shared_ptr<Node> n)
{ {
return (is_type<op::Add>(n) || is_type<op::And>(n) || is_type<op::Divide>(n) || return (
is_type<op::Equal>(n) || is_type<op::Greater>(n) || is_type<op::GreaterEq>(n) || is_type<op::v0::Add>(n) || is_type<op::v1::Add>(n) || is_type<op::v0::Multiply>(n) ||
is_type<op::Less>(n) || is_type<op::LessEq>(n) || is_type<op::Maximum>(n) || is_type<op::v1::Multiply>(n) || is_type<op::v0::Divide>(n) || is_type<op::v1::Divide>(n) ||
is_type<op::Minimum>(n) || is_type<op::Multiply>(n) || is_type<op::NotEqual>(n) || is_type<op::v0::Power>(n) || is_type<op::v1::Power>(n) || is_type<op::v0::Equal>(n) ||
is_type<op::Or>(n) || is_type<op::Power>(n) || is_type<op::Subtract>(n) || is_type<op::v1::Equal>(n) || is_type<op::v0::NotEqual>(n) || is_type<op::v1::NotEqual>(n) ||
is_type<op::Xor>(n)); is_type<op::v0::Greater>(n) || is_type<op::v1::Greater>(n) ||
is_type<op::v0::GreaterEq>(n) || is_type<op::v1::GreaterEq>(n) ||
is_type<op::v0::Less>(n) || is_type<op::v1::Less>(n) || is_type<op::v0::LessEq>(n) ||
is_type<op::v1::LessEqual>(n) || is_type<op::v0::Maximum>(n) ||
is_type<op::v1::Maximum>(n) || is_type<op::v0::Minimum>(n) || is_type<op::v1::Minimum>(n) ||
is_type<op::v0::And>(n) || is_type<op::v1::LogicalAnd>(n) || is_type<op::v0::Or>(n) ||
is_type<op::v1::LogicalOr>(n) || is_type<op::v0::Xor>(n) ||
is_type<op::v1::LogicalXor>(n) || is_type<op::Subtract>(n));
} }
void pass::ConstantFolding::construct_constant_binary() void pass::ConstantFolding::construct_constant_binary()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment