Commit 72493caf authored by baojun's avatar baojun Committed by Sang Ik Lee

Set output element type for case of dynamic input (#4071)

* set output type

* set output type
parent 415a852d
...@@ -132,12 +132,18 @@ shared_ptr<Node> op::PartialSlice::copy_with_new_args(const NodeVector& new_args ...@@ -132,12 +132,18 @@ shared_ptr<Node> op::PartialSlice::copy_with_new_args(const NodeVector& new_args
void op::PartialSlice::pre_validate_and_infer_types() void op::PartialSlice::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape data_pshape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
if (data_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
} }
void op::PartialSlice::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas) void op::PartialSlice::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
...@@ -222,11 +228,16 @@ shared_ptr<Node> op::PartialSliceBackprop::copy_with_new_args(const NodeVector& ...@@ -222,11 +228,16 @@ shared_ptr<Node> op::PartialSliceBackprop::copy_with_new_args(const NodeVector&
void op::PartialSliceBackprop::pre_validate_and_infer_types() void op::PartialSliceBackprop::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape data_pshape = get_input_partial_shape(0);
PartialShape delta_pshape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
set_output_type(0, get_input_element_type(0), PartialShape::dynamic()); if (data_pshape.is_dynamic() || delta_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
} }
...@@ -133,6 +133,24 @@ NodeVector op::SoftmaxCrossEntropy::decompose_op() const ...@@ -133,6 +133,24 @@ NodeVector op::SoftmaxCrossEntropy::decompose_op() const
} }
} }
void op::SoftmaxCrossEntropy::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
PartialShape data_pshape = get_input_partial_shape(0);
PartialShape labels_pshape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
if (data_pshape.is_dynamic() || labels_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
}
shared_ptr<Node> op::SoftmaxCrossEntropy::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::SoftmaxCrossEntropy::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
...@@ -157,12 +175,20 @@ op::SoftmaxCrossEntropyBackprop::SoftmaxCrossEntropyBackprop(const Output<Node>& ...@@ -157,12 +175,20 @@ op::SoftmaxCrossEntropyBackprop::SoftmaxCrossEntropyBackprop(const Output<Node>&
void op::SoftmaxCrossEntropyBackprop::pre_validate_and_infer_types() void op::SoftmaxCrossEntropyBackprop::pre_validate_and_infer_types()
{ {
element::Type input_element_type = get_input_element_type(0); element::Type input_element_type = get_input_element_type(0);
PartialShape delta_pshape = get_input_partial_shape(0);
PartialShape softmax_pshape = get_input_partial_shape(1);
PartialShape labels_pshape = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(), input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ", "Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type, input_element_type,
")."); ").");
if (delta_pshape.is_dynamic() || softmax_pshape.is_dynamic() || labels_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
} }
shared_ptr<Node> shared_ptr<Node>
......
...@@ -48,6 +48,8 @@ namespace ngraph ...@@ -48,6 +48,8 @@ namespace ngraph
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment