Commit ab63fd33 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Subgraph Labels (#285)

* subgraph labels

* adding more documentation for Label

* minor fixes
parent 8c50b179
......@@ -46,9 +46,9 @@ namespace ngraph
{
if (pattern_map[label] != graph_node)
{
NGRAPH_DEBUG << "get_bound_node " << pattern_map[label]->get_name() << " , "
<< pattern_map[label] << " NOT match " << graph_node->get_name()
<< " , " << graph_node;
NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name()
<< " , " << pattern_map[label] << " NOT match "
<< graph_node->get_name() << " , " << graph_node;
is_match = false;
}
}
......@@ -60,11 +60,23 @@ namespace ngraph
if (is_match) //in case label was already bound this rebinds it to the same node (harmless; and the logic seems cleaner)
{
NGRAPH_DEBUG << "(Re)binding get_bound_node " << graph_node->get_name() << " , "
<< graph_node << " , " << graph_node->get_name();
pattern_map[label] = graph_node;
}
auto args = get_arguments(label);
if (args.size() > 0)
{
assert(args.size() ==
1); //it should be impossible to construct labels w/ more than one arg
NGRAPH_DEBUG << "[MATCHER] Label describes a sub graph in the pattern";
is_match = match_node(args.at(0), graph_node, pattern_map);
}
if (is_match)
{
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node "
<< graph_node->get_name() << " , " << graph_node << " , "
<< graph_node->get_name();
pattern_map[label] = graph_node;
}
}
return is_match;
}
......@@ -91,6 +103,11 @@ namespace ngraph
PatternMap& pattern_map)
{
assert(pattern_node && graph_node);
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< "matched " << graph_node << " , " << graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{
return match_pattern(label_node, graph_node, pattern_map);
......
......@@ -30,17 +30,43 @@ namespace ngraph
class Label : public Pattern
{
public:
/// \brief creates a Label node from \sa node.
///
/// this Label node can be bound to arbitrary nodes in an input graph
/// as long as provided \sa pred is satisfied and the node hasn't been previously bound to
/// a different node in the input graph
/// \code{.cpp}
/// auto pattern = pattern::op::Label::make_from_node(a); //a is op::Parameter
/// matcher.match(pattern, a));
/// \endcode
static std::shared_ptr<Label>
make_from_node(const std::shared_ptr<ngraph::Node>& node,
Predicate pred = nullptr)
{
auto label = std::make_shared<Label>(pred);
auto label = std::make_shared<Label>(Nodes{}, pred);
label->set_value_type_checked(node->get_value_type());
return label;
}
Label(Predicate pred = nullptr)
: Pattern("Label", Nodes{}, pred)
/// \brief creates a Label node containing a sub-pattern described by \sa node.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa node
/// Example:
/// \code{.cpp}
/// auto add = a + b; //a and b are op::Parameter in this example
/// auto label = pattern::op::Label::wrap(add);
/// \endcode
static std::shared_ptr<Label> wrap(const std::shared_ptr<ngraph::Node>& node,
Predicate pred = nullptr)
{
auto label = std::make_shared<Label>(Nodes{node}, pred);
label->set_value_type_checked(node->get_value_type());
return label;
}
Label(const Nodes& subgraph, Predicate pred)
: Pattern("Label", Nodes{subgraph}, pred)
{
}
};
......
......@@ -346,4 +346,30 @@ TEST(pattern, matcher)
make_shared<op::Constant>(element::Float32::element_type(), Shape{1}, std::to_string(1));
auto patternf = pattern::op::Label::make_from_node(fconst1_0);
ASSERT_FALSE(n.match(patternf * fconst1_0, a * iconst1_1)); //different iconst
//Subgraph labels
auto add = a + b;
auto label = pattern::op::Label::wrap(add);
ASSERT_TRUE(n.match(label, add));
ASSERT_EQ(n.get_pattern_map()[label], add);
ASSERT_FALSE(n.match(label, a - b));
ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
ASSERT_EQ(n.get_pattern_map()[label], add);
//Correlations
auto label1 = pattern::op::Label::make_from_node(a);
auto label2 = pattern::op::Label::wrap(label1 + b);
auto sub_label1 = label1 - label2;
ASSERT_TRUE(n.match(sub_label1, a - add));
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
ASSERT_FALSE(n.match(sub_label1, add - a));
auto add_label1 = label1 + label2;
ASSERT_TRUE(n.match(add_label1, add + a));
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment