Unverified Commit 4345e39d authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Pattern matching for sum (#293)

* the first stab at pattern for sum

test refactoring, debug msg clean up, formatting fixes

removing v1 and cleaning up v2 + formatting

rollback the changes in reduce_ops

rename v2 -> sum_pred

remove unused funcs

switch to new c-tors

remove TensorViewType

removing an assert

fix a docstring to match a c-tor

* fixes after rebase
parent c5ffe8e9
...@@ -157,6 +157,11 @@ namespace ngraph ...@@ -157,6 +157,11 @@ namespace ngraph
template <typename T> template <typename T>
std::vector<T> get_vector() const std::vector<T> get_vector() const
{ {
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
{
throw ngraph_error("Buffer over-read");
}
std::vector<T> rc; std::vector<T> rc;
const T* p = reinterpret_cast<const T*>(m_data); const T* p = reinterpret_cast<const T*>(m_data);
for (size_t i = 0; i < shape_size(m_shape); i++) for (size_t i = 0; i < shape_size(m_shape); i++)
......
...@@ -47,8 +47,8 @@ namespace ngraph ...@@ -47,8 +47,8 @@ namespace ngraph
if (pattern_map[label] != graph_node) if (pattern_map[label] != graph_node)
{ {
NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name() NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name()
<< " , " << pattern_map[label] << " NOT match " << " , " << pattern_map[label] << " does NOT match "
<< graph_node->get_name() << " , " << graph_node; << graph_node->get_name();
is_match = false; is_match = false;
} }
} }
...@@ -71,9 +71,8 @@ namespace ngraph ...@@ -71,9 +71,8 @@ namespace ngraph
if (is_match) if (is_match)
{ {
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node " NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node " << label->get_name()
<< graph_node->get_name() << " , " << graph_node << " , " << " , " << graph_node << " , " << graph_node->get_name();
<< graph_node->get_name();
pattern_map[label] = graph_node; pattern_map[label] = graph_node;
} }
} }
...@@ -105,8 +104,8 @@ namespace ngraph ...@@ -105,8 +104,8 @@ namespace ngraph
assert(pattern_node && graph_node); assert(pattern_node && graph_node);
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : " NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " " << "pattern = " << pattern_node->get_name() << " matched "
<< "matched " << graph_node << " , " << graph_node->get_name(); << graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node)) if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{ {
...@@ -151,9 +150,9 @@ namespace ngraph ...@@ -151,9 +150,9 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& graph_node, const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map) PatternMap& pattern_map)
{ {
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] " NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_arguments : "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " " << "pattern = " << pattern_node->get_name() << " "
<< "matched " << graph_node << " , " << graph_node->get_name(); << "matched " << graph_node->get_name();
auto args = get_arguments(graph_node); auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node); auto pattern_args = get_arguments(pattern_node);
...@@ -171,7 +170,7 @@ namespace ngraph ...@@ -171,7 +170,7 @@ namespace ngraph
do do
{ {
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node " NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name() << " , " << graph_node; << graph_node->get_name();
PatternMap copy{pattern_map}; PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy)) if (match_permutation(pattern_args, args, copy))
{ {
...@@ -231,14 +230,10 @@ namespace ngraph ...@@ -231,14 +230,10 @@ namespace ngraph
throw "m_pattern_node or graph_node are not set!"; throw "m_pattern_node or graph_node are not set!";
} }
if (get_users(m_pattern_node).size()) (void)get_users; //to supress an unused function warning
{
throw "Pattern Node must not be used elsewhere!";
}
NGRAPH_DEBUG << "Starting match pattern = " << m_pattern_node << " , " NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< m_pattern_node->get_name() << " , graph_node = " << graph_node << " , " << " , graph_node = " << graph_node->get_name();
<< graph_node->get_name();
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map); bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match) if (is_match)
......
...@@ -30,43 +30,37 @@ namespace ngraph ...@@ -30,43 +30,37 @@ namespace ngraph
class Label : public Pattern class Label : public Pattern
{ {
public: public:
/// \brief creates a Label node from \sa node. /// \brief creates a Label node containing a sub-pattern described by \sa type and \sa shape.
/// ///
/// this Label node can be bound to arbitrary nodes in an input graph /// this Label node can be bound only to the nodes in the input graph
/// as long as provided \sa pred is satisfied and the node hasn't been previously bound to /// that match the pattern specified by \sa wrapped_nodes
/// a different node in the input graph /// Example:
/// \code{.cpp} /// \code{.cpp}
/// auto pattern = pattern::op::Label::make_from_node(a); //a is op::Parameter /// auto add = a + b; //a and b are op::Parameter in this example
/// matcher.match(pattern, a)); /// auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{2,2} , nullptr, Nodes{add});
/// \endcode /// \endcode
static std::shared_ptr<Label> Label(const element::Type& type,
make_from_node(const std::shared_ptr<ngraph::Node>& node, const Shape s,
Predicate pred = nullptr) Predicate pred = nullptr,
const Nodes& wrapped_nodes = Nodes{})
: Pattern("Label", wrapped_nodes, pred)
{ {
auto label = std::make_shared<Label>(Nodes{}, pred); add_output(type, s);
label->add_output(node->get_element_type(), node->get_shape());
return label;
} }
/// \brief creates a Label node containing a sub-pattern described by \sa node. /// \brief creates a Label node containing a sub-pattern described by the type and shape of \sa node.
/// ///
/// this Label node can be bound only to the nodes in the input graph /// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa node /// that match the pattern specified by \sa wrapped_nodes
/// Example: /// Example:
/// \code{.cpp} /// \code{.cpp}
/// auto add = a + b; //a and b are op::Parameter in this example /// auto add = a + b; //a and b are op::Parameter in this example
/// auto label = pattern::op::Label::wrap(add); /// auto label = std::make_shared<pattern::op::Label>(add, nullptr, Nodes{add});
/// \endcode /// \endcode
static std::shared_ptr<Label> wrap(const std::shared_ptr<ngraph::Node>& node, Label(std::shared_ptr<Node> node,
Predicate pred = nullptr) Predicate pred = nullptr,
{ const Nodes& wrapped_nodes = Nodes{})
auto label = std::make_shared<Label>(Nodes{node}, pred); : Label(node->get_element_type(), node->get_shape(), pred, wrapped_nodes)
label->add_output(node->get_element_type(), node->get_shape());
return label;
}
Label(const Nodes& subgraph, Predicate pred)
: Pattern("Label", Nodes{subgraph}, pred)
{ {
} }
}; };
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment