Commit 68f6110c authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Smaller set of constant folding kernels (#3653)

* Instantiate only valid cases for binary op constant folding kernels

* removed unused parameter
parent deb4baef
...@@ -49,8 +49,7 @@ ...@@ -49,8 +49,7 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
template <class Tin, class Tout> static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b, shared_ptr<op::Constant> b,
shared_ptr<Node> binary, shared_ptr<Node> binary,
NodeExecutorTy func) NodeExecutorTy func)
...@@ -61,7 +60,7 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -61,7 +60,7 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
// auto-broadcast is in use, and the CPU functors don't yet support that. // auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape()) if (func != nullptr && a->get_shape() == b->get_shape())
{ {
vector<Tout> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs; vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr())); inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr())); inputs.push_back(const_cast<void*>(b->get_data_ptr()));
...@@ -73,51 +72,74 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -73,51 +72,74 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
} }
else else
{ {
if (auto add_node = as_type_ptr<op::Add>(binary)) if (auto and_node = as_type_ptr<op::And>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), vector<char> out_vec(shape_size(out_shape));
"Input/output types do not match"); runtime::reference::logical_and<char>(a->get_data_ptr<char>(),
vector<Tin> out_vec(shape_size(out_shape)); b->get_data_ptr<char>(),
runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
add_node->get_autob()); and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto and_node = as_type_ptr<op::And>(binary)) else if (auto or_node = as_type_ptr<op::Or>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), vector<char> out_vec(shape_size(out_shape));
"Input/output types do not match"); runtime::reference::logical_or<char>(a->get_data_ptr<char>(),
vector<Tin> out_vec(shape_size(out_shape)); b->get_data_ptr<char>(),
runtime::reference::logical_and<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
and_node->get_autob()); or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto divide_node = as_type_ptr<op::Divide>(binary)) else if (auto xor_node = as_type_ptr<op::Xor>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), vector<char> out_vec(shape_size(out_shape));
"Input/output types do not match"); runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
vector<Tin> out_vec(shape_size(out_shape)); b->get_data_ptr<char>(),
shared_ptr<op::Divide> divop = as_type_ptr<op::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
divide_node->get_autob(), xor_node->get_autob());
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto equal_node = as_type_ptr<op::Equal>(binary)) else
{
NGRAPH_CHECK(
false,
"fold_constant_binary_logical must be consistent with is_supported_binary_op");
}
}
}
template <class Tin>
shared_ptr<op::Constant> fold_constant_binary_comparison(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
// NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<char> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
if (auto equal_node = as_type_ptr<op::Equal>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(), runtime::reference::equal<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
...@@ -129,7 +151,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -129,7 +151,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
} }
else if (auto greater_node = as_type_ptr<op::Greater>(binary)) else if (auto greater_node = as_type_ptr<op::Greater>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
...@@ -141,7 +162,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -141,7 +162,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
} }
else if (auto greater_eq_node = as_type_ptr<op::GreaterEq>(binary)) else if (auto greater_eq_node = as_type_ptr<op::GreaterEq>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::greater_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
...@@ -153,7 +173,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -153,7 +173,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
} }
else if (auto less_node = as_type_ptr<op::Less>(binary)) else if (auto less_node = as_type_ptr<op::Less>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::less<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
...@@ -165,7 +184,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -165,7 +184,6 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
} }
else if (auto less_eq_node = as_type_ptr<op::LessEq>(binary)) else if (auto less_eq_node = as_type_ptr<op::LessEq>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean");
vector<char> out_vec(shape_size(out_shape)); vector<char> out_vec(shape_size(out_shape));
runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(), runtime::reference::less_eq<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
...@@ -175,94 +193,128 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -175,94 +193,128 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
less_eq_node->get_autob()); less_eq_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto maximum_node = as_type_ptr<op::Maximum>(binary)) else if (auto not_equal_node = as_type_ptr<op::NotEqual>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), vector<char> out_vec(shape_size(out_shape));
"Input/output types do not match"); runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(),
vector<Tin> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
maximum_node->get_autob()); not_equal_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto minimum_node = as_type_ptr<op::Minimum>(binary)) else
{
NGRAPH_CHECK(false,
"fold_constant_binary must be consistent with is_supported_binary_op");
}
}
}
template <class Tin, class Tout = Tin>
shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary,
NodeExecutorTy func)
{
auto out_shape = binary->get_shape();
// NOTE: We will skip the executor if the shapes do not match, because that means
// auto-broadcast is in use, and the CPU functors don't yet support that.
if (func != nullptr && a->get_shape() == b->get_shape())
{
vector<Tout> out_vec(shape_size(out_shape));
vector<void*> inputs;
inputs.push_back(const_cast<void*>(a->get_data_ptr()));
inputs.push_back(const_cast<void*>(b->get_data_ptr()));
vector<void*> outputs;
outputs.push_back(out_vec.data());
func(inputs, outputs);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else
{
if (auto add_node = as_type_ptr<op::Add>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(), runtime::reference::add<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
minimum_node->get_autob()); add_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto multiply_node = as_type_ptr<op::Multiply>(binary)) else if (auto divide_node = as_type_ptr<op::Divide>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(), shared_ptr<op::Divide> divop = as_type_ptr<op::Divide>(binary);
bool pythondiv = divop->is_pythondiv();
runtime::reference::divide<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
multiply_node->get_autob()); divide_node->get_autob(),
pythondiv);
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto not_equal_node = as_type_ptr<op::NotEqual>(binary)) else if (auto maximum_node = as_type_ptr<op::Maximum>(binary))
{ {
NGRAPH_CHECK(element::from<Tout>() == element::boolean, "Output type is not boolean"); NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
vector<char> out_vec(shape_size(out_shape)); "Input/output types do not match");
runtime::reference::not_equal<Tin>(a->get_data_ptr<Tin>(), vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::maximum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
not_equal_node->get_autob()); maximum_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto or_node = as_type_ptr<op::Or>(binary)) else if (auto minimum_node = as_type_ptr<op::Minimum>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::logical_or<Tin>(a->get_data_ptr<Tin>(), runtime::reference::minimum<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
or_node->get_autob()); minimum_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto subtract_node = as_type_ptr<op::Subtract>(binary)) else if (auto multiply_node = as_type_ptr<op::Multiply>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(), runtime::reference::multiply<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
subtract_node->get_autob()); multiply_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else if (auto xor_node = as_type_ptr<op::Xor>(binary)) else if (auto subtract_node = as_type_ptr<op::Subtract>(binary))
{ {
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(), NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match"); "Input/output types do not match");
vector<Tin> out_vec(shape_size(out_shape)); vector<Tout> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<Tin>(a->get_data_ptr<Tin>(), runtime::reference::subtract<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(), b->get_data_ptr<Tin>(),
out_vec.data(), out_vec.data(),
a->get_shape(), a->get_shape(),
b->get_shape(), b->get_shape(),
xor_node->get_autob()); subtract_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec); return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
} }
else else
...@@ -274,35 +326,26 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a, ...@@ -274,35 +326,26 @@ shared_ptr<op::Constant> fold_constant_binary(shared_ptr<op::Constant> a,
} }
template <class Tin> template <class Tin>
shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out, shared_ptr<op::Constant> fold_constant_binary_helper(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b, shared_ptr<op::Constant> b,
shared_ptr<Node> binary, shared_ptr<Node> binary,
NodeExecutorTy func) NodeExecutorTy func)
{ {
switch (et_out) if (binary->is_binary_elementwise_comparison())
{ {
case element::Type_t::undefined: return fold_constant_binary_comparison<Tin>(a, b, binary, func);
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback"); }
case element::Type_t::dynamic: else if (binary->is_binary_elementwise_arithmetic())
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_binary_callback"); {
case element::Type_t::boolean: return fold_constant_binary<Tin, char>(a, b, binary, func); return fold_constant_binary_arithmetic<Tin>(a, b, binary, func);
case element::Type_t::bf16: return fold_constant_binary<Tin, bfloat16>(a, b, binary, func); }
case element::Type_t::f16: return fold_constant_binary<Tin, float16>(a, b, binary, func); else
case element::Type_t::f32: return fold_constant_binary<Tin, float>(a, b, binary, func); {
case element::Type_t::f64: return fold_constant_binary<Tin, double>(a, b, binary, func); NGRAPH_CHECK(
case element::Type_t::i8: return fold_constant_binary<Tin, int8_t>(a, b, binary, func); false, "fold_constant_binary_helper only available for comparison and arithmetic ops");
case element::Type_t::i16: return fold_constant_binary<Tin, int16_t>(a, b, binary, func);
case element::Type_t::i32: return fold_constant_binary<Tin, int32_t>(a, b, binary, func);
case element::Type_t::i64: return fold_constant_binary<Tin, int64_t>(a, b, binary, func);
case element::Type_t::u8: return fold_constant_binary<Tin, uint8_t>(a, b, binary, func);
case element::Type_t::u16: return fold_constant_binary<Tin, uint16_t>(a, b, binary, func);
case element::Type_t::u32: return fold_constant_binary<Tin, uint32_t>(a, b, binary, func);
case element::Type_t::u64: return fold_constant_binary<Tin, uint64_t>(a, b, binary, func);
} }
NGRAPH_UNREACHABLE("Unreachable switch case");
} }
bool is_supported_binary_op(std::shared_ptr<Node> n) bool is_supported_binary_op(std::shared_ptr<Node> n)
{ {
return (is_type<op::Add>(n) || is_type<op::And>(n) || is_type<op::Divide>(n) || return (is_type<op::Add>(n) || is_type<op::And>(n) || is_type<op::Divide>(n) ||
...@@ -353,69 +396,79 @@ void pass::ConstantFolding::construct_constant_binary() ...@@ -353,69 +396,79 @@ void pass::ConstantFolding::construct_constant_binary()
} }
std::shared_ptr<Node> replacement; std::shared_ptr<Node> replacement;
if (binary_match->is_binary_elementwise_logical())
{
replacement = fold_constant_binary_logical(a_match, b_match, binary_match, func);
}
else
{
auto in_type = a_match->get_output_element_type(0); auto in_type = a_match->get_output_element_type(0);
auto out_type = binary_match->get_output_element_type(0); auto out_type = binary_match->get_output_element_type(0);
switch (in_type) switch (in_type)
{ {
case element::Type_t::undefined: case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in constant_binary_callback"); NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_binary_callback");
break; break;
case element::Type_t::dynamic: case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_binary_callback"); NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_binary_callback");
break; break;
case element::Type_t::boolean: case element::Type_t::boolean:
replacement = replacement =
fold_constant_binary_helper<char>(out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<char>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::bf16: case element::Type_t::bf16:
replacement = fold_constant_binary_helper<bfloat16>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<bfloat16>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::f16: case element::Type_t::f16:
replacement = fold_constant_binary_helper<float16>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<float16>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::f32: case element::Type_t::f32:
replacement = replacement =
fold_constant_binary_helper<float>(out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<float>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::f64: case element::Type_t::f64:
replacement = replacement =
fold_constant_binary_helper<double>(out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<double>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::i8: case element::Type_t::i8:
replacement = replacement =
fold_constant_binary_helper<int8_t>(out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<int8_t>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::i16: case element::Type_t::i16:
replacement = fold_constant_binary_helper<int16_t>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<int16_t>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::i32: case element::Type_t::i32:
replacement = fold_constant_binary_helper<int32_t>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<int32_t>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::i64: case element::Type_t::i64:
replacement = fold_constant_binary_helper<int64_t>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<int64_t>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::u8: case element::Type_t::u8:
replacement = fold_constant_binary_helper<uint8_t>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<uint8_t>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::u16: case element::Type_t::u16:
replacement = fold_constant_binary_helper<uint16_t>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<uint16_t>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::u32: case element::Type_t::u32:
replacement = fold_constant_binary_helper<uint32_t>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<uint32_t>(a_match, b_match, binary_match, func);
break; break;
case element::Type_t::u64: case element::Type_t::u64:
replacement = fold_constant_binary_helper<uint64_t>( replacement =
out_type, a_match, b_match, binary_match, func); fold_constant_binary_helper<uint64_t>(a_match, b_match, binary_match, func);
break; break;
} }
}
replace_node(m.get_match_root(), replacement); replace_node(m.get_match_root(), replacement);
return true; return true;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment