Commit af889535 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Update pattern ops to propagate partial shapes (#1986)

parent 4918449c
......@@ -31,7 +31,7 @@ namespace ngraph
public:
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa shape.
Any(const element::Type& type,
const Shape s,
const PartialShape& s,
Predicate pred,
const NodeVector& wrapped_nodes)
: Pattern("Any", wrapped_nodes, pred)
......@@ -45,7 +45,10 @@ namespace ngraph
/// \brief creates a Any node containing a sub-pattern described by the type and shape of \sa node.
Any(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: Any(node->get_element_type(), node->get_shape(), pred, wrapped_nodes)
: Any(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
{
}
};
......
......@@ -41,7 +41,7 @@ namespace ngraph
/// auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{2,2} , nullptr, NodeVector{add});
/// \endcode
Label(const element::Type& type,
const Shape s,
const PartialShape& s,
Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{})
: Pattern("Label", wrapped_nodes, pred)
......@@ -61,7 +61,10 @@ namespace ngraph
Label(std::shared_ptr<Node> node,
Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{})
: Label(node->get_element_type(), node->get_shape(), pred, wrapped_nodes)
: Label(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
{
}
};
......
......@@ -34,7 +34,7 @@ namespace ngraph
Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr)
: Pattern("Skip", NodeVector{arg}, predicate)
{
set_output_type(0, arg->get_element_type(), arg->get_shape());
set_output_type(0, arg->get_element_type(), arg->get_output_partial_shape(0));
}
};
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment