Unverified Commit 3bffe536 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Cyphers/pattern (#4095)

* Make pattern matcher node-based

Simplify implementation
Add support for Or, Branch
Start of support for recurrent pattern

* Only save state at branch points

* Factor Or out of label

* Documentation

* Review

* Only ops need to match on shape/output index
parent 35d8e436
......@@ -557,11 +557,24 @@ set (SRC
pass/pass_util.cpp
pattern/matcher.cpp
pattern/matcher.hpp
pattern/op/any.cpp
pattern/op/any.hpp
pattern/op/any_of.cpp
pattern/op/any_of.hpp
pattern/op/branch.cpp
pattern/op/branch.hpp
pattern/op/capture.cpp
pattern/op/capture.hpp
pattern/op/label.cpp
pattern/op/label.hpp
pattern/op/or.cpp
pattern/op/or.hpp
pattern/op/pattern.cpp
pattern/op/pattern.hpp
pattern/op/skip.cpp
pattern/op/skip.hpp
pattern/op/true.cpp
pattern/op/true.hpp
placement.cpp
placement.hpp
provenance.cpp
......
......@@ -27,6 +27,7 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/placement.hpp"
using namespace std;
......@@ -930,6 +931,23 @@ void Node::validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& a
set_output_type(0, element::boolean, args_pshape);
}
bool Node::match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
if (pattern_value.get_index() != graph_value.get_index() ||
(matcher->is_strict_mode() &&
(!pattern_value.get_element_type().compatible(graph_value.get_element_type()) ||
!pattern_value.get_partial_shape().compatible(graph_value.get_partial_shape()))))
{
return false;
}
matcher->add_node(graph_value);
return graph_value.get_node_shared_ptr()->get_type_info() == get_type_info() &&
matcher->match_arguments(pattern_value, graph_value);
}
// default implementation for the node to check if it contains partial shape
// we will override this method, for the Op's which depends on additional shape
// attribute to determine if node contains partial shape or not
......
......@@ -67,6 +67,11 @@ namespace ngraph
}
} // namespace op
namespace pattern
{
class Matcher;
}
using ResultVector = std::vector<std::shared_ptr<op::v0::Result>>;
namespace autodiff
......@@ -260,6 +265,7 @@ namespace ngraph
virtual bool is_constant() const;
virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; }
virtual bool is_pattern() const { return false; }
virtual bool is_commutative() const { return false; }
virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
......@@ -502,6 +508,10 @@ namespace ngraph
return m_op_annotations;
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value);
private:
descriptor::Input& get_input_descriptor(size_t position);
descriptor::Output& get_output_descriptor(size_t position);
......@@ -722,6 +732,12 @@ namespace ngraph
/// A null output
Output() = default;
void reset()
{
m_node.reset();
m_index = 0;
}
/// This output position for a different node
Output<Node> for_node(const std::shared_ptr<Node>& node) { return Output(node, m_index); }
/// \return A pointer to the node referred to by this output handle.
......@@ -828,6 +844,12 @@ namespace ngraph
/// A null output
Output() = default;
void reset()
{
m_node.reset();
m_index = 0;
}
/// This output position for a different node
Output<const Node> for_node(const std::shared_ptr<const Node>& node)
{
......
This diff is collapsed.
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Any::type_info;
const NodeTypeInfo& pattern::op::Any::get_type_info() const
{
return type_info;
}
bool pattern::op::Any::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->add_node(graph_value);
return m_predicate(graph_value) && matcher->match_arguments(pattern_value, graph_value);
}
......@@ -25,7 +25,8 @@ namespace ngraph
{
namespace op
{
/// \brief Anys are used in patterns to express arbitrary queries on a node
/// The graph value is to the matched value list. If the predicate is true for the node
/// and the arguments match, the match succeeds.
class NGRAPH_API Any : public Pattern
{
public:
......@@ -35,26 +36,38 @@ namespace ngraph
/// shape.
Any(const element::Type& type,
const PartialShape& s,
Predicate pred,
const NodeVector& wrapped_nodes)
: Pattern(wrapped_nodes, pred)
ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred)
{
if (!pred)
{
throw ngraph_error("predicate is required");
}
set_output_type(0, type, s);
}
Any(const element::Type& type,
const PartialShape& s,
NodePredicate pred,
const NodeVector& wrapped_values)
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values))
{
}
/// \brief creates a Any node containing a sub-pattern described by the type and
/// shape of \sa node.
Any(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: Any(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
Any(const Output<Node>& node,
ValuePredicate pred,
const OutputVector& wrapped_values)
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values)
{
}
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
: Any(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values))
{
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::AnyOf::type_info;
const NodeTypeInfo& pattern::op::AnyOf::get_type_info() const
{
return type_info;
}
bool pattern::op::AnyOf::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->add_node(graph_value);
return m_predicate(graph_value) && ([&]() {
for (auto arg : graph_value.get_node_shared_ptr()->input_values())
{
auto saved = matcher->start_match();
if (matcher->match_value(input_value(0), arg))
{
return saved.finish(true);
}
}
return false;
}());
}
......@@ -25,13 +25,13 @@ namespace ngraph
{
namespace op
{
/// \brief AnyOfs are used in patterns to express arbitrary queries on a node
/// The graph value is added to the matched values list. If the predicate is true for
/// the
/// graph node, a submatch is performed on the input of AnyOf and each input of the
/// graph node. The first match that succeeds results in a successful match. Otherwise
/// the match fails.
///
/// When AnyOf predicate matches a node; Matcher tries to match node's arguments to
/// a single argument of AnyOf one by one. The first match is returned.
/// This is useful for nodes with variable number of arguments such as Concat
/// AnyOf enables on to specify one single branch/chain. The remaining arguments
/// can be discovered (in a callback) by simply inspecting matched node's argument.
/// AnyOf may be given a type and shape for use in strict mode.
class NGRAPH_API AnyOf : public Pattern
{
public:
......@@ -41,31 +41,46 @@ namespace ngraph
/// \sa shape.
AnyOf(const element::Type& type,
const PartialShape& s,
Predicate pred,
const NodeVector& wrapped_nodes)
: Pattern(wrapped_nodes, pred)
ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(wrapped_values, pred)
{
if (!pred)
{
throw ngraph_error("predicate is required");
}
if (wrapped_nodes.size() != 1)
if (wrapped_values.size() != 1)
{
throw ngraph_error("AnyOf expects exactly one argument");
}
set_output_type(0, type, s);
}
AnyOf(const element::Type& type,
const PartialShape& s,
NodePredicate pred,
const NodeVector& wrapped_values)
: AnyOf(type,
s,
[pred](const Output<Node>& value) {
return pred(value.as_single_output_node(false));
},
as_output_vector(wrapped_values))
{
}
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
/// shape of \sa node.
AnyOf(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: AnyOf(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
AnyOf(const Output<Node>& node,
ValuePredicate pred,
const OutputVector& wrapped_values)
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values)
{
}
AnyOf(std::shared_ptr<Node> node,
NodePredicate pred,
const NodeVector& wrapped_values)
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values))
{
}
bool match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/branch.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Branch::type_info;
const NodeTypeInfo& pattern::op::Branch::get_type_info() const
{
return type_info;
}
bool pattern::op::Branch::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
return matcher->match_value(get_destination(), graph_value);
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// A branch adds a loop to the pattern. The branch match is successful if the
/// destination node pattern matches the graph value. The destination node is a node in
/// the pattern graph that will not have been created some time after the Branch node is
/// created; use set_destination to add it.
///
/// The branch destination is not stored as a shared pointer to prevent reference
/// cycles. Thus the destination node must be referenced in some other way to prevent it
/// from being deleted.
class NGRAPH_API Branch : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Creates a Branch pattern
/// \param pattern the destinationing pattern
/// \param labels Labels where the destination may occur
Branch()
: Pattern(OutputVector{})
{
set_output_type(0, element::f32, Shape{});
}
void set_destination(const Output<Node>& destination)
{
m_destination_node = destination.get_node();
m_destination_index = destination.get_index();
}
Output<Node> get_destination() const
{
return m_destination_node == nullptr
? Output<Node>()
: Output<Node>{m_destination_node->shared_from_this(),
m_destination_index};
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
Node* m_destination_node{nullptr};
size_t m_destination_index{0};
};
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/capture.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Capture::type_info;
const NodeTypeInfo& pattern::op::Capture::get_type_info() const
{
return type_info;
}
bool pattern::op::Capture::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->capture(m_static_nodes);
return true;
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// Experimental for support of recurrent matches.
///
/// Capture adds the pattern value map to a list of pattern value maps and resets
/// matches for pattern nodes not in the static node list. The match always succeeds.
class NGRAPH_API Capture : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
const NodeTypeInfo& get_type_info() const override;
Capture(const Output<Node>& arg)
: Pattern({arg})
{
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
/// \brief static nodes are retained after a capture. All other nodes are dropped
std::set<Node*> get_static_nodes() { return m_static_nodes; }
void set_static_nodes(const std::set<Node*>& static_nodes)
{
m_static_nodes = static_nodes;
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
std::set<Node*> m_static_nodes;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/op/true.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Label::type_info;
const NodeTypeInfo& pattern::op::Label::get_type_info() const
{
return type_info;
}
Output<Node> pattern::op::Label::wrap_values(const OutputVector& wrapped_values)
{
switch (wrapped_values.size())
{
case 0: return make_shared<pattern::op::True>()->output(0);
case 1: return wrapped_values[0];
default: return make_shared<pattern::op::Or>(wrapped_values)->output(0);
}
}
bool pattern::op::Label::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
if (m_predicate(graph_value))
{
auto& pattern_map = matcher->get_pattern_value_map();
auto saved = matcher->start_match();
matcher->add_node(graph_value);
if (pattern_map.count(shared_from_this()))
{
return saved.finish(pattern_map[shared_from_this()] == graph_value);
}
else
{
pattern_map[shared_from_this()] = graph_value;
return saved.finish(matcher->match_value(input_value(0), graph_value));
}
}
return false;
}
......@@ -25,9 +25,15 @@ namespace ngraph
{
namespace op
{
/// \brief Labels are used in patterns to express repeating nodes in an input graph
/// and bind them to specific nodes from the graph
/// Fails if the predicate returns false on the graph value.
///
/// The graph value is added to the matched values list. If the Label is already
/// associated with a value, the match succeeds if the value is the same as the graph
/// value. Otherwise, the label is associated with the graph value and the match
/// succeeds if the pattern input matches the graph value.
///
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
class NGRAPH_API Label : public Pattern
{
public:
......@@ -44,38 +50,95 @@ namespace ngraph
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
/// Shape{2,2},
/// nullptr,
/// NodeVector{add});
/// OutputVector{add});
/// \endcode
Label(const element::Type& type,
const PartialShape& s,
Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{})
: Pattern(wrapped_nodes, pred)
const ValuePredicate pred,
const OutputVector& wrapped_values)
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred)
{
set_output_type(0, type, s);
}
Label(const element::Type& type, const PartialShape& s)
: Label(type, s, [](const Output<Node>&) { return true; }, OutputVector())
{
}
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
: Label(type, s, pred, OutputVector{})
{
}
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
: Label(type, s, as_value_predicate(pred), OutputVector{})
{
}
Label(const element::Type& type,
const PartialShape& s,
const NodePredicate pred,
const NodeVector& wrapped_values)
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values))
{
}
/// \brief creates a Label node containing a sub-pattern described by the type and
/// shape of \sa node.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_nodes
/// that match the pattern specified by \sa wrapped_values
/// Example:
/// \code{.cpp}
/// auto add = a + b; // a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(add,
/// nullptr,
/// NodeVector{add});
/// OutputVector{add});
/// \endcode
Label(std::shared_ptr<Node> node,
Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{})
: Label(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
Label(const Output<Node>& value,
const ValuePredicate pred,
const OutputVector& wrapped_values)
: Label(
value.get_element_type(), value.get_partial_shape(), pred, wrapped_values)
{
}
Label(const Output<Node>& value, const ValuePredicate pred)
: Label(
value.get_element_type(), value.get_partial_shape(), pred, OutputVector{})
{
}
Label(const Output<Node>& value, const NodePredicate pred)
: Label(value.get_element_type(),
value.get_partial_shape(),
as_value_predicate(pred),
OutputVector{})
{
}
Label(const Output<Node>& value)
: Label(value.get_element_type(),
value.get_partial_shape(),
[](const Output<Node>&) { return true; },
OutputVector{})
{
}
Label(const Output<Node>& node,
const NodePredicate pred,
const NodeVector& wrapped_values)
: Label(node.get_element_type(),
node.get_partial_shape(),
as_value_predicate(pred),
as_output_vector(wrapped_values))
{
}
bool match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
protected:
static Output<Node> wrap_values(const OutputVector& wrapped_values);
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Or::type_info;
const NodeTypeInfo& pattern::op::Or::get_type_info() const
{
return type_info;
}
bool pattern::op::Or::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
for (auto input_value : input_values())
{
auto saved = matcher->start_match();
if (matcher->match_value(input_value, graph_value))
{
return saved.finish(true);
}
}
return false;
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// A submatch on the graph value is performed on each input to the Or; the match
/// succeeds on the first match. Otherwise the match fails.
class NGRAPH_API Or : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternOr", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief creates an Or node matching one of several sub-patterns in order. Does
/// not add node to match list.
/// \param patterns The patterns to try for matching
Or(const OutputVector& patterns)
: Pattern(patterns)
{
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <regex>
#include "pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
// The symbols are required to be in cpp file to workaround RTTI issue on Android LLVM
ValuePredicate Pattern::get_predicate() const { return m_predicate; }
ValuePredicate as_value_predicate(NodePredicate pred)
{
if (pred == nullptr)
{
return [](const Output<Node>&) { return true; };
}
else
{
return [pred](const Output<Node>& value) {
return pred(value.get_node_shared_ptr());
};
}
}
}
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map)
{
PatternMap result;
for (auto& kv : pattern_value_map)
{
result[kv.first] = kv.second.get_node_shared_ptr();
}
return result;
}
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map)
{
PatternValueMap result;
for (auto& kv : pattern_map)
{
result[kv.first] = kv.second;
}
return result;
}
}
}
......@@ -26,16 +26,53 @@ namespace ngraph
{
namespace op
{
using Predicate = std::function<bool(std::shared_ptr<Node>)>;
class Label;
}
class Matcher;
class MatchState;
using RPatternValueMap = std::map<std::shared_ptr<Node>, OutputVector>;
using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
using PatternValueMaps = std::vector<PatternValueMap>;
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class()
{
auto pred = [](std::shared_ptr<Node> node) -> bool { return is_type<T>(node); };
return pred;
}
namespace op
{
using NodePredicate = std::function<bool(std::shared_ptr<Node>)>;
using ValuePredicate = std::function<bool(const Output<Node>& value)>;
ValuePredicate as_value_predicate(NodePredicate pred);
class NGRAPH_API Pattern : public Node
{
public:
/// \brief \p a base class for \sa Skip and \sa Label
///
Pattern(const NodeVector& nodes, Predicate pred)
: Node(nodes)
Pattern(const OutputVector& patterns, ValuePredicate pred)
: Node(patterns)
, m_predicate(pred)
{
if (!m_predicate)
{
m_predicate = [](const Output<Node>&) { return true; };
}
}
Pattern(const OutputVector& patterns)
: Pattern(patterns, nullptr)
{
}
......@@ -45,10 +82,11 @@ namespace ngraph
throw ngraph_error("Uncopyable");
}
Predicate get_predicate() const;
ValuePredicate get_predicate() const;
bool is_pattern() const override { return true; }
protected:
std::function<bool(std::shared_ptr<Node>)> m_predicate;
ValuePredicate m_predicate;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::Skip::type_info;
const NodeTypeInfo& pattern::op::Skip::get_type_info() const
{
return type_info;
}
bool pattern::op::Skip::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
matcher->add_node(graph_value);
return m_predicate(graph_value) ? matcher->match_arguments(pattern_value, graph_value)
: matcher->match_value(input_value(0), graph_value);
}
......@@ -25,19 +25,29 @@ namespace ngraph
{
namespace op
{
/// \brief \p Skip allows users to specify unexpected nodes in a pattern
/// and skip them if a predicate condition is satisfied.
///
/// The graph value is added to the matched value list. If the predicate is true, the
/// match succeeds if the arguments match; if the predicate is false, the match succeeds
/// if the pattern input matches the graph value.
class NGRAPH_API Skip : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternSkip", 0};
const NodeTypeInfo& get_type_info() const override;
Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr)
: Pattern(NodeVector{arg}, predicate)
Skip(const Output<Node>& arg, ValuePredicate pred)
: Pattern({arg}, pred)
{
set_output_type(0, arg->get_element_type(), arg->get_output_partial_shape(0));
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
Skip(const Output<Node>& arg, NodePredicate pred = nullptr)
: Pattern({arg}, as_value_predicate(pred))
{
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
}
virtual bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pattern/op/true.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo pattern::op::True::type_info;
const NodeTypeInfo& pattern::op::True::get_type_info() const
{
return type_info;
}
bool pattern::op::True::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
return true;
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// \brief The match always succeeds.
class NGRAPH_API True : public Pattern
{
public:
static constexpr NodeTypeInfo type_info{"patternTrue", 0};
const NodeTypeInfo& get_type_info() const override;
/// \brief Always matches, does not add node to match list.
True()
: Pattern(OutputVector{})
{
}
bool match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
};
}
}
}
......@@ -48,6 +48,7 @@
#include "ngraph/op/tanh.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
......@@ -540,8 +541,11 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
ref_rnn_type);
auto lstm_goe = std::make_shared<ngraph::op::GetOutputElement>(lstm, 1);
// We cannot attach labels to multi-output nodes, so we attach a label to the goe instead
auto lstm_goe_label =
std::make_shared<pattern::op::Label>(lstm_goe, nullptr, NodeVector{lstm_goe});
auto lstm_goe_label = std::make_shared<pattern::op::Label>(
lstm_goe,
nullptr,
OutputVector{std::make_shared<pattern::op::Or>(
OutputVector{lstm_goe, std::make_shared<ngraph::op::GetOutputElement>(lstm, 0)})});
auto lstm_goe_slice =
std::make_shared<ngraph::op::Slice>(lstm_goe_label, Coordinate{10, 0}, Coordinate{20, 100});
......@@ -935,6 +939,7 @@ void ngraph::runtime::cpu::pass::RNNFusion::construct_rnn_lstm_fprop()
};
auto m = std::make_shared<pattern::RecurrentMatcher>(
std::make_shared<ngraph::op::GetOutputElement>(lstm, 1),
lstm_goe,
lstm_ct,
std::set<std::shared_ptr<pattern::op::Label>>{lstm_weights_layer_shared,
......@@ -1255,10 +1260,8 @@ void ngraph::runtime::cpu::pass::BiDirectionalRnn::construct_bidirectional_rnn()
// Define a call back that needs to called once the DFG matches the pattern
auto callback = [rnn_left_to_right, rnn_right_to_left](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto rnn_ltor_node =
std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
auto rnn_rtol_node =
std::static_pointer_cast<ngraph::op::Rnn>(pattern_map[rnn_right_to_left]);
auto rnn_ltor_node = as_type_ptr<ngraph::op::Rnn>(pattern_map[rnn_left_to_right]);
auto rnn_rtol_node = as_type_ptr<ngraph::op::Rnn>(pattern_map[rnn_right_to_left]);
if (rnn_ltor_node->get_src_sequence_length() != rnn_rtol_node->get_src_sequence_length())
{
......
......@@ -122,7 +122,7 @@ TEST(cpu_fusion, gemm_pattern)
auto pbroadcast = make_shared<op::Broadcast>(b, dot->get_shape(), AxisSet{0});
auto padd = pdot + pbroadcast;
TestMatcher n(nullptr);
TestMatcher n;
ASSERT_TRUE(n.match(padd, add));
ASSERT_EQ(n.get_pattern_map()[W], A);
ASSERT_EQ(n.get_pattern_map()[x], B);
......
......@@ -37,8 +37,11 @@
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/branch.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/pattern/op/true.hpp"
#include "ngraph/serializer.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
......@@ -296,7 +299,7 @@ TEST(pattern, matcher)
{
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
TestMatcher n(nullptr);
TestMatcher n;
ASSERT_TRUE(n.match(a, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
......@@ -435,9 +438,24 @@ TEST(pattern, matcher)
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
// Or
ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a + b));
ASSERT_TRUE(n.match(std::make_shared<pattern::op::Or>(OutputVector{a + b, a - b}), a - b));
// Branch
{
auto branch = std::make_shared<pattern::op::Branch>();
auto star = std::make_shared<pattern::op::Or>(
OutputVector{branch, std::make_shared<pattern::op::True>()});
auto pattern = star + star;
branch->set_destination(pattern);
ASSERT_TRUE(n.match(pattern, ((a + b) + (b + a) + a)));
ASSERT_EQ(n.get_matched_nodes().size(), 4);
}
// strict mode
{
TestMatcher sm(nullptr, "TestMatcher", true);
TestMatcher sm(Output<Node>{}, "TestMatcher", true);
// exact shape and type
auto scalar_param = make_shared<op::Parameter>(element::i32, Shape{});
auto label_dynamic_shape =
......@@ -462,7 +480,7 @@ TEST(pattern, matcher)
TEST(pattern, mean)
{
// construct mean
TestMatcher n(nullptr);
TestMatcher n;
auto input = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
......@@ -477,7 +495,7 @@ TEST(pattern, mean)
TEST(pattern, variance)
{
// construct variance
TestMatcher n(nullptr);
TestMatcher n;
auto N = op::Constant::create(element::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::Multiply>(input, input);
......@@ -733,7 +751,7 @@ TEST(pattern, is_contained_match)
Shape shape{};
auto a = make_shared<op::Parameter>(element::i32, shape);
auto absn = make_shared<op::Abs>(a);
TestMatcher n(nullptr);
TestMatcher n;
auto label_a = std::make_shared<pattern::op::Label>(a);
auto label_abs = make_shared<op::Abs>(a);
......
......@@ -18,21 +18,26 @@
class TestMatcher : public ngraph::pattern::Matcher
{
using ngraph::pattern::Matcher::Matcher;
bool virtual match_node(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) override
public:
TestMatcher()
: TestMatcher(ngraph::Output<ngraph::Node>{})
{
}
bool virtual match_value(const ngraph::Output<ngraph::Node>& pattern_value,
const ngraph::Output<ngraph::Node>& graph_value) override
{
if (ngraph::as_type_ptr<::ngraph::op::Parameter>(pattern_node))
if (ngraph::is_type<::ngraph::op::Parameter>(pattern_value.get_node_shared_ptr()))
{
bool result = pattern_node == ngraph::as_type_ptr<::ngraph::op::Parameter>(graph_node);
bool result = pattern_value == graph_value;
if (result)
{
m_matched_list.push_back(graph_node);
m_matched_list.push_back(graph_value.get_node_shared_ptr());
}
return result;
}
return this->ngraph::pattern::Matcher::match_node(pattern_node, graph_node, pattern_map);
return this->ngraph::pattern::Matcher::match_value(pattern_value, graph_value);
}
public:
......@@ -44,15 +49,7 @@ public:
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
m_pattern_map.clear();
m_match_root.reset();
m_matched_list.clear();
bool is_match = match_node(pattern_node, graph_node, m_pattern_map);
if (is_match)
{
m_match_root = graph_node;
}
return is_match;
m_pattern_node = pattern_node;
return ngraph::pattern::Matcher::match(graph_node, ngraph::pattern::PatternValueMap{});
}
};
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment