Commit 68f6110c authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Smaller set of constant folding kernels (#3653)

* Instantiate only valid cases for binary op constant folding kernels

* removed unused parameter
parent deb4baef
......@@ -49,11 +49,10 @@
using namespace std;
using namespace ngraph;
template <class Tin, class Tout>
shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
......@@ -61,7 +60,7 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<Tout> out_vec(shape_size(out_shape));
vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
......@@ -73,51 +72,74 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
}
else
{
if (auto add_node = as_type_ptr<op::Add>(binary))
if (auto and_node = as_type_ptr<op::And>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
add_node->get_autob());
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto and_node = as_type_ptr<op::And>(binary))
else if (auto or_node = as_type_ptr<op::Or>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::logical_and<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
and_node->get_autob());
or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto divide_node = as_type_ptr<op::Divide>(binary))
else if (auto xor_node = as_type_ptr<op::Xor>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
shared_ptr<op::Divide> divop = as_type_ptr<op::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
divide_node->get_autob(),
pythondiv);
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto equal_node = as_type_ptr<op::Equal>(binary))
else
{
NGRAPH_CHECK(
false,
"fold_constant_binary_logical must be consistent with is_supported_binary_op");
}
}
}
template <class Tin>
shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
// NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
if (auto equal_node = as_type_ptr<op::Equal>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
......@@ -129,7 +151,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
}
else if (auto greater_node = as_type_ptr<op::Greater>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
......@@ -141,7 +162,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
}
else if (auto greater_eq_node = as_type_ptr<op::GreaterEq>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
......@@ -153,7 +173,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
}
else if (auto less_node = as_type_ptr<op::Less>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
......@@ -165,7 +184,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
}
else if (auto less_eq_node = as_type_ptr<op::LessEq>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
......@@ -175,11 +193,83 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
less_eq_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto not_equal_node = as_type_ptr<op::NotEqual>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
not_equal_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
NGRAPH_CHECK(false,
"fold_constant_binary must be consistent with is_supported_binary_op");
}
}
}
template <class Tin, class Tout = Tin>
shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
// NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<Tout> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
if (auto add_node = as_type_ptr<op::Add>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
add_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto divide_node = as_type_ptr<op::Divide>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::Divide> divop = as_type_ptr<op::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
divide_node->get_autob(),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto maximum_node = as_type_ptr<op::Maximum>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
......@@ -192,7 +282,7 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
......@@ -205,7 +295,7 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
......@@ -214,36 +304,11 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
multiply_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto not_equal_node = as_type_ptr<op::NotEqual>(binary))
{
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape));
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
not_equal_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto or_node = as_type_ptr<op::Or>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::logical_or<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto subtract_node = as_type_ptr<op::Subtract>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
......@@ -252,19 +317,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
subtract_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto xor_node = as_type_ptr<op::Xor>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
NGRAPH_CHECK(false,
......@@ -274,35 +326,26 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
}
template <class Tin>
shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out,
shared_ptr<op::Constant> a,
shared_ptr<op::Constant> fold_constant_binary_helper(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
switch (et_out)
if (binary->is_binary_elementwise_comparison())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback");
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_binary_callback");
case element::Type_t::boolean: return fold_constant_binary<Tin, char>(a, b, binary, func);
case element::Type_t::bf16: return fold_constant_binary<Tin, bfloat16>(a, b, binary, func);
case element::Type_t::f16: return fold_constant_binary<Tin, float16>(a, b, binary, func);
case element::Type_t::f32: return fold_constant_binary<Tin, float>(a, b, binary, func);
case element::Type_t::f64: return fold_constant_binary<Tin, double>(a, b, binary, func);
case element::Type_t::i8: return fold_constant_binary<Tin, int8_t>(a, b, binary, func);
case element::Type_t::i16: return fold_constant_binary<Tin, int16_t>(a, b, binary, func);
case element::Type_t::i32: return fold_constant_binary<Tin, int32_t>(a, b, binary, func);
case element::Type_t::i64: return fold_constant_binary<Tin, int64_t>(a, b, binary, func);
case element::Type_t::u8: return fold_constant_binary<Tin, uint8_t>(a, b, binary, func);
case element::Type_t::u16: return fold_constant_binary<Tin, uint16_t>(a, b, binary, func);
case element::Type_t::u32: return fold_constant_binary<Tin, uint32_t>(a, b, binary, func);
case element::Type_t::u64: return fold_constant_binary<Tin, uint64_t>(a, b, binary, func);
return fold_constant_binary_comparison<Tin>(a, b, binary, func);
}
else if (binary->is_binary_elementwise_arithmetic())
{
return fold_constant_binary_arithmetic<Tin>(a, b, binary, func);
}
else
{
NGRAPH_CHECK(
false, "fold_constant_binary_helper only available for comparison and arithmetic ops");
}
NGRAPH_UNREACHABLE("Unreachable switch case");
}
bool is_supported_binary_op(std::shared_ptr<Node> n)
{
return (is_type<op::Add>(n) || is_type<op::And>(n) || is_type<op::Divide>(n) ||
......@@ -353,68 +396,78 @@ void pass::ConstantFolding::construct_constant_binary()
}
std::shared_ptr<Node> replacement;
auto in_type = a_match->get_output_element_type(0);
auto out_type = binary_match->get_output_element_type(0);
switch (in_type)
if (binary_match->is_binary_elementwise_logical())
{
replacement = fold_constant_binary_logical(a_match, b_match, binary_match, func);
}
else
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_binary_callback");
break;
case element::Type_t::boolean:
replacement =
fold_constant_binary_helper<char>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::bf16:
replacement = fold_constant_binary_helper<bfloat16>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::f16:
replacement = fold_constant_binary_helper<float16>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::f32:
replacement =
fold_constant_binary_helper<float>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::f64:
replacement =
fold_constant_binary_helper<double>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i8:
replacement =
fold_constant_binary_helper<int8_t>(out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i16:
replacement = fold_constant_binary_helper<int16_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i32:
replacement = fold_constant_binary_helper<int32_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::i64:
replacement = fold_constant_binary_helper<int64_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u8:
replacement = fold_constant_binary_helper<uint8_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u16:
replacement = fold_constant_binary_helper<uint16_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u32:
replacement = fold_constant_binary_helper<uint32_t>(
out_type, a_match, b_match, binary_match, func);
break;
case element::Type_t::u64:
replacement = fold_constant_binary_helper<uint64_t>(
out_type, a_match, b_match, binary_match, func);
break;
auto in_type = a_match->get_output_element_type(0);
auto out_type = binary_match->get_output_element_type(0);
switch (in_type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_binary_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_binary_callback");
break;
case element::Type_t::boolean:
replacement =
fold_constant_binary_helper<char>(a_match, b_match, binary_match, func);
break;
case element::Type_t::bf16:
replacement =
fold_constant_binary_helper<bfloat16>(a_match, b_match, binary_match, func);
break;
case element::Type_t::f16:
replacement =
fold_constant_binary_helper<float16>(a_match, b_match, binary_match, func);
break;
case element::Type_t::f32:
replacement =
fold_constant_binary_helper<float>(a_match, b_match, binary_match, func);
break;
case element::Type_t::f64:
replacement =
fold_constant_binary_helper<double>(a_match, b_match, binary_match, func);
break;
case element::Type_t::i8:
replacement =
fold_constant_binary_helper<int8_t>(a_match, b_match, binary_match, func);
break;
case element::Type_t::i16:
replacement =
fold_constant_binary_helper<int16_t>(a_match, b_match, binary_match, func);
break;
case element::Type_t::i32:
replacement =
fold_constant_binary_helper<int32_t>(a_match, b_match, binary_match, func);
break;
case element::Type_t::i64:
replacement =
fold_constant_binary_helper<int64_t>(a_match, b_match, binary_match, func);
break;
case element::Type_t::u8:
replacement =
fold_constant_binary_helper<uint8_t>(a_match, b_match, binary_match, func);
break;
case element::Type_t::u16:
replacement =
fold_constant_binary_helper<uint16_t>(a_match, b_match, binary_match, func);
break;
case element::Type_t::u32:
replacement =
fold_constant_binary_helper<uint32_t>(a_match, b_match, binary_match, func);
break;
case element::Type_t::u64:
replacement =
fold_constant_binary_helper<uint64_t>(a_match, b_match, binary_match, func);
break;
}
}
replace_node(m.get_match_root(), replacement);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment