Commit 72493caf authored by baojun's avatar baojun Committed by Sang Ik Lee

Set output element type for case of dynamic input (#4071)

* set output type

* set output type
parent 415a852d
......@@ -132,12 +132,18 @@ shared_ptr<Node> op::PartialSlice::copy_with_new_args(const NodeVector& new_args
void op::PartialSlice::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
PartialShape data_pshape = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
if (data_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
}
void op::PartialSlice::generate_adjoints(autodiff::Adjoints& adjoints, const OutputVector& deltas)
......@@ -222,11 +228,16 @@ shared_ptr<Node> op::PartialSliceBackprop::copy_with_new_args(const NodeVector&
void op::PartialSliceBackprop::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
PartialShape data_pshape = get_input_partial_shape(0);
PartialShape delta_pshape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
if (data_pshape.is_dynamic() || delta_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
}
......@@ -133,6 +133,24 @@ NodeVector op::SoftmaxCrossEntropy::decompose_op() const
}
}
void op::SoftmaxCrossEntropy::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
PartialShape data_pshape = get_input_partial_shape(0);
PartialShape labels_pshape = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
if (data_pshape.is_dynamic() || labels_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
}
shared_ptr<Node> op::SoftmaxCrossEntropy::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......@@ -157,12 +175,20 @@ op::SoftmaxCrossEntropyBackprop::SoftmaxCrossEntropyBackprop(const Output<Node>&
void op::SoftmaxCrossEntropyBackprop::pre_validate_and_infer_types()
{
element::Type input_element_type = get_input_element_type(0);
PartialShape delta_pshape = get_input_partial_shape(0);
PartialShape softmax_pshape = get_input_partial_shape(1);
PartialShape labels_pshape = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
input_element_type.is_dynamic() || input_element_type.is_real(),
"Argument element type must be f16, bf16, f32, f64 or dynamic (got ",
input_element_type,
").");
if (delta_pshape.is_dynamic() || softmax_pshape.is_dynamic() || labels_pshape.is_dynamic())
{
set_output_type(0, input_element_type, PartialShape::dynamic());
}
}
shared_ptr<Node>
......
......@@ -48,6 +48,8 @@ namespace ngraph
virtual NodeVector decompose_op() const override;
void pre_validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment