Commit 1c2cd853 authored by Ivan Tikhonov's avatar Ivan Tikhonov Committed by Scott Cyphers

Fix for the bug with as_type_ptr for TensorIterator::Input/Ouput desc (#3906)

* Updated unit test to reproduce a bug

* Code style

* Add exports

* Added missed export
parent f7dc9104
......@@ -44,6 +44,7 @@ namespace ngraph
class BodyLambda : public Lambda
{
public:
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"BodyLamdba", 0};
const DiscreteTypeInfo& get_type_info() const { return type_info; }
BodyLambda(const OutputVector& outputs, const ParameterVector& parameters)
......@@ -78,6 +79,7 @@ namespace ngraph
class SliceInputDescription : public InputDescription
{
public:
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"SliceInputDescription", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
/// \param input_index Position of the TensorIterator input
......@@ -108,6 +110,7 @@ namespace ngraph
class MergedInputDescription : public InputDescription
{
public:
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"MergedInputDescription", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
/// \param input_index Position of the TensorIterator input supplying a value to
......@@ -127,6 +130,7 @@ namespace ngraph
class InvariantInputDescription : public InputDescription
{
public:
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"InvariantInputDescription", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
......@@ -158,6 +162,7 @@ namespace ngraph
class ConcatOutputDescription : public OutputDescription
{
public:
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"ConcatOutputDescription", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
/// \param body_value_index A body value that produces the output
......@@ -188,6 +193,7 @@ namespace ngraph
class BodyOutputDescription : public OutputDescription
{
public:
NGRAPH_API
static constexpr DiscreteTypeInfo type_info{"BodyOutputDescription", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
/// \param body_value_index A body value that produces the output
......
......@@ -638,12 +638,51 @@ TEST(serialize, tensor_iterator_2_slice_inputs_part_size_2_dynamic)
tensor_iterator->set_sliced_input(Yi, Y, 0, 2, 2, -1, 1);
tensor_iterator->set_invariant_input(M_body, M);
// check input descriptors
for (auto& desc : tensor_iterator->get_input_descriptions())
{
auto type_info = desc->get_type_info();
if (std::strcmp(type_info.name, "InvariantInputDescription") == 0)
{
auto input_desc =
as_type_ptr<ngraph::op::TensorIterator::InvariantInputDescription>(desc);
EXPECT_NE(input_desc, nullptr);
}
else if (std::strcmp(type_info.name, "SliceInputDescription") == 0)
{
auto input_desc = as_type_ptr<ngraph::op::TensorIterator::SliceInputDescription>(desc);
EXPECT_NE(input_desc, nullptr);
}
else if (std::strcmp(type_info.name, "MergedInputDescription") == 0)
{
auto input_desc = as_type_ptr<ngraph::op::TensorIterator::MergedInputDescription>(desc);
EXPECT_NE(input_desc, nullptr);
}
}
// Output 0 is last Zo
auto out0 = tensor_iterator->get_iter_value(Zo, -1);
// Output 1 is concat of Zos
// start=0, stride=2, part_size=2, end=20, axis=1
auto out1 = tensor_iterator->get_concatenated_slices(Zo, 0, 2, 2, 20, 1);
// check output descriptors
for (auto& desc : tensor_iterator->get_output_descriptions())
{
auto type_info = desc->get_type_info();
if (std::strcmp(type_info.name, "ConcatOutputDescription") == 0)
{
auto output_desc =
as_type_ptr<ngraph::op::TensorIterator::ConcatOutputDescription>(desc);
EXPECT_NE(output_desc, nullptr);
}
else if (std::strcmp(type_info.name, "BodyOutputDescription") == 0)
{
auto output_desc = as_type_ptr<ngraph::op::TensorIterator::BodyOutputDescription>(desc);
EXPECT_NE(output_desc, nullptr);
}
}
auto result0 = make_shared<op::Result>(out0);
auto result1 = make_shared<op::Result>(out1);
Shape out0_shape{32, 2, 10};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment