Commit 4fe2da61 authored by Ivan Tikhonov's avatar Ivan Tikhonov Committed by Scott Cyphers

Fix shape inference of TensorIterator body (#3922)

* fix for shape inference of tensor iterator body

* updated unit test for case end = -2

* indexes in unit tests

* Updated formula for num_iterations
parent ca9adeb1
...@@ -271,6 +271,14 @@ void op::TensorIterator::validate_and_infer_types() ...@@ -271,6 +271,14 @@ void op::TensorIterator::validate_and_infer_types()
std::vector<std::shared_ptr<Node>> ends; std::vector<std::shared_ptr<Node>> ends;
auto make_positive = [](int64_t value, uint64_t dim_size) -> int64_t {
if (value < 0)
{
value = dim_size + value;
}
return value;
};
// Input // Input
uint64_t index_it = 0; uint64_t index_it = 0;
for (auto input_description : m_input_descriptions) for (auto input_description : m_input_descriptions)
...@@ -285,41 +293,26 @@ void op::TensorIterator::validate_and_infer_types() ...@@ -285,41 +293,26 @@ void op::TensorIterator::validate_and_infer_types()
m_body->get_parameters().at(slice_input_description->m_body_parameter_index); m_body->get_parameters().at(slice_input_description->m_body_parameter_index);
auto body_param_partial_shape = body_parameter->get_partial_shape(); auto body_param_partial_shape = body_parameter->get_partial_shape();
auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape(); auto input_partial_shape = inputs().at(index).get_source_output().get_partial_shape();
auto start = slice_input_description->m_start; if (input_partial_shape.is_static())
auto part_size = slice_input_description->m_part_size;
auto end = slice_input_description->m_end;
if (end != -1)
{ {
auto input_shape = input_partial_shape.to_shape();
auto axis = slice_input_description->m_axis;
auto part_size = slice_input_description->m_part_size;
auto dim_size = input_shape[axis];
auto start = make_positive(slice_input_description->m_start, dim_size);
auto end = make_positive(slice_input_description->m_end, dim_size);
if (m_num_iterations == -1) if (m_num_iterations == -1)
{ {
m_num_iterations = end - start; // +1 because the left and right borders are included [start, end]
m_num_iterations = (abs(end - start) + 1) / part_size;
} }
else else
{ {
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(this,
this, m_num_iterations == end - start, "Number of slices not the same"); m_num_iterations == (abs(end - start) + 1) / part_size,
} "Number of slices not the same");
}
if (input_partial_shape.is_static())
{
auto input_shape = input_partial_shape.to_shape();
auto axis = slice_input_description->m_axis;
if (end == -1)
{
// for simple RNN case where stride is the same as part_size
// when end is -1, we assume that we slice the input from "start" to the very
// end.
end = static_cast<size_t>(input_shape[axis]) / part_size + start;
if (m_num_iterations == -1)
{
m_num_iterations = end - start;
}
else
{
NODE_VALIDATION_CHECK(
this, m_num_iterations == end - start, "Number of slices not the same");
}
} }
if (body_param_partial_shape.is_static()) if (body_param_partial_shape.is_static())
...@@ -421,23 +414,10 @@ void op::TensorIterator::validate_and_infer_types() ...@@ -421,23 +414,10 @@ void op::TensorIterator::validate_and_infer_types()
if (body_value_partial_shape.is_static()) if (body_value_partial_shape.is_static())
{ {
auto body_value_shape = body_value_partial_shape.to_shape(); auto body_value_shape = body_value_partial_shape.to_shape();
auto start = concat_output_description->m_start;
auto part_size = concat_output_description->m_part_size; auto part_size = concat_output_description->m_part_size;
auto end = concat_output_description->m_end;
auto axis = concat_output_description->m_axis; auto axis = concat_output_description->m_axis;
Shape out_shape{body_value_shape}; Shape out_shape{body_value_shape};
if (end != -1)
{
if (m_num_iterations != -1)
{
NODE_VALIDATION_CHECK(
this, m_num_iterations == end - start, "Number of slices not the same");
}
else
{
m_num_iterations = end - start;
}
}
if (m_num_iterations != -1) if (m_num_iterations != -1)
{ {
// for simple RNN case where stride is the same as part_size // for simple RNN case where stride is the same as part_size
......
...@@ -486,8 +486,8 @@ TEST(serialize, tensor_iterator_raw) ...@@ -486,8 +486,8 @@ TEST(serialize, tensor_iterator_raw)
auto tensor_iterator = make_shared<op::TensorIterator>(); auto tensor_iterator = make_shared<op::TensorIterator>();
tensor_iterator->set_body(body); tensor_iterator->set_body(body);
// The Xi are the elements of Xseq // The Xi are the elements of Xseq
// start=0, stride=1, part_size=1, end=40, axis=1 // start=0, stride=1, part_size=1, end=39, axis=1
tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, 40, 1); tensor_iterator->set_sliced_input(Xi, X, 0, 1, 1, 39, 1);
// Hi is Hinit on the first iteration, Ho after that // Hi is Hinit on the first iteration, Ho after that
tensor_iterator->set_merged_input(Hi, Hinit, Ho); tensor_iterator->set_merged_input(Hi, Hinit, Ho);
tensor_iterator->set_invariant_input(WH_body, WH); tensor_iterator->set_invariant_input(WH_body, WH);
...@@ -499,8 +499,8 @@ TEST(serialize, tensor_iterator_raw) ...@@ -499,8 +499,8 @@ TEST(serialize, tensor_iterator_raw)
// Output 0 is last Yo // Output 0 is last Yo
auto out0 = tensor_iterator->get_iter_value(Yo, -1); auto out0 = tensor_iterator->get_iter_value(Yo, -1);
// Output 1 is concat of hidden states // Output 1 is concat of hidden states
// start=0, stride=1, part_size=1, end=40, axis=1 // start=0, stride=1, part_size=1, end=39, axis=1
auto out1 = tensor_iterator->get_concatenated_slices(Ho, 0, 1, 1, 40, 1); auto out1 = tensor_iterator->get_concatenated_slices(Ho, 0, 1, 1, 39, 1);
auto results = ResultVector{make_shared<op::Result>(out0), make_shared<op::Result>(out1)}; auto results = ResultVector{make_shared<op::Result>(out0), make_shared<op::Result>(out1)};
auto f = make_shared<Function>(results, ParameterVector{X, Hinit, WH, WX, bH, WY, bY}); auto f = make_shared<Function>(results, ParameterVector{X, Hinit, WH, WX, bH, WY, bY});
...@@ -543,7 +543,7 @@ TEST(serialize, tensor_iterator_lstm) ...@@ -543,7 +543,7 @@ TEST(serialize, tensor_iterator_lstm)
auto tensor_iterator = make_shared<op::TensorIterator>(); auto tensor_iterator = make_shared<op::TensorIterator>();
tensor_iterator->set_body(body); tensor_iterator->set_body(body);
// start=0, stride=1, part_size=1, end=40, axis=1 // start=0, stride=1, part_size=1, end=39, axis=1
tensor_iterator->set_sliced_input(X, SENT, 0, 1, 1, -1, 1); tensor_iterator->set_sliced_input(X, SENT, 0, 1, 1, -1, 1);
// H_t is Hinit on the first iteration, Ho after that // H_t is Hinit on the first iteration, Ho after that
tensor_iterator->set_merged_input(H_t, H_init, H_o); tensor_iterator->set_merged_input(H_t, H_init, H_o);
...@@ -583,8 +583,8 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2) ...@@ -583,8 +583,8 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2)
auto tensor_iterator = make_shared<op::TensorIterator>(); auto tensor_iterator = make_shared<op::TensorIterator>();
tensor_iterator->set_body(body); tensor_iterator->set_body(body);
// The Xi are the elements of Xseq // The Xi are the elements of Xseq
// start=0, stride=2, part_size=2, end=20, axis=1 // start=0, stride=2, part_size=2, end=39, axis=1
tensor_iterator->set_sliced_input(Xi, X, 0, 2, 2, 20, 1); tensor_iterator->set_sliced_input(Xi, X, 0, 2, 2, 39, 1);
// The Yi are the elements of Yseq // The Yi are the elements of Yseq
// start=0, stride=2, part_size=2, end=-1, axis=1 // start=0, stride=2, part_size=2, end=-1, axis=1
tensor_iterator->set_sliced_input(Yi, Y, 0, 2, 2, -1, 1); tensor_iterator->set_sliced_input(Yi, Y, 0, 2, 2, -1, 1);
...@@ -593,8 +593,8 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2) ...@@ -593,8 +593,8 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2)
// Output 0 is last Zo // Output 0 is last Zo
auto out0 = tensor_iterator->get_iter_value(Zo, -1); auto out0 = tensor_iterator->get_iter_value(Zo, -1);
// Output 1 is concat of Zos // Output 1 is concat of Zos
// start=0, stride=2, part_size=2, end=20, axis=1 // start=0, stride=2, part_size=2, end=39, axis=1
auto out1 = tensor_iterator->get_concatenated_slices(Zo, 0, 2, 2, 20, 1); auto out1 = tensor_iterator->get_concatenated_slices(Zo, 0, 2, 2, 39, 1);
auto result0 = make_shared<op::Result>(out0); auto result0 = make_shared<op::Result>(out0);
auto result1 = make_shared<op::Result>(out1); auto result1 = make_shared<op::Result>(out1);
...@@ -631,11 +631,11 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2_dynamic) ...@@ -631,11 +631,11 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2_dynamic)
auto tensor_iterator = make_shared<op::TensorIterator>(); auto tensor_iterator = make_shared<op::TensorIterator>();
tensor_iterator->set_body(body); tensor_iterator->set_body(body);
// The Xi are the elements of Xseq // The Xi are the elements of Xseq
// start=0, stride=2, part_size=2, end=20, axis=1 // start=0, stride=2, part_size=2, end=38, axis=1
tensor_iterator->set_sliced_input(Xi, X, 0, 2, 2, 20, 1); tensor_iterator->set_sliced_input(Xi, X, 0, 2, 2, 38, 1);
// The Yi are the elements of Yseq // The Yi are the elements of Yseq
// start=0, stride=2, part_size=2, end=-1, axis=1 // start=0, stride=2, part_size=2, end=-2, axis=1
tensor_iterator->set_sliced_input(Yi, Y, 0, 2, 2, -1, 1); tensor_iterator->set_sliced_input(Yi, Y, 0, 2, 2, -2, 1);
tensor_iterator->set_invariant_input(M_body, M); tensor_iterator->set_invariant_input(M_body, M);
// check input descriptors // check input descriptors
...@@ -663,8 +663,8 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2_dynamic) ...@@ -663,8 +663,8 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2_dynamic)
// Output 0 is last Zo // Output 0 is last Zo
auto out0 = tensor_iterator->get_iter_value(Zo, -1); auto out0 = tensor_iterator->get_iter_value(Zo, -1);
// Output 1 is concat of Zos // Output 1 is concat of Zos
// start=0, stride=2, part_size=2, end=20, axis=1 // start=0, stride=2, part_size=2, end=38, axis=1
auto out1 = tensor_iterator->get_concatenated_slices(Zo, 0, 2, 2, 20, 1); auto out1 = tensor_iterator->get_concatenated_slices(Zo, 0, 2, 2, 38, 1);
// check output descriptors // check output descriptors
for (auto& desc : tensor_iterator->get_output_descriptions()) for (auto& desc : tensor_iterator->get_output_descriptions())
...@@ -686,7 +686,7 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2_dynamic) ...@@ -686,7 +686,7 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2_dynamic)
auto result0 = make_shared<op::Result>(out0); auto result0 = make_shared<op::Result>(out0);
auto result1 = make_shared<op::Result>(out1); auto result1 = make_shared<op::Result>(out1);
Shape out0_shape{32, 2, 10}; Shape out0_shape{32, 2, 10};
Shape out1_shape{32, 40, 10}; Shape out1_shape{32, 38, 10};
auto results = ResultVector{result0, result1}; auto results = ResultVector{result0, result1};
auto f = make_shared<Function>(results, ParameterVector{X, Y, M}); auto f = make_shared<Function>(results, ParameterVector{X, Y, M});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment