Unverified Commit 3bffe536 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/pattern (#4095)

* Make pattern matcher node-based

Simplify implementation
Add support for Or, Branch
Start of support for recurrent pattern

* Only save state at branch points

* Factor Or out of label

* Documentation

* Review

* Only ops need to match on shape/output index
parent 35d8e436
......@@ -557,11 +557,24 @@ set (SRC
pass/pass_util.cpp
pattern/matcher.cpp
pattern/matcher.hpp
pattern/op/any.cpp
pattern/op/any.hpp
pattern/op/any_of.cpp
pattern/op/any_of.hpp
pattern/op/branch.cpp
pattern/op/branch.hpp
pattern/op/capture.cpp
pattern/op/capture.hpp
pattern/op/label.cpp
pattern/op/label.hpp
pattern/op/or.cpp
pattern/op/or.hpp
pattern/op/pattern.cpp
pattern/op/pattern.hpp
pattern/op/skip.cpp
pattern/op/skip.hpp
pattern/op/true.cpp
pattern/op/true.hpp
placement.cpp
placement.hpp
provenance.cpp
......
......@@ -27,6 +27,7 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/placement.hpp"
using namespace std;
......@@ -930,6 +931,23 @@ void Node::validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& a
set_output_type(0, element::boolean, args_pshape);
}
bool Node::match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
if (pattern_value.get_index() != graph_value.get_index() ||
(matcher->is_strict_mode() &&
(!pattern_value.get_element_type().compatible(graph_value.get_element_type()) ||
!pattern_value.get_partial_shape().compatible(graph_value.get_partial_shape()))))
{
return false;
}
matcher->add_node(graph_value);
return graph_value.get_node_shared_ptr()->get_type_info() == get_type_info() &&
matcher->match_arguments(pattern_value, graph_value);
}
// default implementation for the node to check if it contains partial shape
// we will override this method, for the Op's which depends on additional shape
// attribute to determine if node contains partial shape or not
......
......@@ -67,6 +67,11 @@ namespace ngraph
}
} // namespace op
namespace pattern
{
class Matcher;
}
using ResultVector = std::vector<std::shared_ptr<op::v0::Result>>;
namespace autodiff
......@@ -260,6 +265,7 @@ namespace ngraph
virtual bool is_constant() const;
virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; }
virtual bool is_pattern() const { return false; }
virtual bool is_commutative() const { return false; }
virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
......@@ -502,6 +508,10 @@ namespace ngraph
return m_op_annotations;
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value);
private:
descriptor::Input& get_input_descriptor(size_t position);
descriptor::Output& get_output_descriptor(size_t position);
......@@ -722,6 +732,12 @@ namespace ngraph
/// A null output
Output() = default;
void reset()
{
m_node.reset();
m_index = 0;
}
/// This output position for a different node
Output<Node> for_node(const std::shared_ptr<Node>& node) { return Output(node, m_index); }
/// \return A pointer to the node referred to by this output handle.
......@@ -828,6 +844,12 @@ namespace ngraph
/// A null output
Output() = default;
void reset()
{
m_node.reset();
m_index = 0;
}
/// This output position for a different node
Output<const Node> for_node(const std::shared_ptr<const Node>& node)
{
......
......@@ -14,87 +14,71 @@
// limitations under the License.
//*****************************************************************************
#include "matcher.hpp"
#include <algorithm>
#include <regex>
#include <typeindex>
#include <typeinfo>
#include "matcher.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
MatcherState::MatcherState(Matcher* matcher)
: m_matcher(matcher)
, m_pattern_value_map(matcher->m_pattern_map)
, m_watermark(matcher->m_matched_list.size())
, m_capture_size(matcher->m_pattern_value_maps.size())
{
// The symbols are requiered to be in cpp file to workaround RTTI issue on Android LLVM
const NodeTypeInfo& Any::get_type_info() const { return type_info; }
const NodeTypeInfo& AnyOf::get_type_info() const { return type_info; }
const NodeTypeInfo& Label::get_type_info() const { return type_info; }
const NodeTypeInfo& Skip::get_type_info() const { return type_info; }
Predicate Pattern::get_predicate() const { return m_predicate; }
}
constexpr NodeTypeInfo op::AnyOf::type_info;
constexpr NodeTypeInfo op::Any::type_info;
constexpr NodeTypeInfo op::Label::type_info;
constexpr NodeTypeInfo op::Skip::type_info;
std::shared_ptr<Node> Matcher::get_match_root() { return m_match_root; }
bool Matcher::match_pattern(const std::shared_ptr<op::Label>& label,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
MatcherState::~MatcherState()
{
bool is_match = true;
if (pattern_map.count(label))
if (m_restore)
{
if (pattern_map[label] != graph_node)
{
NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name()
<< " , " << pattern_map[label] << " does NOT match "
<< graph_node->get_name();
is_match = false;
}
}
else
{
auto predicate = label->get_predicate();
is_match = !predicate || predicate(graph_node);
m_matcher->m_matched_list.erase(m_matcher->m_matched_list.begin() + m_watermark,
m_matcher->m_matched_list.end());
m_matcher->m_pattern_value_maps.erase(m_pattern_value_maps.begin() + m_capture_size,
m_pattern_value_maps.end());
m_matcher->m_pattern_map = m_pattern_value_map;
}
}
if (is_match) // in case label was already bound this rebinds it to the same node
// (harmless; and the logic seems cleaner)
{
auto args = label->get_arguments();
if (args.size() > 0)
{
if (args.size() != 1)
{
throw ngraph_error("Labels can only take 1 argument!");
}
NGRAPH_DEBUG << "[MATCHER] Label describes a sub graph in the pattern";
is_match = match_node(args.at(0), graph_node, pattern_map);
}
bool MatcherState::finish(bool is_successful)
{
m_restore = !is_successful;
return is_successful;
}
PatternMap Matcher::get_pattern_map() const { return as_pattern_map(m_pattern_map); }
size_t Matcher::add_node(Output<Node> node)
{
size_t result = m_matched_list.size();
m_matched_list.push_back(node.get_node_shared_ptr());
return result;
}
if (is_match)
{
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node " << label->get_name()
<< " , " << graph_node << " , " << graph_node->get_name();
pattern_map[label] = graph_node;
}
}
std::shared_ptr<Node> Matcher::get_match_root()
{
return m_match_root.get_node_shared_ptr();
}
if (!is_match)
MatcherState Matcher::start_match() { return MatcherState(this); }
Output<Node> Matcher::get_match_value() { return m_match_root; }
void Matcher::capture(const std::set<Node*>& static_nodes)
{
m_pattern_value_maps.push_back(m_pattern_map);
m_pattern_map.clear();
for (auto key_value : m_pattern_value_maps.back())
{
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name()
<< " for pattern " << label->get_name();
if (static_nodes.count(key_value.first.get()) > 0)
{
m_pattern_map.insert(key_value);
}
}
return is_match;
}
bool Matcher::is_contained_match(const NodeVector& exclusions, bool ignore_unused)
{
if (exclusions.empty())
......@@ -105,7 +89,7 @@ namespace ngraph
// leaf label
if (entry.first->get_input_size() == 0)
{
label_exclusions.push_back(entry.second);
label_exclusions.push_back(entry.second.get_node_shared_ptr());
}
}
return ngraph::get_subgraph_outputs(
......@@ -116,112 +100,11 @@ namespace ngraph
return ngraph::get_subgraph_outputs(get_matched_nodes(), exclusions).size() < 2;
}
bool Matcher::match_skip(const std::shared_ptr<op::Skip>& skip,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
auto predicate = skip->get_predicate();
if (!predicate || predicate(graph_node))
{
return match_arguments(skip, graph_node, pattern_map);
}
else
{
auto args = skip->get_arguments();
if (args.size() != 1)
{
throw ngraph_error("Skip can only take one argument");
}
return match_node(args.at(0), graph_node, pattern_map);
}
}
bool Matcher::match_any(const std::shared_ptr<op::Any>& any,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
bool Matcher::match_value(const ngraph::Output<Node>& pattern_value,
const ngraph::Output<Node>& graph_value)
{
auto predicate = any->get_predicate();
if (!predicate)
{
throw ngraph_error("predicate is required");
}
if (predicate(graph_node))
{
return match_arguments(any, graph_node, pattern_map);
}
else
{
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name()
<< " for pattern " << any->get_name();
return false;
}
}
bool Matcher::match_any_of(const std::shared_ptr<op::AnyOf>& any,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
auto predicate = any->get_predicate();
if (!predicate)
{
throw ngraph_error("predicate is required");
}
if (predicate(graph_node))
{
for (auto arg : graph_node->get_arguments())
{
PatternMap copy{pattern_map};
if (match_node(any->get_argument(0), arg, copy))
{
pattern_map.insert(begin(copy), end(copy));
return true;
}
}
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name()
<< " for pattern " << any->get_name();
return false;
}
else
{
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name()
<< " for pattern " << any->get_name();
return false;
}
}
bool Matcher::match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
if (!pattern_node || !graph_node)
{
throw ngraph_error("pattern_node or graph_node shouldn't be nullptrs!");
}
add_node(graph_node);
size_t watermark = m_matched_list.size() - 1;
// we can skip multi-output nodes since their shapes will be compared
// when their individual GOE are matched
// this also gives a bit more flexibility since we don't have to worry
// about *all* outputs of a pattern node but only the ones we want to match.
if (m_strict_mode && graph_node->get_outputs().size() == 1)
{
bool shape_match = pattern_node->get_output_partial_shape(0).compatible(
graph_node->get_output_partial_shape(0));
bool et_match =
pattern_node->get_element_type().compatible(graph_node->get_element_type());
if (!shape_match || !et_match)
{
return abort_match(watermark, false);
}
}
std::shared_ptr<Node> pattern_node = pattern_value.get_node_shared_ptr();
std::shared_ptr<Node> graph_node = graph_value.get_node_shared_ptr();
// This env var allows one to specify node name patterns to abort pattern matching
// at particular nodes. The upshot is that one can quickly zero in on an offending
......@@ -232,83 +115,41 @@ namespace ngraph
static const std::regex node_skip_regex(node_skip_cregex);
if (std::regex_match(graph_node->get_name(), node_skip_regex))
{
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name()
NGRAPH_DEBUG << "[MATCHER] Aborting at " << *graph_node
<< " due to NGRAPH_MATCHER_SKIP set to " << node_skip_cregex;
return abort_match(watermark, false);
return false;
}
}
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node->get_name() << " matched "
<< graph_node->get_name();
if (auto label_node = as_type_ptr<op::Label>(pattern_node))
{
return abort_match(watermark, match_pattern(label_node, graph_node, pattern_map));
}
if (auto skip_node =
as_type_ptr<op::Skip>(pattern_node)) // matches PatternSkipOp semantics
{
return abort_match(watermark, match_skip(skip_node, graph_node, pattern_map));
}
if (auto any_node = as_type_ptr<op::Any>(pattern_node))
{
return abort_match(watermark, match_any(any_node, graph_node, pattern_map));
}
if (auto any_of_node = as_type_ptr<op::AnyOf>(pattern_node))
{
return abort_match(watermark, match_any_of(any_of_node, graph_node, pattern_map));
}
auto p_pattern_node = pattern_node.get();
auto p_graph_node = graph_node.get();
if (std::type_index(typeid(*p_pattern_node)) == std::type_index(typeid(*p_graph_node)))
{
return abort_match(watermark,
match_arguments(pattern_node, graph_node, pattern_map));
}
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name() << " for pattern "
<< pattern_node->get_name();
return abort_match(watermark, false);
return pattern_node->match_value(this, pattern_value, graph_value);
}
bool Matcher::match_permutation(const NodeVector& pattern_args,
const NodeVector& args,
PatternMap& pattern_map)
bool Matcher::match_permutation(const OutputVector& pattern_args, const OutputVector& args)
{
m_depth++;
for (size_t i = 0; i < args.size(); i++)
{
if (!match_node(pattern_args.at(i), args.at(i), pattern_map))
if (!match_value(pattern_args.at(i), args.at(i)))
{
m_depth--;
return false;
}
}
m_depth--;
return true;
}
bool Matcher::match_arguments(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map)
bool Matcher::match_arguments(const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_arguments : "
<< "pattern = " << pattern_node->get_name() << " "
<< "matched " << graph_node->get_name();
auto pattern_node = pattern_value.get_node_shared_ptr();
auto graph_node = graph_value.get_node_shared_ptr();
NGRAPH_DEBUG << "[MATCHER] Match arguments at " << *graph_node << " for pattern "
<< *pattern_node;
auto args = graph_node->get_arguments();
auto pattern_args = pattern_node->get_arguments();
auto args = graph_node->input_values();
auto pattern_args = pattern_node->input_values();
if (args.size() != pattern_args.size())
{
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name()
<< " for pattern " << pattern_node->get_name();
NGRAPH_DEBUG << "[MATCHER] Aborting at " << *graph_node << " for pattern "
<< *pattern_node;
return false;
}
......@@ -318,112 +159,104 @@ namespace ngraph
// heap's algo should be faster
std::sort(begin(pattern_args),
end(pattern_args),
[](const std::shared_ptr<ngraph::Node>& n1,
const std::shared_ptr<ngraph::Node>& n2) {
return n1->get_instance_id() < n2->get_instance_id();
});
[](const ngraph::Output<ngraph::Node>& n1,
const ngraph::Output<ngraph::Node>& n2) { return n1 < n2; });
do
{
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name();
PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy))
auto saved = start_match();
if (match_permutation(pattern_args, args))
{
pattern_map.insert(begin(copy), end(copy));
return true;
return saved.finish(true);
}
} while (std::next_permutation(begin(pattern_args),
end(pattern_args),
[](const std::shared_ptr<ngraph::Node>& n1,
const std::shared_ptr<ngraph::Node>& n2) {
return n1->get_instance_id() <
n2->get_instance_id();
}));
} while (std::next_permutation(
begin(pattern_args),
end(pattern_args),
[](const ngraph::Output<ngraph::Node>& n1,
const ngraph::Output<ngraph::Node>& n2) { return n1 < n2; }));
}
else
{
PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy))
{
pattern_map.insert(begin(copy), end(copy));
return true;
}
return match_permutation(pattern_args, args);
}
NGRAPH_DEBUG << "[MATCHER] Aborting at " << graph_node->get_name() << " for pattern "
<< pattern_node->get_name();
NGRAPH_DEBUG << "[MATCHER] Aborting at " << *graph_node << " for pattern "
<< *pattern_node;
return false;
}
bool Matcher::match(const std::shared_ptr<Node>& graph_node)
bool Matcher::match(const Output<Node>& graph_value)
{
// clear our state
m_match_root.reset();
m_pattern_map.clear();
m_matched_list.clear();
if (!m_pattern_node || !graph_node)
{
throw ngraph_error("m_pattern_node or graph_node are not set");
}
NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
return match(graph_value, PatternValueMap{});
}
bool Matcher::match(const std::shared_ptr<Node>& graph_node,
const PatternMap& previous_matches)
bool Matcher::match(const Output<Node>& graph_value,
const PatternValueMap& previous_matches)
{
// clear our state
m_match_root.reset();
m_pattern_map.clear();
m_matched_list.clear();
// insert previous matches
m_pattern_map.insert(previous_matches.cbegin(), previous_matches.cend());
if (!m_pattern_node || !graph_node)
auto saved = start_match();
bool is_match = saved.finish(match_value(m_pattern_node, graph_value));
if (is_match)
{
throw ngraph_error("m_pattern_node or graph_node are not set");
m_match_root = graph_value;
}
return is_match;
}
NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
bool Matcher::match(const Output<Node>& graph_value, const PatternMap& previous_matches)
{
return match(graph_value, as_pattern_value_map(previous_matches));
}
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match)
namespace
{
std::set<std::shared_ptr<Node>>
as_node_set(const std::set<std::shared_ptr<op::Label>>& label_set)
{
m_match_root = graph_node;
std::set<std::shared_ptr<Node>> result;
for (auto label : label_set)
{
result.insert(label);
}
return result;
}
return is_match;
}
bool RecurrentMatcher::match(std::shared_ptr<Node> graph)
RecurrentMatcher::RecurrentMatcher(
const Output<Node>& initial_pattern,
const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
: RecurrentMatcher(initial_pattern, pattern, rpattern, as_node_set(correlated_patterns))
{
}
bool RecurrentMatcher::match(Output<Node> graph)
{
bool matched = false;
Matcher m(m_pattern);
Matcher::PatternMap previous_matches;
Matcher m_initial(m_initial_pattern);
Matcher m_repeat(m_pattern);
Matcher& m = m_initial;
PatternValueMap previous_matches;
m_matches.clear();
m_match_root = graph;
NGRAPH_DEBUG << "matching graph to " << graph->get_name() << std::endl;
// try to match one cell (i.e. pattern)
while (m.match(graph, previous_matches))
{
matched = true;
// move to the next cell
graph = m.get_pattern_map()[m_recurrent_pattern];
NGRAPH_DEBUG << "setting graph to " << graph->get_name() << std::endl;
graph = m.get_pattern_value_map()[m_recurrent_pattern];
// copy bound nodes for the current pattern graph into a global matches map
for (auto cur_match : m.get_pattern_map())
for (auto cur_match : m.get_pattern_value_map())
{
m_matches[cur_match.first].push_back(cur_match.second);
}
......@@ -434,28 +267,13 @@ namespace ngraph
// unbounded by default
for (auto cor_pat : m_correlated_patterns)
{
if (m.get_pattern_map().count(cor_pat) != 0)
{
// assert that bound nodes from the previous and current matches are the
// same
if (previous_matches.count(cor_pat) != 0)
{
if (previous_matches[cor_pat] != m.get_pattern_map()[cor_pat])
{
throw ngraph_error(
"previous matches and current matches aren't consistent!");
}
}
previous_matches[cor_pat] = m.get_pattern_map()[cor_pat];
}
previous_matches[cor_pat] = m.get_pattern_value_map()[cor_pat];
}
m = m_repeat;
}
if (!matched)
{
NGRAPH_DEBUG << "[RecurrentMatcher] Aborting at " << graph->get_name()
<< " for pattern " << m_pattern->get_name();
m_match_root.reset();
}
......
......@@ -16,6 +16,7 @@
#pragma once
#include <algorithm>
#include <functional>
#include <memory.h>
......@@ -35,43 +36,56 @@ namespace ngraph
namespace pattern
{
using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>;
class Matcher;
template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class()
class NGRAPH_API MatcherState
{
auto pred = [](std::shared_ptr<Node> node) -> bool { return is_type<T>(node); };
return pred;
}
public:
MatcherState(Matcher*);
bool finish(bool is_successful);
~MatcherState();
namespace op
{
class Label;
}
protected:
Matcher* m_matcher;
PatternValueMap m_pattern_value_map;
PatternValueMaps m_pattern_value_maps;
size_t m_watermark;
size_t m_capture_size;
bool m_restore{true};
};
/// \brief Matcher matches (compares) two graphs
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
/// automaton that is described by an extended computation graph. The matcher executes
/// by attempting to match the start node of the pattern to a computation graph value
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
/// submatches. Submatches add match state changes to the enclosing match if the submatch
/// succeeds; otherwise the state is reverted.
///
/// The default match behavior of a pattern node with a graph nodes is that the computation
/// graph value is added to the end of the matched value list and the match succeeds if the
/// node/pattern types match and the input values match. In the case of a commutative node,
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
/// element type and shape must also match.
///
/// Pattern nodes that have different match behavior are in ngraph::pattern::op and have
/// descriptions of their match behavior.
class NGRAPH_API Matcher
{
public:
using PatternMap = std::map<std::shared_ptr<op::Label>, std::shared_ptr<Node>>;
using PatternMap = ngraph::pattern::PatternMap;
// Avoid implicit string construction from nullptr.
Matcher(const std::shared_ptr<Node>& pattern_node, std::nullptr_t name) = delete;
Matcher(const std::shared_ptr<Node>& pattern_node)
Matcher(const Output<Node>& pattern_node)
: m_pattern_node{pattern_node}
, m_depth{0}
, m_name{"Unnamed"}
, m_strict_mode{false}
{
}
Matcher(const std::shared_ptr<Node>& pattern_node, const std::string& name)
Matcher(const Output<Node>& pattern_node, const std::string& name)
: m_pattern_node(pattern_node)
, m_depth{0}
, m_name{name}
, m_strict_mode{false}
{
}
/// \brief Constructs a Matcher object
......@@ -79,11 +93,8 @@ namespace ngraph
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
/// \param name is a string which is used for logging and disabling a matcher
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
Matcher(const std::shared_ptr<Node>& pattern_node,
const std::string& name,
bool strict_mode)
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
: m_pattern_node(pattern_node)
, m_depth(0)
, m_name(name)
, m_strict_mode(strict_mode)
{
......@@ -92,14 +103,15 @@ namespace ngraph
virtual ~Matcher() {}
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_node is an input graph to be matched against
bool match(const std::shared_ptr<Node>& graph_node);
/// \param graph_value is an input graph to be matched against
bool match(const Output<Node>& graph_value);
/// \brief Matches a pattern to \p graph_node
///
/// \param graph_node is an input graph to be matched against
/// \param graph_value is an input graph to be matched against
/// \param previous_matches contains previous mappings from labels to nodes to use
bool match(const std::shared_ptr<Node>& graph_node, const PatternMap& previous_matches);
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
template <typename T>
static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node)
......@@ -123,65 +135,53 @@ namespace ngraph
}
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
const NodeVector& get_matched_nodes() { return m_matched_list; }
const NodeVector get_matched_nodes() { return as_node_vector(m_matched_list); }
const OutputVector& get_matched_values() const { return m_matched_list; }
OutputVector& get_matched_values() { return m_matched_list; }
void reset() {}
const std::string& get_name() { return m_name; }
std::shared_ptr<Node> get_pattern() { return m_pattern_node; }
std::shared_ptr<Node> get_pattern() { return m_pattern_node.as_single_output_node(); }
Output<Node> get_pattern_value() { return m_pattern_node; }
std::shared_ptr<Node> get_match_root();
PatternMap get_pattern_map() { return PatternMap{m_pattern_map}; }
Output<Node> get_match_value();
PatternMap get_pattern_map() const;
PatternValueMap& get_pattern_value_map() { return m_pattern_map; }
PatternValueMaps& get_pattern_value_maps() { return m_pattern_value_maps; }
/// \brief Low-level helper to match recurring patterns
///
/// \param graph is a graph to be matched against
/// \param pattern is a recurring pattern
/// \param rpattern specifies a node to recur from next
/// \param patterns a map from labels to matches
friend op::Label; // TODO: refine to match_class
protected:
void add_node(std::shared_ptr<Node> node) { m_matched_list.push_back(node); }
bool abort_match(size_t watermark, bool matched)
{
if (!matched)
{
m_matched_list.erase(m_matched_list.begin() + watermark, m_matched_list.end());
}
return matched;
}
size_t add_node(Output<Node> node);
bool virtual match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool virtual match_value(const ngraph::Output<Node>& pattern_value,
const ngraph::Output<Node>& graph_value);
virtual bool match_arguments(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool is_strict_mode() { return m_strict_mode; }
virtual bool match_arguments(const Output<Node>& pattern_value,
const Output<Node>& graph_value);
std::shared_ptr<Node> m_match_root;
std::shared_ptr<Node> m_pattern_node;
PatternMap m_pattern_map;
NodeVector m_matched_list;
void capture(const std::set<Node*>& static_nodes);
private:
static std::string pad(size_t num) { return std::string(num, ' '); }
bool match_permutation(const NodeVector& pattern_args,
const NodeVector& args,
PatternMap& pattern_map);
bool match_pattern(const std::shared_ptr<op::Label>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_skip(const std::shared_ptr<op::Skip>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_any_of(const std::shared_ptr<op::AnyOf>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
size_t m_depth;
std::string m_name;
bool m_strict_mode;
size_t get_number_of_recurrent_matches() const { return m_pattern_value_maps.size(); }
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
size_t get_number_of_bound_labels() const;
/// \brief Try a match
MatcherState start_match();
Output<Node> m_match_root;
Output<Node> m_pattern_node;
PatternValueMap m_pattern_map;
PatternValueMaps m_pattern_value_maps;
OutputVector m_matched_list;
protected:
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
std::string m_name{"unnamed"};
bool m_strict_mode{false};
};
class RecurrentMatcher
......@@ -190,30 +190,60 @@ namespace ngraph
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param initial_pattern is a pattern sub graph describing the initial cell
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should
/// start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
/// across all cells
RecurrentMatcher(std::shared_ptr<Node> pattern,
std::shared_ptr<op::Label> rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
: m_pattern(pattern)
RecurrentMatcher(const Output<Node>& initial_pattern,
const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<Node>>& correlated_patterns)
: m_initial_pattern(initial_pattern)
, m_pattern(pattern)
, m_recurrent_pattern(rpattern)
, m_correlated_patterns(correlated_patterns)
{
}
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
///
/// \param pattern is a pattern sub graph describing an individual cell
/// \param rpattern is a (recurring) label to denote which node the next match should
/// start at
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
/// across all cells
RecurrentMatcher(const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<Node>>& correlated_patterns)
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns)
{
}
RecurrentMatcher(const Output<Node>& initial_pattern,
const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
RecurrentMatcher(const Output<Node>& pattern,
const std::shared_ptr<Node>& rpattern,
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns)
{
}
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
/// describing an individual cell
NodeVector get_bound_nodes_for_pattern(std::shared_ptr<op::Label> pattern) const
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const
{
if (m_matches.count(pattern) == 0)
{
throw ngraph_error("No bound nodes for a given label");
}
return NodeVector{m_matches.at(pattern)};
return as_node_vector(m_matches.at(pattern));
}
size_t get_number_of_recurrent_matches() const
......@@ -228,15 +258,17 @@ namespace ngraph
size_t get_number_of_bound_labels() const { return m_matches.size(); }
/// \brief Tries to match a pattern for an individual cell to a given \p graph
bool match(std::shared_ptr<Node> graph);
bool match(Output<Node> graph);
std::shared_ptr<Node> get_match_root() { return m_match_root; }
std::shared_ptr<Node> get_match_root() { return m_match_root.get_node_shared_ptr(); }
Output<Node> get_match_value() { return m_match_root; }
private:
std::shared_ptr<Node> m_pattern;
std::shared_ptr<op::Label> m_recurrent_pattern;
const std::set<std::shared_ptr<op::Label>> m_correlated_patterns;
RPatternMap m_matches;
std::shared_ptr<Node> m_match_root;
Output<Node> m_initial_pattern;
Output<Node> m_pattern;
std::shared_ptr<Node> m_recurrent_pattern;
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
RPatternValueMap m_matches;
Output<Node> m_match_root;
};
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Any::type_info;
const NodeTypeInfo& pattern::op::Any::get_type_info() const
{
return type_info;
}
bool pattern::op::Any::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->add_node(graph_value);
return m_predicate(graph_value) && matcher->match_arguments(pattern_value, graph_value);
}
......@@ -25,7 +25,8 @@ namespace ngraph
{
namespace op
{
/// \brief Anys are used in patterns to express arbitrary queries on a node
/// The graph value is to the matched value list. If the predicate is true for the node
/// and the arguments match, the match succeeds.
class NGRAPH_API Any : public Pattern
{
public:
......@@ -35,26 +36,38 @@ namespace ngraph
/// shape.
Any(const element::Type& type,
const PartialShape& s,
Predicate pred,
const NodeVector& wrapped_nodes)
: Pattern(wrapped_nodes, pred)
ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred)
{
if (!pred)
{
throw ngraph_error("predicate is required");
}
set_output_type(0, type, s);
}
Any(const element::Type& type,
const PartialShape& s,
NodePredicate pred,
const NodeVector& wrapped_values)
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values))
{
}
/// \brief creates a Any node containing a sub-pattern described by the type and
/// shape of \sa node.
Any(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: Any(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
Any(const Output<Node>& node,
ValuePredicate pred,
const OutputVector& wrapped_values)
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values)
{
}
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
: Any(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values))
{
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::AnyOf::type_info;
const NodeTypeInfo& pattern::op::AnyOf::get_type_info() const
{
return type_info;
}
bool pattern::op::AnyOf::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->add_node(graph_value);
return m_predicate(graph_value) && ([&]() {
for (auto arg : graph_value.get_node_shared_ptr()->input_values())
{
auto saved = matcher->start_match();
if (matcher->match_value(input_value(0), arg))
{
return saved.finish(true);
}
}
return false;
}());
}
......@@ -25,13 +25,13 @@ namespace ngraph
{
namespace op
{
/// \brief AnyOfs are used in patterns to express arbitrary queries on a node
/// The graph value is added to the matched values list. If the predicate is true for
/// the
/// graph node, a submatch is performed on the input of AnyOf and each input of the
/// graph node. The first match that succeeds results in a successful match. Otherwise
/// the match fails.
///
/// When AnyOf predicate matches a node; Matcher tries to match node's arguments to
/// a single argument of AnyOf one by one. The first match is returned.
/// This is useful for nodes with variable number of arguments such as Concat
/// AnyOf enables on to specify one single branch/chain. The remaining arguments
/// can be discovered (in a callback) by simply inspecting matched node's argument.
/// AnyOf may be given a type and shape for use in strict mode.
class NGRAPH_API AnyOf : public Pattern
{
public:
......@@ -41,31 +41,46 @@ namespace ngraph
/// \sa shape.
AnyOf(const element::Type& type,
const PartialShape& s,
Predicate pred,
const NodeVector& wrapped_nodes)
: Pattern(wrapped_nodes, pred)
ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred)
{
if (!pred)
{
throw ngraph_error("predicate is required");
}
if (wrapped_nodes.size() != 1)
if (wrapped_values.size() != 1)
{
throw ngraph_error("AnyOf expects exactly one argument");
}
set_output_type(0, type, s);
}
AnyOf(const element::Type& type,
const PartialShape& s,
NodePredicate pred,
const NodeVector& wrapped_values)
: AnyOf(type,
s,
[pred](const Output<Node>& value) {
return pred(value.as_single_output_node(false));
},
as_output_vector(wrapped_values))
{
}
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
/// shape of \sa node.
AnyOf(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: AnyOf(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
AnyOf(const Output<Node>& node,
ValuePredicate pred,
const OutputVector& wrapped_values)
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values)
{
}
AnyOf(std::shared_ptr<Node> node,
NodePredicate pred,
const NodeVector& wrapped_values)
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values))
{
}
bool match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/branch.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Branch::type_info;
const NodeTypeInfo& pattern::op::Branch::get_type_info() const
{
return type_info;
}
bool pattern::op::Branch::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
return matcher->match_value(get_destination(), graph_value);
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// A branch adds a loop to the pattern. The branch match is successful if the
/// destination node pattern matches the graph value. The destination node is a node in
/// the pattern graph that will not have been created some time after the Branch node is
/// created; use set_destination to add it.
///
/// The branch destination is not stored as a shared pointer to prevent reference
/// cycles. Thus the destination node must be referenced in some other way to prevent it
/// from being deleted.
class NGRAPH_API Branch : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Creates a Branch pattern
/// \param pattern the destinationing pattern
/// \param labels Labels where the destination may occur
Branch()
: Pattern(OutputVector{})
{
set_output_type(0, element::f32, Shape{});
}
void set_destination(const Output<Node>& destination)
{
m_destination_node = destination.get_node();
m_destination_index = destination.get_index();
}
Output<Node> get_destination() const
{
return m_destination_node == nullptr
? Output<Node>()
: Output<Node>{m_destination_node->shared_from_this(),
m_destination_index};
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
Node* m_destination_node{nullptr};
size_t m_destination_index{0};
};
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/capture.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Capture::type_info;
const NodeTypeInfo& pattern::op::Capture::get_type_info() const
{
return type_info;
}
bool pattern::op::Capture::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->capture(m_static_nodes);
return true;
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// Experimental for support of recurrent matches.
///
/// Capture adds the pattern value map to a list of pattern value maps and resets
/// matches for pattern nodes not in the static node list. The match always succeeds.
class NGRAPH_API Capture : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
const NodeTypeInfo& get_type_info() const override;
Capture(const Output<Node>& arg)
: Pattern({arg})
{
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
/// \brief static nodes are retained after a capture. All other nodes are dropped
std::set<Node*> get_static_nodes() { return m_static_nodes; }
void set_static_nodes(const std::set<Node*>& static_nodes)
{
m_static_nodes = static_nodes;
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
std::set<Node*> m_static_nodes;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/op/true.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Label::type_info;
const NodeTypeInfo& pattern::op::Label::get_type_info() const
{
return type_info;
}
Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values)
{
switch (wrapped_values.size())
{
case 0: return make_shared<pattern::op::True>()->output(0);
case 1: return wrapped_values[0];
default: return make_shared<pattern::op::Or>(wrapped_values)->output(0);
}
}
bool pattern::op::Label::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
if (m_predicate(graph_value))
{
auto& pattern_map = matcher->get_pattern_value_map();
auto saved = matcher->start_match();
matcher->add_node(graph_value);
if (pattern_map.count(shared_from_this()))
{
return saved.finish(pattern_map[shared_from_this()] == graph_value);
}
else
{
pattern_map[shared_from_this()] = graph_value;
return saved.finish(matcher->match_value(input_value(0), graph_value));
}
}
return false;
}
......@@ -25,9 +25,15 @@ namespace ngraph
{
namespace op
{
/// \brief Labels are used in patterns to express repeating nodes in an input graph
/// and bind them to specific nodes from the graph
/// Fails if the predicate returns false on the graph value.
///
/// The graph value is added to the matched values list. If the Label is already
/// associated with a value, the match succeeds if the value is the same as the graph
/// value. Otherwise, the label is associated with the graph value and the match
/// succeeds if the pattern input matches the graph value.
///
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
class NGRAPH_API Label : public Pattern
{
public:
......@@ -44,38 +50,95 @@ namespace ngraph
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
/// Shape{2,2},
/// nullptr,
/// NodeVector{add});
/// OutputVector{add});
/// \endcode
Label(const element::Type& type,
const PartialShape& s,
Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{})
: Pattern(wrapped_nodes, pred)
const ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred)
{
set_output_type(0, type, s);
}
Label(const element::Type& type, const PartialShape& s)
: Label(type, s, [](const Output<Node>&) { return true; }, OutputVector())
{
}
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
: Label(type, s, pred, OutputVector{})
{
}
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
: Label(type, s, as_value_predicate(pred), OutputVector{})
{
}
Label(const element::Type& type,
const PartialShape& s,
const NodePredicate pred,
const NodeVector& wrapped_values)
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values))
{
}
/// \brief creates a Label node containing a sub-pattern described by the type and
/// shape of \sa node.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_nodes
/// that match the pattern specified by \sa wrapped_values
/// Example:
/// \code{.cpp}
/// auto add = a + b; // a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(add,
/// nullptr,
/// NodeVector{add});
/// OutputVector{add});
/// \endcode
Label(std::shared_ptr<Node> node,
Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{})
: Label(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
Label(const Output<Node>& value,
const ValuePredicate pred,
const OutputVector& wrapped_values)
: Label(
value.get_element_type(), value.get_partial_shape(), pred, wrapped_values)
{
}
Label(const Output<Node>& value, const ValuePredicate pred)
: Label(
value.get_element_type(), value.get_partial_shape(), pred, OutputVector{})
{
}
Label(const Output<Node>& value, const NodePredicate pred)
: Label(value.get_element_type(),
value.get_partial_shape(),
as_value_predicate(pred),
OutputVector{})
{
}
Label(const Output<Node>& value)
: Label(value.get_element_type(),
value.get_partial_shape(),
[](const Output<Node>&) { return true; },
OutputVector{})
{
}
Label(const Output<Node>& node,
const NodePredicate pred,
const NodeVector& wrapped_values)
: Label(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values))
{
}
bool match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
static Output<Node> wrap_values(const OutputVector& wrapped_values);
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Or::type_info;
const NodeTypeInfo& pattern::op::Or::get_type_info() const
{
return type_info;
}
bool pattern::op::Or::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
for (auto input_value : input_values())
{
auto saved = matcher->start_match();
if (matcher->match_value(input_value, graph_value))
{
return saved.finish(true);
}
}
return false;
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// A submatch on the graph value is performed on each input to the Or; the match
/// succeeds on the first match. Otherwise the match fails.
class NGRAPH_API Or : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternOr", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates an Or node matching one of several sub-patterns in order. Does
/// not add node to match list.
/// \param patterns The patterns to try for matching
Or(const OutputVector& patterns)
: Pattern(patterns)
{
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <regex>
#include "pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
// The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM
ValuePredicate Pattern::get_predicate() const { return m_predicate; }
ValuePredicate as_value_predicate(NodePredicate pred)
{
if (pred == nullptr)
{
return [](const Output<Node>&) { return true; };
}
else
{
return [pred](const Output<Node>& value) {
return pred(value.get_node_shared_ptr());
};
}
}
}
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map)
{
PatternMap result;
for (auto& kv : pattern_value_map)
{
result[kv.first] = kv.second.get_node_shared_ptr();
}
return result;
}
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map)
{
PatternValueMap result;
for (auto& kv : pattern_map)
{
result[kv.first] = kv.second;
}
return result;
}
}
}
......@@ -26,16 +26,53 @@ namespace ngraph
{
namespace op
{
using Predicate = std::function<bool(std::shared_ptr<Node>)>;
class Label;
}
class Matcher;
class MatchState;
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
using PatternValueMaps = std::vector<PatternValueMap>;
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class()
{
auto pred = [](std::shared_ptr<Node> node) -> bool { return is_type<T>(node); };
return pred;
}
namespace op
{
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
using ValuePredicate = std::function<bool(const Output<Node>& value)>;
ValuePredicate as_value_predicate(NodePredicate pred);
class NGRAPH_API Pattern : public Node
{
public:
/// \brief \p a base class for \sa Skip and \sa Label
///
Pattern(const NodeVector& nodes, Predicate pred)
: Node(nodes)
Pattern(const OutputVector& patterns, ValuePredicate pred)
: Node(patterns)
, m_predicate(pred)
{
if (!m_predicate)
{
m_predicate = [](const Output<Node>&) { return true; };
}
}
Pattern(const OutputVector& patterns)
: Pattern(patterns, nullptr)
{
}
......@@ -45,10 +82,11 @@ namespace ngraph
throw ngraph_error("Uncopyable");
}
Predicate get_predicate() const;
ValuePredicate get_predicate() const;
bool is_pattern() const override { return true; }
protected:
std::function<bool(std::shared_ptr<Node>)> m_predicate;
ValuePredicate m_predicate;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Skip::type_info;
const NodeTypeInfo& pattern::op::Skip::get_type_info() const
{
return type_info;
}
bool pattern::op::Skip::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->add_node(graph_value);
return m_predicate(graph_value) ? matcher->match_arguments(pattern_value, graph_value)
: matcher->match_value(input_value(0), graph_value);
}
......@@ -25,19 +25,29 @@ namespace ngraph
{
namespace op
{
/// \brief \p Skip allows users to specify unexpected nodes in a pattern
/// and skip them if a predicate condition is satisfied.
///
/// The graph value is added to the matched value list. If the predicate is true, the
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
/// if the pattern input matches the graph value.
class NGRAPH_API Skip : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
const NodeTypeInfo& get_type_info() const override;
Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr)
: Pattern(NodeVector{arg}, predicate)
Skip(const Output<Node>& arg, ValuePredicate pred)
: Pattern({arg}, pred)
{
set_output_type(0, arg->get_element_type(), arg->get_output_partial_shape(0));
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
Skip(const Output<Node>& arg, NodePredicate pred = nullptr)
: Pattern({arg}, as_value_predicate(pred))
{
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/true.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::True::type_info;
const NodeTypeInfo& pattern::op::True::get_type_info() const
{
return type_info;
}
bool pattern::op::True::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
return true;
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// \brief The match always succeeds.
class NGRAPH_API True : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Always matches, does not add node to match list.
True()
: Pattern(OutputVector{})
{
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
}
......@@ -48,6 +48,7 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
......@@ -540,8 +541,11 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
ref_rnn_type);
auto lstm_goe = std::make_shared<ngraph::op::GetOutputElement>(lstm, 1);
// We cannot attach labels to multi-output nodes, so we attach a label to the goe instead
auto lstm_goe_label =
std::make_shared<pattern::op::Label>(lstm_goe, nullptr, NodeVector{lstm_goe});
auto lstm_goe_label = std::make_shared<pattern::op::Label>(
lstm_goe,
nullptr,
OutputVector{std::make_shared<pattern::op::Or>(
OutputVector{lstm_goe, std::make_shared<ngraph::op::GetOutputElement>(lstm, 0)})});
auto lstm_goe_slice =
std::make_shared<ngraph::op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});
......@@ -935,6 +939,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
};
auto m = std::make_shared<pattern::RecurrentMatcher>(
std::make_shared<ngraph::op::GetOutputElement>(lstm, 1),
lstm_goe,
lstm_ct,
std::set<std::shared_ptr<pattern::op::Label>>{lstm_weights_layer_shared,
......@@ -1255,10 +1260,8 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
// Define a call back that needs to called once the DFG matches the pattern
auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto rnn_ltor_node =
std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
auto rnn_rtol_node =
std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_right_to_left]);
auto rnn_ltor_node = as_type_ptr<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
auto rnn_rtol_node = as_type_ptr<ngraph::op::Rnn>(pattern_map[rnn_right_to_left]);
if (rnn_ltor_node->get_src_sequence_length() != rnn_rtol_node->get_src_sequence_length())
{
......
......@@ -122,7 +122,7 @@ TEST(cpu_fusion, gemm_pattern)
auto pbroadcast = make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
auto padd = pdot + pbroadcast;
TestMatcher n(nullptr);
TestMatcher n;
ASSERT_TRUE(n.match(padd, add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
......
......@@ -37,8 +37,11 @@
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/branch.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/pattern/op/true.hpp"
#include "ngraph/serializer.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
......@@ -296,7 +299,7 @@ TEST(pattern, matcher)
{
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
TestMatcher n(nullptr);
TestMatcher n;
ASSERT_TRUE(n.match(a, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
......@@ -435,9 +438,24 @@ TEST(pattern, matcher)
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
// Or
ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
// Branch
{
auto branch = std::make_shared<pattern::op::Branch>();
auto star = std::make_shared<pattern::op::Or>(
OutputVector{branch, std::make_shared<pattern::op::True>()});
auto pattern = star + star;
branch->set_destination(pattern);
ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
ASSERT_EQ(n.get_matched_nodes().size(), 4);
}
// strict mode
{
TestMatcher sm(nullptr, "TestMatcher", true);
TestMatcher sm(Output<Node>{}, "TestMatcher", true);
// exact shape and type
auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
auto label_dynamic_shape =
......@@ -462,7 +480,7 @@ TEST(pattern, matcher)
TEST(pattern, mean)
{
// construct mean
TestMatcher n(nullptr);
TestMatcher n;
auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
......@@ -477,7 +495,7 @@ TEST(pattern, mean)
TEST(pattern, variance)
{
// construct variance
TestMatcher n(nullptr);
TestMatcher n;
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input);
......@@ -733,7 +751,7 @@ TEST(pattern, is_contained_match)
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
auto absn = make_shared<op::Abs>(a);
TestMatcher n(nullptr);
TestMatcher n;
auto label_a = std::make_shared<pattern::op::Label>(a);
auto label_abs = make_shared<op::Abs>(a);
......
......@@ -18,21 +18,26 @@
class TestMatcher : public ngraph::pattern::Matcher
{
using ngraph::pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) override
public:
TestMatcher()
: TestMatcher(ngraph::Output<ngraph::Node>{})
{
}
bool virtual match_value(const ngraph::Output<ngraph::Node>& pattern_value,
const ngraph::Output<ngraph::Node>& graph_value) override
{
if (ngraph::as_type_ptr<::ngraph::op::Parameter>(pattern_node))
if (ngraph::is_type<::ngraph::op::Parameter>(pattern_value.get_node_shared_ptr()))
{
bool result = pattern_node == ngraph::as_type_ptr<::ngraph::op::Parameter>(graph_node);
bool result = pattern_value == graph_value;
if (result)
{
m_matched_list.push_back(graph_node);
m_matched_list.push_back(graph_value.get_node_shared_ptr());
}
return result;
}
return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
return this->ngraph::pattern::Matcher::match_value(pattern_value, graph_value);
}
public:
......@@ -44,15 +49,7 @@ public:
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
m_matched_list.clear();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
m_pattern_node = pattern_node;
return ngraph::pattern::Matcher::match(graph_node, ngraph::pattern::PatternValueMap{});
}
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment