Commit bff65fe3 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Any op (#1036)

* add any op
parent 05a4fbef
......@@ -71,19 +71,19 @@ namespace ngraph
return is_match;
}
bool Matcher::match_skip(const std::shared_ptr<op::Skip>& any,
bool Matcher::match_skip(const std::shared_ptr<op::Skip>& skip,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
auto predicate = any->get_predicate();
auto predicate = skip->get_predicate();
if (!predicate || any->get_predicate()(graph_node))
if (!predicate || predicate(graph_node))
{
return match_arguments(any, graph_node, pattern_map);
return match_arguments(skip, graph_node, pattern_map);
}
else
{
auto args = any->get_arguments();
auto args = skip->get_arguments();
if (args.size() != 1)
{
throw ngraph_error("Skip can only take one argument");
......@@ -93,6 +93,26 @@ namespace ngraph
}
}
bool Matcher::match_any(const std::shared_ptr<op::Any>& any,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
{
auto predicate = any->get_predicate();
if (!predicate)
{
throw ngraph_error("predicate is required");
}
if (predicate(graph_node))
{
return match_arguments(any, graph_node, pattern_map);
}
else
{
return false;
}
}
bool Matcher::match_node(const std::shared_ptr<Node>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map)
......@@ -111,10 +131,15 @@ namespace ngraph
return match_pattern(label_node, graph_node, pattern_map);
}
if (auto any_node = std::dynamic_pointer_cast<op::Skip>(
if (auto skip_node = std::dynamic_pointer_cast<op::Skip>(
pattern_node)) //matches PatternSkipOp semantics
{
return match_skip(any_node, graph_node, pattern_map);
return match_skip(skip_node, graph_node, pattern_map);
}
if (auto any_node = std::dynamic_pointer_cast<op::Any>(pattern_node))
{
return match_any(any_node, graph_node, pattern_map);
}
auto p_pattern_node = pattern_node.get();
......
......@@ -17,9 +17,12 @@
#pragma once
#include <cassert>
#include <functional>
#include <memory.h>
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
......@@ -36,6 +39,16 @@ namespace ngraph
using recurrent_graph_rewrite_callback = std::function<bool(class RecurrentMatcher& m)>;
using RPatternMap = std::map<std::shared_ptr<op::Label>, NodeVector>;
template <typename T>
std::function<bool(std::shared_ptr<Node>)> has_class()
{
auto pred = [](std::shared_ptr<Node> node) -> bool {
return std::dynamic_pointer_cast<T>(node) != nullptr;
};
return pred;
}
namespace op
{
class Label;
......@@ -130,6 +143,9 @@ namespace ngraph
bool match_skip(const std::shared_ptr<op::Skip>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
bool match_any(const std::shared_ptr<op::Any>& pattern_node,
const std::shared_ptr<Node>& graph_node,
PatternMap& pattern_map);
graph_rewrite_callback m_callback;
size_t m_depth;
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/pattern/op/pattern.hpp"
namespace ngraph
{
namespace pattern
{
namespace op
{
/// \brief Anys are used in patterns to express arbitrary queries on a node
class Any : public Pattern
{
public:
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa shape.
Any(const element::Type& type,
const Shape s,
Predicate pred,
const NodeVector& wrapped_nodes)
: Pattern("Any", wrapped_nodes, pred)
{
if (!pred)
{
throw ngraph_error("predicate is required");
}
add_output(type, s);
}
/// \brief creates a Any node containing a sub-pattern described by the type and shape of \sa node.
Any(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: Any(node->get_element_type(), node->get_shape(), pred, wrapped_nodes)
{
}
};
}
}
}
......@@ -402,19 +402,32 @@ TEST(pattern, matcher)
auto any = std::make_shared<pattern::op::Skip>(a);
ASSERT_TRUE(n.match(any, abs));
auto any_false =
std::make_shared<pattern::op::Skip>(a, [](std::shared_ptr<Node> no) { return false; });
auto false_pred = [](std::shared_ptr<Node> no) { return false; };
auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
ASSERT_TRUE(n.match(any_false, a));
auto pattern = std::make_shared<pattern::op::Label>(a);
ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto pattern_false =
std::make_shared<pattern::op::Label>(a, [](std::shared_ptr<Node> no) { return false; });
auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
ASSERT_FALSE(n.match(pattern_false, a));
auto b = make_shared<op::Parameter>(element::i32, shape);
auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
ASSERT_TRUE(n.match(bea, a + b));
ASSERT_TRUE(n.match(bea, b + a));
auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
ASSERT_FALSE(n.match(bea_false, a + b));
auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
auto ab = a + b;
ASSERT_TRUE(n.match(bea_label, ab));
ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
auto d = make_shared<op::Parameter>(element::i32, shape);
ASSERT_FALSE(n.match(d, b));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment