Commit 36473a8a authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Inplace results (#1162)

* inplace results

* fix parameter propagation

* fix python tests
parent f078800c
......@@ -659,7 +659,10 @@ using namespace ngraph::runtime;
{
shared_ptr<descriptor::TensorView> itv =
res->get_inputs().at(0).get_output().get_tensor_view();
auto output_name = ss.str();
m_variable_name_map[itv->get_tensor().get_name()] = ss.str();
propagate_in_place_output(&(res->get_inputs().at(0).get_output()), output_name);
}
}
......@@ -954,6 +957,43 @@ using namespace ngraph::runtime;
}
}
void runtime::cpu::CPU_ExternalFunction::propagate_in_place_output(
ngraph::descriptor::Output* res_src_output, std::string output_name)
{
//we start with a particular output
//which is an argument to a given op::Result
size_t offset = res_src_output->get_tensor().get_pool_offset();
auto it = res_src_output;
bool propagate_further = false;
do
{
propagate_further = false;
auto arg = std::dynamic_pointer_cast<ngraph::op::Op>(it->get_node());
if (!arg)
{
break;
}
if (auto op_annotations = arg->get_op_annotations())
{
auto oi_pairs = op_annotations->get_in_place_oi_pairs();
if (oi_pairs.count(it->get_index()) != 0)
{
size_t input_index = oi_pairs.at(it->get_index());
auto& input_tensor = arg->get_inputs().at(input_index).get_tensor();
if (input_tensor.get_pool_offset() == offset &&
!arg->get_inputs().at(input_index).get_output().get_node()->is_parameter())
{
NGRAPH_DEBUG << "Reusing " << output_name << " for " << input_tensor.get_name();
m_variable_name_map[input_tensor.get_name()] = output_name;
it = &arg->get_inputs().at(input_index).get_output();
propagate_further = true;
}
}
}
} while (propagate_further);
}
void runtime::cpu::CPU_ExternalFunction::build()
{
if (m_is_built)
......
......@@ -111,6 +111,8 @@ namespace ngraph
void compile();
private:
void propagate_in_place_output(ngraph::descriptor::Output* res_src_output,
std::string output_name);
void emit_debug_function_entry(codegen::CodeWriter& writer,
Node* node,
const std::vector<TensorViewWrapper>& in,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment