Unverified Commit 18998c41 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Merge branch 'master' into aprocter/de-eigenize-partial

parents 676e601d 2f0a33c3
...@@ -5,12 +5,14 @@ ...@@ -5,12 +5,14 @@
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes) bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers)
{ {
bool rewritten = false; bool rewritten = false;
for (auto node : nodes) for (auto node : nodes)
{ {
for (auto matcher : m_matchers) for (auto matcher : matchers)
{ {
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , " NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , "
<< node->get_name(); << node->get_name();
...@@ -29,53 +31,7 @@ bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Nod ...@@ -29,53 +31,7 @@ bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Nod
return rewritten; return rewritten;
} }
void ngraph::pass::GraphRewrite::replace_node(std::shared_ptr<Node> target, bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes)
std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::pass::GraphRewrite::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{ {
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , " return run_matchers_on_nodes_list(nodes, m_matchers);
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
//TODO: [nikolayk] recursively walk target and update users()
//nodes w/ empty users sets should be DSE'ed.
} }
...@@ -49,10 +49,10 @@ public: ...@@ -49,10 +49,10 @@ public:
} }
void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); } void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); }
static void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
static void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) override; virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) override;
static bool
run_matchers_on_nodes_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers);
private: private:
//enable cascading rewrites //enable cascading rewrites
......
...@@ -36,43 +36,19 @@ namespace ngraph ...@@ -36,43 +36,19 @@ namespace ngraph
begin(arguments), end(arguments)); //vector is needed for generating permutations begin(arguments), end(arguments)); //vector is needed for generating permutations
} }
std::shared_ptr<Node> Matcher::match_root() std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
{ bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
assert(is_match()); const std::shared_ptr<Node>& graph_node,
return m_match_root; PatternMap& pattern_map)
}
void Matcher::reset_pattern_nodes(
std::shared_ptr<Node> node) //TODO: [nikolayk] this doesn't have to be recursive
//even better we should walk the entire pattern subgraph once
//and keep track of all pattern nodes
{
auto label = std::dynamic_pointer_cast<::ngraph::pattern::op::Label>(node);
NGRAPH_DEBUG << "reset_pattern_nodes : node = " << node->get_name() << " , " << node;
if (label)
{
NGRAPH_DEBUG << "reset_pattern_nodes : label = " << node->get_name() << " , "
<< node;
label->reset();
}
for (auto arg : get_arguments(node))
{
reset_pattern_nodes(arg);
}
}
void Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node)
{ {
bool is_match = true; bool is_match = true;
if (label->is_bound()) if (pattern_map.count(label))
{ {
if (label->get_bound_node() != graph_node) if (pattern_map[label] != graph_node)
{ {
NGRAPH_DEBUG << "get_bound_node " << label->get_bound_node()->get_name() NGRAPH_DEBUG << "get_bound_node " << pattern_map[label]->get_name() << " , "
<< " , " << label->get_bound_node() << " NOT match " << pattern_map[label] << " NOT match " << graph_node->get_name()
<< graph_node->get_name() << " , " << graph_node; << " , " << graph_node;
is_match = false; is_match = false;
} }
} }
...@@ -82,103 +58,96 @@ namespace ngraph ...@@ -82,103 +58,96 @@ namespace ngraph
is_match = !predicate || predicate(graph_node); is_match = !predicate || predicate(graph_node);
} }
if (is_match) if (is_match) //in case label was already bound this rebinds it to the same node (harmless; and the logic seems cleaner)
{ {
NGRAPH_DEBUG << "Binding get_bound_node " << graph_node->get_name() << " , " NGRAPH_DEBUG << "(Re)binding get_bound_node " << graph_node->get_name() << " , "
<< graph_node << " , " << graph_node->get_name(); << graph_node << " , " << graph_node->get_name();
label->bind(graph_node); pattern_map[label] = graph_node;
}
else
{
reset();
m_match_root.reset();
NGRAPH_DEBUG << "MATCHER IS MATCH : " << this->is_match();
} }
return is_match;
} }
void Matcher::match_any(const std::shared_ptr<op::Any>& any, bool Matcher::match_any(const std::shared_ptr<op::Any>& any,
const std::shared_ptr<Node>& graph_node) const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{ {
auto predicate = any->get_predicate(); auto predicate = any->get_predicate();
if (!predicate || any->get_predicate()(graph_node)) if (!predicate || any->get_predicate()(graph_node))
{ {
on_match_class(any, graph_node, true); return match_arguments(any, graph_node, pattern_map);
} }
else else
{ {
auto args = get_arguments(any); auto args = get_arguments(any);
assert(args.size() == 1); assert(args.size() == 1);
on_match_class(args.at(0), graph_node, true); return match_node(args.at(0), graph_node, pattern_map);
} }
} }
void Matcher::match_class(const std::shared_ptr<Node>& pattern_node, bool Matcher::match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node) const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{ {
assert(pattern_node && graph_node); assert(pattern_node && graph_node);
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node)) if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{ {
match_pattern(label_node, graph_node); return match_pattern(label_node, graph_node, pattern_map);
return;
} }
if (auto any_node = std::dynamic_pointer_cast<op::Any>( if (auto any_node = std::dynamic_pointer_cast<op::Any>(
pattern_node)) //matches PatternSkipOp semantics pattern_node)) //matches PatternSkipOp semantics
{ {
match_any(any_node, graph_node); return match_any(any_node, graph_node, pattern_map);
return; }
auto p_pattern_node = pattern_node.get();
auto p_graph_node = graph_node.get();
if (std::type_index(typeid(*p_pattern_node)) == std::type_index(typeid(*p_graph_node)))
{
return match_arguments(pattern_node, graph_node, pattern_map);
} }
on_match_class(pattern_node, return false;
graph_node,
std::type_index(typeid(*&*pattern_node)) ==
std::type_index(typeid(*&*graph_node)));
} }
void Matcher::match_arguments(const Nodes& pattern_args, const Nodes& args) bool Matcher::match_permutation(const Nodes& pattern_args,
const Nodes& args,
PatternMap& pattern_map)
{ {
m_depth++; m_depth++;
for (size_t i = 0; i < args.size(); i++) for (size_t i = 0; i < args.size(); i++)
{ {
match_class(pattern_args.at(i), args.at(i)); if (!match_node(pattern_args.at(i), args.at(i), pattern_map))
if (!is_match())
{ {
m_depth--; m_depth--;
return; return false;
} }
} }
m_depth--; m_depth--;
return true;
} }
void Matcher::on_match_class(const std::shared_ptr<ngraph::Node>& pattern_node, bool Matcher::match_arguments(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node, const std::shared_ptr<ngraph::Node>& graph_node,
bool is_match) PatternMap& pattern_map)
{ {
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] " NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " " << "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< (is_match ? " " : "NOT ") << "matched " << graph_node << " , " << "matched " << graph_node << " , " << graph_node->get_name();
<< graph_node->get_name();
if (!is_match)
{
reset_pattern_nodes(pattern_node);
m_match_root.reset();
return;
}
auto args = get_arguments(graph_node); auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node); auto pattern_args = get_arguments(pattern_node);
if (args.size() != pattern_args.size()) if (args.size() != pattern_args.size())
{ {
reset_pattern_nodes(pattern_node); return false;
m_match_root.reset();
return;
} }
if (graph_node->is_commutative()) if (graph_node->is_commutative())
{ {
auto old_match_root = m_match_root;
std::sort( std::sort(
begin(pattern_args), begin(pattern_args),
end(pattern_args)); //TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster end(pattern_args)); //TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
...@@ -186,20 +155,24 @@ namespace ngraph ...@@ -186,20 +155,24 @@ namespace ngraph
{ {
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node " NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name() << " , " << graph_node; << graph_node->get_name() << " , " << graph_node;
reset_pattern_nodes(pattern_node); PatternMap copy{pattern_map};
m_match_root = if (match_permutation(pattern_args, args, copy))
old_match_root; //previous permutation wasn't a match; reset m_match_root
match_arguments(pattern_args, args);
if (this->is_match())
{ {
return; pattern_map.insert(begin(copy), end(copy));
return true;
} }
} while (std::next_permutation(begin(pattern_args), end(pattern_args))); } while (std::next_permutation(begin(pattern_args), end(pattern_args)));
} }
else else
{ {
match_arguments(pattern_args, args); PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy))
{
pattern_map.insert(begin(copy), end(copy));
return true;
}
} }
return false;
} }
void Matcher::process_match(::ngraph::pattern::gr_callback_fn callback) void Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
...@@ -211,7 +184,7 @@ namespace ngraph ...@@ -211,7 +184,7 @@ namespace ngraph
} }
assert(cb); assert(cb);
assert(is_match()); assert(this->m_match_root);
cb(*this); cb(*this);
} }
...@@ -230,29 +203,32 @@ namespace ngraph ...@@ -230,29 +203,32 @@ namespace ngraph
return result; return result;
} }
bool Matcher::match(const std::shared_ptr<Node>& pattern_node, bool Matcher::match(const std::shared_ptr<Node>& graph_node)
const std::shared_ptr<Node>& graph_node)
{ {
if (!pattern_node || !graph_node) //clear our state
m_match_root.reset();
m_pattern_map.clear();
if (!m_pattern_node || !graph_node)
{ {
NGRAPH_DEBUG << "pattern_node or graph_node are not set; matching FAILED"; throw "m_pattern_node or graph_node are not set!";
m_match_root.reset();
} }
if (get_users(pattern_node).size()) if (get_users(m_pattern_node).size())
{ {
throw "Pattern Node must not be used elsewhere!"; throw "Pattern Node must not be used elsewhere!";
} }
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , " NGRAPH_DEBUG << "Starting match pattern = " << m_pattern_node << " , "
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , " << m_pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name(); << graph_node->get_name();
reset_pattern_nodes(pattern_node); bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
m_match_root = graph_node; if (is_match)
match_class(pattern_node, graph_node); {
//NGRAPH_DEBUG << pad(2 * m_depth) << "is_match() " << is_match(); m_match_root = graph_node;
return is_match(); }
return is_match;
} }
} }
} }
...@@ -41,58 +41,58 @@ namespace ngraph ...@@ -41,58 +41,58 @@ namespace ngraph
class Matcher class Matcher
{ {
public: public:
using PatternMap = std::map<std::shared_ptr<op::Label>, std::shared_ptr<Node>>;
/// \brief Constructs a Matcher object /// \brief Constructs a Matcher object
/// ///
/// \param pattern_node is a pattern sub graph that will be matched against input graphs /// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param callback is a callback function that will be called on a successful match /// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr, Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
gr_callback_fn callback = nullptr) gr_callback_fn callback = nullptr)
: m_match_root(nullptr) : m_pattern_node(pattern_node)
, m_pattern_node(pattern_node)
, m_callback(callback) , m_callback(callback)
, m_depth(0) , m_depth(0)
{ {
} }
virtual ~Matcher() {} virtual ~Matcher() {}
// Called when the pattern node matches a graph node.
virtual void on_match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
bool is_match);
/// \brief Matches a pattern to \p graph_node /// \brief Matches a pattern to \p graph_node
/// ///
/// \param graph_node is an input graph to be matched against /// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node) bool match(const std::shared_ptr<Node>& graph_node);
{
return match(m_pattern_node, graph_node);
}
bool match(const std::shared_ptr<Node>& pattern_node, //keep public for testing for now
const std::shared_ptr<Node>& graph_node);
void process_match(gr_callback_fn callback = nullptr); void process_match(gr_callback_fn callback = nullptr);
void reset() {} void reset() {}
bool is_match() { return m_match_root != nullptr; }
std::shared_ptr<Node> pattern_node() { return m_pattern_node; } std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
std::shared_ptr<Node> match_root(); std::shared_ptr<Node> match_root();
void reset_pattern_nodes(std::shared_ptr<Node> node); PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
friend op::Label; //TODO: refine to match_class friend op::Label; //TODO: refine to match_class
protected: protected:
void virtual match_class(const std::shared_ptr<Node>& pattern_node, bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node); const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
virtual bool match_arguments(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
void match_arguments(const Nodes& pattern_args, const Nodes& args);
void match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node);
void match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node);
std::shared_ptr<Node> m_match_root; std::shared_ptr<Node> m_match_root;
std::shared_ptr<Node> m_pattern_node; std::shared_ptr<Node> m_pattern_node;
PatternMap m_pattern_map;
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
bool match_permutation(const Nodes& pattern_args,
const Nodes& args,
PatternMap& pattern_map);
bool match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
gr_callback_fn m_callback; gr_callback_fn m_callback;
size_t m_depth; size_t m_depth;
}; };
......
...@@ -38,17 +38,11 @@ namespace ngraph ...@@ -38,17 +38,11 @@ namespace ngraph
label->set_value_type_checked(node->get_value_type()); label->set_value_type_checked(node->get_value_type());
return label; return label;
} }
bool is_bound() { return m_bound != nullptr; }
std::shared_ptr<Node> get_bound_node() { return m_bound; }
void reset() { m_bound.reset(); }
void bind(std::shared_ptr<Node> n) { m_bound = n; }
Label(Predicate pred = nullptr) Label(Predicate pred = nullptr)
: Pattern("Label", Nodes{}, pred) : Pattern("Label", Nodes{}, pred)
{ {
} }
private:
std::shared_ptr<Node> m_bound;
}; };
} }
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
#include <cassert>
#include <deque> #include <deque>
#include <forward_list> #include <forward_list>
#include <iomanip> #include <iomanip>
...@@ -210,3 +211,50 @@ void ngraph::free_nodes(shared_ptr<Function> p) ...@@ -210,3 +211,50 @@ void ngraph::free_nodes(shared_ptr<Function> p)
n->clear_arguments(); n->clear_arguments();
} }
} }
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
}
...@@ -245,4 +245,9 @@ namespace ngraph ...@@ -245,4 +245,9 @@ namespace ngraph
std::function<void(std::shared_ptr<Function>)> f); std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>); void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
} // end namespace ngraph } // end namespace ngraph
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp" #include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
...@@ -35,22 +36,37 @@ using namespace std; ...@@ -35,22 +36,37 @@ using namespace std;
class TestMatcher : public pattern::Matcher class TestMatcher : public pattern::Matcher
{ {
using pattern::Matcher::Matcher; using pattern::Matcher::Matcher;
void virtual match_class(const std::shared_ptr<Node>& pattern_node, bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node) override const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map) override
{ {
static const auto parameter_type = std::type_index(typeid(::ngraph::op::Parameter)); if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
const auto pattern_type = std::type_index(typeid(*&*pattern_node));
if (pattern_type == parameter_type)
{ {
on_match_class(pattern_node, return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
graph_node,
pattern_node.get() ==
dynamic_cast<::ngraph::op::Parameter*>(graph_node.get()));
return;
} }
this->pattern::Matcher::match_class(pattern_node, graph_node); return this->pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<Node>& pattern_node, const std::shared_ptr<Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , "
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
} }
}; };
...@@ -75,18 +91,19 @@ public: ...@@ -75,18 +91,19 @@ public:
NGRAPH_DEBUG << "IN CALLBACK"; NGRAPH_DEBUG << "IN CALLBACK";
assert(m.match_root()->get_arguments().size() == 2); assert(m.match_root()->get_arguments().size() == 2);
size_t const_node_index = auto pattern_map = m.get_pattern_map();
m.match_root()->get_arguments().at(0) == pattern->get_bound_node();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>( auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>(
m.match_root()->get_arguments().at(const_node_index)); m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index); auto second_node = m.match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node; NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern->get_bound_node()->description() << " , " NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern->get_bound_node(); << pattern_map[pattern];
assert(const_node); assert(const_node);
auto pattern_value_type = dynamic_pointer_cast<const TensorViewType>( auto pattern_value_type =
pattern->get_bound_node()->get_value_type()); dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type = auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type()); dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node); assert(pattern_value_type && const_node);
...@@ -110,7 +127,7 @@ public: ...@@ -110,7 +127,7 @@ public:
} }
NGRAPH_DEBUG << "BEFORE REPLACE"; NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::pass::GraphRewrite::replace_node(m.match_root(), pattern->get_bound_node()); ngraph::replace_node(m.match_root(), pattern_map[pattern]);
}; };
auto m = make_shared<TestMatcher>(pattern * iconst1, callback); auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
...@@ -129,18 +146,19 @@ public: ...@@ -129,18 +146,19 @@ public:
NGRAPH_DEBUG << "IN CALLBACK"; NGRAPH_DEBUG << "IN CALLBACK";
assert(m.match_root()->get_arguments().size() == 2); assert(m.match_root()->get_arguments().size() == 2);
size_t const_node_index = auto pattern_map = m.get_pattern_map();
m.match_root()->get_arguments().at(0) == pattern->get_bound_node();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>( auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>(
m.match_root()->get_arguments().at(const_node_index)); m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index); auto second_node = m.match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node; NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern->get_bound_node()->description() << " , " NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern->get_bound_node(); << pattern_map[pattern];
assert(const_node); assert(const_node);
auto pattern_value_type = dynamic_pointer_cast<const TensorViewType>( auto pattern_value_type =
pattern->get_bound_node()->get_value_type()); dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type = auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type()); dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node); assert(pattern_value_type && const_node);
...@@ -164,7 +182,7 @@ public: ...@@ -164,7 +182,7 @@ public:
} }
NGRAPH_DEBUG << "BEFORE REPLACE"; NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::pass::GraphRewrite::replace_node(m.match_root(), pattern->get_bound_node()); ngraph::replace_node(m.match_root(), pattern_map[pattern]);
}; };
auto m = make_shared<TestMatcher>(pattern + iconst0, callback); auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
...@@ -292,7 +310,7 @@ TEST(pattern, matcher) ...@@ -292,7 +310,7 @@ TEST(pattern, matcher)
auto pattern = pattern::op::Label::make_from_node(a); auto pattern = pattern::op::Label::make_from_node(a);
ASSERT_TRUE(n.match(pattern, a)); ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(pattern->get_bound_node(), a); ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto pattern_false = auto pattern_false =
pattern::op::Label::make_from_node(a, [](std::shared_ptr<Node> no) { return false; }); pattern::op::Label::make_from_node(a, [](std::shared_ptr<Node> no) { return false; });
...@@ -306,14 +324,14 @@ TEST(pattern, matcher) ...@@ -306,14 +324,14 @@ TEST(pattern, matcher)
ASSERT_TRUE(n.match(any + b, abs + b)); ASSERT_TRUE(n.match(any + b, abs + b));
ASSERT_TRUE(n.match(pattern + b, abs + b)); ASSERT_TRUE(n.match(pattern + b, abs + b));
ASSERT_EQ(pattern->get_bound_node(), abs); ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_TRUE(n.match(b + pattern, abs + b)); ASSERT_TRUE(n.match(b + pattern, abs + b));
ASSERT_EQ(pattern->get_bound_node(), abs); ASSERT_EQ(n.get_pattern_map()[pattern], abs);
auto c = make_shared<op::Parameter>(element::Int32::element_type(), shape); auto c = make_shared<op::Parameter>(element::Int32::element_type(), shape);
ASSERT_TRUE(n.match(c * (b + pattern), c * (abs + b))); ASSERT_TRUE(n.match(c * (b + pattern), c * (abs + b)));
ASSERT_EQ(pattern->get_bound_node(), abs); ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_TRUE(n.match(c * (any + b), c * (abs + b))); //nested any ASSERT_TRUE(n.match(c * (any + b), c * (abs + b))); //nested any
ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any
...@@ -323,7 +341,7 @@ TEST(pattern, matcher) ...@@ -323,7 +341,7 @@ TEST(pattern, matcher)
auto iconst1_0 = construct_constant_node(1); auto iconst1_0 = construct_constant_node(1);
auto iconst1_1 = construct_constant_node(1); auto iconst1_1 = construct_constant_node(1);
ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst
ASSERT_EQ(pattern->get_bound_node(), a); ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto fconst1_0 = auto fconst1_0 =
make_shared<op::Constant>(element::Float32::element_type(), Shape{1}, std::to_string(1)); make_shared<op::Constant>(element::Float32::element_type(), Shape{1}, std::to_string(1));
auto patternf = pattern::op::Label::make_from_node(fconst1_0); auto patternf = pattern::op::Label::make_from_node(fconst1_0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment