Commit 44f7479d authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Move callback out of matcher (#2305)

* ngraph core changed

* cpu bits + pattern

* change perms

* revert std::cout to NGRAPH_DEBUG

* fix comp errors

* fix comment

* more comment fixes

* Remove callback argument

* clean up.

* fixed tests.

* more fixes for backends.

* fixe namespace.

* similar fix for recurrent matcher, and improves.

* fix build.
parent 715eeb37
......@@ -52,7 +52,7 @@ static shared_ptr<pattern::Matcher>
{
auto bcst = make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::Broadcast>());
auto bcst_label = make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label), nullptr);
auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label));
return matcher;
}
......@@ -86,7 +86,7 @@ static bool simplify_concat(shared_ptr<Node> n)
auto skip_reshape = make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::Reshape>());
auto matcher = make_shared<pattern::Matcher>(skip_reshape, nullptr);
auto matcher = make_shared<pattern::Matcher>(skip_reshape);
Coordinate prev_lower_bounds;
Shape prev_slice_shape;
......
......@@ -80,7 +80,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
auto slice_weights_label =
std::make_shared<pattern::op::Label>(slice_weights, nullptr, NodeVector{slice_weights});
auto conv = std::make_shared<op::Convolution>(slice_data, slice_weights_label);
auto matcher = std::make_shared<pattern::Matcher>(conv, nullptr);
auto matcher = std::make_shared<pattern::Matcher>(conv);
NGRAPH_DEBUG << "In simplify_concat (group convolution) for " << n->get_name();
......
......@@ -113,8 +113,8 @@ void pass::ConcatElimination::construct_concat_elimination()
return false;
};
auto m = std::make_shared<pattern::Matcher>(concat_label, callback);
this->add_matcher(m);
auto m = std::make_shared<pattern::Matcher>(concat_label, "ConcatElimination");
this->add_matcher(m, callback);
}
bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function)
......
......@@ -180,9 +180,8 @@ void pass::ConstantFolding::construct_constant_pad()
return false;
};
auto pad_matcher =
make_shared<pattern::Matcher>(pad, constant_pad_callback, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher);
auto pad_matcher = make_shared<pattern::Matcher>(pad, "ConstantFolding.ConstantPad");
this->add_matcher(pad_matcher, constant_pad_callback);
}
void pass::ConstantFolding::construct_constant_reshape()
......@@ -245,9 +244,9 @@ void pass::ConstantFolding::construct_constant_reshape()
return false;
};
auto reshape_matcher = make_shared<pattern::Matcher>(
reshape, constant_reshape_callback, "ConstantFolding.ConstantReshape");
this->add_matcher(reshape_matcher);
auto reshape_matcher =
make_shared<pattern::Matcher>(reshape, "ConstantFolding.ConstantReshape");
this->add_matcher(reshape_matcher, constant_reshape_callback);
}
template <class T>
......@@ -340,9 +339,9 @@ void pass::ConstantFolding::construct_constant_broadcast()
return false;
};
auto broadcast_matcher = make_shared<pattern::Matcher>(
broadcast, constant_broadcast_callback, "ConstantFolding.ConstantBroadcast");
this->add_matcher(broadcast_matcher);
auto broadcast_matcher =
make_shared<pattern::Matcher>(broadcast, "ConstantFolding.ConstantBroadcast");
this->add_matcher(broadcast_matcher, constant_broadcast_callback);
}
template <class T>
......@@ -478,9 +477,8 @@ void pass::ConstantFolding::construct_constant_binary()
return false;
};
auto reshape_matcher = make_shared<pattern::Matcher>(
bea, constant_binary_callback, "ConstantFolding.ConstantBinary");
this->add_matcher(reshape_matcher);
auto reshape_matcher = make_shared<pattern::Matcher>(bea, "ConstantFolding.ConstantBinary");
this->add_matcher(reshape_matcher, constant_binary_callback);
}
bool is_supported_unary_op(std::shared_ptr<Node> n)
......@@ -609,9 +607,8 @@ void pass::ConstantFolding::construct_constant_unary()
return false;
};
auto reshape_matcher = make_shared<pattern::Matcher>(
uea, constant_unary_callback, "ConstantFolding.ConstantUnary");
this->add_matcher(reshape_matcher);
auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary");
this->add_matcher(reshape_matcher, constant_unary_callback);
}
template <class QUANT, class REAL>
......@@ -682,9 +679,9 @@ void pass::ConstantFolding::construct_constant_dequantize()
return false;
};
auto dequantize_matcher = make_shared<pattern::Matcher>(
dequant, constant_dequantize_callback, "ConstantFolding.ConstantDequantize");
this->add_matcher(dequantize_matcher);
auto dequantize_matcher =
make_shared<pattern::Matcher>(dequant, "ConstantFolding.ConstantDequantize");
this->add_matcher(dequantize_matcher, constant_dequantize_callback);
}
template <class REAL, class QUANT>
......@@ -757,7 +754,7 @@ void pass::ConstantFolding::construct_constant_quantize()
return false;
};
auto quantize_matcher = make_shared<pattern::Matcher>(
quant, constant_quantize_callback, "ConstantFolding.ConstantQuantize");
this->add_matcher(quantize_matcher);
auto quantize_matcher =
make_shared<pattern::Matcher>(quant, "ConstantFolding.ConstantQuantize");
this->add_matcher(quantize_matcher, constant_quantize_callback);
}
......@@ -64,7 +64,7 @@ void pass::CoreFusion::construct_relu()
std::make_shared<pattern::op::Skip>(zero, pattern::has_class<op::Broadcast>());
auto max = make_shared<op::Maximum>(skip_broadcast, val);
pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
auto callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu against "
<< m.get_match_root()->get_name();
......@@ -82,8 +82,8 @@ void pass::CoreFusion::construct_relu()
return true;
};
auto m = make_shared<pattern::Matcher>(max, callback, "CoreFusion.Relu");
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(max, "CoreFusion.Relu");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_sigmoid()
......@@ -101,7 +101,7 @@ void pass::CoreFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<op::Divide>(skip_broadcast, add_exp);
// Define a call back that needs to called once the DFG matches the pattern
pattern::graph_rewrite_callback callback = [input, constant](pattern::Matcher& m) {
auto callback = [input, constant](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -130,8 +130,8 @@ void pass::CoreFusion::construct_sigmoid()
return true;
};
auto m = std::make_shared<pattern::Matcher>(divide_1_over_exp, callback, "CoreFusion.Sigmoid");
this->add_matcher(m);
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, "CoreFusion.Sigmoid");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_sigmoid_bprop()
......@@ -159,7 +159,7 @@ void pass::CoreFusion::construct_sigmoid_bprop()
auto negative_2 = std::make_shared<op::Negative>(multiply_2);
// Define a call back that needs to called once the DFG matches the pattern
pattern::graph_rewrite_callback callback = [input, delta](pattern::Matcher& m) {
auto callback = [input, delta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_bprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -182,8 +182,8 @@ void pass::CoreFusion::construct_sigmoid_bprop()
return true;
};
auto m = std::make_shared<pattern::Matcher>(negative_2, callback, "CoreFusion.SigmoidBprop");
this->add_matcher(m);
auto m = std::make_shared<ngraph::pattern::Matcher>(negative_2, "CoreFusion.SigmoidBprop");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_folded_batch_norm()
......@@ -211,8 +211,7 @@ void pass::CoreFusion::construct_folded_batch_norm()
auto shape_r = Shape{1, 2, 2, 2};
auto bn = std::make_shared<op::BatchNormInference>(eps, gamma, beta, pconv, mean, var);
pattern::graph_rewrite_callback callback = [input, filters, mean, var, gamma, beta](
pattern::Matcher& m) {
auto callback = [input, filters, mean, var, gamma, beta](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for folded batch norm against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -263,8 +262,8 @@ void pass::CoreFusion::construct_folded_batch_norm()
};
auto m = std::make_shared<pattern::Matcher>(bn, callback, "CoreFusion.FoldedBatchNorm");
this->add_matcher(m);
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "CoreFusion.FoldedBatchNorm");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_conv_affine_folding()
......@@ -292,8 +291,7 @@ void pass::CoreFusion::construct_conv_affine_folding()
auto multiply = std::make_shared<op::Multiply>(conv_label, A_label);
auto add = std::make_shared<op::Add>(multiply, B_label);
pattern::graph_rewrite_callback callback =
[input, filters, conv_label, A_label, B_label](pattern::Matcher& m) {
auto callback = [input, filters, conv_label, A_label, B_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for conv affine folding against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -374,8 +372,8 @@ void pass::CoreFusion::construct_conv_affine_folding()
};
auto m = std::make_shared<pattern::Matcher>(add, callback, "CoreFusion.ConvAffineFolding");
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(add, "CoreFusion.ConvAffineFolding");
this->add_matcher(m, callback);
}
static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv,
......@@ -445,7 +443,7 @@ void pass::CoreFusion::construct_reshape_broadcast()
auto reshape1 = make_shared<op::Reshape>(input, AxisVector{0}, Shape{10, 1});
auto broadcast = make_shared<op::Broadcast>(reshape1, Shape{10, 1, 20}, AxisSet{2});
pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
auto callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_reshape_broadcast against "
<< m.get_match_root()->get_name();
......@@ -504,8 +502,8 @@ void pass::CoreFusion::construct_reshape_broadcast()
return true;
};
auto m = make_shared<pattern::Matcher>(broadcast, callback, "CoreFusion.ReshapeBroadcast");
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(broadcast, "CoreFusion.ReshapeBroadcast");
this->add_matcher(m, callback);
}
// conv(56w3s1) conv(28w3s2)
......@@ -550,7 +548,7 @@ void pass::CoreFusion::construct_optimized_strided_conv()
auto weights_eltwise = std::make_shared<pattern::op::Label>(element::f32, win_size_1);
auto eltwise_conv = std::make_shared<op::Convolution>(eltwise_label, weights_eltwise);
pattern::graph_rewrite_callback callback = [win_size_1,
auto callback = [win_size_1,
eltwise_label,
conv_stride1_label,
conv_stride3_label,
......@@ -692,9 +690,8 @@ void pass::CoreFusion::construct_optimized_strided_conv()
return true;
};
auto m =
make_shared<pattern::Matcher>(eltwise_conv, callback, "CoreFusion.OptimizedStridedConv");
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(eltwise_conv, "CoreFusion.OptimizedStridedConv");
this->add_matcher(m, callback);
}
void pass::CoreFusion::construct_reshape_softmax_reshape()
......@@ -706,7 +703,7 @@ void pass::CoreFusion::construct_reshape_softmax_reshape()
auto softmax = make_shared<op::Softmax>(reshape1, AxisSet{1});
auto reshape2 = make_shared<op::Reshape>(softmax, io, input_shape);
pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
auto callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_reshape_softmax_reshape against "
<< m.get_match_root()->get_name();
......@@ -740,8 +737,8 @@ void pass::CoreFusion::construct_reshape_softmax_reshape()
return true;
};
auto m = make_shared<pattern::Matcher>(reshape2, callback, "CoreFusion.ReshapeSoftmaxReshape");
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(reshape2, "CoreFusion.ReshapeSoftmaxReshape");
this->add_matcher(m, callback);
}
void ngraph::pass::CoreFusion::construct_conv_bias()
......@@ -762,7 +759,7 @@ void ngraph::pass::CoreFusion::construct_conv_bias()
Strides{1, 1});
auto p_conv_bias = pbroadcast + pconv1;
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_conv_bias against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -816,9 +813,8 @@ void ngraph::pass::CoreFusion::construct_conv_bias()
return true;
};
auto m =
std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, callback, "CoreFusion.ConvBias");
this->add_matcher(m);
auto m = std::make_shared<ngraph::pattern::Matcher>(p_conv_bias, "CoreFusion.ConvBias");
this->add_matcher(m, callback);
}
void ngraph::pass::CoreFusion::construct_conv_bias_add()
......@@ -839,7 +835,7 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add()
auto add_input = std::make_shared<pattern::op::Label>(element::f32, pconv->get_shape());
auto padd = std::make_shared<ngraph::op::Add>(add_input, pconv);
pattern::graph_rewrite_callback callback = [data_batch, filters](pattern::Matcher& m) {
auto callback = [data_batch, filters](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_conv_sum against "
<< m.get_match_root()->get_name();
......@@ -867,6 +863,6 @@ void ngraph::pass::CoreFusion::construct_conv_bias_add()
return true;
};
auto m = std::make_shared<pattern::Matcher>(padd, callback, "CoreFusion.ConvBiasAdd");
this->add_matcher(m);
auto m = std::make_shared<pattern::Matcher>(padd, "CoreFusion.ConvBiasAdd");
this->add_matcher(m, callback);
}
......@@ -22,7 +22,6 @@
#include "graph_rewrite.hpp"
#include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
......@@ -64,35 +63,28 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
bool rewritten = false;
const size_t NUM_TRIES = 10;
size_t tries = NUM_TRIES;
vector<shared_ptr<pattern::Matcher>> original_matchers{m_matchers};
bool is_dynamic_function = f->is_dynamic();
vector<MatchClosure> original_matchers{m_matchers};
do
{
rewritten = false;
vector<shared_ptr<pattern::Matcher>> matchers{m_matchers};
// m_matchers may contain newly constructed matchers for matchers
// that need multiple passes. See comments above.
vector<MatchClosure> run_matchers{m_matchers};
m_matchers.clear();
for (auto node : f->get_ordered_ops())
{
for (auto matcher : matchers)
for (auto& mc : run_matchers)
{
if (is_dynamic_function &&
(matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
{
NGRAPH_DEBUG
<< "matcher requires static shape but the function is dynamic, "
<< "skipping this optimization till the shapes are fully materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << matcher->get_name() << "("
<< matcher->get_pattern()->get_name() << ") on " << node->get_name();
if (matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << matcher << matcher->get_name() << " matched "
NGRAPH_DEBUG << "Running matcher " << mc.matcher->get_name() << "("
<< mc.matcher->get_pattern()->get_name() << ") on "
<< node->get_name();
if (matcher->process_match())
if (mc.matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << mc.matcher << mc.matcher->get_name()
<< " matched " << node->get_name();
if (mc.callback(*mc.matcher.get()))
{
rewritten = true;
is_dynamic_function = f->is_dynamic();
break;
}
}
......@@ -105,7 +97,7 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
return (NUM_TRIES - tries) > 1; //this means a graph was transformed
}
static const vector<regex> initialize_fusion_regexes()
static vector<regex> initialize_fusion_regexes()
{
const char* cnsf = getenv("NGRAPH_DISABLED_FUSIONS");
vector<regex> regexes;
......@@ -122,7 +114,7 @@ static const vector<regex> initialize_fusion_regexes()
return regexes;
}
bool pass::GraphRewrite::is_enabled(shared_ptr<pattern::Matcher> m)
bool pass::GraphRewrite::is_enabled(const shared_ptr<pattern::Matcher>& m) const
{
//note, regexes are static to avoid re-initialization
static const auto regexes = initialize_fusion_regexes();
......@@ -139,41 +131,39 @@ bool pass::GraphRewrite::is_enabled(shared_ptr<pattern::Matcher> m)
return true;
}
void pass::GraphRewrite::add_matcher(shared_ptr<pattern::Matcher> m)
void pass::GraphRewrite::add_matcher(const shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback)
{
if (is_enabled(m))
{
m_matchers.push_back(m);
m_matchers.push_back({m, callback});
}
}
void pass::RecurrentGraphRewrite::add_matcher(
const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback)
{
m_matchers.push_back({m, callback});
}
bool pass::RecurrentGraphRewrite::run_on_function(shared_ptr<Function> f)
{
bool changed = false;
size_t i = 0;
bool is_dynamic_function = f->is_dynamic();
do
{
for (auto node : f->get_ops())
{
for (auto matcher : m_matchers)
for (auto& mc : m_matchers)
{
if (is_dynamic_function &&
(matcher->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE)))
{
NGRAPH_DEBUG
<< "matcher requires static shape but the function is dynamic, "
<< "skipping this optimization till the shapes are fully materialized";
continue;
}
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node->get_name();
if (matcher->match(node))
NGRAPH_DEBUG << "Running matcher " << mc.matcher << " on " << node->get_name();
if (mc.matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node->get_name();
if (matcher->process_match())
NGRAPH_DEBUG << "Matcher " << mc.matcher << " matched " << node->get_name();
if (mc.callback(*mc.matcher.get()))
{
changed = true;
is_dynamic_function = f->is_dynamic();
goto next_fusion;
}
}
......
......@@ -17,8 +17,11 @@
#pragma once
#include <functional>
#include <memory>
#include <set>
#include "ngraph/pass/pass.hpp"
#include "ngraph/pattern/matcher.hpp"
namespace ngraph
{
......@@ -27,11 +30,10 @@ namespace ngraph
class GraphRewrite;
class RecurrentGraphRewrite;
}
namespace pattern
{
class Matcher;
class RecurrentMatcher;
}
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
using recurrent_graph_rewrite_callback =
std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
}
/// \brief GraphRewrite (in tandem with \sa Matcher) performs transformations on specified patterns
......@@ -52,13 +54,20 @@ public:
{
}
bool is_enabled(std::shared_ptr<pattern::Matcher> m);
void add_matcher(std::shared_ptr<pattern::Matcher> m);
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const ngraph::graph_rewrite_callback& callback);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
protected:
bool is_enabled(const std::shared_ptr<pattern::Matcher>& m) const;
private:
// enable cascading rewrites
std::vector<std::shared_ptr<pattern::Matcher>> m_matchers;
struct MatchClosure
{
std::shared_ptr<pattern::Matcher> matcher;
ngraph::graph_rewrite_callback callback;
};
std::vector<MatchClosure> m_matchers;
};
class ngraph::pass::RecurrentGraphRewrite : public FunctionPass
......@@ -70,10 +79,17 @@ public:
{
}
void add_matcher(std::shared_ptr<pattern::RecurrentMatcher> m) { m_matchers.push_back(m); }
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
const ngraph::recurrent_graph_rewrite_callback& callback);
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
private:
size_t m_num_iters;
std::vector<std::shared_ptr<pattern::RecurrentMatcher>> m_matchers;
struct MatchClosure
{
std::shared_ptr<pattern::RecurrentMatcher> matcher;
ngraph::recurrent_graph_rewrite_callback callback;
};
std::vector<MatchClosure> m_matchers;
};
......@@ -77,9 +77,9 @@ pass::PrefixReshapeElimination::PrefixReshapeElimination()
},
NodeVector{reshape_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
replace_node(m.get_matched_nodes().at(1), m.get_matched_nodes().at(2));
return true;
};
add_matcher(make_shared<pattern::Matcher>(target_op, callback));
add_matcher(make_shared<pattern::Matcher>(target_op, "PrefixReshapeElimination"), callback);
}
......@@ -72,8 +72,8 @@ void pass::ReshapeElimination::construct_identity_reshape_pattern()
return true;
};
auto m = make_shared<pattern::Matcher>(reshape1, callback);
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(reshape1);
this->add_matcher(m, callback);
}
void pass::ReshapeElimination::construct_reshapex2_pattern()
......@@ -131,8 +131,8 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
return false;
};
auto m = make_shared<pattern::Matcher>(reshape2, callback);
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(reshape2);
this->add_matcher(m, callback);
}
void pass::ReshapeElimination::construct_dot_transpose_pattern()
......@@ -145,7 +145,7 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
auto pdot = make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
auto preshape = make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.get_match_root()->get_name();
......@@ -188,8 +188,8 @@ void pass::ReshapeElimination::construct_dot_transpose_pattern()
return true;
};
auto m = make_shared<pattern::Matcher>(preshape, callback);
this->add_matcher(m);
auto m = make_shared<pattern::Matcher>(preshape);
this->add_matcher(m, callback);
}
void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
......@@ -289,7 +289,7 @@ void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
return modify_graph;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
reshape_label, op, empty_correlated_matches, callback);
this->add_matcher(m);
auto m =
std::make_shared<pattern::RecurrentMatcher>(reshape_label, op, empty_correlated_matches);
this->add_matcher(m, callback);
}
......@@ -87,7 +87,7 @@ namespace ngraph
for (auto entry : m_pattern_map)
{
// leaf label
if (entry.first->get_input_size() == 0)
if (entry.first->get_inputs().empty())
{
label_exclusions.push_back(entry.second);
}
......@@ -194,7 +194,7 @@ namespace ngraph
// when their individual GOE are matched
// this also gives a bit more flexibility since we don't have to worry
// about *all* outputs of a pattern node but only the ones we want to match.
if (m_strict_mode && graph_node->get_output_size() == 1)
if (m_strict_mode && graph_node->get_outputs().size() == 1)
{
bool shape_match = pattern_node->get_output_partial_shape(0).compatible(
graph_node->get_output_partial_shape(0));
......@@ -328,26 +328,6 @@ namespace ngraph
return false;
}
bool Matcher::process_match(::ngraph::pattern::graph_rewrite_callback callback)
{
graph_rewrite_callback cb = m_callback;
if (callback)
{
cb = callback;
}
if (!cb)
{
throw ngraph_error("process_match invoked w/o a callback function");
}
if (!this->m_match_root)
{
throw ngraph_error("process_match invoked w/o a match");
}
return cb(*this);
}
bool Matcher::match(const std::shared_ptr<Node>& graph_node)
{
// clear our state
......@@ -397,11 +377,6 @@ namespace ngraph
return is_match;
}
bool Matcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
{
return m_property.is_set(prop);
}
bool RecurrentMatcher::match(std::shared_ptr<Node> graph)
{
bool matched = false;
......@@ -457,11 +432,5 @@ namespace ngraph
return matched;
}
bool RecurrentMatcher::process_match() { return m_callback(*this); }
bool RecurrentMatcher::get_property(const ngraph::pass::PassPropertyMask& prop) const
{
return m_property.is_set(prop);
}
}
}
......@@ -25,7 +25,6 @@
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
......@@ -36,8 +35,6 @@ namespace ngraph
namespace pattern
{
using graph_rewrite_callback = std::function<bool(class Matcher& m)>;
using recurrent_graph_rewrite_callback = std::function<bool(class RecurrentMatcher& m)>;
using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>;
template <typename T>
......@@ -61,21 +58,35 @@ namespace ngraph
{
public:
using PatternMap = std::map<std::shared_ptr<op::Label>, std::shared_ptr<Node>>;
// Avoid implicit string construction from nullptr.
Matcher(const std::shared_ptr<Node>& pattern_node, std::nullptr_t name) = delete;
Matcher(const std::shared_ptr<Node>& pattern_node)
: m_pattern_node{pattern_node}
, m_depth{0}
, m_name{"Unnamed"}
, m_strict_mode{false}
{
}
Matcher(const std::shared_ptr<Node>& pattern_node, const std::string& name)
: m_pattern_node(pattern_node)
, m_depth{0}
, m_name{name}
, m_strict_mode{false}
{
}
/// \brief Constructs a Matcher object
///
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
graph_rewrite_callback callback = nullptr,
const std::string& name = "Unnamed",
pass::PassPropertyMask property = pass::PassProperty::REGULAR_FUSIONS,
bool strict_mode = false)
/// \param name is a string which is used for logging and disabling a matcher
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
Matcher(const std::shared_ptr<Node>& pattern_node,
const std::string& name,
bool strict_mode)
: m_pattern_node(pattern_node)
, m_callback(callback)
, m_depth(0)
, m_name(name)
, m_property(property)
, m_strict_mode(strict_mode)
{
}
......@@ -113,13 +124,10 @@ namespace ngraph
return matched;
}
bool get_property(const pass::PassPropertyMask& prop) const;
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
bool process_match(graph_rewrite_callback callback = nullptr);
NodeVector get_matched_nodes() { return m_matched_list; }
const NodeVector& get_matched_nodes() { return m_matched_list; }
void reset() {}
std::string get_name() { return m_name; }
const std::string& get_name() { return m_name; }
std::shared_ptr<Node> get_pattern() { return m_pattern_node; }
std::shared_ptr<Node> get_match_root();
PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
......@@ -173,10 +181,8 @@ namespace ngraph
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
graph_rewrite_callback m_callback;
size_t m_depth;
std::string m_name;
pass::PassPropertyMask m_property;
bool m_strict_mode;
};
......@@ -189,17 +195,12 @@ namespace ngraph
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same across all cells
// \param is a callback function that will be called on a successful match
RecurrentMatcher(std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns,
recurrent_graph_rewrite_callback callback,
pass::PassPropertyMask property = pass::PassProperty::REGULAR_FUSIONS)
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
: m_pattern(pattern)
, m_recurrent_pattern(rpattern)
, m_correlated_patterns(correlated_patterns)
, m_callback(callback)
, m_property(property)
{
}
......@@ -229,19 +230,12 @@ namespace ngraph
/// \brief Tries to match a pattern for an individual cell to a given \p graph
bool match(std::shared_ptr<Node> graph);
/// \brief Invoked by a pass to process a successful match
bool process_match();
bool get_property(const pass::PassPropertyMask& prop) const;
std::shared_ptr<Node> get_match_root() { return m_match_root; }
private:
std::shared_ptr<Node> m_pattern;
std::shared_ptr<op::Label> m_recurrent_pattern;
const std::set<std::shared_ptr<op::Label>> m_correlated_patterns;
RPatternMap m_matches;
recurrent_graph_rewrite_callback m_callback;
pass::PassPropertyMask m_property;
std::shared_ptr<Node> m_match_root;
};
}
......
......@@ -19,7 +19,6 @@
#include <functional>
#include "ngraph/node.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
......
This diff is collapsed.
......@@ -101,7 +101,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
Strides{1, 1},
true);
pattern::graph_rewrite_callback callback = [data_conv](pattern::Matcher& m) {
auto callback = [data_conv](pattern::Matcher& m) {
NGRAPH_DEBUG << "conv_horizontal_fusion: In a callback for conv horizontal fusion for "
<< m.get_match_root()->get_name();
......@@ -188,7 +188,7 @@ void ngraph::runtime::cpu::pass::CPUHorizontalFusion::cpu_conv_horizontal_fusion
return true;
};
auto m = make_shared<pattern::Matcher>(
conv_bias, callback, "CPUHorizontalFusion.CpuConvHorizontalFusion");
this->add_matcher(m);
auto m =
make_shared<pattern::Matcher>(conv_bias, "CPUHorizontalFusion.CpuConvHorizontalFusion");
this->add_matcher(m, callback);
}
......@@ -422,7 +422,7 @@ std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
auto slice_weights_label =
std::make_shared<pattern::op::Label>(slice_weights, nullptr, NodeVector{slice_weights});
auto conv = std::make_shared<op::Convolution>(slice_data, slice_weights_label);
auto matcher = std::make_shared<pattern::Matcher>(conv, nullptr);
auto matcher = std::make_shared<pattern::Matcher>(conv);
NGRAPH_DEBUG << "In simplify_concat (group convolution) for " << n->get_name();
......
......@@ -53,7 +53,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu
auto conv = std::make_shared<ngraph::op::Convolution>(
data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});
pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) {
auto callback = [param](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_weight against "
<< m.get_match_root()->get_name();
......@@ -116,9 +116,9 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu
return true;
};
auto m = make_shared<pattern::Matcher>(
conv, callback, "CPUPostLayoutOptimizations.ConstructWeight_fusion");
this->add_matcher(m);
auto m =
make_shared<pattern::Matcher>(conv, "CPUPostLayoutOptimizations.ConstructWeight_fusion");
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_slice_convertLayout_fusion()
......@@ -130,7 +130,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_slice_con
auto lt_desc = std::make_shared<runtime::cpu::LayoutDescriptor>(*tvt);
auto cvt_lt = std::make_shared<runtime::cpu::op::ConvertLayout>(slice, lt_desc);
pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) {
auto callback = [param](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_slice_converLayout against "
<< m.get_match_root()->get_name();
......@@ -169,8 +169,8 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_slice_con
};
auto m = make_shared<pattern::Matcher>(
cvt_lt, callback, "CPUPostLayoutOptimizations.ConstructSliceConvertLayoutFusion");
this->add_matcher(m);
cvt_lt, "CPUPostLayoutOptimizations.ConstructSliceConvertLayoutFusion");
this->add_matcher(m, callback);
}
// Reshape(transpose) + ConvertLayout
......@@ -193,7 +193,7 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::
std::make_shared<runtime::cpu::LayoutDescriptor>(*reshape->get_output_tensor_ptr());
auto cvt_lt = std::make_shared<runtime::cpu::op::ConvertLayout>(reshape, lt_desc);
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_reshape_converLayout against "
<< m.get_match_root()->get_name();
......@@ -264,6 +264,6 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::
};
auto m = make_shared<pattern::Matcher>(
cvt_lt, callback, "CPUPostLayoutOptimizations.ConstructReshapeConvertLayoutFusion");
this->add_matcher(m);
cvt_lt, "CPUPostLayoutOptimizations.ConstructReshapeConvertLayoutFusion");
this->add_matcher(m, callback);
}
......@@ -78,7 +78,7 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<ngraph::op::Divide>(broadcast_constant, add_exp);
// Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
auto callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -102,9 +102,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_sigmoid()
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(
divide_1_over_exp, callback, "LSTMFusion.Sigmoid");
this->add_matcher(m);
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, "LSTMFusion.Sigmoid");
this->add_matcher(m, callback);
}
static void replace_collapse_node_user(std::shared_ptr<Node> collapsed_node,
......@@ -184,14 +183,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
std::make_shared<ngraph::op::Multiply>(ot, std::make_shared<ngraph::op::Tanh>(ct_label));
// Define a call back that needs to called once the DFG matches the pattern
pattern::graph_rewrite_callback callback = [ct_label,
w_i2h,
bias_i2h,
w_h2h,
bias_h2h,
xt,
ht_1,
ct_1](pattern::Matcher& m) {
auto callback = [ct_label, w_i2h, bias_i2h, w_h2h, bias_h2h, xt, ht_1, ct_1](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_lstm pattern against "
<< m.get_match_root()->get_name();
......@@ -326,8 +319,8 @@ void ngraph::runtime::cpu::pass::LSTMFusion::construct_lstm_fprop()
ngraph::replace_node(m.get_match_root(), ht_slice);
return true;
};
auto m = std::make_shared<pattern::Matcher>(ht, callback, "LSTMFusion.Fprop");
this->add_matcher(m);
auto m = std::make_shared<pattern::Matcher>(ht, "LSTMFusion.Fprop");
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
......@@ -378,13 +371,12 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
auto lstm_goe_slice =
std::make_shared<ngraph::op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});
pattern::recurrent_graph_rewrite_callback callback = [lstm_goe_label,
auto callback = [lstm_goe_label,
lstm_src_layer,
lstm_src_iter_label,
lstm_weights_layer_label,
lstm_weights_iter_label,
lstm_bias_label](
pattern::RecurrentMatcher& m) {
lstm_bias_label](pattern::RecurrentMatcher& m) {
NGRAPH_DEBUG << " In recurrent RNN fusion callback";
......@@ -568,9 +560,8 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
std::set<std::shared_ptr<pattern::op::Label>>{lstm_weights_layer_shared,
lstm_weights_iter_shared,
lstm_bias_layer_shared,
lstm_bias_iter_shared},
callback);
this->add_matcher(m);
lstm_bias_iter_shared});
this->add_matcher(m, callback);
}
static std::shared_ptr<Node> stack_rnn_inputs(NodeVector rnn_input_nodes)
......@@ -613,13 +604,12 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
auto rnn_goe0_label =
std::make_shared<pattern::op::Label>(rnn_goe0, nullptr, NodeVector{rnn_goe0});
pattern::recurrent_graph_rewrite_callback callback = [rnn_src_layer,
auto callback = [rnn_src_layer,
rnn_src_iter,
rnn_weights_layer,
rnn_weights_iter,
rnn_bias,
rnn_goe0_label](
pattern::RecurrentMatcher& m) {
rnn_goe0_label](pattern::RecurrentMatcher& m) {
auto number_of_rnn_cell_matched = m.get_number_of_recurrent_matches();
NGRAPH_DEBUG << " In Recurrent multi layer RNN fusion callback ";
NGRAPH_DEBUG << " Number of RNN's Matched: " << number_of_rnn_cell_matched;
......@@ -775,8 +765,8 @@ void ngraph::runtime::cpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
rnn_goe0_label, rnn_src_layer, empty_correlated_matches, callback);
this->add_matcher(m);
rnn_goe0_label, rnn_src_layer, empty_correlated_matches);
this->add_matcher(m, callback);
}
void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
......@@ -813,8 +803,7 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
NodeVector{rnn_ltor_goe0_reshape_tnc, skip_reverse_seq}, 0);
// Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [rnn_left_to_right,
rnn_right_to_left](pattern::Matcher& m) {
auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto rnn_ltor_node =
......@@ -914,6 +903,6 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(concat, callback);
this->add_matcher(m);
auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "BiDirectionalRnn");
this->add_matcher(m, callback);
}
......@@ -70,7 +70,7 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_sigmoid()
auto divide_1_over_exp = std::make_shared<op::Divide>(broadcast_constant, add_exp);
//Define a call back that needs to called once the DFG matches the pattern
ngraph::pattern::graph_rewrite_callback callback = [input](pattern::Matcher& m) {
auto callback = [input](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
......@@ -94,8 +94,8 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_sigmoid()
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, callback);
this->add_matcher(m);
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp);
this->add_matcher(m, callback);
}
static std::shared_ptr<Node> compute_lstm_params(const std::shared_ptr<Node>& w_x,
......@@ -177,7 +177,7 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_lstm_fprop()
auto ht_label = std::make_shared<pattern::op::Label>(ht, nullptr, NodeVector{ht});
//Define a call back that needs to called once the DFG matches the pattern
pattern::graph_rewrite_callback callback = [ct_label,
auto callback = [ct_label,
input_xt,
weights_i2h,
hidden_ht,
......@@ -335,8 +335,8 @@ void ngraph::runtime::gpu::pass::LSTMFusion::construct_lstm_fprop()
ngraph::replace_node(m.get_match_root(), ht_output);
return true;
};
auto m = std::make_shared<pattern::Matcher>(ht, callback);
this->add_matcher(m);
auto m = std::make_shared<pattern::Matcher>(ht);
this->add_matcher(m, callback);
}
static std::shared_ptr<ngraph::Node>
......@@ -393,11 +393,7 @@ void ngraph::runtime::gpu::pass::RNNFusion::construct_rnn_lstm_fprop()
auto goe = std::make_shared<op::GetOutputElement>(lstm, 0); // hidden output
auto lstm_node_label = std::make_shared<pattern::op::Label>(goe, nullptr, NodeVector{goe});
pattern::recurrent_graph_rewrite_callback callback = [lstm_node_label,
xt,
ht_1,
params_label,
rpattern_ct_1](
auto callback = [lstm_node_label, xt, ht_1, params_label, rpattern_ct_1](
pattern::RecurrentMatcher& m) {
NGRAPH_DEBUG << " In RNN fusion callback";
......@@ -603,8 +599,8 @@ void ngraph::runtime::gpu::pass::RNNFusion::construct_rnn_lstm_fprop()
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
lstm_node_label, rpattern_ct_1, empty_correlated_matches, callback);
this->add_matcher(m);
lstm_node_label, rpattern_ct_1, empty_correlated_matches);
this->add_matcher(m, callback);
}
static std::shared_ptr<Node>
......@@ -683,8 +679,7 @@ void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
auto rnn_ht_label =
std::make_shared<pattern::op::Label>(rnn_ht_out, nullptr, NodeVector{rnn_ht_out});
pattern::recurrent_graph_rewrite_callback callback =
[src_layer_label, src_iter_label, params_label, state_iter_label, rnn_ht_label](
auto callback = [src_layer_label, src_iter_label, params_label, state_iter_label, rnn_ht_label](
pattern::RecurrentMatcher& m) {
if (m.get_number_of_recurrent_matches() <= 1)
......@@ -826,8 +821,7 @@ void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
continue;
}
if (auto rnn_goe_node =
std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
if (auto rnn_goe_node = std::dynamic_pointer_cast<op::GetOutputElement>(rnn_goes))
{
// we need to only replace the {ht} consumers of the last RNN layer,
// since for other layers the intermediate outputs {ht} will be computed
......@@ -852,6 +846,6 @@ void ngraph::runtime::gpu::pass::MultiLayerRNNFusion::construct_multi_layer_rnn_
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto m = std::make_shared<pattern::RecurrentMatcher>(
rnn_ht_label, src_layer_label, empty_correlated_matches, callback);
this->add_matcher(m);
rnn_ht_label, src_layer_label, empty_correlated_matches);
this->add_matcher(m, callback);
}
......@@ -32,7 +32,7 @@ ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
return dynamic_cast<ngraph::op::Concat*>(node.get()) != nullptr;
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
auto concat = std::static_pointer_cast<ngraph::op::Concat>(m.get_match_root());
auto args = concat->get_arguments();
......@@ -114,5 +114,5 @@ ngraph::runtime::plaidml::pass::ConcatElision::ConcatElision()
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(concat_op, callback));
add_matcher(std::make_shared<pattern::Matcher>(concat_op), callback);
}
......@@ -30,7 +30,7 @@ ngraph::runtime::plaidml::pass::ConcatSplit::ConcatSplit()
return op != nullptr && kMaxConcatInputs < op->get_input_size();
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
auto concat = std::static_pointer_cast<ngraph::op::Concat>(m.get_match_root());
auto args = concat->get_arguments();
......@@ -73,5 +73,5 @@ ngraph::runtime::plaidml::pass::ConcatSplit::ConcatSplit()
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(concat_op, callback));
add_matcher(std::make_shared<pattern::Matcher>(concat_op), callback);
}
......@@ -71,7 +71,7 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data
},
NodeVector{producer_op});
pattern::graph_rewrite_callback callback = [producer_op](pattern::Matcher& m) {
auto callback = [producer_op](pattern::Matcher& m) {
auto consumer = m.get_match_root();
auto producer = m.get_pattern_map()[producer_op];
NGRAPH_DEBUG << "Adding conversion for " << producer->description() << " -> "
......@@ -89,6 +89,6 @@ void ngraph::runtime::plaidml::pass::ExplicitLogicals::construct_logical_to_data
return true;
};
auto m = std::make_shared<pattern::Matcher>(data_consumer_op, callback);
add_matcher(m);
auto m = std::make_shared<pattern::Matcher>(data_consumer_op);
add_matcher(m, callback);
}
......@@ -39,7 +39,7 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
},
NodeVector{broadcast_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
// Since the broadcast is going to an elementwise operation, we
// can always replace it with an equivalent reshape that uses ones
// for the broadcast axes.
......@@ -70,5 +70,5 @@ ngraph::runtime::plaidml::pass::ImplicitBroadcast::ImplicitBroadcast()
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op, callback));
add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
}
......@@ -34,7 +34,7 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
pattern::has_class<ngraph::op::ConvolutionBackpropFilters>()(node);
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
auto to_transpose = [](const std::shared_ptr<Node>& node) -> ngraph::op::Reshape* {
if (!node)
{
......@@ -140,5 +140,5 @@ ngraph::runtime::plaidml::pass::LowerConvolutions::LowerConvolutions()
return false;
};
add_matcher(std::make_shared<pattern::Matcher>(convolution_op, callback));
add_matcher(std::make_shared<pattern::Matcher>(convolution_op), callback);
}
......@@ -35,7 +35,7 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
},
NodeVector{upper_replicate_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
auto nodes = m.get_matched_nodes();
auto lower = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(0));
auto upper = std::static_pointer_cast<plaidml::op::Replicate>(nodes.at(1));
......@@ -54,5 +54,5 @@ ngraph::runtime::plaidml::pass::ReplicateCombination::ReplicateCombination()
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(lower_replicate_op, callback));
add_matcher(std::make_shared<pattern::Matcher>(lower_replicate_op), callback);
}
......@@ -43,7 +43,7 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
},
NodeVector{skip_op});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
bool replaced_any = false;
auto nodes = m.get_matched_nodes();
std::size_t dim_limit = nodes.at(1)->get_shape().size();
......@@ -81,5 +81,5 @@ ngraph::runtime::plaidml::pass::ReplicateElision::ReplicateElision()
return replaced_any;
};
add_matcher(std::make_shared<pattern::Matcher>(target_op, callback));
add_matcher(std::make_shared<pattern::Matcher>(target_op), callback);
}
......@@ -111,7 +111,7 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
filters_shape.at(1) == 3 && filters_shape.at(2) > 4 && filters_shape.at(3) > 4);
});
pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto callback = [](pattern::Matcher& m) {
auto conv = std::static_pointer_cast<plaidml::op::Convolution>(m.get_match_root());
NodeVector args = conv->get_arguments();
std::shared_ptr<ngraph::op::Constant> a;
......@@ -126,5 +126,5 @@ ngraph::runtime::plaidml::pass::Winograd::Winograd()
return true;
};
add_matcher(std::make_shared<pattern::Matcher>(convolution_op, callback));
add_matcher(std::make_shared<pattern::Matcher>(convolution_op), callback);
}
......@@ -131,7 +131,7 @@ TEST(control_dependencies, clone_function_cdop)
auto f = make_shared<Function>(cdop, ParameterVector{A});
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
......@@ -152,7 +152,7 @@ TEST(control_dependencies, clone_function_cdop_abs)
auto f = make_shared<Function>(absn_cdop, ParameterVector{A, B});
auto clone = ngraph::clone_function(*f.get());
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
......@@ -175,7 +175,7 @@ TEST(control_dependencies, serialize_cdop)
string js = serialize(f, 4);
shared_ptr<Function> clone = deserialize(js);
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
......@@ -199,7 +199,7 @@ TEST(control_dependencies, serialize_cdop_abs)
string js = serialize(f, 4);
shared_ptr<Function> clone = deserialize(js);
auto matcher = std::make_shared<pattern::Matcher>(cdop, nullptr);
auto matcher = std::make_shared<pattern::Matcher>(cdop);
auto cdop_clone = clone->get_results().at(0)->get_argument(0)->get_argument(0);
ASSERT_TRUE(matcher->match(cdop_clone));
auto cloned_deps = cdop_clone->get_control_dependencies();
......
......@@ -90,7 +90,7 @@ public:
auto iconst1 = construct_constant_node(1);
auto pattern = std::make_shared<pattern::op::Label>(iconst1);
ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
auto callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.get_match_root()->get_name();
NGRAPH_CHECK(m.get_match_root()->get_arguments().size() == 2);
......@@ -126,8 +126,8 @@ public:
return true;
};
auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
this->add_matcher(m);
auto m = make_shared<TestMatcher>(pattern * iconst1);
this->add_matcher(m, callback);
}
void construct_add_zero()
......@@ -172,8 +172,9 @@ public:
return true;
};
auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
this->add_matcher(m);
auto add = pattern + iconst0;
auto m = make_shared<TestMatcher>(add);
this->add_matcher(m, callback);
}
TestGraphRewrite()
......@@ -446,7 +447,7 @@ TEST(pattern, matcher)
// strict mode
{
TestMatcher sm(nullptr, nullptr, "TestMatcher", pass::PassProperty::REGULAR_FUSIONS, true);
TestMatcher sm(nullptr, "TestMatcher", true);
// exact shape and type
auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
auto label_dynamic_shape =
......@@ -540,7 +541,7 @@ TEST(pattern, recurrent_pattern)
auto add3 = iconst0 + add2;
auto padd = iconst0 + rpattern;
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
RecurrentMatcher rm(padd, rpattern, empty_correlated_matches, nullptr);
RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
ASSERT_TRUE(rm.match(add3));
ASSERT_EQ(rm.get_number_of_bound_labels(), 1);
auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
......@@ -554,7 +555,7 @@ TEST(pattern, recurrent_pattern)
auto add2_2 = iconst1 + add1;
auto add3_2 = iconst0 + add2_2;
auto padd2 = iconst_label + rpattern;
RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches, nullptr);
RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
ASSERT_TRUE(rm2.match(add3_2));
ASSERT_EQ(rm2.get_number_of_bound_labels(), 2);
recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
......@@ -569,7 +570,7 @@ TEST(pattern, recurrent_pattern)
// Non-matching correlated labels
std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
correlated_matches.insert(iconst_label);
RecurrentMatcher rm3(padd2, rpattern, correlated_matches, nullptr);
RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
ASSERT_TRUE(rm3.match(add3_2));
ASSERT_EQ(rm3.get_number_of_bound_labels(), 2);
iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
......@@ -602,8 +603,7 @@ public:
auto rpattern = std::make_shared<pattern::op::Label>(element::i32, shape);
auto padd = iconst_label + rpattern;
ngraph::pattern::recurrent_graph_rewrite_callback callback = [iconst_label, rpattern](
pattern::RecurrentMatcher& rm) {
auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
<< rm.get_match_root()->get_name();
......@@ -634,9 +634,8 @@ public:
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto rm = make_shared<pattern::RecurrentMatcher>(
padd, rpattern, empty_correlated_matches, callback);
this->add_matcher(rm);
auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
this->add_matcher(rm, callback);
}
TestRecurrentGraphRewrite()
......@@ -697,7 +696,7 @@ TEST(pattern, label_on_skip)
auto bcst = std::make_shared<pattern::op::Skip>(const_label, bcst_pred);
auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher = std::make_shared<pattern::Matcher>(
std::make_shared<op::Multiply>(label, bcst_label), nullptr);
std::make_shared<op::Multiply>(label, bcst_label), "label_on_skip");
auto const_broadcast = make_shared<op::Broadcast>(iconst, shape, AxisSet{0, 1});
auto mul = a * const_broadcast;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment