Commit 6ed233db authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

AnyOf for nodes with a variable number of arguments (#2075)

* any_of matching

* include a new file any_of.hpp
parent bf727e36
......@@ -135,6 +135,36 @@ namespace ngraph
}
}
bool Matcher::match_any_of(const std::shared_ptr<op::AnyOf>& any,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
auto predicate = any->get_predicate();
if (!predicate)
{
throw ngraph_error("predicate is required");
}
if (predicate(graph_node))
{
for (auto arg : graph_node->get_arguments())
{
PatternMap copy{pattern_map};
if (match_node(any->get_argument(0), arg, copy))
{
pattern_map.insert(begin(copy), end(copy));
return true;
}
}
return false;
}
else
{
return false;
}
}
bool Matcher::match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
......@@ -167,6 +197,11 @@ namespace ngraph
return abort_match(watermark, match_any(any_node, graph_node, pattern_map));
}
if (auto any_of_node = std::dynamic_pointer_cast<op::AnyOf>(pattern_node))
{
return abort_match(watermark, match_any_of(any_of_node, graph_node, pattern_map));
}
auto p_pattern_node = pattern_node.get();
auto p_graph_node = graph_node.get();
......
......@@ -23,6 +23,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/any_of.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
......@@ -163,6 +164,9 @@ namespace ngraph
bool match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_any_of(const std::shared_ptr<op::AnyOf>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
graph_rewrite_callback m_callback;
size_t m_depth;
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF AnyOf KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// \brief AnyOfs are used in patterns to express arbitrary queries on a node
///
/// When AnyOf predicate matches a node; Matcher tries to match node's arguments to
/// a single argument of AnyOf one by one. The first match is returned.
/// This is useful for nodes with variable number of arguments such as Concat
/// AnyOf enables on to specify one single branch/chain. The remaining arguments
/// can be discovered (in a callback) by simply inspecting matched node's argument.
class AnyOf : public Pattern
{
public:
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and \sa shape.
AnyOf(const element::Type& type,
const PartialShape& s,
Predicate pred,
const NodeVector& wrapped_nodes)
: Pattern("AnyOf", wrapped_nodes, pred)
{
if (!pred)
{
throw ngraph_error("predicate is required");
}
if (wrapped_nodes.size() != 1)
{
throw ngraph_error("AnyOf expects exactly one argument");
}
set_output_type(0, type, s);
}
/// \brief creates a AnyOf node containing a sub-pattern described by the type and shape of \sa node.
AnyOf(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: AnyOf(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
{
}
};
}
}
}
......@@ -440,6 +440,23 @@ TEST(pattern, matcher)
auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
ASSERT_FALSE(n.match(bea_false, a + b));
auto add_abs_b = abs + b;
auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
auto add_b_abs = b + abs;
ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
auto bea_any_of_label =
std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
auto ab = a + b;
ASSERT_TRUE(n.match(bea_label, ab));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment