Commit 44f7479d authored by Nick Korovaiko, committed by Robert Kimball

Move callback out of matcher (#2305)

* ngraph core changed

* cpu bits + pattern

* change perms

* revert std::cout to NGRAPH_DEBUG

* fix comp errors

* fix comment

* more comment fixes

* Remove callback argument

* clean up.

* fixed tests.

* more fixes for backends.

* fix namespace.

* similar fix for recurrent matcher, and improvements.

* fix build.
parent 715eeb37
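In short: `pattern::Matcher` and `pattern::RecurrentMatcher` no longer own a rewrite callback; callers now pass the callback to `GraphRewrite::add_matcher` / `RecurrentGraphRewrite::add_matcher` alongside the matcher. A minimal before/after sketch of a call site, assuming the nGraph tree at this commit ("MyFusion" and the pattern are illustrative, not code from this diff):

```cpp
// Sketch only: MyFusion is a hypothetical GraphRewrite subclass.
void MyFusion::construct_my_fusion()
{
    auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto callback = [](pattern::Matcher& m) {
        // Rewrite the subgraph rooted at m.get_match_root() here;
        // return true iff the graph was modified.
        return false;
    };

    // Before this commit: the callback was baked into the Matcher.
    //   auto m = std::make_shared<pattern::Matcher>(label, callback, "MyFusion");
    //   this->add_matcher(m);

    // After: the Matcher carries only the pattern and a name;
    // the callback travels with add_matcher.
    auto m = std::make_shared<pattern::Matcher>(label, "MyFusion");
    this->add_matcher(m, callback);
}
```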
@@ -52,7 +52,7 @@ static shared_ptr<pattern::Matcher>
 {
     auto bcst = make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>());
     auto bcst_label = make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
-    auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label), nullptr);
+    auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label));
     return matcher;
 }
@@ -86,7 +86,7 @@ static bool simplify_concat(shared_ptr<Node> n)
     auto skip_reshape = make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::Reshape>());
-    auto matcher = make_shared<pattern::Matcher>(skip_reshape, nullptr);
+    auto matcher = make_shared<pattern::Matcher>(skip_reshape);
     Coordinate prev_lower_bounds;
     Shape prev_slice_shape;
...
@@ -80,7 +80,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
     auto slice_weights_label =
         std::make_shared<pattern::op::Label>(slice_weights, nullptr, NodeVector{slice_weights});
     auto conv = std::make_shared<op::Convolution>(slice_data, slice_weights_label);
-    auto matcher = std::make_shared<pattern::Matcher>(conv, nullptr);
+    auto matcher = std::make_shared<pattern::Matcher>(conv);
     NGRAPH_DEBUG << "In simplify_concat (group convolution) for " << n->get_name();
...
@@ -113,8 +113,8 @@ void pass::ConcatElimination::construct_concat_elimination()
         return false;
     };

-    auto m = std::make_shared<pattern::Matcher>(concat_label, callback);
-    this->add_matcher(m);
+    auto m = std::make_shared<pattern::Matcher>(concat_label, "ConcatElimination");
+    this->add_matcher(m, callback);
 }

 bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function)
...
@@ -180,9 +180,8 @@ void pass::ConstantFolding::construct_constant_pad()
         return false;
     };

-    auto pad_matcher =
-        make_shared<pattern::Matcher>(pad, constant_pad_callback, "ConstantFolding.ConstantPad");
-    this->add_matcher(pad_matcher);
+    auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
+    this->add_matcher(pad_matcher, constant_pad_callback);
 }

 void pass::ConstantFolding::construct_constant_reshape()
@@ -245,9 +244,9 @@ void pass::ConstantFolding::construct_constant_reshape()
         return false;
     };

-    auto reshape_matcher = make_shared<pattern::Matcher>(
-        reshape, constant_reshape_callback, "ConstantFolding.ConstantReshape");
-    this->add_matcher(reshape_matcher);
+    auto reshape_matcher =
+        make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
+    this->add_matcher(reshape_matcher, constant_reshape_callback);
 }

 template <class T>
@@ -340,9 +339,9 @@ void pass::ConstantFolding::construct_constant_broadcast()
         return false;
     };

-    auto broadcast_matcher = make_shared<pattern::Matcher>(
-        broadcast, constant_broadcast_callback, "ConstantFolding.ConstantBroadcast");
-    this->add_matcher(broadcast_matcher);
+    auto broadcast_matcher =
+        make_shared<pattern::Matcher>(broadcast, "ConstantFolding.ConstantBroadcast");
+    this->add_matcher(broadcast_matcher, constant_broadcast_callback);
 }

 template <class T>
@@ -478,9 +477,8 @@ void pass::ConstantFolding::construct_constant_binary()
         return false;
     };

-    auto reshape_matcher = make_shared<pattern::Matcher>(
-        bea, constant_binary_callback, "ConstantFolding.ConstantBinary");
-    this->add_matcher(reshape_matcher);
+    auto reshape_matcher = make_shared<pattern::Matcher>(bea, "ConstantFolding.ConstantBinary");
+    this->add_matcher(reshape_matcher, constant_binary_callback);
 }

 bool is_supported_unary_op(std::shared_ptr<Node> n)
@@ -609,9 +607,8 @@ void pass::ConstantFolding::construct_constant_unary()
         return false;
     };

-    auto reshape_matcher = make_shared<pattern::Matcher>(
-        uea, constant_unary_callback, "ConstantFolding.ConstantUnary");
-    this->add_matcher(reshape_matcher);
+    auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary");
+    this->add_matcher(reshape_matcher, constant_unary_callback);
 }

 template <class QUANT, class REAL>
@@ -682,9 +679,9 @@ void pass::ConstantFolding::construct_constant_dequantize()
         return false;
     };

-    auto dequantize_matcher = make_shared<pattern::Matcher>(
-        dequant, constant_dequantize_callback, "ConstantFolding.ConstantDequantize");
-    this->add_matcher(dequantize_matcher);
+    auto dequantize_matcher =
+        make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize");
+    this->add_matcher(dequantize_matcher, constant_dequantize_callback);
 }

 template <class REAL, class QUANT>
@@ -757,7 +754,7 @@ void pass::ConstantFolding::construct_constant_quantize()
         return false;
     };

-    auto quantize_matcher = make_shared<pattern::Matcher>(
-        quant, constant_quantize_callback, "ConstantFolding.ConstantQuantize");
-    this->add_matcher(quantize_matcher);
+    auto quantize_matcher =
+        make_shared<pattern::Matcher>(quant, "ConstantFolding.ConstantQuantize");
+    this->add_matcher(quantize_matcher, constant_quantize_callback);
 }
...
@@ -22,7 +22,6 @@
 #include "graph_rewrite.hpp"
 #include "ngraph/log.hpp"
-#include "ngraph/pattern/matcher.hpp"

 using namespace std;
 using namespace ngraph;
@@ -64,35 +63,28 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
     bool rewritten = false;
     const size_t NUM_TRIES = 10;
     size_t tries = NUM_TRIES;
-    vector<shared_ptr<pattern::Matcher>> original_matchers{m_matchers};
-    bool is_dynamic_function = f->is_dynamic();
+    vector<MatchClosure> original_matchers{m_matchers};
     do
     {
         rewritten = false;
-        vector<shared_ptr<pattern::Matcher>> matchers{m_matchers};
+        // m_matchers may contain newly constructed matchers for matchers
+        // that need multiple passes. See comments above.
+        vector<MatchClosure> run_matchers{m_matchers};
         m_matchers.clear();
         for (auto node : f->get_ordered_ops())
         {
-            for (auto matcher : matchers)
+            for (auto& mc : run_matchers)
             {
-                if (is_dynamic_function &&
-                    (matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
-                {
-                    NGRAPH_DEBUG
-                        << "matcher requires static shape but the function is dynamic, "
-                        << "skipping this optimization till the shapes are fully materialized";
-                    continue;
-                }
-                NGRAPH_DEBUG << "Running matcher " << matcher->get_name() << "("
-                             << matcher->get_pattern()->get_name() << ") on " << node->get_name();
-                if (matcher->match(node))
+                NGRAPH_DEBUG << "Running matcher " << mc.matcher->get_name() << "("
+                             << mc.matcher->get_pattern()->get_name() << ") on "
+                             << node->get_name();
+                if (mc.matcher->match(node))
                 {
-                    NGRAPH_DEBUG << "Matcher " << matcher << matcher->get_name() << " matched "
-                                 << node->get_name();
-                    if (matcher->process_match())
+                    NGRAPH_DEBUG << "Matcher " << mc.matcher << mc.matcher->get_name()
+                                 << " matched " << node->get_name();
+                    if (mc.callback(*mc.matcher.get()))
                     {
                         rewritten = true;
-                        is_dynamic_function = f->is_dynamic();
                         break;
                     }
                 }
@@ -105,7 +97,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
     return (NUM_TRIES - tries) > 1; //this means a graph was transformed
 }

-static const vector<regex> initialize_fusion_regexes()
+static vector<regex> initialize_fusion_regexes()
 {
     const char* cnsf = getenv("NGRAPH_DISABLED_FUSIONS");
     vector<regex> regexes;
@@ -122,7 +114,7 @@ static const vector<regex> initialize_fusion_regexes()
     return regexes;
 }

-bool pass::GraphRewrite::is_enabled(shared_ptr<pattern::Matcher> m)
+bool pass::GraphRewrite::is_enabled(const shared_ptr<pattern::Matcher>& m) const
 {
     //note, regexes are static to avoid re-initialization
     static const auto regexes = initialize_fusion_regexes();
@@ -139,41 +131,39 @@ bool pass::GraphRewrite::is_enabled(shared_ptr<pattern::Matcher> m)
     return true;
 }

-void pass::GraphRewrite::add_matcher(shared_ptr<pattern::Matcher> m)
+void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
+                                     const graph_rewrite_callback& callback)
 {
     if (is_enabled(m))
     {
-        m_matchers.push_back(m);
+        m_matchers.push_back({m, callback});
     }
 }

+void pass::RecurrentGraphRewrite::add_matcher(
+    const std::shared_ptr<pattern::RecurrentMatcher>& m,
+    const ngraph::recurrent_graph_rewrite_callback& callback)
+{
+    m_matchers.push_back({m, callback});
+}
+
 bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
 {
     bool changed = false;
     size_t i = 0;
-    bool is_dynamic_function = f->is_dynamic();
     do
     {
         for (auto node : f->get_ops())
         {
-            for (auto matcher : m_matchers)
+            for (auto& mc : m_matchers)
             {
-                if (is_dynamic_function &&
-                    (matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
-                {
-                    NGRAPH_DEBUG
-                        << "matcher requires static shape but the function is dynamic, "
-                        << "skipping this optimization till the shapes are fully materialized";
-                    continue;
-                }
-                NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node->get_name();
-                if (matcher->match(node))
+                NGRAPH_DEBUG << "Running matcher " << mc.matcher << " on " << node->get_name();
+                if (mc.matcher->match(node))
                 {
-                    NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node->get_name();
-                    if (matcher->process_match())
+                    NGRAPH_DEBUG << "Matcher " << mc.matcher << " matched " << node->get_name();
+                    if (mc.callback(*mc.matcher.get()))
                     {
                         changed = true;
-                        is_dynamic_function = f->is_dynamic();
                         goto next_fusion;
                     }
                 }
...
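Note that the name given to a `Matcher` is what `is_enabled` tests against the regexes built from `NGRAPH_DISABLED_FUSIONS` in `initialize_fusion_regexes` above, so named matchers stay individually disableable even though the callback has moved out. A hedged sketch, restating the `ConstantFolding.ConstantPad` registration from the earlier file:

```cpp
// Sketch: with NGRAPH_DISABLED_FUSIONS set to a regex such as "ConstantFolding.*",
// is_enabled() rejects the matcher at registration time and the callback never runs.
auto m = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(m, constant_pad_callback); // silently skipped when the name is disabled
```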
@@ -17,8 +17,11 @@
 #pragma once

 #include <functional>
+#include <memory>
 #include <set>

 #include "ngraph/pass/pass.hpp"
+#include "ngraph/pattern/matcher.hpp"

 namespace ngraph
 {
@@ -27,11 +30,10 @@ namespace ngraph
         class GraphRewrite;
         class RecurrentGraphRewrite;
     }

-    namespace pattern
-    {
-        class Matcher;
-        class RecurrentMatcher;
-    }
+    using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
+    using recurrent_graph_rewrite_callback =
+        std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
 }

 /// \brief GraphRewrite (in tandem with \sa Matcher) performs transformations on specified patterns
@@ -52,13 +54,20 @@ public:
     {
     }

-    bool is_enabled(std::shared_ptr<pattern::Matcher> m);
-    void add_matcher(std::shared_ptr<pattern::Matcher> m);
+    void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
+                     const ngraph::graph_rewrite_callback& callback);
     virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);

+protected:
+    bool is_enabled(const std::shared_ptr<pattern::Matcher>& m) const;
+
 private:
-    // enable cascading rewrites
-    std::vector<std::shared_ptr<pattern::Matcher>> m_matchers;
+    struct MatchClosure
+    {
+        std::shared_ptr<pattern::Matcher> matcher;
+        ngraph::graph_rewrite_callback callback;
+    };
+    std::vector<MatchClosure> m_matchers;
 };

 class ngraph::pass::RecurrentGraphRewrite : public FunctionPass
@@ -70,10 +79,17 @@ public:
     {
     }

-    void add_matcher(std::shared_ptr<pattern::RecurrentMatcher> m) { m_matchers.push_back(m); }
+    void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
+                     const ngraph::recurrent_graph_rewrite_callback& callback);
     virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);

 private:
     size_t m_num_iters;
-    std::vector<std::shared_ptr<pattern::RecurrentMatcher>> m_matchers;
+
+    struct MatchClosure
+    {
+        std::shared_ptr<pattern::RecurrentMatcher> matcher;
+        ngraph::recurrent_graph_rewrite_callback callback;
+    };
+    std::vector<MatchClosure> m_matchers;
 };
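The header now pairs each matcher with its callback in a private `MatchClosure`, and the callback aliases live in namespace `ngraph` rather than `ngraph::pattern`. A condensed restatement of the new storage, for orientation only (not additional API):

```cpp
// Condensed from the header diff above.
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher&)>;

struct MatchClosure
{
    std::shared_ptr<pattern::Matcher> matcher; // what to look for
    ngraph::graph_rewrite_callback callback;   // what to do on a match
};
std::vector<MatchClosure> m_matchers; // one closure per registered rewrite;
                                      // run_on_function calls mc.callback(*mc.matcher)
                                      // where it used to call matcher->process_match()
```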
@@ -77,9 +77,9 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
         },
         NodeVector{reshape_op});

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
         return true;
     };
-    add_matcher(make_shared<pattern::Matcher>(target_op, callback));
+    add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"), callback);
 }
@@ -72,8 +72,8 @@ void pass::ReshapeElimination::construct_identity_reshape_pattern()
         return true;
     };

-    auto m = make_shared<pattern::Matcher>(reshape1, callback);
-    this->add_matcher(m);
+    auto m = make_shared<pattern::Matcher>(reshape1);
+    this->add_matcher(m, callback);
 }

 void pass::ReshapeElimination::construct_reshapex2_pattern()
@@ -131,8 +131,8 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
         return false;
     };

-    auto m = make_shared<pattern::Matcher>(reshape2, callback);
-    this->add_matcher(m);
+    auto m = make_shared<pattern::Matcher>(reshape2);
+    this->add_matcher(m, callback);
 }

 void pass::ReshapeElimination::construct_dot_transpose_pattern()
@@ -145,7 +145,7 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
     auto pdot = make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
     auto preshape = make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
                      << m.get_match_root()->get_name();
@@ -188,8 +188,8 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
         return true;
     };

-    auto m = make_shared<pattern::Matcher>(preshape, callback);
-    this->add_matcher(m);
+    auto m = make_shared<pattern::Matcher>(preshape);
+    this->add_matcher(m, callback);
 }

 void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
@@ -289,7 +289,7 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
         return modify_graph;
     };
     std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
-    auto m = std::make_shared<pattern::RecurrentMatcher>(
-        reshape_label, op, empty_correlated_matches, callback);
-    this->add_matcher(m);
+    auto m =
+        std::make_shared<pattern::RecurrentMatcher>(reshape_label, op, empty_correlated_matches);
+    this->add_matcher(m, callback);
 }
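The recurrent flavor follows the same shape: `RecurrentMatcher` drops its callback parameter, and `RecurrentGraphRewrite::add_matcher` takes the callback instead. A minimal sketch mirroring `construct_recurrent_reshape` above (`reshape_label` and `op` stand in for the real pattern labels):

```cpp
// Sketch only: the callback body is a placeholder.
auto callback = [](pattern::RecurrentMatcher& m) {
    // Walk m.get_bound_nodes_for_pattern(...) and rewrite;
    // report whether the graph changed.
    return false;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m =
    std::make_shared<pattern::RecurrentMatcher>(reshape_label, op, empty_correlated_matches);
this->add_matcher(m, callback); // callback registered here, no longer stored in the matcher
```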
@@ -87,7 +87,7 @@ namespace ngraph
                for (auto entry : m_pattern_map)
                {
                    // leaf label
-                   if (entry.first->get_input_size() == 0)
+                   if (entry.first->get_inputs().empty())
                    {
                        label_exclusions.push_back(entry.second);
                    }
@@ -194,7 +194,7 @@ namespace ngraph
                // when their individual GOE are matched
                // this also gives a bit more flexibility since we don't have to worry
                // about *all* outputs of a pattern node but only the ones we want to match.
-               if (m_strict_mode && graph_node->get_output_size() == 1)
+               if (m_strict_mode && graph_node->get_outputs().size() == 1)
                {
                    bool shape_match = pattern_node->get_output_partial_shape(0).compatible(
                        graph_node->get_output_partial_shape(0));
@@ -328,26 +328,6 @@ namespace ngraph
            return false;
        }

-       bool Matcher::process_match(::ngraph::pattern::graph_rewrite_callback callback)
-       {
-           graph_rewrite_callback cb = m_callback;
-           if (callback)
-           {
-               cb = callback;
-           }
-
-           if (!cb)
-           {
-               throw ngraph_error("process_match invoked w/o a callback function");
-           }
-
-           if (!this->m_match_root)
-           {
-               throw ngraph_error("process_match invoked w/o a match");
-           }
-           return cb(*this);
-       }
-
        bool Matcher::match(const std::shared_ptr<Node>& graph_node)
        {
            // clear our state
@@ -397,11 +377,6 @@ namespace ngraph
            return is_match;
        }

-       bool Matcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
-       {
-           return m_property.is_set(prop);
-       }
-
        bool RecurrentMatcher::match(std::shared_ptr<Node> graph)
        {
            bool matched = false;
@@ -457,11 +432,5 @@ namespace ngraph
            return matched;
        }
-
-       bool RecurrentMatcher::process_match() { return m_callback(*this); }
-
-       bool RecurrentMatcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
-       {
-           return m_property.is_set(prop);
-       }
    }
}
@@ -25,7 +25,6 @@
 #include "ngraph/pattern/op/any_of.hpp"
 #include "ngraph/pattern/op/label.hpp"
 #include "ngraph/pattern/op/skip.hpp"
-#include "ngraph/util.hpp"

 namespace ngraph
 {
@@ -36,8 +35,6 @@ namespace ngraph
    namespace pattern
    {
-       using graph_rewrite_callback = std::function<bool(class Matcher& m)>;
-       using recurrent_graph_rewrite_callback = std::function<bool(class RecurrentMatcher& m)>;
        using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>;

        template <typename T>
@@ -61,21 +58,35 @@ namespace ngraph
        {
        public:
            using PatternMap = std::map<std::shared_ptr<op::Label>, std::shared_ptr<Node>>;

+           // Avoid implicit string construction from nullptr.
+           Matcher(const std::shared_ptr<Node>& pattern_node, std::nullptr_t name) = delete;
+
+           Matcher(const std::shared_ptr<Node>& pattern_node)
+               : m_pattern_node{pattern_node}
+               , m_depth{0}
+               , m_name{"Unnamed"}
+               , m_strict_mode{false}
+           {
+           }
+
+           Matcher(const std::shared_ptr<Node>& pattern_node, const std::string& name)
+               : m_pattern_node(pattern_node)
+               , m_depth{0}
+               , m_name{name}
+               , m_strict_mode{false}
+           {
+           }
+
            /// \brief Constructs a Matcher object
            ///
            /// \param pattern_node is a pattern sub graph that will be matched against input graphs
-           /// \param callback is a callback function that will be called on a successful match
-           Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
-                   graph_rewrite_callback callback = nullptr,
-                   const std::string& name = "Unnamed",
-                   pass::PassPropertyMask property = pass::PassProperty::REGULAR_FUSIONS,
-                   bool strict_mode = false)
+           /// \param name is a string which is used for logging and disabling a matcher
+           /// \param strict_mode forces a matcher to consider shapes and ET of nodes
+           Matcher(const std::shared_ptr<Node>& pattern_node,
+                   const std::string& name,
+                   bool strict_mode)
                : m_pattern_node(pattern_node)
-               , m_callback(callback)
                , m_depth(0)
                , m_name(name)
-               , m_property(property)
                , m_strict_mode(strict_mode)
            {
            }
@@ -113,13 +124,10 @@ namespace ngraph
                return matched;
            }

-           bool get_property(const pass::PassPropertyMask& prop) const;
-
            bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
-
-           bool process_match(graph_rewrite_callback callback = nullptr);
-
-           NodeVector get_matched_nodes() { return m_matched_list; }
+           const NodeVector& get_matched_nodes() { return m_matched_list; }
            void reset() {}
-           std::string get_name() { return m_name; }
+           const std::string& get_name() { return m_name; }
            std::shared_ptr<Node> get_pattern() { return m_pattern_node; }
            std::shared_ptr<Node> get_match_root();
            PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
@@ -173,10 +181,8 @@ namespace ngraph
                const std::shared_ptr<Node>& graph_node,
                PatternMap& pattern_map);

-           graph_rewrite_callback m_callback;
            size_t m_depth;
            std::string m_name;
-           pass::PassPropertyMask m_property;
            bool m_strict_mode;
        };
@@ -189,17 +195,12 @@ namespace ngraph
            /// \param pattern is a pattern sub graph describing an individual cell
            /// \param rpattern is a (recurring) label to denote which node the next match should start at
            /// \param correlated_patterns is a set of labels whose bound nodes must remain the same across all cells
-           // \param is a callback function that will be called on a successful match
            RecurrentMatcher(std::shared_ptr<Node> pattern,
                             std::shared_ptr<op::Label> rpattern,
-                            const std::set<std::shared_ptr<op::Label>>& correlated_patterns,
-                            recurrent_graph_rewrite_callback callback,
-                            pass::PassPropertyMask property = pass::PassProperty::REGULAR_FUSIONS)
+                            const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
                : m_pattern(pattern)
                , m_recurrent_pattern(rpattern)
                , m_correlated_patterns(correlated_patterns)
-               , m_callback(callback)
-               , m_property(property)
            {
            }
@@ -229,19 +230,12 @@ namespace ngraph
            /// \brief Tries to match a pattern for an individual cell to a given \p graph
            bool match(std::shared_ptr<Node> graph);

-           /// \brief Invoked by a pass to process a successful match
-           bool process_match();
-
-           bool get_property(const pass::PassPropertyMask& prop) const;
-
            std::shared_ptr<Node> get_match_root() { return m_match_root; }

        private:
            std::shared_ptr<Node> m_pattern;
            std::shared_ptr<op::Label> m_recurrent_pattern;
            const std::set<std::shared_ptr<op::Label>> m_correlated_patterns;
            RPatternMap m_matches;
-           recurrent_graph_rewrite_callback m_callback;
-           pass::PassPropertyMask m_property;
            std::shared_ptr<Node> m_match_root;
        };
    }
...
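With the callback and `PassPropertyMask` gone, `Matcher` construction reduces to three overloads, plus a deleted `std::nullptr_t` overload so old `Matcher(pattern, nullptr)` call sites fail to compile instead of silently constructing a `std::string` from `nullptr`. A usage sketch (`pattern_node` assumed in scope, the name is illustrative):

```cpp
auto m1 = std::make_shared<pattern::Matcher>(pattern_node);                  // name "Unnamed"
auto m2 = std::make_shared<pattern::Matcher>(pattern_node, "MyPass.MyMatcher");  // named
auto m3 = std::make_shared<pattern::Matcher>(                                // strict mode:
    pattern_node, "MyPass.MyMatcher", true);                                 // shapes/element
                                                                             // types are checked
// std::make_shared<pattern::Matcher>(pattern_node, nullptr); // compile error: deleted overload
```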
@@ -19,7 +19,6 @@
 #include <functional>

 #include "ngraph/node.hpp"
-#include "ngraph/pass/graph_rewrite.hpp"

 namespace ngraph
 {
...
@@ -101,7 +101,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
         Strides{1, 1},
         true);

-    pattern::graph_rewrite_callback callback = [data_conv](pattern::Matcher& m) {
+    auto callback = [data_conv](pattern::Matcher& m) {
         NGRAPH_DEBUG << "conv_horizontal_fusion: In a callback for conv horizontal fusion for "
                      << m.get_match_root()->get_name();
@@ -188,7 +188,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
         return true;
     };

-    auto m = make_shared<pattern::Matcher>(
-        conv_bias, callback, "CPUHorizontalFusion.CpuConvHorizontalFusion");
-    this->add_matcher(m);
+    auto m =
+        make_shared<pattern::Matcher>(conv_bias, "CPUHorizontalFusion.CpuConvHorizontalFusion");
+    this->add_matcher(m, callback);
 }
@@ -422,7 +422,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
     auto slice_weights_label =
         std::make_shared<pattern::op::Label>(slice_weights, nullptr, NodeVector{slice_weights});
     auto conv = std::make_shared<op::Convolution>(slice_data, slice_weights_label);
-    auto matcher = std::make_shared<pattern::Matcher>(conv, nullptr);
+    auto matcher = std::make_shared<pattern::Matcher>(conv);
     NGRAPH_DEBUG << "In simplify_concat (group convolution) for " << n->get_name();
...
@@ -53,7 +53,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu
     auto conv = std::make_shared<ngraph::op::Convolution>(
         data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});

-    pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) {
+    auto callback = [param](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In a callback for construct_weight against "
                      << m.get_match_root()->get_name();
@@ -116,9 +116,9 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu
         return true;
     };

-    auto m = make_shared<pattern::Matcher>(
-        conv, callback, "CPUPostLayoutOptimizations.ConstructWeight_fusion");
-    this->add_matcher(m);
+    auto m =
+        make_shared<pattern::Matcher>(conv, "CPUPostLayoutOptimizations.ConstructWeight_fusion");
+    this->add_matcher(m, callback);
 }

 void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_slice_convertLayout_fusion()
@@ -130,7 +130,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_slice_con
     auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt);
     auto cvt_lt = std::make_shared<runtime::cpu::op::ConvertLayout>(slice, lt_desc);

-    pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) {
+    auto callback = [param](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In a callback for construct_slice_converLayout against "
                      << m.get_match_root()->get_name();
@@ -169,8 +169,8 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_slice_con
     };

-    auto m = make_shared<pattern::Matcher>(
-        cvt_lt, callback, "CPUPostLayoutOptimizations.ConstructSliceConvertLayoutFusion");
-    this->add_matcher(m);
+    auto m = make_shared<pattern::Matcher>(
+        cvt_lt, "CPUPostLayoutOptimizations.ConstructSliceConvertLayoutFusion");
+    this->add_matcher(m, callback);
 }

 // Reshape(transpose) + ConvertLayout
@@ -193,7 +193,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::
         std::make_shared<runtime::cpu::LayoutDescriptor>(*reshape->get_output_tensor_ptr());
     auto cvt_lt = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape, lt_desc);

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In a callback for construct_reshape_converLayout against "
                      << m.get_match_root()->get_name();
@@ -264,6 +264,6 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::
     };

     auto m = make_shared<pattern::Matcher>(
-        cvt_lt, callback, "CPUPostLayoutOptimizations.ConstructReshapeConvertLayoutFusion");
-    this->add_matcher(m);
+        cvt_lt, "CPUPostLayoutOptimizations.ConstructReshapeConvertLayoutFusion");
+    this->add_matcher(m, callback);
 }
@@ -78,7 +78,7 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
     auto divide_1_over_exp = std::make_shared<ngraph::op::Divide>(broadcast_constant, add_exp);

     // Define a call back that needs to called once the DFG matches the pattern
-    ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
+    auto callback = [input](pattern::Matcher& m) {
         NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
                      << m.get_match_root()->get_name();
         auto pattern_map = m.get_pattern_map();
@@ -102,9 +102,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
         return true;
     };

-    auto m = std::make_shared<ngraph::pattern::Matcher>(
-        divide_1_over_exp, callback, "LSTMFusion.Sigmoid");
-    this->add_matcher(m);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, "LSTMFusion.Sigmoid");
+    this->add_matcher(m, callback);
 }

 static void replace_collapse_node_user(std::shared_ptr<Node> collapsed_node,
@@ -184,14 +183,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
         std::make_shared<ngraph::op::Multiply>(ot, std::make_shared<ngraph::op::Tanh>(ct_label));

     // Define a call back that needs to called once the DFG matches the pattern
-    pattern::graph_rewrite_callback callback = [ct_label,
-                                                w_i2h,
-                                                bias_i2h,
-                                                w_h2h,
-                                                bias_h2h,
-                                                xt,
-                                                ht_1,
-                                                ct_1](pattern::Matcher& m) {
+    auto callback = [ct_label, w_i2h, bias_i2h, w_h2h, bias_h2h, xt, ht_1, ct_1](
+        pattern::Matcher& m) {
         NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against "
                      << m.get_match_root()->get_name();
@@ -326,8 +319,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
         ngraph::replace_node(m.get_match_root(), ht_slice);
         return true;
     };
-    auto m = std::make_shared<pattern::Matcher>(ht, callback, "LSTMFusion.Fprop");
-    this->add_matcher(m);
+    auto m = std::make_shared<pattern::Matcher>(ht, "LSTMFusion.Fprop");
+    this->add_matcher(m, callback);
 }

 void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
@@ -378,13 +371,12 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
     auto lstm_goe_slice =
         std::make_shared<ngraph::op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});

-    pattern::recurrent_graph_rewrite_callback callback = [lstm_goe_label,
-                                                          lstm_src_layer,
-                                                          lstm_src_iter_label,
-                                                          lstm_weights_layer_label,
-                                                          lstm_weights_iter_label,
-                                                          lstm_bias_label](
-        pattern::RecurrentMatcher& m) {
+    auto callback = [lstm_goe_label,
+                     lstm_src_layer,
+                     lstm_src_iter_label,
+                     lstm_weights_layer_label,
+                     lstm_weights_iter_label,
+                     lstm_bias_label](pattern::RecurrentMatcher& m) {

         NGRAPH_DEBUG << " In recurrent RNN fusion callback";
@@ -568,9 +560,8 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
         std::set<std::shared_ptr<pattern::op::Label>>{lstm_weights_layer_shared,
                                                       lstm_weights_iter_shared,
                                                       lstm_bias_layer_shared,
-                                                      lstm_bias_iter_shared},
-        callback);
-    this->add_matcher(m);
+                                                      lstm_bias_iter_shared});
+    this->add_matcher(m, callback);
 }

 static std::shared_ptr<Node> stack_rnn_inputs(NodeVector rnn_input_nodes)
@@ -613,13 +604,12 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
     auto rnn_goe0_label =
         std::make_shared<pattern::op::Label>(rnn_goe0, nullptr, NodeVector{rnn_goe0});

-    pattern::recurrent_graph_rewrite_callback callback = [rnn_src_layer,
-                                                          rnn_src_iter,
-                                                          rnn_weights_layer,
-                                                          rnn_weights_iter,
-                                                          rnn_bias,
-                                                          rnn_goe0_label](
-        pattern::RecurrentMatcher& m) {
+    auto callback = [rnn_src_layer,
+                     rnn_src_iter,
+                     rnn_weights_layer,
+                     rnn_weights_iter,
+                     rnn_bias,
+                     rnn_goe0_label](pattern::RecurrentMatcher& m) {
         auto number_of_rnn_cell_matched = m.get_number_of_recurrent_matches();
         NGRAPH_DEBUG << " In Recurrent multi layer RNN fusion callback ";
         NGRAPH_DEBUG << " Number of RNN's Matched: " << number_of_rnn_cell_matched;
@@ -775,8 +765,8 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
     std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
     auto m = std::make_shared<pattern::RecurrentMatcher>(
-        rnn_goe0_label, rnn_src_layer, empty_correlated_matches, callback);
-    this->add_matcher(m);
+        rnn_goe0_label, rnn_src_layer, empty_correlated_matches);
+    this->add_matcher(m, callback);
 }

 void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
@@ -813,8 +803,7 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
         NodeVector{rnn_ltor_goe0_reshape_tnc, skip_reverse_seq}, 0);

     // Define a call back that needs to called once the DFG matches the pattern
-    ngraph::pattern::graph_rewrite_callback callback = [rnn_left_to_right,
-                                                        rnn_right_to_left](pattern::Matcher& m) {
+    auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
         auto pattern_map = m.get_pattern_map();

         auto rnn_ltor_node =
@@ -914,6 +903,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
         return true;
     };

-    auto m = std::make_shared<ngraph::pattern::Matcher>(concat, callback);
-    this->add_matcher(m);
+    auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "BiDirectionalRnn");
+    this->add_matcher(m, callback);
 }
@@ -32,7 +32,7 @@ ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
         return dynamic_cast<ngraph::op::Concat*>(node.get()) != nullptr;
     });

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         auto concat = std::static_pointer_cast<ngraph::op::Concat>(m.get_match_root());
         auto args = concat->get_arguments();
@@ -114,5 +114,5 @@ ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
         return true;
     };

-    add_matcher(std::make_shared<pattern::Matcher>(concat_op, callback));
+    add_matcher(std::make_shared<pattern::Matcher>(concat_op), callback);
 }
@@ -30,7 +30,7 @@ ngraph::runtime::plaidml::pass::ConcatSplit::ConcatSplit()
         return op != nullptr && kMaxConcatInputs < op->get_input_size();
     });

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         auto concat = std::static_pointer_cast<ngraph::op::Concat>(m.get_match_root());
         auto args = concat->get_arguments();
@@ -73,5 +73,5 @@ ngraph::runtime::plaidml::pass::ConcatSplit::ConcatSplit()
         return true;
     };

-    add_matcher(std::make_shared<pattern::Matcher>(concat_op, callback));
+    add_matcher(std::make_shared<pattern::Matcher>(concat_op), callback);
 }
@@ -71,7 +71,7 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data
         },
         NodeVector{producer_op});

-    pattern::graph_rewrite_callback callback = [producer_op](pattern::Matcher& m) {
+    auto callback = [producer_op](pattern::Matcher& m) {
         auto consumer = m.get_match_root();
         auto producer = m.get_pattern_map()[producer_op];
         NGRAPH_DEBUG << "Adding conversion for " << producer->description() << " -> "
@@ -89,6 +89,6 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data
         return true;
     };

-    auto m = std::make_shared<pattern::Matcher>(data_consumer_op, callback);
-    add_matcher(m);
+    auto m = std::make_shared<pattern::Matcher>(data_consumer_op);
+    add_matcher(m, callback);
 }
@@ -39,7 +39,7 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
         },
         NodeVector{broadcast_op});

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         // Since the broadcast is going to an elementwise operation, we
         // can always replace it with an equivalent reshape that uses ones
         // for the broadcast axes.
@@ -70,5 +70,5 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
         return true;
     };

-    add_matcher(std::make_shared<pattern::Matcher>(target_op, callback));
+    add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
 }
@@ -34,7 +34,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
                pattern::has_class<ngraph::op::ConvolutionBackpropFilters>()(node);
     });

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         auto to_transpose = [](const std::shared_ptr<Node>& node) -> ngraph::op::Reshape* {
             if (!node)
             {
@@ -140,5 +140,5 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
         return false;
     };

-    add_matcher(std::make_shared<pattern::Matcher>(convolution_op, callback));
+    add_matcher(std::make_shared<pattern::Matcher>(convolution_op), callback);
 }
@@ -35,7 +35,7 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
         },
         NodeVector{upper_replicate_op});

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         auto nodes = m.get_matched_nodes();
         auto lower = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(0));
         auto upper = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(1));
@@ -54,5 +54,5 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
         return true;
     };

-    add_matcher(std::make_shared<pattern::Matcher>(lower_replicate_op, callback));
+    add_matcher(std::make_shared<pattern::Matcher>(lower_replicate_op), callback);
 }
@@ -43,7 +43,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
         },
         NodeVector{skip_op});

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         bool replaced_any = false;
         auto nodes = m.get_matched_nodes();
         std::size_t dim_limit = nodes.at(1)->get_shape().size();
@@ -81,5 +81,5 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
         return replaced_any;
     };

-    add_matcher(std::make_shared<pattern::Matcher>(target_op, callback));
+    add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
 }
@@ -111,7 +111,7 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
                 filters_shape.at(1) == 3 && filters_shape.at(2) > 4 && filters_shape.at(3) > 4);
     });

-    pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
+    auto callback = [](pattern::Matcher& m) {
         auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
         NodeVector args = conv->get_arguments();
         std::shared_ptr<ngraph::op::Constant> a;
@@ -126,5 +126,5 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
         return true;
     };

-    add_matcher(std::make_shared<pattern::Matcher>(convolution_op, callback));
+    add_matcher(std::make_shared<pattern::Matcher>(convolution_op), callback);
 }
@@ -131,7 +131,7 @@ TEST(control_dependencies, clone_function_cdop)
     auto f = make_shared<Function>(cdop, ParameterVector{A});
     auto clone = ngraph::clone_function(*f.get());

-    auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
+    auto matcher = std::make_shared<pattern::Matcher>(cdop);
     auto cdop_clone = clone->get_results().at(0)->get_argument(0);
     ASSERT_TRUE(matcher->match(cdop_clone));
     auto cloned_deps = cdop_clone->get_control_dependencies();
@@ -152,7 +152,7 @@ TEST(control_dependencies, clone_function_cdop_abs)
     auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
     auto clone = ngraph::clone_function(*f.get());

-    auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
+    auto matcher = std::make_shared<pattern::Matcher>(cdop);
     auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
     ASSERT_TRUE(matcher->match(cdop_clone));
     auto cloned_deps = cdop_clone->get_control_dependencies();
@@ -175,7 +175,7 @@ TEST(control_dependencies, serialize_cdop)
     string js = serialize(f, 4);
     shared_ptr<Function> clone = deserialize(js);

-    auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
+    auto matcher = std::make_shared<pattern::Matcher>(cdop);
     auto cdop_clone = clone->get_results().at(0)->get_argument(0);
     ASSERT_TRUE(matcher->match(cdop_clone));
     auto cloned_deps = cdop_clone->get_control_dependencies();
@@ -199,7 +199,7 @@ TEST(control_dependencies, serialize_cdop_abs)
     string js = serialize(f, 4);
     shared_ptr<Function> clone = deserialize(js);

-    auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
+    auto matcher = std::make_shared<pattern::Matcher>(cdop);
     auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
     ASSERT_TRUE(matcher->match(cdop_clone));
     auto cloned_deps = cdop_clone->get_control_dependencies();
...
@@ -90,7 +90,7 @@ public:
         auto iconst1 = construct_constant_node(1);
         auto pattern = std::make_shared<pattern::op::Label>(iconst1);

-        ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
+        auto callback = [pattern](pattern::Matcher& m) {
             NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
                          << m.get_match_root()->get_name();
             NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
@@ -126,8 +126,8 @@ public:
             return true;
         };

-        auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
-        this->add_matcher(m);
+        auto m = make_shared<TestMatcher>(pattern * iconst1);
+        this->add_matcher(m, callback);
     }

     void construct_add_zero()
@@ -172,8 +172,9 @@ public:
             return true;
         };

-        auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
-        this->add_matcher(m);
+        auto add = pattern + iconst0;
+        auto m = make_shared<TestMatcher>(add);
+        this->add_matcher(m, callback);
     }

     TestGraphRewrite()
@@ -446,7 +447,7 @@ TEST(pattern, matcher)
     // strict mode
     {
-        TestMatcher sm(nullptr, nullptr, "TestMatcher", pass::PassProperty::REGULAR_FUSIONS, true);
+        TestMatcher sm(nullptr, "TestMatcher", true);
         // exact shape and type
         auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
         auto label_dynamic_shape =
@@ -540,7 +541,7 @@ TEST(pattern, recurrent_pattern)
     auto add3 = iconst0 + add2;
     auto padd = iconst0 + rpattern;
     std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
-    RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
+    RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
     ASSERT_TRUE(rm.match(add3));
     ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
     auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
@@ -554,7 +555,7 @@ TEST(pattern, recurrent_pattern)
     auto add2_2 = iconst1 + add1;
     auto add3_2 = iconst0 + add2_2;
     auto padd2 = iconst_label + rpattern;
-    RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches, nullptr);
+    RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
     ASSERT_TRUE(rm2.match(add3_2));
     ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
     recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
@@ -569,7 +570,7 @@ TEST(pattern, recurrent_pattern)
     // Non-matching correlated labels
     std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
     correlated_matches.insert(iconst_label);
-    RecurrentMatcher rm3(padd2, rpattern, correlated_matches, nullptr);
+    RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
     ASSERT_TRUE(rm3.match(add3_2));
     ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
     iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
@@ -602,8 +603,7 @@ public:
         auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
         auto padd = iconst_label + rpattern;

-        ngraph::pattern::recurrent_graph_rewrite_callback callback = [iconst_label, rpattern](
-            pattern::RecurrentMatcher& rm) {
+        auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
             NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
                          << rm.get_match_root()->get_name();
@@ -634,9 +634,8 @@ public:
         };

         std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
-        auto rm = make_shared<pattern::RecurrentMatcher>(
-            padd, rpattern, empty_correlated_matches, callback);
-        this->add_matcher(rm);
+        auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
+        this->add_matcher(rm, callback);
     }

     TestRecurrentGraphRewrite()
@@ -697,7 +696,7 @@ TEST(pattern, label_on_skip)
     auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
     auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
-    auto matcher = std::make_shared<pattern::Matcher>(
-        std::make_shared<op::Multiply>(label, bcst_label), nullptr);
+    auto matcher = std::make_shared<pattern::Matcher>(
+        std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");

     auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
     auto mul = a * const_broadcast;
...
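Matching without rewriting also gets simpler, as the test updates above show: no `nullptr` placeholder is needed when a matcher is used only to query a graph. A minimal sketch (`pattern_node` and `candidate` assumed in scope):

```cpp
auto matcher = std::make_shared<pattern::Matcher>(pattern_node);
if (matcher->match(candidate))
{
    auto root = matcher->get_match_root(); // inspect the match; no callback involved
}
```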