Commit 606ad20b authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Fix Matcher's getters to adhere to ngraph coding guidelines (#916)

* rename getters to adhere to ngraph coding guidelines

* fix renaminb

* fix build errors
parent 4f2316e8
......@@ -52,7 +52,7 @@ static std::shared_ptr<pattern::Matcher>
static std::shared_ptr<pattern::op::Label>
get_broadcast_label(std::shared_ptr<pattern::Matcher> matcher)
{
return std::dynamic_pointer_cast<pattern::op::Label>(matcher->pattern_node()->get_argument(1));
return std::dynamic_pointer_cast<pattern::op::Label>(matcher->get_pattern()->get_argument(1));
}
//`simplify_multiply` optimizes the following 4 *base* cases
......
......@@ -53,7 +53,8 @@ void pass::CoreFusion::construct_relu()
auto max = make_shared<op::Maximum>(skip_broadcast, val);
pattern::graph_rewrite_callback callback = [val, zero](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_relu against " << m.match_root()->get_name();
NGRAPH_DEBUG << "In a callback for construct_relu against "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto mzero = m.get_pattern_map()[zero];
......@@ -62,10 +63,10 @@ void pass::CoreFusion::construct_relu()
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
return false;
}
auto mpattern = m.match_root();
auto mpattern = m.get_match_root();
auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
ngraph::replace_node(m.match_root(), cg);
ngraph::replace_node(m.get_match_root(), cg);
return true;
};
......
......@@ -61,11 +61,11 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
<< m.match_root()->get_name();
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto gop = pattern_map[op];
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
if (r1->get_shape() != gop->get_shape())
{
......@@ -82,7 +82,7 @@ void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
return false;
}
ngraph::replace_node(m.match_root(), gop);
ngraph::replace_node(m.get_match_root(), gop);
return true;
};
......@@ -101,22 +101,22 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = "
<< m.match_root()->get_name();
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto gop = pattern_map[op];
if (gop->get_shape() != m.match_root()->get_shape())
if (gop->get_shape() != m.get_match_root()->get_shape())
{
NGRAPH_DEBUG << "Operand shape doesn't match the shape of the second reshape!";
NGRAPH_DEBUG << "gop " << gop->get_name()
<< "shape = " << vector_to_string(gop->get_shape());
NGRAPH_DEBUG << "match_root " << m.match_root()->get_name()
<< "shape = " << vector_to_string(m.match_root()->get_shape());
NGRAPH_DEBUG << "match_root " << m.get_match_root()->get_name()
<< "shape = " << vector_to_string(m.get_match_root()->get_shape());
return false;
}
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto r1 = std::dynamic_pointer_cast<op::Reshape>(r2->get_argument(0));
Shape do_r2(r1->get_shape().size());
......@@ -132,7 +132,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
{
NGRAPH_DEBUG << "Two reshapes were removed!";
ngraph::replace_node(m.match_root(), gop);
ngraph::replace_node(m.get_match_root(), gop);
return true;
}
......@@ -141,7 +141,7 @@ void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
if (perm2 == do_r1)
{
NGRAPH_DEBUG << "Two transposes were removed!";
ngraph::replace_node(m.match_root(), gop);
ngraph::replace_node(m.get_match_root(), gop);
return true;
}
......@@ -163,9 +163,9 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
<< m.match_root()->get_name();
<< m.get_match_root()->get_name();
auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
auto mtranspose = std::dynamic_pointer_cast<op::Reshape>(m.get_match_root());
//this also checks the rank
if (mtranspose->get_input_order() != AxisVector{1, 0})
{
......@@ -190,7 +190,7 @@ void ngraph::pass::ReshapeElimination::construct_dot_transpose_pattern()
auto reshape1 = std::make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);
auto tdot = std::shared_ptr<Node>(new op::Dot(reshape1, reshape0));
ngraph::replace_node(m.match_root(), tdot);
ngraph::replace_node(m.get_match_root(), tdot);
return true;
};
......
......@@ -26,7 +26,7 @@ namespace ngraph
{
namespace pattern
{
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
std::shared_ptr<Node> Matcher::get_match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
......
......@@ -95,8 +95,8 @@ namespace ngraph
bool process_match(graph_rewrite_callback callback = nullptr);
void reset() {}
std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
std::shared_ptr<Node> match_root();
std::shared_ptr<Node> get_pattern() { return m_pattern_node; }
std::shared_ptr<Node> get_match_root();
PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
/// \brief Low-level helper to match recurring patterns
///
......
This diff is collapsed.
......@@ -51,9 +51,10 @@ void ngraph::runtime::cpu::pass::CPUPostLayoutOptimizations::construct_weight_fu
data_conv, cvt_lt_conv, Strides{1, 1}, Strides{1, 1});
pattern::graph_rewrite_callback callback = [param](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_weight against " << m.match_root()->get_name();
NGRAPH_DEBUG << "In a callback for construct_weight against "
<< m.get_match_root()->get_name();
auto m_cvt_lt = m.match_root()->get_argument(1);
auto m_cvt_lt = m.get_match_root()->get_argument(1);
auto m_reshape_conv = m_cvt_lt->get_argument(0);
std::shared_ptr<Node> m_conv_bprop;
......
......@@ -152,15 +152,16 @@ public:
ngraph::pattern::graph_rewrite_callback callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.match_root()->get_name();
assert(m.match_root()->get_arguments().size() == 2);
<< m.get_match_root()->get_name();
assert(m.get_match_root()->get_arguments().size() == 2);
auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
size_t const_node_index =
m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index);
m.get_match_root()->get_arguments().at(const_node_index));
auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
......@@ -181,7 +182,7 @@ public:
return false;
}
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
return true;
};
......@@ -197,15 +198,16 @@ public:
auto callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_add_zero against "
<< m.match_root()->get_name();
assert(m.match_root()->get_arguments().size() == 2);
<< m.get_match_root()->get_name();
assert(m.get_match_root()->get_arguments().size() == 2);
auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
size_t const_node_index =
m.get_match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index);
m.get_match_root()->get_arguments().at(const_node_index));
auto second_node = m.get_match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
......@@ -226,7 +228,7 @@ public:
return false;
}
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
return true;
};
......@@ -240,14 +242,14 @@ public:
ngraph::pattern::graph_rewrite_callback callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
<< m.match_root()->get_name();
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
<< m.get_match_root()->get_name();
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.get_match_root());
auto reducee = reduce->get_inputs().at(0).get_output().get_node();
NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum =
std::shared_ptr<ngraph::Node>(new op::Sum(reducee, reduce->get_reduction_axes()));
ngraph::replace_node(m.match_root(), sum);
ngraph::replace_node(m.get_match_root(), sum);
return true;
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment