Commit 3e68842b authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

REBASE: graph pattern matcher half I/O half arguments/users (#269)

* Start of pattern matcher

recursive graph matcher, pattern node

add matcher.cpp

add files for matcher, graph_rewrite

add const to on_match_class

fix comp errors

reshuffle pattern matching code across corresponding files

fix comment

run clang-format

graph_rewrite replace_node

getting simple test cases to work

op/pattern.cpp

toward graph_rewrite tests

older matcher API

before clean up tests

before rebase

build bbrks

more tests

clean up

more clean-up

more cleanup 2

more clean up 3

clean up 4

clang errors

clang errors2

apply code format

move match_class to matcher

major clean up after moving match_class to matcher.cpp

removing tracing changes

rebased as of 11/8

make matcher use i/o descs to traverse the graph; change replace_io

switching to io tds

graph_rewrite tests fail

all tests pass

formatting

unhandle outputs explicitly for now

reset permissions back to 0644; bad bad windows

fixes after rebase

* fixes

* addressing Scott's feedback
parent ca977e70
......@@ -57,6 +57,7 @@ set (SRC
ops/unary_elementwise_arithmetic.cpp
ops/unary_elementwise.cpp
pass/dump_sorted.cpp
pass/graph_rewrite.cpp
pass/liveness.cpp
pass/manager.cpp
pass/manager_state.cpp
......@@ -66,6 +67,7 @@ set (SRC
pass/topological_sort.cpp
pass/visualize_tree.cpp
runtime/aligned_buffer.cpp
pattern/matcher.cpp
runtime/backend.cpp
runtime/manager.cpp
runtime/ngvm/call_frame.cpp
......
......@@ -24,11 +24,18 @@ Input::Input(Node* node, size_t index, size_t argno, size_t arg_index, Output& o
, m_index(index)
, m_argno(argno)
, m_arg_index(arg_index)
, m_output(output)
, m_output(&output)
{
output.add_input(this);
}
void Input::replace_output(Output& new_output)
{
m_output->remove_input(this);
new_output.add_input(this);
m_output = &new_output;
}
std::shared_ptr<Node> Input::get_node()
{
return m_node->shared_from_this();
......@@ -36,25 +43,25 @@ std::shared_ptr<Node> Input::get_node()
const Tensor& Input::get_tensor() const
{
return m_output.get_tensor();
return m_output->get_tensor();
}
Tensor& Input::get_tensor()
{
return m_output.get_tensor();
return m_output->get_tensor();
}
std::shared_ptr<const TensorView> Input::get_tensor_view() const
{
return m_output.get_tensor_view();
return m_output->get_tensor_view();
}
std::shared_ptr<TensorView> Input::get_tensor_view()
{
return m_output.get_tensor_view();
return m_output->get_tensor_view();
}
std::shared_ptr<const TensorViewType> Input::get_tensor_view_type() const
{
return m_output.get_tensor_view()->get_tensor_view_type();
return m_output->get_tensor_view()->get_tensor_view_type();
}
......@@ -50,15 +50,17 @@ namespace ngraph
/// @return the position within all supplied tensors of this input
size_t get_index() const { return m_index; }
// @return the connected output
const Output& get_output() const { return m_output; }
const Output& get_output() const { return *m_output; }
// @return the connected output
Output& get_output() { return m_output; }
Output& get_output() { return *m_output; }
// @return the tensor of the connected output
const Tensor& get_tensor() const;
// @return the tensor of the connected output
Tensor& get_tensor();
void replace_output(Output& output);
/// @return the tensor view for the connected output
std::shared_ptr<const TensorView> get_tensor_view() const;
......@@ -73,7 +75,7 @@ namespace ngraph
size_t m_index; // Index into all input tensors
size_t m_argno; // Arg number for this input
size_t m_arg_index; // Index into arg's tensors
Output& m_output;
Output* m_output;
private:
Input(const Input&) = delete;
......
......@@ -33,6 +33,11 @@ void Output::add_input(Input* input)
m_inputs.insert(input);
}
void Output::remove_input(Input* input)
{
m_inputs.erase(input);
}
std::shared_ptr<Node> Output::get_node() const
{
return m_node->shared_from_this();
......
......@@ -44,6 +44,7 @@ namespace ngraph
size_t get_index() const { return m_index; }
std::shared_ptr<TensorView> get_tensor_view() const { return m_tensor_view; }
void add_input(Input* input);
void remove_input(Input* input);
const std::set<Input*>& get_inputs() const { return m_inputs; }
const Tensor& get_tensor() const;
Tensor& get_tensor();
......
......@@ -37,6 +37,12 @@ static condition_variable queue_condition;
static unique_ptr<thread> queue_thread;
static bool active = false;
std::ostream& nervana::get_nil_stream()
{
static std::stringstream nil;
return nil;
}
class nervana::thread_starter
{
public:
......@@ -100,6 +106,7 @@ nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const
case LOG_TYPE::_LOG_TYPE_ERROR: _stream << "[ERR ] "; break;
case LOG_TYPE::_LOG_TYPE_WARNING: _stream << "[WARN] "; break;
case LOG_TYPE::_LOG_TYPE_INFO: _stream << "[INFO] "; break;
case LOG_TYPE::_LOG_TYPE_DEBUG: _stream << "[DEBUG] "; break;
}
std::time_t tt = chrono::system_clock::to_time_t(chrono::system_clock::now());
......
......@@ -58,6 +58,7 @@ namespace nervana
_LOG_TYPE_ERROR,
_LOG_TYPE_WARNING,
_LOG_TYPE_INFO,
_LOG_TYPE_DEBUG,
};
class log_helper
......@@ -88,6 +89,8 @@ namespace nervana
static std::deque<std::string> queue;
};
extern std::ostream& get_nil_stream();
#define NGRAPH_ERR \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_ERROR, \
nervana::get_file_name(__FILE__), \
......@@ -106,4 +109,13 @@ namespace nervana
__LINE__, \
__PRETTY_FUNCTION__) \
.stream()
/*
#define NGRAPH_DEBUG \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG, \
nervana::get_file_name(__FILE__), \
__LINE__, \
__PRETTY_FUNCTION__) \
.stream()
*/
#define NGRAPH_DEBUG nervana::get_nil_stream()
}
......@@ -13,6 +13,10 @@
// ----------------------------------------------------------------------------
#include "ngraph/node.hpp"
#include <memory>
#include <typeindex>
#include <typeinfo>
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/descriptor/primary_tensor_view.hpp"
#include "ngraph/ops/parameter.hpp"
......
......@@ -90,7 +90,7 @@ namespace ngraph
bool is_output() const;
void set_is_output();
virtual bool is_constant() const;
virtual bool is_commutative() { return false; }
size_t get_instance_id() const { return m_instance_id; }
friend std::ostream& operator<<(std::ostream&, const Node&);
......
......@@ -65,6 +65,7 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
virtual bool is_commutative() override { return true; }
};
}
......
......@@ -63,6 +63,7 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override;
virtual bool is_commutative() override { return true; }
};
};
......
File mode changed from 100644 to 100755
#include "graph_rewrite.hpp"
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "ngraph/log.hpp"
#include "ngraph/pattern/matcher.hpp"
bool ngraph::pass::GraphRewrite::run_on_call_graph(std::list<std::shared_ptr<Node>>& nodes)
{
bool rewritten = false;
for (auto node : nodes)
{
for (auto matcher : m_matchers)
{
NGRAPH_DEBUG << "Running matcher " << matcher << " on " << node << " , "
<< node->get_name();
if (!node->is_output() /*this restriction can be lifted when we find an use case for it*/
&&
matcher->match(node))
{
NGRAPH_DEBUG << "Matcher " << matcher << " matched " << node << " , "
<< node->get_name();
rewritten = true;
matcher->process_match();
break; //move onto the next node
}
}
}
return rewritten;
}
void ngraph::pass::GraphRewrite::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
if (target->is_output()) //this restriction can be lifted when we find an use case for it
{
return;
}
//fix input/output descriptors
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
assert(target->get_outputs().size() == replacement->get_outputs().size());
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{
begin(target_output.get_inputs()),
end(target_output.get_inputs())}; //replace_output modifies target_output->m_inputs
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
//fix users and arguments
replace_node_users_arguments(target, replacement);
}
void ngraph::pass::GraphRewrite::replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement)
{
NGRAPH_DEBUG << "Replacing target = " << target << " , " << target->get_name() << " , "
<< "replacement = " << replacement << " , " << replacement->get_name();
NGRAPH_DEBUG << "user = " << replacement << " , " << replacement->get_name();
for (auto user : target->users())
{
auto& args = const_cast<ngraph::Nodes&>(user->get_arguments());
auto it = std::find(begin(args), end(args), target);
assert(it != end(args));
//NGRAPH_DEBUG << "Replaced " << *it << " w/ " << replacement << " in args of " << user << " , args = " << &args;
it = args.erase(it);
args.insert(it, replacement);
const_cast<std::multiset<Node*>&>(replacement->users()).insert(user);
}
const_cast<std::multiset<Node*>&>(target->users()).clear();
//TODO: [nikolayk] recursively walk target and update users()
//nodes w/ empty users sets should be DSE'ed.
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include <set>
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class GraphRewrite;
}
namespace pattern
{
class Matcher;
}
}
/// \brief GraphRewrite (in tandem with \sa Matcher) performs transformations on specified patterns
///
/// Graph rewrite pass essentially allows pass users to rewrite parts of the
/// input graph in any way they want. Fusion is one example of graph rewrite that
/// fuses multiple ops together. At a high-level users of the pass need to
/// specify 2 things: 1) which ops to fuse (via \sa Matcher, and 2) how to create new op(s) from
/// the existing ops by providing a callback to \p Matcher object
/// Patterns can be added by using \sa add_matcher
/// Callbacks should use \sa replace_node to transform matched sub graphs
class ngraph::pass::GraphRewrite : public CallGraphPass
{
public:
GraphRewrite()
: CallGraphPass()
{
}
void add_matcher(std::shared_ptr<pattern::Matcher> m) { m_matchers.push_back(m); }
static void replace_node_users_arguments(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement);
static void replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement);
virtual bool run_on_call_graph(std::list<std::shared_ptr<ngraph::Node>>&) override;
private:
//enable cascading rewrites
std::vector<std::shared_ptr<pattern::Matcher>> m_matchers;
};
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "matcher.hpp"
#include <algorithm>
#include <typeindex>
#include <typeinfo>
#include "ngraph/log.hpp"
#include "ngraph/ops/parameter.hpp"
namespace ngraph
{
namespace pattern
{
static std::vector<std::shared_ptr<Node>> get_arguments(std::shared_ptr<Node> n)
{
std::unordered_set<std::shared_ptr<Node>> arguments;
for (const auto& input : n->get_inputs())
{
arguments.insert(input.get_output().get_node());
}
return std::vector<std::shared_ptr<Node>>(
begin(arguments), end(arguments)); //vector is needed for generating permutations
}
std::shared_ptr<Node> Matcher::match_root()
{
assert(is_match());
return m_match_root;
}
void Matcher::reset_pattern_nodes(
std::shared_ptr<Node> node) //TODO: [nikolayk] this doesn't have to be recursive
//even better we should walk the entire pattern subgraph once
//and keep track of all pattern nodes
{
auto label = std::dynamic_pointer_cast<::ngraph::pattern::op::Label>(node);
NGRAPH_DEBUG << "reset_pattern_nodes : node = " << node->get_name() << " , " << node;
if (label)
{
NGRAPH_DEBUG << "reset_pattern_nodes : label = " << node->get_name() << " , "
<< node;
label->reset();
}
for (auto arg : get_arguments(node))
{
reset_pattern_nodes(arg);
}
}
void Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node)
{
bool is_match = true;
if (label->is_bound())
{
if (label->get_bound_node() != graph_node)
{
NGRAPH_DEBUG << "get_bound_node " << label->get_bound_node()->get_name()
<< " , " << label->get_bound_node() << " NOT match "
<< graph_node->get_name() << " , " << graph_node;
is_match = false;
}
}
else
{
auto predicate = label->get_predicate();
is_match = !predicate || predicate(graph_node);
}
if (is_match)
{
NGRAPH_DEBUG << "Binding get_bound_node " << graph_node->get_name() << " , "
<< graph_node << " , " << graph_node->get_name();
label->bind(graph_node);
}
else
{
reset();
m_match_root.reset();
NGRAPH_DEBUG << "MATCHER IS MATCH : " << this->is_match();
}
}
void Matcher::match_any(const std::shared_ptr<op::Any>& any,
const std::shared_ptr<Node>& graph_node)
{
auto predicate = any->get_predicate();
if (!predicate || any->get_predicate()(graph_node))
{
on_match_class(any, graph_node, true);
}
else
{
auto args = get_arguments(any);
assert(args.size() == 1);
on_match_class(args.at(0), graph_node, true);
}
}
void Matcher::match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node)
{
assert(pattern_node && graph_node);
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{
match_pattern(label_node, graph_node);
return;
}
if (auto any_node = std::dynamic_pointer_cast<op::Any>(
pattern_node)) //matches PatternSkipOp semantics
{
match_any(any_node, graph_node);
return;
}
on_match_class(pattern_node,
graph_node,
std::type_index(typeid(*&*pattern_node)) ==
std::type_index(typeid(*&*graph_node)));
}
void Matcher::match_arguments(const Nodes& pattern_args, const Nodes& args)
{
m_depth++;
for (size_t i = 0; i < args.size(); i++)
{
match_class(pattern_args.at(i), args.at(i));
if (!is_match())
{
m_depth--;
return;
}
}
m_depth--;
}
void Matcher::on_match_class(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
bool is_match)
{
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< (is_match ? " " : "NOT ") << "matched " << graph_node << " , "
<< graph_node->get_name();
if (!is_match)
{
reset_pattern_nodes(pattern_node);
m_match_root.reset();
return;
}
auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node);
if (args.size() != pattern_args.size())
{
reset_pattern_nodes(pattern_node);
m_match_root.reset();
return;
}
if (graph_node->is_commutative())
{
auto old_match_root = m_match_root;
std::sort(
begin(pattern_args),
end(pattern_args)); //TODO: [nikolayk] we don't really have to use lexicographically-based perms, heap's algo should be faster
do
{
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name() << " , " << graph_node;
reset_pattern_nodes(pattern_node);
m_match_root =
old_match_root; //previous permutation wasn't a match; reset m_match_root
match_arguments(pattern_args, args);
if (this->is_match())
{
return;
}
} while (std::next_permutation(begin(pattern_args), end(pattern_args)));
}
else
{
match_arguments(pattern_args, args);
}
}
void Matcher::process_match(::ngraph::pattern::gr_callback_fn callback)
{
gr_callback_fn cb = m_callback;
if (callback)
{
cb = callback;
}
assert(cb);
assert(is_match());
cb(*this);
}
static Nodes get_users(std::shared_ptr<Node> node)
{
Nodes result;
for (auto& output : node->get_outputs())
{
for (auto input : output.get_inputs())
{
result.push_back(input->get_node());
}
}
return result;
}
bool Matcher::match(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node)
{
if (!pattern_node || !graph_node)
{
NGRAPH_DEBUG << "pattern_node or graph_node are not set; matching FAILED";
m_match_root.reset();
}
if (get_users(pattern_node).size())
{
throw "Pattern Node must not be used elsewhere!";
}
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node << " , "
<< pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
reset_pattern_nodes(pattern_node);
m_match_root = graph_node;
match_class(pattern_node, graph_node);
//NGRAPH_DEBUG << pad(2 * m_depth) << "is_match() " << is_match();
return is_match();
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <cassert>
#include <memory.h>
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
namespace ngraph
{
namespace pass
{
class GraphRewrite;
}
namespace pattern
{
using gr_callback_fn = std::function<void(class Matcher& m)>;
namespace op
{
class Label;
}
/// \brief Matcher matches (compares) two graphs
///
class Matcher
{
public:
/// \brief Constructs a Matcher object
///
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param callback is a callback function that will be called on a successful match
Matcher(const std::shared_ptr<Node> pattern_node = nullptr,
gr_callback_fn callback = nullptr)
: m_match_root(nullptr)
, m_pattern_node(pattern_node)
, m_callback(callback)
, m_depth(0)
{
}
virtual ~Matcher() {}
// Called when the pattern node matches a graph node.
virtual void on_match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
bool is_match);
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node)
{
return match(m_pattern_node, graph_node);
}
bool match(const std::shared_ptr<Node>& pattern_node, //keep public for testing for now
const std::shared_ptr<Node>& graph_node);
void process_match(gr_callback_fn callback = nullptr);
void reset() {}
bool is_match() { return m_match_root != nullptr; }
std::shared_ptr<Node> pattern_node() { return m_pattern_node; }
std::shared_ptr<Node> match_root();
void reset_pattern_nodes(std::shared_ptr<Node> node);
friend op::Label; //TODO: refine to match_class
protected:
void virtual match_class(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node);
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
void match_arguments(const Nodes& pattern_args, const Nodes& args);
void match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node);
void match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node);
std::shared_ptr<Node> m_match_root;
std::shared_ptr<Node> m_pattern_node;
gr_callback_fn m_callback;
size_t m_depth;
};
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// \brief \p Any allows users to specify unexpected nodes in a pattern
/// and skip them if a predicate condition is satisfied.
///
class Any : public Pattern
{
public:
Any(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr)
: Pattern("Any", Nodes{arg}, predicate)
{
set_value_type_checked(arg->get_value_type());
//m_arguments.push_back(arg);
//const_cast<std::multiset<Node*>&>(arg->users()).insert(this);
}
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// \brief Labels are used in patterns to express repeating nodes in an input graph
/// and bind them to specific nodes from the graph
///
class Label : public Pattern
{
public:
static std::shared_ptr<Label>
make_from_node(const std::shared_ptr<ngraph::Node>& node,
Predicate pred = nullptr)
{
auto label = std::make_shared<Label>(pred);
label->set_value_type_checked(node->get_value_type());
return label;
}
bool is_bound() { return m_bound != nullptr; }
std::shared_ptr<Node> get_bound_node() { return m_bound; }
void reset() { m_bound.reset(); }
void bind(std::shared_ptr<Node> n) { m_bound = n; }
Label(Predicate pred = nullptr)
: Pattern("Label", Nodes{}, pred)
{
}
private:
std::shared_ptr<Node> m_bound;
};
}
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include <functional>
#include "ngraph/node.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pattern/matcher.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
using Predicate = std::function<bool(std::shared_ptr<Node>)>;
class Pattern : public Node
{
public:
/// \brief \p a base class for \sa Any and \sa Label
///
Pattern(const std::string& type_name, const Nodes& nodes, Predicate pred)
: Node(type_name, nodes)
, m_predicate(pred)
{
}
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
throw ngraph_error("Uncopyable");
}
Predicate get_predicate() const { return m_predicate; }
protected:
std::function<bool(std::shared_ptr<Node>)> m_predicate;
};
}
}
}
......@@ -36,6 +36,7 @@ set (SRC
pass_manager.cpp
pass_memory_layout.cpp
serialize.cpp
pattern.cpp
shape.cpp
tensor.cpp
topological_sort.cpp
......
......@@ -71,10 +71,6 @@ TEST(build_graph, node_comparison)
auto parg = make_shared<op::Parameter>(element::Float32::element_type(), Shape{});
auto pattern_dot = make_shared<op::Dot>(parg, parg);
ASSERT_TRUE(pattern_dot->is_same_op_type(dot));
// TODO This passes because typeid is not behaving as documented.
// Need to figure out what's wrong.
ASSERT_FALSE(pattern_dot->is_same_op_type(add));
}
TEST(build_graph, literal)
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment