Unverified Commit 4345e39d authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Pattern matching for sum (#293)

* the first stab at pattern for sum

test refactoring, debug msg clean up, formatting fixes

removing v1 and cleaning up v2 + formatting

rollback the changes in reduce_ops

rename v2 -> sum_pred

remove unused funcs

switch to new c-tors

remove TensorViewType

removing an assert

fix a docstring to match a c-tor

* fixes after rebase
parent c5ffe8e9
...@@ -157,6 +157,11 @@ namespace ngraph ...@@ -157,6 +157,11 @@ namespace ngraph
template <typename T> template <typename T>
std::vector<T> get_vector() const std::vector<T> get_vector() const
{ {
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
{
throw ngraph_error("Buffer over-read");
}
std::vector<T> rc; std::vector<T> rc;
const T* p = reinterpret_cast<const T*>(m_data); const T* p = reinterpret_cast<const T*>(m_data);
for (size_t i = 0; i < shape_size(m_shape); i++) for (size_t i = 0; i < shape_size(m_shape); i++)
......
...@@ -47,8 +47,8 @@ namespace ngraph ...@@ -47,8 +47,8 @@ namespace ngraph
if (pattern_map[label] != graph_node) if (pattern_map[label] != graph_node)
{ {
NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name() NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name()
<< " , " << pattern_map[label] << " NOT match " << " , " << pattern_map[label] << " does NOT match "
<< graph_node->get_name() << " , " << graph_node; << graph_node->get_name();
is_match = false; is_match = false;
} }
} }
...@@ -71,9 +71,8 @@ namespace ngraph ...@@ -71,9 +71,8 @@ namespace ngraph
if (is_match) if (is_match)
{ {
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node " NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node " << label->get_name()
<< graph_node->get_name() << " , " << graph_node << " , " << " , " << graph_node << " , " << graph_node->get_name();
<< graph_node->get_name();
pattern_map[label] = graph_node; pattern_map[label] = graph_node;
} }
} }
...@@ -105,8 +104,8 @@ namespace ngraph ...@@ -105,8 +104,8 @@ namespace ngraph
assert(pattern_node && graph_node); assert(pattern_node && graph_node);
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : " NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " " << "pattern = " << pattern_node->get_name() << " matched "
<< "matched " << graph_node << " , " << graph_node->get_name(); << graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node)) if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{ {
...@@ -151,9 +150,9 @@ namespace ngraph ...@@ -151,9 +150,9 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& graph_node, const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) PatternMap& pattern_map)
{ {
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] " NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_arguments : "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " " << "pattern = " << pattern_node->get_name() << " "
<< "matched " << graph_node << " , " << graph_node->get_name(); << "matched " << graph_node->get_name();
auto args = get_arguments(graph_node); auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node); auto pattern_args = get_arguments(pattern_node);
...@@ -171,7 +170,7 @@ namespace ngraph ...@@ -171,7 +170,7 @@ namespace ngraph
do do
{ {
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node " NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name() << " , " << graph_node; << graph_node->get_name();
PatternMap copy{pattern_map}; PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy)) if (match_permutation(pattern_args, args, copy))
{ {
...@@ -231,14 +230,10 @@ namespace ngraph ...@@ -231,14 +230,10 @@ namespace ngraph
throw "m_pattern_node or graph_node are not set!"; throw "m_pattern_node or graph_node are not set!";
} }
if (get_users(m_pattern_node).size()) (void)get_users; //to supress an unused function warning
{
throw "Pattern Node must not be used elsewhere!";
}
NGRAPH_DEBUG << "Starting match pattern = " << m_pattern_node << " , " NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< m_pattern_node->get_name() << " , graph_node = " << graph_node << " , " << " , graph_node = " << graph_node->get_name();
<< graph_node->get_name();
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map); bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match) if (is_match)
......
...@@ -30,43 +30,37 @@ namespace ngraph ...@@ -30,43 +30,37 @@ namespace ngraph
class Label : public Pattern class Label : public Pattern
{ {
public: public:
/// \brief creates a Label node from \sa node. /// \brief creates a Label node containing a sub-pattern described by \sa type and \sa shape.
/// ///
/// this Label node can be bound to arbitrary nodes in an input graph /// this Label node can be bound only to the nodes in the input graph
/// as long as provided \sa pred is satisfied and the node hasn't been previously bound to /// that match the pattern specified by \sa wrapped_nodes
/// a different node in the input graph /// Example:
/// \code{.cpp} /// \code{.cpp}
/// auto pattern = pattern::op::Label::make_from_node(a); //a is op::Parameter /// auto add = a + b; //a and b are op::Parameter in this example
/// matcher.match(pattern, a)); /// auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{2,2} , nullptr, Nodes{add});
/// \endcode /// \endcode
static std::shared_ptr<Label> Label(const element::Type& type,
make_from_node(const std::shared_ptr<ngraph::Node>& node, const Shape s,
Predicate pred = nullptr) Predicate pred = nullptr,
const Nodes& wrapped_nodes = Nodes{})
: Pattern("Label", wrapped_nodes, pred)
{ {
auto label = std::make_shared<Label>(Nodes{}, pred); add_output(type, s);
label->add_output(node->get_element_type(), node->get_shape());
return label;
} }
/// \brief creates a Label node containing a sub-pattern described by \sa node. /// \brief creates a Label node containing a sub-pattern described by the type and shape of \sa node.
/// ///
/// this Label node can be bound only to the nodes in the input graph /// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa node /// that match the pattern specified by \sa wrapped_nodes
/// Example: /// Example:
/// \code{.cpp} /// \code{.cpp}
/// auto add = a + b; //a and b are op::Parameter in this example /// auto add = a + b; //a and b are op::Parameter in this example
/// auto label = pattern::op::Label::wrap(add); /// auto label = std::make_shared<pattern::op::Label>(add, nullptr, Nodes{add});
/// \endcode /// \endcode
static std::shared_ptr<Label> wrap(const std::shared_ptr<ngraph::Node>& node, Label(std::shared_ptr<Node> node,
Predicate pred = nullptr) Predicate pred = nullptr,
{ const Nodes& wrapped_nodes = Nodes{})
auto label = std::make_shared<Label>(Nodes{node}, pred); : Label(node->get_element_type(), node->get_shape(), pred, wrapped_nodes)
label->add_output(node->get_element_type(), node->get_shape());
return label;
}
Label(const Nodes& subgraph, Predicate pred)
: Pattern("Label", Nodes{subgraph}, pred)
{ {
} }
}; };
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp" #include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
...@@ -53,9 +54,8 @@ public: ...@@ -53,9 +54,8 @@ public:
assert( assert(
pattern_node && pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match` graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , " NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , " << " , graph_node = " << graph_node->get_name();
<< graph_node->get_name();
m_pattern_map.clear(); m_pattern_map.clear();
m_match_root.reset(); m_match_root.reset();
...@@ -69,9 +69,94 @@ public: ...@@ -69,9 +69,94 @@ public:
} }
}; };
template <typename T>
std::shared_ptr<Node> create_reduction(const std::shared_ptr<Node>& node,
const std::string& init_val,
const AxisSet& reduction_axes)
{
const auto& et = node->get_element_type();
auto f_A = std::make_shared<op::Parameter>(et, Shape{});
auto f_B = std::make_shared<op::Parameter>(et, Shape{});
auto f = std::make_shared<Function>(std::make_shared<T>(f_A, f_B), op::Parameters{f_A, f_B});
auto init = std::make_shared<op::Constant>(et, Shape{}, std::vector<std::string>({init_val}));
return std::make_shared<op::Reduce>(node, init, f, reduction_axes);
}
std::shared_ptr<Node> xla_sum(const std::shared_ptr<Node>& node, const AxisSet& reduction_axes)
{
return create_reduction<op::Add>(node, "0", reduction_axes);
}
static std::shared_ptr<Node> construct_constant_node(int n) static std::shared_ptr<Node> construct_constant_node(int n)
{ {
return op::Constant::create(element::i32, Shape{1}, {n}); return op::Constant::create(element::i32, Shape{}, {n});
}
bool is_equal_to_const_value(std::string const_value, std::shared_ptr<Node> reduce_constant)
{
if (auto rc = std::dynamic_pointer_cast<op::Constant>(reduce_constant))
{
auto cshape = rc->get_shape();
size_t n = shape_size(cshape);
//awkward(but generic) way to construct a constant of a given type, shape, value
std::vector<std::string> vz{n, const_value};
auto zero_constant = std::make_shared<op::Constant>(rc->get_element_type(), cshape, vz);
//equally awkward way to compare elements to const_value
size_t n_bytes = n * rc->get_element_type().size();
NGRAPH_DEBUG << "Comparing " << n_bytes << " bytes";
return !memcmp(zero_constant->get_data_ptr(), rc->get_data_ptr(), n_bytes);
}
else
{
return false;
}
}
bool is_zero(std::shared_ptr<Node> reduce_constant)
{
return is_equal_to_const_value("0", reduce_constant);
}
bool sum_predicate(std::shared_ptr<Node> gn)
{
NGRAPH_DEBUG << "pred_v2 : looking at " << gn->get_name();
if (auto r = std::dynamic_pointer_cast<op::Reduce>(gn))
{
auto reducee = gn->get_input_op(0);
auto reduce_constant = gn->get_input_op(1);
if (!is_zero(reduce_constant))
{
return false;
}
NGRAPH_DEBUG << "looking at function's result "
<< r->get_functions()[0]->get_result()->get_name();
if (auto sum = std::dynamic_pointer_cast<op::Add>(r->get_functions()[0]->get_result()))
{
auto parm1 = std::dynamic_pointer_cast<op::Parameter>(sum->get_input_op(0));
auto parm2 = std::dynamic_pointer_cast<op::Parameter>(sum->get_input_op(1));
const auto parm_or_nil = [](std::shared_ptr<Node> p) {
return p ? p->get_name() : std::string("(nil)");
};
NGRAPH_DEBUG << "parm1 = " << parm_or_nil(parm1) << " , parm2 = " << parm_or_nil(parm2)
<< std::endl;
if (parm1 && parm2 && parm1 != parm2)
{
return true;
}
}
}
return false;
}
std::shared_ptr<pattern::op::Label> construct_sum_pattern() //for the sake of explicitness
{
return std::make_shared<pattern::op::Label>(element::i32, Shape{}, sum_predicate);
} }
class TestGraphRewrite : public ngraph::pass::GraphRewrite class TestGraphRewrite : public ngraph::pass::GraphRewrite
...@@ -81,12 +166,11 @@ public: ...@@ -81,12 +166,11 @@ public:
{ {
//pattern #1 : a * 1 = a //pattern #1 : a * 1 = a
auto iconst1 = construct_constant_node(1); auto iconst1 = construct_constant_node(1);
auto pattern = pattern::op::Label::make_from_node(iconst1); auto pattern = std::make_shared<pattern::op::Label>(iconst1);
NGRAPH_DEBUG << "IN TestGraphRewrite";
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) { ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "IN CALLBACK"; NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2); assert(m.match_root()->get_input_ops().size() == 2);
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -95,15 +179,14 @@ public: ...@@ -95,15 +179,14 @@ public:
auto const_node = dynamic_pointer_cast<op::Constant>( auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_input_ops().at(const_node_index)); m.match_root()->get_input_ops().at(const_node_index));
auto second_node = m.match_root()->get_input_ops().at(const_node_index); auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node; NGRAPH_DEBUG << "second_node = " << second_node->get_name()
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , " << " , pattern = " << pattern_map[pattern]->get_name();
<< pattern_map[pattern];
ASSERT_TRUE(const_node); ASSERT_TRUE(const_node);
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() || if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape()) pattern_map[pattern]->get_shape() != const_node->get_shape())
{ {
NGRAPH_DEBUG << "TYPE/SHAPE"; NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return; return;
} }
...@@ -113,11 +196,9 @@ public: ...@@ -113,11 +196,9 @@ public:
if (!all_ones) if (!all_ones)
{ {
NGRAPH_DEBUG << "ALL_ONES"; NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
return; return;
} }
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::replace_node(m.match_root(), pattern_map[pattern]); ngraph::replace_node(m.match_root(), pattern_map[pattern]);
}; };
...@@ -129,12 +210,11 @@ public: ...@@ -129,12 +210,11 @@ public:
{ {
//pattern #2 : a + 0 = a //pattern #2 : a + 0 = a
auto iconst0 = construct_constant_node(0); auto iconst0 = construct_constant_node(0);
auto pattern = pattern::op::Label::make_from_node(iconst0); auto pattern = std::make_shared<pattern::op::Label>(iconst0);
NGRAPH_DEBUG << "IN TestGraphRewrite";
ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) { ngraph::pattern::gr_callback_fn callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "IN CALLBACK"; NGRAPH_DEBUG << "In a callback for construct_add_zero against "
<< m.match_root()->get_name();
assert(m.match_root()->get_input_ops().size() == 2); assert(m.match_root()->get_input_ops().size() == 2);
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
...@@ -143,15 +223,14 @@ public: ...@@ -143,15 +223,14 @@ public:
auto const_node = dynamic_pointer_cast<op::Constant>( auto const_node = dynamic_pointer_cast<op::Constant>(
m.match_root()->get_input_ops().at(const_node_index)); m.match_root()->get_input_ops().at(const_node_index));
auto second_node = m.match_root()->get_input_ops().at(const_node_index); auto second_node = m.match_root()->get_input_ops().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node; NGRAPH_DEBUG << "second_node = " << second_node->get_name()
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , " << " , pattern = " << pattern_map[pattern]->get_name();
<< pattern_map[pattern];
ASSERT_NE(nullptr, const_node); ASSERT_NE(nullptr, const_node);
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() || if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape()) pattern_map[pattern]->get_shape() != const_node->get_shape())
{ {
NGRAPH_DEBUG << "TYPE/SHAPE"; NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return; return;
} }
...@@ -161,11 +240,10 @@ public: ...@@ -161,11 +240,10 @@ public:
if (!all_zeros) if (!all_zeros)
{ {
NGRAPH_DEBUG << "ALL_ZEROS"; NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
return; return;
} }
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::replace_node(m.match_root(), pattern_map[pattern]); ngraph::replace_node(m.match_root(), pattern_map[pattern]);
}; };
...@@ -173,11 +251,30 @@ public: ...@@ -173,11 +251,30 @@ public:
this->add_matcher(m); this->add_matcher(m);
} }
void construct_sum()
{
auto sum_pattern = construct_sum_pattern();
ngraph::pattern::gr_callback_fn callback = [](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_sum_pattern against "
<< m.match_root()->get_name();
auto reduce = std::dynamic_pointer_cast<op::Reduce>(m.match_root());
auto reducee = reduce->get_inputs().at(0).get_output().get_node();
NGRAPH_DEBUG << "reducee = " << reducee->get_name();
auto sum = std::make_shared<op::Sum>(reducee, reduce->get_reduction_axes());
ngraph::replace_node(reduce, sum);
};
auto m = make_shared<TestMatcher>(sum_pattern, callback);
this->add_matcher(m);
}
TestGraphRewrite() TestGraphRewrite()
: GraphRewrite() : GraphRewrite()
{ {
construct_multiply_by_one(); construct_multiply_by_one();
construct_add_zero(); construct_add_zero();
construct_sum();
} }
}; };
...@@ -185,14 +282,13 @@ static void run_passes(pass::Manager& pass_manager, ...@@ -185,14 +282,13 @@ static void run_passes(pass::Manager& pass_manager,
shared_ptr<Node> graph, shared_ptr<Node> graph,
std::vector<shared_ptr<op::Parameter>> parms) std::vector<shared_ptr<op::Parameter>> parms)
{ {
auto shape = Shape{1};
auto func = make_shared<Function>(graph, op::Parameters{parms}); auto func = make_shared<Function>(graph, op::Parameters{parms});
pass_manager.run_passes(func); pass_manager.run_passes(func);
} }
TEST(pattern, graph_rewrite) TEST(pattern, graph_rewrite)
{ {
auto shape = Shape{1}; auto shape = Shape{};
pass::Manager pass_manager; pass::Manager pass_manager;
pass_manager.register_pass<TestGraphRewrite>(); pass_manager.register_pass<TestGraphRewrite>();
...@@ -270,12 +366,28 @@ TEST(pattern, graph_rewrite) ...@@ -270,12 +366,28 @@ TEST(pattern, graph_rewrite)
ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count( ASSERT_TRUE(a->get_outputs().at(0).get_inputs().count(
&graph->get_inputs().at(1))); //a's output feeds into graph's input &graph->get_inputs().at(1))); //a's output feeds into graph's input
} }
//Sum rewrite
{
auto parm = make_shared<op::Parameter>(element::i32, Shape{2, 2});
auto axes = AxisSet{0, 1};
auto sum_graph = xla_sum(parm, axes);
auto innermost_abs = make_shared<op::Abs>(sum_graph);
auto nested_sum_graph = make_shared<op::Abs>(
make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(innermost_abs))));
run_passes(pass_manager, nested_sum_graph, {parm});
auto sum = std::dynamic_pointer_cast<op::Sum>(innermost_abs->get_input_op(0));
ASSERT_TRUE(sum);
ASSERT_EQ(sum->get_reduction_axes(), axes);
ASSERT_EQ(sum->get_input_op(0), parm);
}
} }
TEST(pattern, matcher) TEST(pattern, matcher)
{ {
auto shape = Shape{1}; auto shape = Shape{};
auto a = make_shared<op::Parameter>(element::i32, shape); auto a = make_shared<op::Parameter>(element::i32, shape);
TestMatcher n(nullptr); TestMatcher n(nullptr);
ASSERT_TRUE(n.match(a, a)); ASSERT_TRUE(n.match(a, a));
...@@ -288,12 +400,12 @@ TEST(pattern, matcher) ...@@ -288,12 +400,12 @@ TEST(pattern, matcher)
std::make_shared<pattern::op::Any>(a, [](std::shared_ptr<Node> no) { return false; }); std::make_shared<pattern::op::Any>(a, [](std::shared_ptr<Node> no) { return false; });
ASSERT_TRUE(n.match(any_false, a)); ASSERT_TRUE(n.match(any_false, a));
auto pattern = pattern::op::Label::make_from_node(a); auto pattern = std::make_shared<pattern::op::Label>(a);
ASSERT_TRUE(n.match(pattern, a)); ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(n.get_pattern_map()[pattern], a); ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto pattern_false = auto pattern_false =
pattern::op::Label::make_from_node(a, [](std::shared_ptr<Node> no) { return false; }); std::make_shared<pattern::op::Label>(a, [](std::shared_ptr<Node> no) { return false; });
ASSERT_FALSE(n.match(pattern_false, a)); ASSERT_FALSE(n.match(pattern_false, a));
auto b = make_shared<op::Parameter>(element::i32, shape); auto b = make_shared<op::Parameter>(element::i32, shape);
...@@ -322,13 +434,13 @@ TEST(pattern, matcher) ...@@ -322,13 +434,13 @@ TEST(pattern, matcher)
auto iconst1_1 = construct_constant_node(1); auto iconst1_1 = construct_constant_node(1);
ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst
ASSERT_EQ(n.get_pattern_map()[pattern], a); ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto fconst1_0 = op::Constant::create(element::f32, Shape{1}, {1}); auto fconst1_0 = op::Constant::create(element::f32, shape, {1});
auto patternf = pattern::op::Label::make_from_node(fconst1_0); auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); //different iconst ASSERT_TRUE(n.match(patternf * fconst1_0, a * iconst1_1)); //different iconst
//Subgraph labels //Subgraph labels
auto add = a + b; auto add = a + b;
auto label = pattern::op::Label::wrap(add); auto label = std::make_shared<pattern::op::Label>(add, nullptr, Nodes{add});
ASSERT_TRUE(n.match(label, add)); ASSERT_TRUE(n.match(label, add));
ASSERT_EQ(n.get_pattern_map()[label], add); ASSERT_EQ(n.get_pattern_map()[label], add);
...@@ -338,8 +450,9 @@ TEST(pattern, matcher) ...@@ -338,8 +450,9 @@ TEST(pattern, matcher)
ASSERT_EQ(n.get_pattern_map()[label], add); ASSERT_EQ(n.get_pattern_map()[label], add);
//Correlations //Correlations
auto label1 = pattern::op::Label::make_from_node(a); auto label1 = std::make_shared<pattern::op::Label>(a);
auto label2 = pattern::op::Label::wrap(label1 + b); auto tmp = label1 + b;
auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, Nodes{tmp});
auto sub_label1 = label1 - label2; auto sub_label1 = label1 - label2;
ASSERT_TRUE(n.match(sub_label1, a - add)); ASSERT_TRUE(n.match(sub_label1, a - add));
ASSERT_EQ(n.get_pattern_map()[label1], a); ASSERT_EQ(n.get_pattern_map()[label1], a);
...@@ -352,3 +465,25 @@ TEST(pattern, matcher) ...@@ -352,3 +465,25 @@ TEST(pattern, matcher)
ASSERT_EQ(n.get_pattern_map()[label1], a); ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add); ASSERT_EQ(n.get_pattern_map()[label2], add);
} }
TEST(pattern, sum)
{
//Sum
TestMatcher n(nullptr);
auto reducee_const = std::make_shared<op::Constant>(
element::i32, Shape{2, 2}, std::vector<std::string>({"0", "0", "0", "0"}));
auto sum_graph = xla_sum(reducee_const, AxisSet{0, 1});
auto reduce_label = construct_sum_pattern();
ASSERT_TRUE(n.match(reduce_label, sum_graph));
ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
auto nested_sum_graph = make_shared<op::Abs>(make_shared<op::Abs>(
make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(sum_graph)))));
auto nested_reduce_label = make_shared<op::Abs>(make_shared<op::Abs>(
make_shared<op::Abs>(make_shared<op::Abs>(make_shared<op::Abs>(reduce_label)))));
ASSERT_TRUE(n.match(nested_reduce_label, nested_sum_graph));
ASSERT_EQ(n.get_pattern_map()[reduce_label], sum_graph);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment