Unverified Commit a1ee816e authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

AlgebraicSimplification simplification (#4326)

* Start of rework

* Faster multiply simplification

* Fix uniform constant check

* Support Add op

* Fix build error
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent b2e7114d
......@@ -40,7 +40,17 @@
using namespace std;
using namespace ngraph;
#define TI(x) type_index(typeid(x))
bool is_uniform_constant(const Input<Node>& input)
{
bool rc = false;
auto node = input.get_source_output().get_node();
if (node->get_type_info() == op::Constant::type_info)
{
auto constant = as_type<op::Constant>(node);
rc = constant->get_all_data_elements_bitwise_identical();
}
return rc;
}
extern template Shape ngraph::apply_permutation<Shape>(Shape input, AxisVector order);
template <typename T>
......@@ -54,11 +64,6 @@ static shared_ptr<pattern::Matcher>
return matcher;
}
static shared_ptr<pattern::op::Label> get_broadcast_label(shared_ptr<pattern::Matcher> matcher)
{
return static_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
}
//`simplify_concat` identifies slices-concat sequences
// that cancel each other. Namely it replaces subgraphs
// similar to the one below with `arg`
......@@ -240,6 +245,109 @@ static bool simplify_concat(shared_ptr<Node> n)
return true;
}
static bool is_uniform_constant(const op::Constant* constant, int value)
{
bool rc = false;
if (constant && constant->get_all_data_elements_bitwise_identical())
{
switch (constant->get_element_type())
{
case ngraph::element::Type_t::undefined:
{
throw runtime_error("is_value type not supported");
}
case ngraph::element::Type_t::dynamic: { throw runtime_error("is_value type not supported");
}
case ngraph::element::Type_t::boolean: break;
case ngraph::element::Type_t::bf16:
rc = *static_cast<const bfloat16*>(constant->get_data_ptr()) ==
bfloat16(static_cast<float>(value));
break;
case ngraph::element::Type_t::f16:
rc = *static_cast<const float16*>(constant->get_data_ptr()) ==
float16(static_cast<float>(value));
break;
case ngraph::element::Type_t::f32:
rc = *static_cast<const float*>(constant->get_data_ptr()) == static_cast<float>(value);
break;
case ngraph::element::Type_t::f64:
rc =
*static_cast<const double*>(constant->get_data_ptr()) == static_cast<double>(value);
break;
case ngraph::element::Type_t::i8:
rc =
*static_cast<const int8_t*>(constant->get_data_ptr()) == static_cast<int8_t>(value);
break;
case ngraph::element::Type_t::i16:
rc = *static_cast<const int16_t*>(constant->get_data_ptr()) ==
static_cast<int16_t>(value);
break;
case ngraph::element::Type_t::i32:
rc = *static_cast<const int32_t*>(constant->get_data_ptr()) ==
static_cast<int32_t>(value);
break;
case ngraph::element::Type_t::i64:
rc = *static_cast<const int64_t*>(constant->get_data_ptr()) ==
static_cast<int64_t>(value);
break;
case ngraph::element::Type_t::u1: throw runtime_error("is_value type not supported");
case ngraph::element::Type_t::u8:
rc = *static_cast<const uint8_t*>(constant->get_data_ptr()) ==
static_cast<uint8_t>(value);
break;
case ngraph::element::Type_t::u16:
rc = *static_cast<const uint16_t*>(constant->get_data_ptr()) ==
static_cast<uint16_t>(value);
break;
case ngraph::element::Type_t::u32:
rc = *static_cast<const uint32_t*>(constant->get_data_ptr()) ==
static_cast<uint32_t>(value);
break;
case ngraph::element::Type_t::u64:
rc = *static_cast<const uint64_t*>(constant->get_data_ptr()) ==
static_cast<uint64_t>(value);
break;
}
}
return rc;
}
static shared_ptr<op::Constant> get_constant(shared_ptr<Node> op)
{
set<Node::type_info_t> nomath = {op::Broadcast::type_info, op::Reshape::type_info};
while (nomath.find(op->get_type_info()) != nomath.end())
{
op = op->input(0).get_source_output().get_node_shared_ptr();
}
return as_type_ptr<op::Constant>(op);
}
static bool is_input_uniform_constant(shared_ptr<Node> op,
int constant_value,
shared_ptr<Node>& constant,
shared_ptr<Node>& value)
{
bool rc = false;
auto c = get_constant(op->input(0).get_source_output().get_node_shared_ptr());
if (is_uniform_constant(c.get(), constant_value))
{
constant = op->input(0).get_source_output().get_node_shared_ptr();
value = op->input(1).get_source_output().get_node_shared_ptr();
rc = true;
}
else
{
c = get_constant(op->input(1).get_source_output().get_node_shared_ptr());
if (is_uniform_constant(c.get(), constant_value))
{
constant = op->input(1).get_source_output().get_node_shared_ptr();
value = op->input(0).get_source_output().get_node_shared_ptr();
rc = true;
}
}
return rc;
}
//`simplify_multiply` optimizes the following 4 *base* cases
//(8 cases in total including variants due to commutativity)
//
......@@ -249,33 +357,28 @@ static bool simplify_concat(shared_ptr<Node> n)
// a * broadcast(1) -> a
static bool simplify_multiply(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_multiply for " << n->get_name();
auto iconst = make_zero(element::i32, Shape{});
auto label = make_shared<pattern::op::Label>(iconst);
auto const_label_zero = make_shared<pattern::op::Label>(iconst, is_zero, NodeVector{iconst});
auto const_label_one = make_shared<pattern::op::Label>(iconst, is_one, NodeVector{iconst});
auto matcher_const_zero = create_binary_matcher<op::Multiply>(label, const_label_zero);
auto matcher_const_one = create_binary_matcher<op::Multiply>(label, const_label_one);
if (matcher_const_zero->match(n))
{
auto bcst_label = get_broadcast_label(matcher_const_zero);
auto bcst_or_cnst = matcher_const_zero->get_pattern_map()[bcst_label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << bcst_or_cnst->get_name();
replace_node(n, bcst_or_cnst);
return true;
}
if (matcher_const_one->match(n))
bool rc = false;
auto multiply = as_type_ptr<op::Multiply>(n);
if (multiply)
{
auto x = matcher_const_one->get_pattern_map()[label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
replace_node(n, x);
return true;
shared_ptr<Node> constant;
shared_ptr<Node> value;
if (is_input_uniform_constant(multiply, 0, constant, value))
{
replace_node(multiply, constant);
rc = true;
}
else
{
if (is_input_uniform_constant(multiply, 1, constant, value))
{
replace_node(multiply, value);
rc = true;
}
}
}
return false;
return rc;
}
//`simplify_add` optimizes the following 2 *base* cases
......@@ -285,32 +388,20 @@ static bool simplify_multiply(shared_ptr<Node> n)
// a + broadcast(0) -> a
static bool simplify_add(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_add for " << n->get_name();
auto iconst = make_zero(element::i32, Shape{});
auto label = make_shared<pattern::op::Label>(iconst);
auto const_label = make_shared<pattern::op::Label>(iconst, nullptr, NodeVector{iconst});
auto matcher = create_binary_matcher<op::Add>(label, const_label);
if (matcher->match(n))
bool rc = false;
auto add = as_type_ptr<op::Add>(n);
if (add)
{
auto pattern_map = matcher->get_pattern_map();
auto x = pattern_map[label];
auto cnst = pattern_map[const_label];
NGRAPH_DEBUG << "Node " << n->get_name() << " matched \" arg + 0 \" \n"
<< " arg : " << x->get_name() << " , const : " << cnst->get_name();
if (is_zero(cnst))
shared_ptr<Node> constant;
shared_ptr<Node> value;
if (is_input_uniform_constant(add, 0, constant, value))
{
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
replace_node(n, x);
return true;
}
else
{
NGRAPH_DEBUG << cnst->get_name() << " not equal to 0 ";
replace_node(add, value);
rc = true;
}
}
return false;
return rc;
}
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
......@@ -455,20 +546,20 @@ static bool simplify_reduction(shared_ptr<Node> n)
return true;
}
static unordered_map<type_index, function<bool(shared_ptr<Node>)>> initialize_ops_to_simplifiers()
static unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> initialize_ops_to_simplifiers()
{
return unordered_map<type_index, function<bool(shared_ptr<Node>)>>(
{{TI(op::Add), simplify_add},
{TI(op::Multiply), simplify_multiply},
{TI(op::Concat), simplify_concat},
{TI(op::Sum),
return unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>>(
{{op::Add::type_info, simplify_add},
{op::Multiply::type_info, simplify_multiply},
{op::Concat::type_info, simplify_concat},
{op::Sum::type_info,
function<bool(shared_ptr<Node>)>{simplify_reduction<op::Sum, get_sum_constant>}},
{TI(op::Product),
{op::Product::type_info,
function<bool(shared_ptr<Node>)>{simplify_reduction<op::Product, get_prod_constant>}},
{TI(op::Log), simplify_log}});
{op::Log::type_info, simplify_log}});
}
static unordered_map<type_index, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
static unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
initialize_ops_to_simplifiers();
bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
......@@ -481,14 +572,11 @@ bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
continue;
}
const Node& node = *n;
auto eh = ops_to_simplifiers.find(TI(node));
if (eh == ops_to_simplifiers.end())
auto eh = ops_to_simplifiers.find(n->get_type_info());
if (eh != ops_to_simplifiers.end())
{
continue;
replaced |= eh->second(n);
}
replaced = eh->second(n) || replaced;
}
return replaced;
}
......@@ -119,7 +119,7 @@ TEST(algebraic_simplification, add_broadcast)
}
}
TEST(algebraic_simplification, multiply_broadcast)
TEST(algebraic_simplification, multiply_broadcast_0)
{
Shape shape{2, 2};
pass::Manager pass_manager;
......@@ -139,7 +139,7 @@ TEST(algebraic_simplification, multiply_broadcast)
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
......@@ -148,6 +148,34 @@ TEST(algebraic_simplification, multiply_broadcast)
}
}
TEST(algebraic_simplification, multiply_broadcast_1)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto const_broadcast = ngraph::builder::make_constant<int32_t>(element::i32, shape, 1);
auto mul_a_0 = a * const_broadcast;
auto mul_a_0_0 = mul_a_0 * const_broadcast;
auto mul_b_0 = b * const_broadcast;
auto mul_b_0_0 = mul_b_0 * const_broadcast;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected[i], results[i]->get_argument(0));
}
}
TEST(algebraic_simplification, zero_plus_zero_commutativity)
{
Shape shape{};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment