Commit bc448701 authored by Gleb Kazantaev's avatar Gleb Kazantaev Committed by Scott Cyphers

Added constant folding for binary ops (#3895)

parent 1d53977a
......@@ -74,7 +74,7 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
}
else
{
if (auto and_node = as_type_ptr<op::And>(binary))
if (auto and_v0_node = as_type_ptr<op::v0::And>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
......@@ -82,21 +82,21 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
out_vec.data(),
a->get_shape(),
b->get_shape(),
and_node->get_autob());
and_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary))
else if (auto logical_and_node = as_type_ptr<op::v1::LogicalAnd>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
logical_xor_node->get_autob());
logical_and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto or_node = as_type_ptr<op::Or>(binary))
else if (auto or_node = as_type_ptr<op::v0::Or>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
......@@ -107,6 +107,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto logical_or_node = as_type_ptr<op::v1::LogicalOr>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
logical_or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto xor_node = as_type_ptr<op::v0::Xor>(binary))
{
vector<char> out_vec(shape_size(out_shape));
......@@ -118,6 +129,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
logical_xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
NGRAPH_CHECK(
......@@ -151,7 +173,18 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
}
else
{
if (auto equal_node = as_type_ptr<op::Equal>(binary))
if (auto equal_v0_node = as_type_ptr<op::v0::Equal>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto equal_v1_node = as_type_ptr<op::v1::Equal>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
......@@ -159,10 +192,10 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
equal_node->get_autob());
equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto greater_node = as_type_ptr<op::Greater>(binary))
else if (auto greater_v0_node = as_type_ptr<op::v0::Greater>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
......@@ -170,10 +203,32 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
greater_node->get_autob());
greater_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto greater_eq_node = as_type_ptr<op::GreaterEq>(binary))
else if (auto greater_v1_node = as_type_ptr<op::v1::Greater>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
greater_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto greater_eq_v0_node = as_type_ptr<op::v0::GreaterEq>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
greater_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto greater_eq_v1_node = as_type_ptr<op::v1::GreaterEq>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
......@@ -181,10 +236,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
greater_eq_node->get_autob());
greater_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto less_v0_node = as_type_ptr<op::v0::Less>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
less_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto less_node = as_type_ptr<op::Less>(binary))
else if (auto less_v1_node = as_type_ptr<op::v1::Less>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
......@@ -192,10 +258,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
less_node->get_autob());
less_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto less_eq_v0_node = as_type_ptr<op::v0::LessEq>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
less_eq_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto less_eq_node = as_type_ptr<op::LessEq>(binary))
else if (auto less_eq_v1_node = as_type_ptr<op::v1::LessEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
......@@ -203,10 +280,21 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
less_eq_node->get_autob());
less_eq_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto not_equal_v0_node = as_type_ptr<op::v0::NotEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
not_equal_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto not_equal_node = as_type_ptr<op::NotEqual>(binary))
else if (auto not_equal_v1_node = as_type_ptr<op::v1::NotEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
......@@ -214,7 +302,7 @@ shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
not_equal_node->get_autob());
not_equal_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
......@@ -249,7 +337,7 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
}
else
{
if (auto add_node = as_type_ptr<op::Add>(binary))
if (auto add_v0_node = as_type_ptr<op::v0::Add>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
......@@ -259,26 +347,55 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
add_node->get_autob());
add_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto divide_node = as_type_ptr<op::Divide>(binary))
else if (auto add_v1_node = as_type_ptr<op::v1::Add>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::Divide> divop = as_type_ptr<op::Divide>(binary);
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
add_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto divide_v0_node = as_type_ptr<op::v0::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Divide> divop = as_type_ptr<op::v0::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
divide_node->get_autob(),
divide_v0_node->get_autob(),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto maximum_node = as_type_ptr<op::Maximum>(binary))
else if (auto divide_v1_node = as_type_ptr<op::v1::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v1::Divide> divop = as_type_ptr<op::v1::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
divide_v1_node->get_autob(),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto maximum_v0_node = as_type_ptr<op::v0::Maximum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
......@@ -288,10 +405,36 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
maximum_node->get_autob());
maximum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto minimum_node = as_type_ptr<op::Minimum>(binary))
else if (auto maximum_v1_node = as_type_ptr<op::v1::Maximum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
maximum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto minimum_v0_node = as_type_ptr<op::v0::Minimum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
minimum_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto minimum_v1_node = as_type_ptr<op::v1::Minimum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
......@@ -301,10 +444,23 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
minimum_node->get_autob());
minimum_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto multiply_v0_node = as_type_ptr<op::v0::Multiply>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
multiply_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto multiply_node = as_type_ptr<op::Multiply>(binary))
else if (auto multiply_v1_node = as_type_ptr<op::v1::Multiply>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
......@@ -314,21 +470,35 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
out_vec.data(),
a->get_shape(),
b->get_shape(),
multiply_node->get_autob());
multiply_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto power_v0_node = as_type_ptr<op::v0::Power>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::v0::Power> powop = as_type_ptr<op::v0::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
power_v0_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto power_node = as_type_ptr<op::Power>(binary))
else if (auto power_v1_node = as_type_ptr<op::v1::Power>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::Power> powop = as_type_ptr<op::Power>(binary);
shared_ptr<op::v1::Power> powop = as_type_ptr<op::v1::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
power_node->get_autob());
power_v1_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto subtract_node = as_type_ptr<op::Subtract>(binary))
......@@ -375,12 +545,19 @@ shared_ptr<op::Constant> fold_constant_binary_helper(shared_ptr<op::Constant> a,
bool is_supported_binary_op(std::shared_ptr<Node> n)
{
return (is_type<op::Add>(n) || is_type<op::And>(n) || is_type<op::Divide>(n) ||
is_type<op::Equal>(n) || is_type<op::Greater>(n) || is_type<op::GreaterEq>(n) ||
is_type<op::Less>(n) || is_type<op::LessEq>(n) || is_type<op::Maximum>(n) ||
is_type<op::Minimum>(n) || is_type<op::Multiply>(n) || is_type<op::NotEqual>(n) ||
is_type<op::Or>(n) || is_type<op::Power>(n) || is_type<op::Subtract>(n) ||
is_type<op::Xor>(n));
return (
is_type<op::v0::Add>(n) || is_type<op::v1::Add>(n) || is_type<op::v0::Multiply>(n) ||
is_type<op::v1::Multiply>(n) || is_type<op::v0::Divide>(n) || is_type<op::v1::Divide>(n) ||
is_type<op::v0::Power>(n) || is_type<op::v1::Power>(n) || is_type<op::v0::Equal>(n) ||
is_type<op::v1::Equal>(n) || is_type<op::v0::NotEqual>(n) || is_type<op::v1::NotEqual>(n) ||
is_type<op::v0::Greater>(n) || is_type<op::v1::Greater>(n) ||
is_type<op::v0::GreaterEq>(n) || is_type<op::v1::GreaterEq>(n) ||
is_type<op::v0::Less>(n) || is_type<op::v1::Less>(n) || is_type<op::v0::LessEq>(n) ||
is_type<op::v1::LessEqual>(n) || is_type<op::v0::Maximum>(n) ||
is_type<op::v1::Maximum>(n) || is_type<op::v0::Minimum>(n) || is_type<op::v1::Minimum>(n) ||
is_type<op::v0::And>(n) || is_type<op::v1::LogicalAnd>(n) || is_type<op::v0::Or>(n) ||
is_type<op::v1::LogicalOr>(n) || is_type<op::v0::Xor>(n) ||
is_type<op::v1::LogicalXor>(n) || is_type<op::Subtract>(n));
}
void pass::ConstantFolding::construct_constant_binary()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment