Commit 9e7ad170 authored by baojun, committed by Scott Cyphers

Update fused ops GroupConvolution, Gelu and LayerNorm to be dynamic friendly (#3876)

* set output element type

* set output element type

* overrode validate_and_infer_types
parent bc448701
@@ -66,12 +66,18 @@ shared_ptr<Node> op::Gelu::copy_with_new_args(const NodeVector& new_args) const
 void op::Gelu::pre_validate_and_infer_types()
 {
     element::Type input_element_type = get_input_element_type(0);
+    PartialShape input_pshape = get_input_partial_shape(0);
     NODE_VALIDATION_CHECK(this,
                           input_element_type.is_dynamic() || input_element_type.is_real(),
                           "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
                           input_element_type,
                           ").");
+
+    if (input_pshape.is_dynamic())
+    {
+        set_output_type(0, input_element_type, input_pshape);
+    }
 }
 
 void op::Gelu::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
@@ -94,12 +100,18 @@ op::GeluBackpropFactor::GeluBackpropFactor(const Output<Node>& x)
 void op::GeluBackpropFactor::pre_validate_and_infer_types()
 {
     element::Type input_element_type = get_input_element_type(0);
+    PartialShape input_pshape = get_input_partial_shape(0);
     NODE_VALIDATION_CHECK(this,
                           input_element_type.is_dynamic() || input_element_type.is_real(),
                           "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
                           input_element_type,
                           ").");
+
+    if (input_pshape.is_dynamic())
+    {
+        set_output_type(0, input_element_type, input_pshape);
+    }
 }
 
 shared_ptr<Node> op::GeluBackpropFactor::copy_with_new_args(const NodeVector& new_args) const
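
For context, a minimal sketch (not part of the commit; the include paths and build setup are assumptions based on the nGraph source layout of this era) of what the Gelu change enables: with a dynamic input shape, the op now sets its output element type and partial shape directly instead of relying on decomposition.

    #include <iostream>
    #include <memory>

    #include "ngraph/op/fused/gelu.hpp"
    #include "ngraph/op/parameter.hpp"

    using namespace ngraph;

    int main()
    {
        // A rank-dynamic input: FusedOp cannot decompose it, so the new
        // branch in pre_validate_and_infer_types sets the output type.
        auto arg = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
        auto gelu = std::make_shared<op::Gelu>(arg);

        std::cout << "output element type: " << gelu->get_output_element_type(0) << "\n"
                  << "output shape:        " << gelu->get_output_partial_shape(0) << "\n";
        return 0;
    }
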
@@ -100,6 +100,10 @@ void op::GroupConvolution::pre_validate_and_infer_types()
                                 get_groups()) == data_shape.to_shape()[1],
                             "Incorrect number of channels per filter");
     }
+    else
+    {
+        set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
+    }
 }
 
 void op::GroupConvolution::post_validate_and_infer_types()
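
A similar sketch for the GroupConvolution branch above (the constructor arguments follow the v0 signature as best recalled; treat the exact parameter list as an assumption): when the data shape is dynamic, the per-filter channel check is skipped and the output shape falls back to PartialShape::dynamic().

    #include <iostream>
    #include <memory>

    #include "ngraph/op/fused/group_conv.hpp"
    #include "ngraph/op/parameter.hpp"

    using namespace ngraph;

    int main()
    {
        // Dynamic data batch, static filters: 8 output channels,
        // 2 input channels per group, 3x3 kernels.
        auto data = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
        auto filters = std::make_shared<op::Parameter>(element::f32, Shape{8, 2, 3, 3});

        auto conv = std::make_shared<op::GroupConvolution>(data,
                                                           filters,
                                                           Strides{1, 1},        // window movement
                                                           Strides{1, 1},        // window dilation
                                                           CoordinateDiff{0, 0}, // padding below
                                                           CoordinateDiff{0, 0}, // padding above
                                                           Strides{1, 1},        // data dilation
                                                           4);                   // groups

        // The new else-branch yields a fully dynamic output shape
        // instead of a failed channel check.
        std::cout << conv->get_output_partial_shape(0) << "\n";
        return 0;
    }
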
@@ -170,7 +170,7 @@ shared_ptr<Node> op::LayerNorm::copy_with_new_args(const NodeVector& new_args) const
     }
 }
 
-void op::LayerNorm::pre_validate_and_infer_types()
+void op::LayerNorm::validate_and_infer_types()
 {
     element::Type input_element_type = get_input_element_type(0);
@@ -509,7 +509,7 @@ shared_ptr<Node> op::LayerNormBackprop::copy_with_new_args(const NodeVector& new_args) const
     }
 }
 
-void op::LayerNormBackprop::pre_validate_and_infer_types()
+void op::LayerNormBackprop::validate_and_infer_types()
 {
     element::Type input_element_type = get_input_element_type(0);
@@ -55,7 +55,7 @@ namespace ngraph
                 virtual NodeVector decompose_op() const override;
 
-                void pre_validate_and_infer_types() override;
+                void validate_and_infer_types() override;
 
                 virtual std::shared_ptr<Node>
                     copy_with_new_args(const NodeVector& new_args) const override;
@@ -121,7 +121,7 @@ namespace ngraph
                 virtual NodeVector decompose_op() const override;
 
-                void pre_validate_and_infer_types() override;
+                void validate_and_infer_types() override;
 
                 virtual std::shared_ptr<Node>
                     copy_with_new_args(const NodeVector& new_args) const override;
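
Finally, a sketch of the LayerNorm side (the keep_stats/begin_norm_axis constructor is assumed from the public header of this era): because validate_and_infer_types is now the full override rather than pre_validate_and_infer_types, shape inference runs even when the input is only partially static and the op cannot be decomposed.

    #include <iostream>
    #include <memory>

    #include "ngraph/op/fused/layer_norm.hpp"
    #include "ngraph/op/parameter.hpp"

    using namespace ngraph;

    int main()
    {
        // Batch dimension dynamic, feature dimension static.
        auto data = std::make_shared<op::Parameter>(
            element::f32, PartialShape{Dimension::dynamic(), 32});

        // keep_stats = true: outputs are (normalized data, mean, variance);
        // begin_norm_axis = 1 normalizes over the feature dimension.
        auto ln = std::make_shared<op::LayerNorm>(data, true, 1);

        for (size_t i = 0; i < ln->get_output_size(); ++i)
        {
            std::cout << "output " << i << ": "
                      << ln->get_output_partial_shape(i) << "\n";
        }
        return 0;
    }
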