Unverified Commit 4345e39d authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Pattern matching for sum (#293)

* the first stab at pattern for sum

test refactoring, debug msg clean up, formatting fixes

removing v1 and cleaning up v2 + formatting

rollback the changes in reduce_ops

rename v2 -> sum_pred

remove unused funcs

switch to new c-tors

remove TensorViewType

removing an assert

fix a docstring to match a c-tor

* fixes after rebase
parent c5ffe8e9
......@@ -157,6 +157,11 @@ namespace ngraph
template <typename T>
std::vector<T> get_vector() const
{
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
{
throw ngraph_error("Buffer over-read");
}
std::vector<T> rc;
const T* p = reinterpret_cast<const T*>(m_data);
for (size_t i = 0; i < shape_size(m_shape); i++)
......
......@@ -47,8 +47,8 @@ namespace ngraph
if (pattern_map[label] != graph_node)
{
NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name()
<< " , " << pattern_map[label] << " NOT match "
<< graph_node->get_name() << " , " << graph_node;
<< " , " << pattern_map[label] << " does NOT match "
<< graph_node->get_name();
is_match = false;
}
}
......@@ -71,9 +71,8 @@ namespace ngraph
if (is_match)
{
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node "
<< graph_node->get_name() << " , " << graph_node << " , "
<< graph_node->get_name();
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node " << label->get_name()
<< " , " << graph_node << " , " << graph_node->get_name();
pattern_map[label] = graph_node;
}
}
......@@ -105,8 +104,8 @@ namespace ngraph
assert(pattern_node && graph_node);
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< "matched " << graph_node << " , " << graph_node->get_name();
<< "pattern = " << pattern_node->get_name() << " matched "
<< graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{
......@@ -151,9 +150,9 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map)
{
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< "matched " << graph_node << " , " << graph_node->get_name();
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_arguments : "
<< "pattern = " << pattern_node->get_name() << " "
<< "matched " << graph_node->get_name();
auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node);
......@@ -171,7 +170,7 @@ namespace ngraph
do
{
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name() << " , " << graph_node;
<< graph_node->get_name();
PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy))
{
......@@ -231,14 +230,10 @@ namespace ngraph
throw "m_pattern_node or graph_node are not set!";
}
if (get_users(m_pattern_node).size())
{
throw "Pattern Node must not be used elsewhere!";
}
(void)get_users; //to supress an unused function warning
NGRAPH_DEBUG << "Starting match pattern = " << m_pattern_node << " , "
<< m_pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match)
......
......@@ -30,43 +30,37 @@ namespace ngraph
class Label : public Pattern
{
public:
/// \brief creates a Label node from \sa node.
/// \brief creates a Label node containing a sub-pattern described by \sa type and \sa shape.
///
/// this Label node can be bound to arbitrary nodes in an input graph
/// as long as provided \sa pred is satisfied and the node hasn't been previously bound to
/// a different node in the input graph
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_nodes
/// Example:
/// \code{.cpp}
/// auto pattern = pattern::op::Label::make_from_node(a); //a is op::Parameter
/// matcher.match(pattern, a));
/// auto add = a + b; //a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{2,2} , nullptr, Nodes{add});
/// \endcode
static std::shared_ptr<Label>
make_from_node(const std::shared_ptr<ngraph::Node>& node,
Predicate pred = nullptr)
Label(const element::Type& type,
const Shape s,
Predicate pred = nullptr,
const Nodes& wrapped_nodes = Nodes{})
: Pattern("Label", wrapped_nodes, pred)
{
auto label = std::make_shared<Label>(Nodes{}, pred);
label->add_output(node->get_element_type(), node->get_shape());
return label;
add_output(type, s);
}
/// \brief creates a Label node containing a sub-pattern described by \sa node.
/// \brief creates a Label node containing a sub-pattern described by the type and shape of \sa node.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa node
/// that match the pattern specified by \sa wrapped_nodes
/// Example:
/// \code{.cpp}
/// auto add = a + b; //a and b are op::Parameter in this example
/// auto label = pattern::op::Label::wrap(add);
/// auto label = std::make_shared<pattern::op::Label>(add, nullptr, Nodes{add});
/// \endcode
static std::shared_ptr<Label> wrap(const std::shared_ptr<ngraph::Node>& node,
Predicate pred = nullptr)
{
auto label = std::make_shared<Label>(Nodes{node}, pred);
label->add_output(node->get_element_type(), node->get_shape());
return label;
}
Label(const Nodes& subgraph, Predicate pred)
: Pattern("Label", Nodes{subgraph}, pred)
Label(std::shared_ptr<Node> node,
Predicate pred = nullptr,
const Nodes& wrapped_nodes = Nodes{})
: Label(node->get_element_type(), node->get_shape(), pred, wrapped_nodes)
{
}
};
......
......@@ -22,6 +22,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
......@@ -53,9 +54,8 @@ public:
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , "
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
......@@ -69,9 +69,94 @@ public:
}
};
template <typename T>
std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node,
const std::string& init_val,
const AxisSet& reduction_axes)
{
const auto& et = node->get_element_type();
auto f_A = std::make_shared<op::Parameter>(et, Shape{});
auto f_B = std::make_shared<op::Parameter>(et, Shape{});
auto f = std::make_shared<Function>(std::make_shared<T>(f_A, f_B), op::Parameters{f_A, f_B});
auto init = std::make_shared<op::Constant>(et, Shape{}, std::vector<std::string>({init_val}));
return std::make_shared<op::Reduce>(node, init, f, reduction_axes);
}
std::shared_ptr<Node> xla_sum(const std::shared_ptr<Node>& node, const AxisSet& reduction_axes)
{
return create_reduction<op::Add>(node, "0", reduction_axes);
}
static std::shared_ptr<Node> construct_constant_node(int n)
{
return op::Constant::create(element::i32, Shape{1}, {n});
return op::Constant::create(element::i32, Shape{}, {n});
}
bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant)
{
if (auto rc = std::dynamic_pointer_cast<op::Constant>(reduce_constant))
{
auto cshape = rc->get_shape();
size_t n = shape_size(cshape);
//awkward(but generic) way to construct a constant of a given type, shape, value
std::vector<std::string> vz{n, const_value};
auto zero_constant = std::make_shared<op::Constant>(rc->get_element_type(), cshape, vz);
//equally awkward way to compare elements to const_value
size_t n_bytes = n * rc->get_element_type().size();
NGRAPH_DEBUG << "Comparing " << n_bytes << " bytes";
return !memcmp(zero_constant->get_data_ptr(), rc->get_data_ptr(), n_bytes);
}
else
{
return false;
}
}
bool is_zero(std::shared_ptr<Node> reduce_constant)
{
return is_equal_to_const_value("0", reduce_constant);
}
bool sum_predicate(std::shared_ptr<Node> gn)
{
NGRAPH_DEBUG << "pred_v2 : looking at " << gn->get_name();
if (auto r = std::dynamic_pointer_cast<op::Reduce>(gn))
{
auto reducee = gn->get_input_op(0);
auto reduce_constant = gn->get_input_op(1);
if (!is_zero(reduce_constant))
{
return false;
}
NGRAPH_DEBUG << "looking at function's result "
<< r->get_functions()[0]->get_result()->get_name();
if (auto sum = std::dynamic_pointer_cast<op::Add>(r->get_functions()[0]->get_result()))
{
auto parm1 = std::dynamic_pointer_cast<op::Parameter>(sum->get_input_op(0));
auto parm2 = std::dynamic_pointer_cast<op::Parameter>(sum->get_input_op(1));
const auto parm_or_nil = [](std::shared_ptr<Node> p) {
return p ? p->get_name() : std::string("(nil)");
};
NGRAPH_DEBUG << "parm1 = " << parm_or_nil(parm1) << " , parm2 = " << parm_or_nil(parm2)
<< std::endl;
if (parm1 && parm2 && parm1 != parm2)
{
return true;
}
}
}
return false;
}
std::shared_ptr<pattern::op::Label> construct_sum_pattern() //for the sake of explicitness
{
return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate);
}
class TestGraphRewrite : public ngraph::pass::GraphRewrite
......@@ -81,12 +166,11 @@ public:
{
//pattern #1 : a * 1 = a
auto iconst1 = construct_constant_node(1);
auto pattern = pattern::op::Label::make_from_node(iconst1);
NGRAPH_DEBUG << "IN TestGraphRewrite";
auto pattern = std::make_shared<pattern::op::Label>(iconst1);
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "IN CALLBACK";
NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2);
auto pattern_map = m.get_pattern_map();
......@@ -95,15 +179,14 @@ public:
auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_input_ops().at(const_node_index));
auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern];
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
ASSERT_TRUE(const_node);
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "TYPE/SHAPE";
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return;
}
......@@ -113,11 +196,9 @@ public:
if (!all_ones)
{
NGRAPH_DEBUG << "ALL_ONES";
NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
return;
}
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
};
......@@ -129,12 +210,11 @@ public:
{
//pattern #2 : a + 0 = a
auto iconst0 = construct_constant_node(0);
auto pattern = pattern::op::Label::make_from_node(iconst0);
NGRAPH_DEBUG << "IN TestGraphRewrite";
auto pattern = std::make_shared<pattern::op::Label>(iconst0);
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "IN CALLBACK";
NGRAPH_DEBUG << "In a callback for construct_add_zero against "
<< m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2);
auto pattern_map = m.get_pattern_map();
......@@ -143,15 +223,14 @@ public:
auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_input_ops().at(const_node_index));
auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern];
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
ASSERT_NE(nullptr, const_node);
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "TYPE/SHAPE";
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return;
}
......@@ -161,11 +240,10 @@ public:
if (!all_zeros)
{
NGRAPH_DEBUG << "ALL_ZEROS";
NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
return;
}
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
};
......@@ -173,11 +251,30 @@ public:
this->add_matcher(m);
}
void construct_sum()
{
auto sum_pattern = construct_sum_pattern();
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
<< m.match_root()->get_name();
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
auto reducee = reduce->get_inputs().at(0).get_output().get_node();
NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum = std::make_shared<op::Sum>(reducee, reduce->get_reduction_axes());
ngraph::replace_node(reduce, sum);
};
auto m = make_shared<TestMatcher>(sum_pattern, callback);
this->add_matcher(m);
}
TestGraphRewrite()
: GraphRewrite()
{
construct_multiply_by_one();
construct_add_zero();
construct_sum();
}
};
......@@ -185,14 +282,13 @@ static void run_passes(pass::Manager& pass_manager,
shared_ptr<Node> graph,
std::vector<shared_ptr<op::Parameter>> parms)
{
auto shape = Shape{1};
auto func = make_shared<Function>(graph, op::Parameters{parms});
pass_manager.run_passes(func);
}
TEST(pattern, graph_rewrite)
{
auto shape = Shape{1};
auto shape = Shape{};
pass::Manager pass_manager;
pass_manager.register_pass<TestGraphRewrite>();
......@@ -270,12 +366,28 @@ TEST(pattern, graph_rewrite)
ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
&graph->get_inputs().at(1))); //a's output feeds into graph's input
}
//Sum rewrite
{
auto parm = make_shared<op::Parameter>(element::i32, Shape{2, 2});
auto axes = AxisSet{0, 1};
auto sum_graph = xla_sum(parm, axes);
auto innermost_abs = make_shared<op::Abs>(sum_graph);
auto nested_sum_graph = make_shared<op::Abs>(
make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(innermost_abs))));
run_passes(pass_manager, nested_sum_graph, {parm});
auto sum = std::dynamic_pointer_cast<op::Sum>(innermost_abs->get_input_op(0));
ASSERT_TRUE(sum);
ASSERT_EQ(sum->get_reduction_axes(), axes);
ASSERT_EQ(sum->get_input_op(0), parm);
}
}
TEST(pattern, matcher)
{
auto shape = Shape{1};
auto shape = Shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
TestMatcher n(nullptr);
ASSERT_TRUE(n.match(a, a));
......@@ -288,12 +400,12 @@ TEST(pattern, matcher)
std::make_shared<pattern::op::Any>(a, [](std::shared_ptr<Node> no) { return false; });
ASSERT_TRUE(n.match(any_false, a));
auto pattern = pattern::op::Label::make_from_node(a);
auto pattern = std::make_shared<pattern::op::Label>(a);
ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto pattern_false =
pattern::op::Label::make_from_node(a, [](std::shared_ptr<Node> no) { return false; });
std::make_shared<pattern::op::Label>(a, [](std::shared_ptr<Node> no) { return false; });
ASSERT_FALSE(n.match(pattern_false, a));
auto b = make_shared<op::Parameter>(element::i32, shape);
......@@ -322,13 +434,13 @@ TEST(pattern, matcher)
auto iconst1_1 = construct_constant_node(1);
ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto fconst1_0 = op::Constant::create(element::f32, Shape{1}, {1});
auto patternf = pattern::op::Label::make_from_node(fconst1_0);
auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); //different iconst
//Subgraph labels
auto add = a + b;
auto label = pattern::op::Label::wrap(add);
auto label = std::make_shared<pattern::op::Label>(add, nullptr, Nodes{add});
ASSERT_TRUE(n.match(label, add));
ASSERT_EQ(n.get_pattern_map()[label], add);
......@@ -338,8 +450,9 @@ TEST(pattern, matcher)
ASSERT_EQ(n.get_pattern_map()[label], add);
//Correlations
auto label1 = pattern::op::Label::make_from_node(a);
auto label2 = pattern::op::Label::wrap(label1 + b);
auto label1 = std::make_shared<pattern::op::Label>(a);
auto tmp = label1 + b;
auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, Nodes{tmp});
auto sub_label1 = label1 - label2;
ASSERT_TRUE(n.match(sub_label1, a - add));
ASSERT_EQ(n.get_pattern_map()[label1], a);
......@@ -352,3 +465,25 @@ TEST(pattern, matcher)
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
}
TEST(pattern, sum)
{
//Sum
TestMatcher n(nullptr);
auto reducee_const = std::make_shared<op::Constant>(
element::i32, Shape{2, 2}, std::vector<std::string>({"0", "0", "0", "0"}));
auto sum_graph = xla_sum(reducee_const, AxisSet{0, 1});
auto reduce_label = construct_sum_pattern();
ASSERT_TRUE(n.match(reduce_label, sum_graph));
ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
auto nested_sum_graph = make_shared<op::Abs>(make_shared<op::Abs>(
make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(sum_graph)))));
auto nested_reduce_label = make_shared<op::Abs>(make_shared<op::Abs>(
make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(reduce_label)))));
ASSERT_TRUE(n.match(nested_reduce_label, nested_sum_graph));
ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment