Commit 1c2b0dc9 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Infrastructure for Common Subexpression Elimination (#927)

* cse init

* init tests

* clean up; more tests

* remove visualizations
parent c349056e
......@@ -106,6 +106,7 @@ set (SRC
op/util/unary_elementwise.cpp
pass/assign_placement.cpp
pass/algebraic_simplification.cpp
pass/cse.cpp
pass/dump_sorted.cpp
pass/get_output_element_elimination.cpp
pass/graph_rewrite.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include <set>
#include <typeinfo>
#include <unordered_map>
#include "cse.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace ngraph;
#define TI(x) std::type_index(typeid(x))
static bool cse_unarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();
return a->get_argument(0) == b->get_argument(0);
}
static bool cse_binarywise(std::shared_ptr<Node> a, std::shared_ptr<Node> b)
{
NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name();
return (a->get_argument(0) == b->get_argument(0) && a->get_argument(1) == b->get_argument(1)) ||
(a->get_argument(1) == b->get_argument(0) && a->get_argument(0) == b->get_argument(1));
}
static std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
initialize_ops_to_cse_handlers()
{
return std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>({
{TI(op::Abs), cse_unarywise}, {TI(op::Add), cse_binarywise},
});
}
static std::unordered_map<std::type_index,
std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
ops_to_cse_handlers = initialize_ops_to_cse_handlers();
class NodeKey
{
public:
NodeKey(std::shared_ptr<Node> n)
: m_node(n)
{
}
std::shared_ptr<Node> get_node() const { return m_node; }
bool operator==(const NodeKey& other) const
{
Node& p_this = *m_node.get();
Node& p_other = *other.get_node().get();
if (TI(p_this) != TI(p_other))
{
return false;
}
auto eh = ops_to_cse_handlers.find(TI(p_this));
if (eh == ops_to_cse_handlers.end())
{
return false;
}
return eh->second(m_node, other.get_node());
}
private:
std::shared_ptr<Node> m_node;
};
namespace std
{
template <>
struct hash<NodeKey>
{
std::size_t operator()(const NodeKey& k) const
{
Node& p_this = *k.get_node().get();
auto ti = TI(p_this);
std::hash<std::type_index> type_hash_compute{};
auto type_hash = type_hash_compute(ti);
std::vector<size_t> arg_ids;
arg_ids.push_back(type_hash);
auto cargs = k.get_node()->get_arguments();
//TODO: Do we need another map, so we could
//specify how to compute hash for each op?
if (p_this.is_commutative())
{
std::sort(begin(cargs), end(cargs));
}
for (auto arg : cargs)
{
arg_ids.push_back(arg->get_instance_id());
}
auto hashc = ngraph::hash_combine(arg_ids);
return hashc;
}
};
}
bool ngraph::pass::CommonSubexpressionElimination::run_on_function(
std::shared_ptr<ngraph::Function> f)
{
bool replaced = false;
std::unordered_map<NodeKey, std::shared_ptr<Node>> expressions{};
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter() ||
n->is_constant() /*we could CSE constants as well*/)
{
continue;
}
NodeKey n_key{n};
if (expressions.count(n_key))
{
ngraph::replace_node(n, expressions.at(n_key));
replaced = true;
}
else
{
expressions.insert(std::make_pair(n_key, n));
}
}
return replaced;
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class CommonSubexpressionElimination;
}
}
class ngraph::pass::CommonSubexpressionElimination : public FunctionPass
{
public:
CommonSubexpressionElimination()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
......@@ -34,6 +34,7 @@ set (SRC
copy.cpp
core_fusion.cpp
cpio.cpp
cse.cpp
element_type.cpp
file_util.cpp
inliner.cpp
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/cse.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(CSE, abs_abs)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs1 = std::make_shared<op::Abs>(A);
auto abs2 = std::make_shared<op::Abs>(A);
auto f = std::make_shared<Function>(NodeVector{abs1, abs2}, op::ParameterVector{A});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
TEST(CSE, abs_abs_negative)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs1 = std::make_shared<op::Abs>(A);
auto abs2 = std::make_shared<op::Abs>(B);
auto f = std::make_shared<Function>(NodeVector{abs1, abs2}, op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), abs1);
ASSERT_EQ(f->get_results().at(1)->get_argument(0), abs2);
}
TEST(CSE, add_add)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto add1 = std::make_shared<op::Add>(A, B);
auto add2 = std::make_shared<op::Add>(A, B);
auto f = std::make_shared<Function>(NodeVector{add1, add2}, op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
TEST(CSE, add_add_commutative)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto add1 = std::make_shared<op::Add>(A, B);
auto add2 = std::make_shared<op::Add>(B, A);
auto f = std::make_shared<Function>(NodeVector{add1, add2}, op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
TEST(CSE, add_add_negative)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto C = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto D = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto add1 = std::make_shared<op::Add>(A, B);
auto add2 = std::make_shared<op::Add>(C, D);
auto f = std::make_shared<Function>(NodeVector{add1, add2}, op::ParameterVector{A, B, C, D});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), add1);
ASSERT_EQ(f->get_results().at(1)->get_argument(0), add2);
}
TEST(CSE, abs_add)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs_a1 = std::make_shared<op::Abs>(A);
auto abs_b1 = std::make_shared<op::Abs>(B);
auto abs_a2 = std::make_shared<op::Abs>(A);
auto abs_b2 = std::make_shared<op::Abs>(B);
auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
auto f = std::make_shared<Function>(NodeVector{add1, add2}, op::ParameterVector{A, B});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
TEST(CSE, abs_add_abs_add)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs_a1 = std::make_shared<op::Abs>(A);
auto abs_b1 = std::make_shared<op::Abs>(B);
auto abs_a2 = std::make_shared<op::Abs>(A);
auto abs_b2 = std::make_shared<op::Abs>(B);
auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
auto abs_add1 = std::make_shared<op::Abs>(add1);
auto abs_add2 = std::make_shared<op::Abs>(add2);
auto C = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto add3 = std::make_shared<op::Add>(abs_add1, C);
auto add4 = std::make_shared<op::Add>(abs_add2, C);
auto f = std::make_shared<Function>(NodeVector{add3, add4}, op::ParameterVector{A, B, C});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->get_argument(0), f->get_results().at(1)->get_argument(0));
}
TEST(CSE, abs_add_abs_add_negative)
{
Shape zero_shape{0};
auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto abs_a1 = std::make_shared<op::Abs>(A);
auto abs_b1 = std::make_shared<op::Abs>(B);
auto abs_a2 = std::make_shared<op::Abs>(A);
auto abs_b2 = std::make_shared<op::Abs>(B);
auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
auto abs_add1 = std::make_shared<op::Abs>(add1);
auto abs_add2 = std::make_shared<op::Abs>(add2);
auto C = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto D = std::make_shared<op::Parameter>(element::i32, zero_shape);
auto add3 = std::make_shared<op::Add>(abs_add1, C);
auto add4 = std::make_shared<op::Add>(abs_add2, D);
auto f = std::make_shared<Function>(NodeVector{add3, add4}, op::ParameterVector{A, B, C, D});
pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.run_passes(f);
auto oadd3 = f->get_results().at(0)->get_argument(0);
auto oadd4 = f->get_results().at(1)->get_argument(0);
ASSERT_EQ(oadd3, add3);
ASSERT_EQ(oadd4, add4);
ASSERT_EQ(oadd3->get_argument(1), C);
ASSERT_EQ(oadd4->get_argument(1), D);
ASSERT_EQ(oadd3->get_argument(0), oadd4->get_argument(0));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment