Unverified Commit 18998c41 authored by Adam Procter's avatar Adam Procter Committed by GitHub

Merge branch 'master' into aprocter/de-eigenize-partial

parents 676e601d 2f0a33c3
......@@ -5,12 +5,14 @@
#include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp"
bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes)
bool ngraph::pass::GraphRewrite::run_matchers_on_nodes_list(
const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers)
{
bool rewritten = false;
for (auto node : nodes)
{
for (auto matcher : m_matchers)
for (auto matcher : matchers)
{
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , "
<< node->get_name();
......@@ -29,53 +31,7 @@ bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Nod
return rewritten;
}
void ngraph::pass::GraphRewrite::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::pass::GraphRewrite::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
//TODO: [nikolayk] recursively walk target and update users()
//nodes w/ empty users sets should be DSE'ed.
return run_matchers_on_nodes_list(nodes, m_matchers);
}
......@@ -49,10 +49,10 @@ public:
}
void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); }
static void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
static void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) override;
static bool
run_matchers_on_nodes_list(const std::list<std::shared_ptr<ngraph::Node>>& nodes,
const std::vector<std::shared_ptr<pattern::Matcher>>& matchers);
private:
//enable cascading rewrites
......
......@@ -36,43 +36,19 @@ namespace ngraph
begin(arguments), end(arguments)); //vector is needed for generating permutations
}
std::shared_ptr<Node> Matcher::match_root()
{
assert(is_match());
return m_match_root;
}
void Matcher::reset_pattern_nodes(
std::shared_ptr<Node> node) //TODO: [nikolayk] this doesn't have to be recursive
//even better we should walk the entire pattern subgraph once
//and keep track of all pattern nodes
{
auto label = std::dynamic_pointer_cast<::ngraph::pattern::op::Label>(node);
NGRAPH_DEBUG << "reset_pattern_nodes : node = " << node->get_name() << " , " << node;
if (label)
{
NGRAPH_DEBUG << "reset_pattern_nodes : label = " << node->get_name() << " , "
<< node;
label->reset();
}
for (auto arg : get_arguments(node))
{
reset_pattern_nodes(arg);
}
}
void Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node)
std::shared_ptr<Node> Matcher::match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
bool is_match = true;
if (label->is_bound())
if (pattern_map.count(label))
{
if (label->get_bound_node() != graph_node)
if (pattern_map[label] != graph_node)
{
NGRAPH_DEBUG << "get_bound_node " << label->get_bound_node()->get_name()
<< " , " << label->get_bound_node() << " NOT match "
<< graph_node->get_name() << " , " << graph_node;
NGRAPH_DEBUG << "get_bound_node " << pattern_map[label]->get_name() << " , "
<< pattern_map[label] << " NOT match " << graph_node->get_name()
<< " , " << graph_node;
is_match = false;
}
}
......@@ -82,103 +58,96 @@ namespace ngraph
is_match = !predicate || predicate(graph_node);
}
if (is_match)
if (is_match) //in case label was already bound this rebinds it to the same node (harmless; and the logic seems cleaner)
{
NGRAPH_DEBUG << "Binding get_bound_node " << graph_node->get_name() << " , "
NGRAPH_DEBUG << "(Re)binding get_bound_node " << graph_node->get_name() << " , "
<< graph_node << " , " << graph_node->get_name();
label->bind(graph_node);
}
else
{
reset();
m_match_root.reset();
NGRAPH_DEBUG << "MATCHER IS MATCH : " << this->is_match();
pattern_map[label] = graph_node;
}
return is_match;
}
void Matcher::match_any(const std::shared_ptr<op::Any>& any,
const std::shared_ptr<Node>& graph_node)
bool Matcher::match_any(const std::shared_ptr<op::Any>& any,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
auto predicate = any->get_predicate();
if (!predicate || any->get_predicate()(graph_node))
{
on_match_class(any, graph_node, true);
return match_arguments(any, graph_node, pattern_map);
}
else
{
auto args = get_arguments(any);
assert(args.size() == 1);
on_match_class(args.at(0), graph_node, true);
return match_node(args.at(0), graph_node, pattern_map);
}
}
void Matcher::match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node)
bool Matcher::match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
assert(pattern_node && graph_node);
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{
match_pattern(label_node, graph_node);
return;
return match_pattern(label_node, graph_node, pattern_map);
}
if (auto any_node = std::dynamic_pointer_cast<op::Any>(
pattern_node)) //matches PatternSkipOp semantics
{
match_any(any_node, graph_node);
return;
return match_any(any_node, graph_node, pattern_map);
}
auto p_pattern_node = pattern_node.get();
auto p_graph_node = graph_node.get();
if (std::type_index(typeid(*p_pattern_node)) == std::type_index(typeid(*p_graph_node)))
{
return match_arguments(pattern_node, graph_node, pattern_map);
}
on_match_class(pattern_node,
graph_node,
std::type_index(typeid(*&*pattern_node)) ==
std::type_index(typeid(*&*graph_node)));
return false;
}
void Matcher::match_arguments(const Nodes& pattern_args, const Nodes& args)
bool Matcher::match_permutation(const Nodes& pattern_args,
const Nodes& args,
PatternMap& pattern_map)
{
m_depth++;
for (size_t i = 0; i < args.size(); i++)
{
match_class(pattern_args.at(i), args.at(i));
if (!is_match())
if (!match_node(pattern_args.at(i), args.at(i), pattern_map))
{
m_depth--;
return;
return false;
}
}
m_depth--;
return true;
}
void Matcher::on_match_class(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
bool is_match)
bool Matcher::match_arguments(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map)
{
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< (is_match ? " " : "NOT ") << "matched " << graph_node << " , "
<< graph_node->get_name();
if (!is_match)
{
reset_pattern_nodes(pattern_node);
m_match_root.reset();
return;
}
<< "matched " << graph_node << " , " << graph_node->get_name();
auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node);
if (args.size() != pattern_args.size())
{
reset_pattern_nodes(pattern_node);
m_match_root.reset();
return;
return false;
}
if (graph_node->is_commutative())
{
auto old_match_root = m_match_root;
std::sort(
begin(pattern_args),
end(pattern_args)); //TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
......@@ -186,20 +155,24 @@ namespace ngraph
{
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name() << " , " << graph_node;
reset_pattern_nodes(pattern_node);
m_match_root =
old_match_root; //previous permutation wasn't a match; reset m_match_root
match_arguments(pattern_args, args);
if (this->is_match())
PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy))
{
return;
pattern_map.insert(begin(copy), end(copy));
return true;
}
} while (std::next_permutation(begin(pattern_args), end(pattern_args)));
}
else
{
match_arguments(pattern_args, args);
PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy))
{
pattern_map.insert(begin(copy), end(copy));
return true;
}
}
return false;
}
void Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
......@@ -211,7 +184,7 @@ namespace ngraph
}
assert(cb);
assert(is_match());
assert(this->m_match_root);
cb(*this);
}
......@@ -230,29 +203,32 @@ namespace ngraph
return result;
}
bool Matcher::match(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node)
bool Matcher::match(const std::shared_ptr<Node>& graph_node)
{
if (!pattern_node || !graph_node)
//clear our state
m_match_root.reset();
m_pattern_map.clear();
if (!m_pattern_node || !graph_node)
{
NGRAPH_DEBUG << "pattern_node or graph_node are not set; matching FAILED";
m_match_root.reset();
throw "m_pattern_node or graph_node are not set!";
}
if (get_users(pattern_node).size())
if (get_users(m_pattern_node).size())
{
throw "Pattern Node must not be used elsewhere!";
}
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , "
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , "
NGRAPH_DEBUG << "Starting match pattern = " << m_pattern_node << " , "
<< m_pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
reset_pattern_nodes(pattern_node);
m_match_root = graph_node;
match_class(pattern_node, graph_node);
//NGRAPH_DEBUG << pad(2 * m_depth) << "is_match() " << is_match();
return is_match();
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
}
}
......@@ -41,58 +41,58 @@ namespace ngraph
class Matcher
{
public:
using PatternMap = std::map<std::shared_ptr<op::Label>, std::shared_ptr<Node>>;
/// \brief Constructs a Matcher object
///
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
gr_callback_fn callback = nullptr)
: m_match_root(nullptr)
, m_pattern_node(pattern_node)
: m_pattern_node(pattern_node)
, m_callback(callback)
, m_depth(0)
{
}
virtual ~Matcher() {}
// Called when the pattern node matches a graph node.
virtual void on_match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
bool is_match);
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node)
{
return match(m_pattern_node, graph_node);
}
bool match(const std::shared_ptr<Node>& pattern_node, //keep public for testing for now
const std::shared_ptr<Node>& graph_node);
bool match(const std::shared_ptr<Node>& graph_node);
void process_match(gr_callback_fn callback = nullptr);
void reset() {}
bool is_match() { return m_match_root != nullptr; }
std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
std::shared_ptr<Node> match_root();
void reset_pattern_nodes(std::shared_ptr<Node> node);
PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
friend op::Label; //TODO: refine to match_class
protected:
void virtual match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node);
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
virtual bool match_arguments(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
void match_arguments(const Nodes& pattern_args, const Nodes& args);
void match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node);
void match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node);
std::shared_ptr<Node> m_match_root;
std::shared_ptr<Node> m_pattern_node;
PatternMap m_pattern_map;
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
bool match_permutation(const Nodes& pattern_args,
const Nodes& args,
PatternMap& pattern_map);
bool match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
gr_callback_fn m_callback;
size_t m_depth;
};
......
......@@ -38,17 +38,11 @@ namespace ngraph
label->set_value_type_checked(node->get_value_type());
return label;
}
bool is_bound() { return m_bound != nullptr; }
std::shared_ptr<Node> get_bound_node() { return m_bound; }
void reset() { m_bound.reset(); }
void bind(std::shared_ptr<Node> n) { m_bound = n; }
Label(Predicate pred = nullptr)
: Pattern("Label", Nodes{}, pred)
{
}
private:
std::shared_ptr<Node> m_bound;
};
}
}
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <cassert>
#include <deque>
#include <forward_list>
#include <iomanip>
......@@ -210,3 +211,50 @@ void ngraph::free_nodes(shared_ptr<Function> p)
n->clear_arguments();
}
}
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
}
......@@ -245,4 +245,9 @@ namespace ngraph
std::function<void(std::shared_ptr<Function>)> f);
void free_nodes(std::shared_ptr<Function>);
void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
} // end namespace ngraph
......@@ -27,6 +27,7 @@
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
using namespace ngraph;
using namespace std;
......@@ -35,22 +36,37 @@ using namespace std;
class TestMatcher : public pattern::Matcher
{
using pattern::Matcher::Matcher;
void virtual match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node) override
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map) override
{
static const auto parameter_type = std::type_index(typeid(::ngraph::op::Parameter));
const auto pattern_type = std::type_index(typeid(*&*pattern_node));
if (pattern_type == parameter_type)
if (std::dynamic_pointer_cast<::ngraph::op::Parameter>(pattern_node))
{
on_match_class(pattern_node,
graph_node,
pattern_node.get() ==
dynamic_cast<::ngraph::op::Parameter*>(graph_node.get()));
return;
return pattern_node.get() == dynamic_cast<::ngraph::op::Parameter*>(graph_node.get());
}
this->pattern::Matcher::match_class(pattern_node, graph_node);
return this->pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
}
public:
bool match(const std::shared_ptr<Node>& pattern_node, const std::shared_ptr<Node>& graph_node)
{
assert(
pattern_node &&
graph_node); //the same condition throws an exception in the non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , "
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
}
};
......@@ -75,18 +91,19 @@ public:
NGRAPH_DEBUG << "IN CALLBACK";
assert(m.match_root()->get_arguments().size() == 2);
size_t const_node_index =
m.match_root()->get_arguments().at(0) == pattern->get_bound_node();
auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>(
m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern->get_bound_node()->description() << " , "
<< pattern->get_bound_node();
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern];
assert(const_node);
auto pattern_value_type = dynamic_pointer_cast<const TensorViewType>(
pattern->get_bound_node()->get_value_type());
auto pattern_value_type =
dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node);
......@@ -110,7 +127,7 @@ public:
}
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::pass::GraphRewrite::replace_node(m.match_root(), pattern->get_bound_node());
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
};
auto m = make_shared<TestMatcher>(pattern * iconst1, callback);
......@@ -129,18 +146,19 @@ public:
NGRAPH_DEBUG << "IN CALLBACK";
assert(m.match_root()->get_arguments().size() == 2);
size_t const_node_index =
m.match_root()->get_arguments().at(0) == pattern->get_bound_node();
auto pattern_map = m.get_pattern_map();
size_t const_node_index = m.match_root()->get_arguments().at(0) == pattern_map[pattern];
auto const_node = dynamic_pointer_cast<op::ParameterizedConstant<element::Int32>>(
m.match_root()->get_arguments().at(const_node_index));
auto second_node = m.match_root()->get_arguments().at(const_node_index);
NGRAPH_DEBUG << "second_node " << second_node->description() << " , " << second_node;
NGRAPH_DEBUG << "pattern " << pattern->get_bound_node()->description() << " , "
<< pattern->get_bound_node();
NGRAPH_DEBUG << "pattern " << pattern_map[pattern]->description() << " , "
<< pattern_map[pattern];
assert(const_node);
auto pattern_value_type = dynamic_pointer_cast<const TensorViewType>(
pattern->get_bound_node()->get_value_type());
auto pattern_value_type =
dynamic_pointer_cast<const TensorViewType>(pattern_map[pattern]->get_value_type());
auto const_node_value_type =
dynamic_pointer_cast<const TensorViewType>(const_node->get_value_type());
assert(pattern_value_type && const_node);
......@@ -164,7 +182,7 @@ public:
}
NGRAPH_DEBUG << "BEFORE REPLACE";
ngraph::pass::GraphRewrite::replace_node(m.match_root(), pattern->get_bound_node());
ngraph::replace_node(m.match_root(), pattern_map[pattern]);
};
auto m = make_shared<TestMatcher>(pattern + iconst0, callback);
......@@ -292,7 +310,7 @@ TEST(pattern, matcher)
auto pattern = pattern::op::Label::make_from_node(a);
ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(pattern->get_bound_node(), a);
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto pattern_false =
pattern::op::Label::make_from_node(a, [](std::shared_ptr<Node> no) { return false; });
......@@ -306,14 +324,14 @@ TEST(pattern, matcher)
ASSERT_TRUE(n.match(any + b, abs + b));
ASSERT_TRUE(n.match(pattern + b, abs + b));
ASSERT_EQ(pattern->get_bound_node(), abs);
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_TRUE(n.match(b + pattern, abs + b));
ASSERT_EQ(pattern->get_bound_node(), abs);
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
auto c = make_shared<op::Parameter>(element::Int32::element_type(), shape);
ASSERT_TRUE(n.match(c * (b + pattern), c * (abs + b)));
ASSERT_EQ(pattern->get_bound_node(), abs);
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_TRUE(n.match(c * (any + b), c * (abs + b))); //nested any
ASSERT_TRUE(n.match(c * (any + b), (b + abs) * c)); //permutations w/ any
......@@ -323,7 +341,7 @@ TEST(pattern, matcher)
auto iconst1_0 = construct_constant_node(1);
auto iconst1_1 = construct_constant_node(1);
ASSERT_TRUE(n.match(pattern * iconst1_0, a * iconst1_1)); //different iconst
ASSERT_EQ(pattern->get_bound_node(), a);
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto fconst1_0 =
make_shared<op::Constant>(element::Float32::element_type(), Shape{1}, std::to_string(1));
auto patternf = pattern::op::Label::make_from_node(fconst1_0);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment