Commit 2f0a33c3 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Improvements to Pattern Matcher (#274)

* move replace_node to util

* static run_matchers_on_nodes_list

* switching to map

* formatting

* addressing Scott's feedback and fixing warnings

* more pattern matcher refactoring

refactoring cont'd
parent d4153c91
......@@ -5,12 +5,14 @@
#include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp"
bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes)
bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers)
{
bool rewritten = false;
for (auto node : nodes)
{
for (auto matcher : m_matchers)
for (auto matcher : matchers)
{
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , "
<< node->get_name();
......@@ -29,53 +31,7 @@ bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Nod
return rewritten;
}
void ngraph::pass::GraphRewrite::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::pass::GraphRewrite::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
//TODO: [nikolayk] recursively walk target and update users()
//nodes w/ empty users sets should be DSE'ed.
return run_matchers_on_nodes_list(nodes, m_matchers);
}
......@@ -49,10 +49,10 @@ public:
}
void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); }
static void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
static void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) override;
static bool
run_matchers_on_nodes_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers);
private:
//enable cascading rewrites
......
This diff is collapsed.
......@@ -41,58 +41,58 @@ namespace ngraph
class Matcher
{
public:
using PatternMap = std::map<std::shared_ptr<op::Label>, std::shared_ptr<Node>>;
/// \brief Constructs a Matcher object
///
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
gr_callback_fn callback = nullptr)
: m_match_root(nullptr)
, m_pattern_node(pattern_node)
: m_pattern_node(pattern_node)
, m_callback(callback)
, m_depth(0)
{
}
virtual ~Matcher() {}
// Called when the pattern node matches a graph node.
virtual void on_match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
bool is_match);
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node)
{
return match(m_pattern_node, graph_node);
}
bool match(const std::shared_ptr<Node>& pattern_node, //keep public for testing for now
const std::shared_ptr<Node>& graph_node);
bool match(const std::shared_ptr<Node>& graph_node);
void process_match(gr_callback_fn callback = nullptr);
void reset() {}
bool is_match() { return m_match_root != nullptr; }
std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
std::shared_ptr<Node> match_root();
void reset_pattern_nodes(std::shared_ptr<Node> node);
PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
friend op::Label; //TODO: refine to match_class
protected:
void virtual match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node);
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
virtual bool match_arguments(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
void match_arguments(const Nodes& pattern_args, const Nodes& args);
void match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node);
void match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node);
std::shared_ptr<Node> m_match_root;
std::shared_ptr<Node> m_pattern_node;
PatternMap m_pattern_map;
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
bool match_permutation(const Nodes& pattern_args,
const Nodes& args,
PatternMap& pattern_map);
bool match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
gr_callback_fn m_callback;
size_t m_depth;
};
......
......@@ -38,17 +38,11 @@ namespace ngraph
label->set_value_type_checked(node->get_value_type());
return label;
}
bool is_bound() { return m_bound != nullptr; }
std::shared_ptr<Node> get_bound_node() { return m_bound; }
void reset() { m_bound.reset(); }
void bind(std::shared_ptr<Node> n) { m_bound = n; }
Label(Predicate pred = nullptr)
: Pattern("Label", Nodes{}, pred)
{
}
private:
std::shared_ptr<Node> m_bound;
};
}
}
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <deque>
#include <forward_list>
#include <iomanip>
......@@ -210,3 +211,50 @@ void ngraph::free_nodes(shared_ptr<Function> p)
n->clear_arguments();
}
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
}
......@@ -245,4 +245,9 @@ namespace ngraph
std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
} // end namespace ngraph
......@@ -27,6 +27,7 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
using namespace std;
......@@ -35,22 +36,37 @@ using namespace std;
class TestMatcher : public pattern::Matcher
{
using pattern::Matcher::Matcher;
void virtual match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node) override
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map) override
{
static const auto parameter_type = std::type_index(typeid(::ngraph::op::Parameter));
const auto pattern_type = std::type_index(typeid(*&*pattern_node));
if (pattern_type == parameter_type)
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
on_match_class(pattern_node,
graph_node,
pattern_node.get() ==
dynamic_cast<::ngraph::op::Parameter*>(graph_node.get()));
return;
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
}
this->pattern::Matcher::match_class(pattern_node, graph_node);
return this->pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<Node>& pattern_node, const std::shared_ptr<Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , "
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
......@@ -75,18 +91,19 @@ public:
NGRAPH_DEBUG << "IN CALLBACK";
assert(m.match_root()->get_arguments().size() == 2);
size_t const_node_index =
m.match_root()->get_arguments().at(0) == pattern->get_bound_node();
auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>(
m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern->get_bound_node()->description() << " , "
<< pattern->get_bound_node();
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern];
assert(const_node);
auto pattern_value_type = dynamic_pointer_cast<const TensorViewType>(
pattern->get_bound_node()->get_value_type());
auto pattern_value_type =
dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node);
......@@ -110,7 +127,7 @@ public:
}
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::pass::GraphRewrite::replace_node(m.match_root(), pattern->get_bound_node());
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
};
auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
......@@ -129,18 +146,19 @@ public:
NGRAPH_DEBUG << "IN CALLBACK";
assert(m.match_root()->get_arguments().size() == 2);
size_t const_node_index =
m.match_root()->get_arguments().at(0) == pattern->get_bound_node();
auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>(
m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern->get_bound_node()->description() << " , "
<< pattern->get_bound_node();
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern];
assert(const_node);
auto pattern_value_type = dynamic_pointer_cast<const TensorViewType>(
pattern->get_bound_node()->get_value_type());
auto pattern_value_type =
dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node);
......@@ -164,7 +182,7 @@ public:
}
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::pass::GraphRewrite::replace_node(m.match_root(), pattern->get_bound_node());
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
};
auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
......@@ -292,7 +310,7 @@ TEST(pattern, matcher)
auto pattern = pattern::op::Label::make_from_node(a);
ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(pattern->get_bound_node(), a);
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto pattern_false =
pattern::op::Label::make_from_node(a, [](std::shared_ptr<Node> no) { return false; });
......@@ -306,14 +324,14 @@ TEST(pattern, matcher)
ASSERT_TRUE(n.match(any + b, abs + b));
ASSERT_TRUE(n.match(pattern + b, abs + b));
ASSERT_EQ(pattern->get_bound_node(), abs);
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_TRUE(n.match(b + pattern, abs + b));
ASSERT_EQ(pattern->get_bound_node(), abs);
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
auto c = make_shared<op::Parameter>(element::Int32::element_type(), shape);
ASSERT_TRUE(n.match(c * (b + pattern), c * (abs + b)));
ASSERT_EQ(pattern->get_bound_node(), abs);
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_TRUE(n.match(c * (any + b), c * (abs + b))); //nested any
ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any
......@@ -323,7 +341,7 @@ TEST(pattern, matcher)
auto iconst1_0 = construct_constant_node(1);
auto iconst1_1 = construct_constant_node(1);
ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst
ASSERT_EQ(pattern->get_bound_node(), a);
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto fconst1_0 =
make_shared<op::Constant>(element::Float32::element_type(), Shape{1}, std::to_string(1));
auto patternf = pattern::op::Label::make_from_node(fconst1_0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment