Commit 9e7ad170 authored by baojun's avatar baojun Committed by Scott Cyphers

Update fused ops groupconvolution, gelu and layernorm to be dynamic friendly (#3876)

* set output et

* set output et

* overwrote validate and infer
parent bc448701
...@@ -66,12 +66,18 @@ shared_ptr<Node> op::Gelu::copy_with_new_args(const NodeVector& new_args) const ...@@ -66,12 +66,18 @@ shared_ptr<Node> op::Gelu::copy_with_new_args(const NodeVector& new_args) const
void op::Gelu::pre_validate_and_infer_types() void op::Gelu::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape input_pshape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
if (input_pshape.is_dynamic())
{
set_output_type(0, input_element_type, input_pshape);
}
} }
void op::Gelu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::Gelu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
...@@ -94,12 +100,18 @@ op::GeluBackpropFactor::GeluBackpropFactor(const Output<Node>& x) ...@@ -94,12 +100,18 @@ op::GeluBackpropFactor::GeluBackpropFactor(const Output<Node>& x)
void op::GeluBackpropFactor::pre_validate_and_infer_types() void op::GeluBackpropFactor::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape input_pshape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
if (input_pshape.is_dynamic())
{
set_output_type(0, input_element_type, input_pshape);
}
} }
shared_ptr<Node> op::GeluBackpropFactor::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GeluBackpropFactor::copy_with_new_args(const NodeVector& new_args) const
......
...@@ -100,6 +100,10 @@ void op::GroupConvolution::pre_validate_and_infer_types() ...@@ -100,6 +100,10 @@ void op::GroupConvolution::pre_validate_and_infer_types()
get_groups()) == data_shape.to_shape()[1], get_groups()) == data_shape.to_shape()[1],
"Incorrect number of channels per filter"); "Incorrect number of channels per filter");
} }
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
}
} }
void op::GroupConvolution::post_validate_and_infer_types() void op::GroupConvolution::post_validate_and_infer_types()
......
...@@ -170,7 +170,7 @@ shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) c ...@@ -170,7 +170,7 @@ shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) c
} }
} }
void op::LayerNorm::pre_validate_and_infer_types() void op::LayerNorm::validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
...@@ -509,7 +509,7 @@ shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new ...@@ -509,7 +509,7 @@ shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new
} }
} }
void op::LayerNormBackprop::pre_validate_and_infer_types() void op::LayerNormBackprop::validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
......
...@@ -55,7 +55,7 @@ namespace ngraph ...@@ -55,7 +55,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -121,7 +121,7 @@ namespace ngraph ...@@ -121,7 +121,7 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment