Commit e786fcfe authored by Ivan Tikhonov's avatar Ivan Tikhonov Committed by Michał Karzyński

TensorIterator: reshape support (#4038)

parent c8988ca9
......@@ -16,6 +16,8 @@
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/pass/get_output_element_elimination.hpp"
#include "ngraph/specialize_function.hpp"
using namespace std;
using namespace ngraph;
......@@ -220,7 +222,7 @@ void op::TensorIterator::revalidate_and_infer_types_for_body_ops()
std::stack<std::shared_ptr<Node>, std::vector<std::shared_ptr<Node>>> nodes_to_do;
std::unordered_set<std::shared_ptr<Node>> nodes_done;
for (auto r : m_body->get_results())
for (const auto& r : m_body->get_results())
{
nodes_to_do.push(r);
}
......@@ -281,7 +283,7 @@ void op::TensorIterator::validate_and_infer_types()
// Input
uint64_t index_it = 0;
for (auto input_description : m_input_descriptions)
for (const auto& input_description : m_input_descriptions)
{
auto index = input_description->m_input_index;
NODE_VALIDATION_CHECK(this, index == index_it, "Input_index not in order");
......@@ -398,7 +400,7 @@ void op::TensorIterator::validate_and_infer_types()
// Output
index_it = 0;
for (auto output_description : m_output_descriptions)
for (const auto& output_description : m_output_descriptions)
{
auto index = output_description->m_output_index;
NODE_VALIDATION_CHECK(this, index == index_it, "Output_index not in order");
......@@ -437,6 +439,48 @@ void op::TensorIterator::validate_and_infer_types()
std::shared_ptr<Node> op::TensorIterator::copy_with_new_args(const NodeVector& new_args) const
{
auto op = make_shared<op::TensorIterator>(as_output_vector(new_args));
op->set_output_size(m_output_descriptions.size());
std::vector<::ngraph::element::Type> types(m_body->get_parameters().size());
std::vector<::ngraph::PartialShape> new_shapes(m_body->get_parameters().size());
for (size_t input_index = 0; input_index < new_args.size(); ++input_index)
{
for (auto& input_description : m_input_descriptions)
{
if (input_description->m_input_index == input_index)
{
types[input_description->m_body_parameter_index] =
new_args[input_index]->get_element_type();
new_shapes[input_description->m_body_parameter_index] =
new_args[input_index]->get_output_partial_shape(0);
if (new_shapes[input_description->m_body_parameter_index].is_static())
{
if (auto slice_in = ::ngraph::as_type_ptr<
ngraph::op::TensorIterator::SliceInputDescription>(input_description))
{
new_shapes[slice_in->m_body_parameter_index][slice_in->m_axis] =
slice_in->m_part_size;
}
}
}
}
}
auto func = std::make_shared<Function>(m_body->get_results(), m_body->get_parameters());
auto spec_func = specialize_function(
func, types, new_shapes, std::vector<void*>(new_args.size(), nullptr), false, true);
op->m_body =
std::make_shared<BodyLambda>(spec_func->get_results(), spec_func->get_parameters());
// TODO: remove this code after the fix on the nGraph side (GetOutputElements)
::ngraph::pass::GetOutputElementElimination goe_elimination;
for (const auto& n : spec_func->get_ops())
{
goe_elimination.run_on_node(n);
}
for (auto& input_description : m_input_descriptions)
{
op->m_input_descriptions.push_back(input_description->copy());
......
......@@ -17,6 +17,7 @@
#include "ngraph/specialize_function.hpp"
#include <pass/constant_folding.hpp>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/tensor_iterator.hpp"
using namespace ngraph;
......@@ -84,6 +85,11 @@ std::shared_ptr<Function>
else
{
m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
// TODO: workaround for shape inference, delete it after fix
if (::ngraph::as_type_ptr<ngraph::op::TensorIterator>(m[old_node.get()]))
{
m[old_node.get()]->validate_and_infer_types();
}
m[old_node.get()]->get_rt_info() = old_node->get_rt_info();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment