Unverified Commit 4345e39d authored by Nick Korovaiko's avatar Nick Korovaiko Committed by GitHub

Pattern matching for sum (#293)

* the first stab at pattern for sum

test refactoring, debug msg clean up, formatting fixes

removing v1 and cleaning up v2 + formatting

rollback the changes in reduce_ops

rename v2 -> sum_pred

remove unused funcs

switch to new c-tors

remove TensorViewType

removing an assert

fix a docstring to match a c-tor

* fixes after rebase
parent c5ffe8e9
......@@ -157,6 +157,11 @@ namespace ngraph
template <typename T>
std::vector<T> get_vector() const
{
if (sizeof(T) > m_element_type.size() && shape_size(m_shape) > 0)
{
throw ngraph_error("Buffer over-read");
}
std::vector<T> rc;
const T* p = reinterpret_cast<const T*>(m_data);
for (size_t i = 0; i < shape_size(m_shape); i++)
......
......@@ -47,8 +47,8 @@ namespace ngraph
if (pattern_map[label] != graph_node)
{
NGRAPH_DEBUG << "[MATCHER] get_bound_node " << pattern_map[label]->get_name()
<< " , " << pattern_map[label] << " NOT match "
<< graph_node->get_name() << " , " << graph_node;
<< " , " << pattern_map[label] << " does NOT match "
<< graph_node->get_name();
is_match = false;
}
}
......@@ -71,9 +71,8 @@ namespace ngraph
if (is_match)
{
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node "
<< graph_node->get_name() << " , " << graph_node << " , "
<< graph_node->get_name();
NGRAPH_DEBUG << "[MATCHER] (Re)binding get_bound_node " << label->get_name()
<< " , " << graph_node << " , " << graph_node->get_name();
pattern_map[label] = graph_node;
}
}
......@@ -105,8 +104,8 @@ namespace ngraph
assert(pattern_node && graph_node);
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_node : "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< "matched " << graph_node << " , " << graph_node->get_name();
<< "pattern = " << pattern_node->get_name() << " matched "
<< graph_node->get_name();
if (auto label_node = std::dynamic_pointer_cast<op::Label>(pattern_node))
{
......@@ -151,9 +150,9 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& graph_node,
PatternMap& pattern_map)
{
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] "
<< "pattern = " << pattern_node << " , " << pattern_node->get_name() << " "
<< "matched " << graph_node << " , " << graph_node->get_name();
NGRAPH_DEBUG << pad(2 * m_depth) << "[MATCHER] in match_arguments : "
<< "pattern = " << pattern_node->get_name() << " "
<< "matched " << graph_node->get_name();
auto args = get_arguments(graph_node);
auto pattern_args = get_arguments(pattern_node);
......@@ -171,7 +170,7 @@ namespace ngraph
do
{
NGRAPH_DEBUG << pad(2 * m_depth) << "Running a permutation for graph_node "
<< graph_node->get_name() << " , " << graph_node;
<< graph_node->get_name();
PatternMap copy{pattern_map};
if (match_permutation(pattern_args, args, copy))
{
......@@ -231,14 +230,10 @@ namespace ngraph
throw "m_pattern_node or graph_node are not set!";
}
if (get_users(m_pattern_node).size())
{
throw "Pattern Node must not be used elsewhere!";
}
(void)get_users; //to supress an unused function warning
NGRAPH_DEBUG << "Starting match pattern = " << m_pattern_node << " , "
<< m_pattern_node->get_name() << " , graph_node = " << graph_node << " , "
<< graph_node->get_name();
NGRAPH_DEBUG << "[MATCHER] Starting match pattern = " << m_pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();
bool is_match = match_node(m_pattern_node, graph_node, m_pattern_map);
if (is_match)
......
......@@ -30,43 +30,37 @@ namespace ngraph
class Label : public Pattern
{
public:
/// \brief creates a Label node from \sa node.
/// \brief creates a Label node containing a sub-pattern described by \sa type and \sa shape.
///
/// this Label node can be bound to arbitrary nodes in an input graph
/// as long as provided \sa pred is satisfied and the node hasn't been previously bound to
/// a different node in the input graph
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa wrapped_nodes
/// Example:
/// \code{.cpp}
/// auto pattern = pattern::op::Label::make_from_node(a); //a is op::Parameter
/// matcher.match(pattern, a));
/// auto add = a + b; //a and b are op::Parameter in this example
/// auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{2,2} , nullptr, Nodes{add});
/// \endcode
static std::shared_ptr<Label>
make_from_node(const std::shared_ptr<ngraph::Node>& node,
Predicate pred = nullptr)
Label(const element::Type& type,
const Shape s,
Predicate pred = nullptr,
const Nodes& wrapped_nodes = Nodes{})
: Pattern("Label", wrapped_nodes, pred)
{
auto label = std::make_shared<Label>(Nodes{}, pred);
label->add_output(node->get_element_type(), node->get_shape());
return label;
add_output(type, s);
}
/// \brief creates a Label node containing a sub-pattern described by \sa node.
/// \brief creates a Label node containing a sub-pattern described by the type and shape of \sa node.
///
/// this Label node can be bound only to the nodes in the input graph
/// that match the pattern specified by \sa node
/// that match the pattern specified by \sa wrapped_nodes
/// Example:
/// \code{.cpp}
/// auto add = a + b; //a and b are op::Parameter in this example
/// auto label = pattern::op::Label::wrap(add);
/// auto label = std::make_shared<pattern::op::Label>(add, nullptr, Nodes{add});
/// \endcode
static std::shared_ptr<Label> wrap(const std::shared_ptr<ngraph::Node>& node,
Predicate pred = nullptr)
{
auto label = std::make_shared<Label>(Nodes{node}, pred);
label->add_output(node->get_element_type(), node->get_shape());
return label;
}
Label(const Nodes& subgraph, Predicate pred)
: Pattern("Label", Nodes{subgraph}, pred)
Label(std::shared_ptr<Node> node,
Predicate pred = nullptr,
const Nodes& wrapped_nodes = Nodes{})
: Label(node->get_element_type(), node->get_shape(), pred, wrapped_nodes)
{
}
};
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment