Commit 10195034 authored by baojun's avatar baojun Committed by Scott Cyphers

Resolve shape prop issue (#4080)

* use pre_validate function

* use pre_validate
parent 2cbe42c8
...@@ -63,15 +63,12 @@ shared_ptr<Node> LayoutConverter::copy_with_new_args(const NodeVector& new_args) ...@@ -63,15 +63,12 @@ shared_ptr<Node> LayoutConverter::copy_with_new_args(const NodeVector& new_args)
return make_shared<LayoutConverter>(new_args.at(0), get_mode()); return make_shared<LayoutConverter>(new_args.at(0), get_mode());
} }
void LayoutConverter::validate_and_infer_types() void LayoutConverter::pre_validate_and_infer_types()
{ {
auto shape = get_input_partial_shape(0); auto shape = get_input_partial_shape(0);
if (shape.is_dynamic()) if (shape.is_dynamic())
{ {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
} }
else
{
FusedOp::validate_and_infer_types();
}
} }
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -125,17 +125,14 @@ shared_ptr<Node> Pool::copy_with_new_args(const NodeVector& new_args) const ...@@ -125,17 +125,14 @@ shared_ptr<Node> Pool::copy_with_new_args(const NodeVector& new_args) const
get_pooling_type()); get_pooling_type());
} }
void Pool::validate_and_infer_types() void Pool::pre_validate_and_infer_types()
{ {
auto shape = get_input_partial_shape(0); auto shape = get_input_partial_shape(0);
if (shape.is_dynamic()) if (shape.is_dynamic())
{ {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
} }
else
{
FusedOp::validate_and_infer_types();
}
} }
constexpr NodeTypeInfo PoolGrad::type_info; constexpr NodeTypeInfo PoolGrad::type_info;
...@@ -162,18 +159,13 @@ PoolGrad::PoolGrad(const Output<Node>& x, ...@@ -162,18 +159,13 @@ PoolGrad::PoolGrad(const Output<Node>& x,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void PoolGrad::validate_and_infer_types() void PoolGrad::pre_validate_and_infer_types()
{ {
auto shape = get_input_partial_shape(0);
if (get_input_partial_shape(0).is_dynamic() || get_input_partial_shape(1).is_dynamic() || if (get_input_partial_shape(0).is_dynamic() || get_input_partial_shape(1).is_dynamic() ||
get_input_partial_shape(2).is_dynamic()) get_input_partial_shape(2).is_dynamic())
{ {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
} }
else
{
FusedOp::validate_and_infer_types();
}
} }
shared_ptr<Node> PoolGrad::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> PoolGrad::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -49,7 +49,7 @@ namespace ngraph ...@@ -49,7 +49,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -103,7 +103,7 @@ namespace ngraph ...@@ -103,7 +103,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -71,16 +71,13 @@ NodeVector ReduceSum::decompose_op() const ...@@ -71,16 +71,13 @@ NodeVector ReduceSum::decompose_op() const
return retval; return retval;
} }
void ReduceSum::validate_and_infer_types() void ReduceSum::pre_validate_and_infer_types()
{ {
auto shape = get_input_partial_shape(0); auto shape = get_input_partial_shape(0);
if (shape.is_dynamic()) if (shape.is_dynamic())
{ {
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
else
{
FusedOp::validate_and_infer_types();
} }
} }
...@@ -161,16 +158,13 @@ NodeVector ReduceSumGrad::decompose_op() const ...@@ -161,16 +158,13 @@ NodeVector ReduceSumGrad::decompose_op() const
return retval; return retval;
} }
void ReduceSumGrad::validate_and_infer_types() void ReduceSumGrad::pre_validate_and_infer_types()
{ {
auto shape = get_input_partial_shape(0); auto shape = get_input_partial_shape(0);
if (shape.is_dynamic()) if (shape.is_dynamic())
{ {
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
else
{
FusedOp::validate_and_infer_types();
} }
} }
......
...@@ -44,7 +44,7 @@ namespace ngraph ...@@ -44,7 +44,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -75,7 +75,7 @@ namespace ngraph ...@@ -75,7 +75,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment