Commit af889535 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Update pattern ops to propagate partial shapes (#1986)

parent 4918449c
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
public: public:
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa shape. /// \brief creates a Any node containing a sub-pattern described by \sa type and \sa shape.
Any(const element::Type& type, Any(const element::Type& type,
const Shape s, const PartialShape& s,
Predicate pred, Predicate pred,
const NodeVector& wrapped_nodes) const NodeVector& wrapped_nodes)
: Pattern("Any", wrapped_nodes, pred) : Pattern("Any", wrapped_nodes, pred)
...@@ -45,7 +45,10 @@ namespace ngraph ...@@ -45,7 +45,10 @@ namespace ngraph
/// \brief creates a Any node containing a sub-pattern described by the type and shape of \sa node. /// \brief creates a Any node containing a sub-pattern described by the type and shape of \sa node.
Any(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes) Any(std::shared_ptr<Node> node, Predicate pred, const NodeVector& wrapped_nodes)
: Any(node->get_element_type(), node->get_shape(), pred, wrapped_nodes) : Any(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
{ {
} }
}; };
......
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
/// auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{2,2} , nullptr, NodeVector{add}); /// auto label = std::make_shared<pattern::op::Label>(element::f32, Shape{2,2} , nullptr, NodeVector{add});
/// \endcode /// \endcode
Label(const element::Type& type, Label(const element::Type& type,
const Shape s, const PartialShape& s,
Predicate pred = nullptr, Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{}) const NodeVector& wrapped_nodes = NodeVector{})
: Pattern("Label", wrapped_nodes, pred) : Pattern("Label", wrapped_nodes, pred)
...@@ -61,7 +61,10 @@ namespace ngraph ...@@ -61,7 +61,10 @@ namespace ngraph
Label(std::shared_ptr<Node> node, Label(std::shared_ptr<Node> node,
Predicate pred = nullptr, Predicate pred = nullptr,
const NodeVector& wrapped_nodes = NodeVector{}) const NodeVector& wrapped_nodes = NodeVector{})
: Label(node->get_element_type(), node->get_shape(), pred, wrapped_nodes) : Label(node->get_element_type(),
node->get_output_partial_shape(0),
pred,
wrapped_nodes)
{ {
} }
}; };
......
...@@ -34,7 +34,7 @@ namespace ngraph ...@@ -34,7 +34,7 @@ namespace ngraph
Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr) Skip(const std::shared_ptr<Node>& arg, Predicate predicate = nullptr)
: Pattern("Skip", NodeVector{arg}, predicate) : Pattern("Skip", NodeVector{arg}, predicate)
{ {
set_output_type(0, arg->get_element_type(), arg->get_shape()); set_output_type(0, arg->get_element_type(), arg->get_output_partial_shape(0));
} }
}; };
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment