Commit bd51497b authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

infra for algebraic simplification and simplifications for Add and Mu… (#878)

* infra for algebraic simplification and simplifications for Add and Multiply (including broadcast consts)

* add tests, fix bugs

* negative tests, 0*0, 0*1, 0+0

* possible fix for 0*1

* remove stale test

* fix merge comp errors

* fix comp errors
parent 1eb9f9bf
......@@ -105,6 +105,7 @@ set (SRC
op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise.cpp
pass/assign_placement.cpp
pass/algebraic_simplification.cpp
pass/dump_sorted.cpp
pass/get_output_element_elimination.cpp
pass/graph_rewrite.cpp
......
......@@ -417,3 +417,15 @@ std::shared_ptr<Node> ngraph::make_constant_from_string(std::string val,
auto cvals = std::vector<std::string>(shape_size(shape), val);
return std::make_shared<op::Constant>(element_type, shape, cvals);
}
bool ngraph::is_zero(std::shared_ptr<Node> reduce_constant)
{
auto result_bool = is_equal_to_const_value("0", reduce_constant);
return result_bool;
}
bool ngraph::is_one(std::shared_ptr<Node> reduce_constant)
{
auto result_bool = is_equal_to_const_value("1", reduce_constant);
return result_bool;
}
......@@ -127,4 +127,8 @@ namespace ngraph
std::shared_ptr<Node> make_constant_from_string(std::string val,
const element::Type& element_type,
const Shape& shape);
bool is_zero(std::shared_ptr<Node> reduce_constant);
bool is_one(std::shared_ptr<Node> reduce_constant);
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include <set>
#include "algebraic_simplification.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace ngraph;
#define TI(x) std::type_index(typeid(x))
template <typename T>
static std::shared_ptr<pattern::Matcher>
create_binary_matcher(std::shared_ptr<pattern::op::Label> label,
std::shared_ptr<pattern::op::Label> const_label)
{
auto bcst_pred = [](std::shared_ptr<Node> n) {
return std::dynamic_pointer_cast<op::Broadcast>(n) != nullptr;
};
auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
auto matcher = std::make_shared<pattern::Matcher>(std::make_shared<T>(label, bcst), nullptr);
return matcher;
}
static bool simplify_multiply(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_multiply for " << n->get_name();
auto iconst = ngraph::make_zero(element::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst);
auto const_label_zero =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
auto const_label_one =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_one, NodeVector{iconst});
auto matcher_const_zero = create_binary_matcher<op::Multiply>(label, const_label_zero);
auto matcher_const_one = create_binary_matcher<op::Multiply>(label, const_label_one);
if (matcher_const_zero->match(n))
{
auto cnst = matcher_const_zero->get_pattern_map()[const_label_zero];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << cnst->get_name();
ngraph::replace_node(n, cnst);
return true;
}
if (matcher_const_one->match(n))
{
auto x = matcher_const_one->get_pattern_map()[label];
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
ngraph::replace_node(n, x);
return true;
}
return false;
}
static bool simplify_add(std::shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_add for " << n->get_name();
auto iconst = ngraph::make_zero(element::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst);
auto const_label = std::make_shared<pattern::op::Label>(iconst, nullptr, NodeVector{iconst});
auto matcher = create_binary_matcher<op::Add>(label, const_label);
if (matcher->match(n))
{
auto pattern_map = matcher->get_pattern_map();
auto x = pattern_map[label];
auto cnst = pattern_map[const_label];
NGRAPH_DEBUG << "Node " << n->get_name() << " matched \" arg + 0 \" \n"
<< " arg : " << x->get_name() << " , const : " << cnst->get_name();
if (ngraph::is_zero(cnst))
{
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << x->get_name();
ngraph::replace_node(n, x);
return true;
}
else
{
NGRAPH_DEBUG << cnst->get_name() << " not equal to 0 ";
}
}
return false;
}
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
initialize_const_values_to_ops()
{
return std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>({
{TI(op::Add), simplify_add}, {TI(op::Multiply), simplify_multiply},
});
}
static std::unordered_map<std::type_index, std::function<bool(std::shared_ptr<Node>)>>
ops_to_const_values = initialize_const_values_to_ops();
bool ngraph::pass::AlgebraicSimplification::run_on_function(std::shared_ptr<ngraph::Function> f)
{
bool replaced = false;
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter())
{
continue;
}
const Node& node = *n;
auto eh = ops_to_const_values.find(TI(node));
if (eh == ops_to_const_values.end())
{
continue;
}
replaced = eh->second(n) || replaced;
}
return replaced;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class AlgebraicSimplification;
}
}
class ngraph::pass::AlgebraicSimplification : public FunctionPass
{
public:
AlgebraicSimplification()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
......@@ -35,12 +35,6 @@
using namespace ngraph;
using namespace std;
bool is_zero(shared_ptr<Node> reduce_constant)
{
auto result_bool = is_equal_to_const_value("0", reduce_constant);
return result_bool;
}
static shared_ptr<Node> construct_constant_node(int n)
{
return op::Constant::create(element::f32, Shape{}, {n});
......@@ -64,7 +58,7 @@ void pass::CoreFusion::construct_relu_pattern()
auto pattern_map = m.get_pattern_map();
auto mzero = m.get_pattern_map()[zero];
if (!is_zero(mzero))
if (!ngraph::is_zero(mzero))
{
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
return false;
......
......@@ -98,6 +98,7 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
......@@ -308,6 +309,7 @@ void runtime::cpu::CPU_ExternalFunction::compile()
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUNopElimination>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
......
......@@ -26,6 +26,7 @@ include_directories(
set (SRC
backend_api.cpp
algebraic_simplification.cpp
backend_debug_api.cpp
builder.cpp
builder_autobroadcast.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/algebraic_simplification.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/serializer.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(algebraic_simplification, add_types_shapes)
{
Shape shapes[] = {Shape{}, Shape{2, 2}, Shape{3, 3, 3}};
element::Type types[] = {element::i32, element::f32, element::f64};
for (auto type : types)
{
for (auto shape : shapes)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto add_a_0 = a + iconst0;
auto add_a_0_0 = add_a_0 + iconst0;
auto add_b_0 = b + iconst0;
auto add_b_0_0 = add_b_0 + iconst0;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
op::ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
}
}
}
}
TEST(algebraic_simplification, add_broadcast)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
auto const_broadcast = make_shared<op::Broadcast>(iconst0, shape, AxisSet{0, 1});
auto add_a_0 = a + const_broadcast;
auto add_a_0_0 = add_a_0 + const_broadcast;
auto add_b_0 = b + const_broadcast;
auto add_b_0_0 = add_b_0 + const_broadcast;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
op::ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
}
}
TEST(algebraic_simplification, zero_plus_zero_commutativity)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto add_a_0 = iconst0 + iconst0;
auto add_a_0_0 = iconst0 + iconst0;
auto add_b_0 = iconst0 + b;
auto add_b_0_0 = iconst0 + b;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
op::ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(ngraph::is_zero(f->get_results().at(2)->get_argument(0)));
ASSERT_EQ(f->get_results().at(4)->get_argument(0), b);
}
TEST(algebraic_simplification, zero_multiply_zero_one)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto iconst1 = ngraph::make_constant_from_string("1", type, shape);
auto add_a_0 = iconst0 * iconst0;
auto add_b_0 = iconst1 * iconst0;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0, c, add_b_0},
op::ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(ngraph::is_zero(f->get_results().at(2)->get_argument(0)));
ASSERT_TRUE(ngraph::is_zero(f->get_results().at(4)->get_argument(0)));
}
TEST(algebraic_simplification, add_negative_tests)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto abs_a = make_shared<op::Abs>(a);
auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
auto add_a_0 = a + iconst2;
auto add_a_0_0 = add_a_0 + iconst2;
auto add_b_0 = b + abs_a;
auto add_b_0_0 = add_b_0 + abs_a;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
op::ParameterVector{a, b, c});
pass_manager.run_passes(f);
auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
}
}
TEST(algebraic_simplification, multiply_negative_tests)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::AlgebraicSimplification>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto abs_a = make_shared<op::Abs>(a);
auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
auto add_a_0 = a * iconst2;
auto add_a_0_0 = add_a_0 * iconst2;
auto add_b_0 = b * abs_a;
auto add_b_0_0 = add_b_0 * abs_a;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
op::ParameterVector{a, b, c});
pass_manager.run_passes(f);
auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i), results.at(i)->get_argument(0));
}
}
......@@ -72,11 +72,6 @@ static std::shared_ptr<Node> construct_constant_node(int n)
return op::Constant::create(element::i32, Shape{}, {n});
}
bool is_zero(std::shared_ptr<Node> reduce_constant)
{
return is_equal_to_const_value("0", reduce_constant);
}
bool sum_predicate(std::shared_ptr<Node> gn)
{
NGRAPH_DEBUG << "pred_v2 : looking at " << gn->get_name();
......@@ -85,7 +80,7 @@ bool sum_predicate(std::shared_ptr<Node> gn)
auto reducee = gn->get_argument(0);
auto reduce_constant = gn->get_argument(1);
if (!is_zero(reduce_constant))
if (!ngraph::is_zero(reduce_constant))
{
return false;
}
......@@ -651,9 +646,9 @@ public:
auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
auto is_iconst_zero = [](std::shared_ptr<Node> n) {
bool result = is_zero(n);
bool result = ngraph::is_zero(n);
NGRAPH_DEBUG << n->get_name() << " is " << (result ? " a zero " : " not a zero");
return is_zero(n);
return ngraph::is_zero(n);
};
bool are_all_iconst_zeros =
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment