Unverified Commit 1974a90d authored by Evgenya Stepyreva's avatar Evgenya Stepyreva Committed by GitHub

Setting outputs for dynamic v1::Split (#4256)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent e37b386e
...@@ -135,13 +135,11 @@ void op::v1::Split::validate_and_infer_types() ...@@ -135,13 +135,11 @@ void op::v1::Split::validate_and_infer_types()
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
this, axis_et.is_integral(), "The 'axis' input only accepts integral types"); this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
if (input_value(1).get_node_shared_ptr()->is_constant()) if (input_value(1).get_node_shared_ptr()->is_constant() && data_ps.is_static())
{ {
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr()); const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto axis = axis_input->cast_vector<int64_t>()[0]; auto axis = axis_input->cast_vector<int64_t>()[0];
if (data_ps.is_static())
{
const auto data_shape = data_ps.to_shape(); const auto data_shape = data_ps.to_shape();
axis = ngraph::normalize_axis(this, axis, data_shape.size()); axis = ngraph::normalize_axis(this, axis, data_shape.size());
...@@ -162,7 +160,6 @@ void op::v1::Split::validate_and_infer_types() ...@@ -162,7 +160,6 @@ void op::v1::Split::validate_and_infer_types()
set_output_type(i, input(0).get_element_type(), each_output_shape); set_output_type(i, input(0).get_element_type(), each_output_shape);
} }
} }
}
else else
{ {
for (size_t i = 0; i < m_num_splits; ++i) for (size_t i = 0; i < m_num_splits; ++i)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment