Unverified commit d5e814aa authored by Robert Kimball, committed by GitHub

Add support for aliased output to CPU and INTERPRETER backends (#320)

* Add an aliased-output unit test
* Add support for aliased outputs to the INTERPRETER and CPU backends
parent 83433ef2
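A function output is "aliased" when the same tensor view backs more than one entry in the function's output list, as the new unit test below does with `Nodes{C, C}`. Both backends use the same bookkeeping: map each result tensor to every output index it backs, compute the value once into the first index, then copy it into the remaining slots. A minimal standalone sketch of that idea (stand-in types, not the actual nGraph classes):

```cpp
#include <cstddef>
#include <unordered_map>
#include <vector>

struct Tensor
{
}; // stand-in for descriptor::TensorView

int main()
{
    Tensor t;
    // The same tensor backs outputs 0 and 1, as in Function(Nodes{C, C}, ...).
    std::vector<Tensor*> function_outputs{&t, &t};

    std::unordered_map<Tensor*, std::vector<std::size_t>> output_alias_map;
    std::vector<std::size_t> aliases;
    for (std::size_t i = 0; i < function_outputs.size(); i++)
    {
        std::vector<std::size_t>& al = output_alias_map[function_outputs[i]];
        al.push_back(i);
        if (al.size() > 1)
        {
            aliases.push_back(i); // every index after the first must be filled by copy
        }
    }
    // output_alias_map[&t] == {0, 1}: the op writes output 0 once, then the
    // backend copies those bytes into output 1.
}
```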
@@ -351,7 +351,24 @@ using namespace ngraph::runtime;
writer << "\n";
writer << "// Define outputs\n";
// create alias list
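// output_alias_map: result tensor view -> all output indices it backs;
// "aliases" collects every index after the first, which the generated
// code must fill with a copy (see handle_output_alias below)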
size_t output_index = 0;
unordered_map<descriptor::TensorView*, vector<size_t>> output_alias_map;
vector<size_t> aliases;
for (const descriptor::Output* output : current_function->get_outputs())
{
shared_ptr<descriptor::TensorView> otv = output->get_tensor_view();
vector<size_t>& al = output_alias_map[otv.get()];
al.push_back(output_index);
if (al.size() > 1)
{
aliases.push_back(output_index);
}
output_index++;
}
output_index = 0;
set<string> output_names;
for (const descriptor::Output* output : current_function->get_outputs())
{
@@ -373,7 +390,7 @@ using namespace ngraph::runtime;
}
}
}
- if (!parameter_as_output)
+ if (!parameter_as_output && !contains(aliases, output_index))
{
string type = et.c_type_string();
writer << type << "* " << tv->get_tensor().get_name() << " = static_cast<" << type
@@ -413,13 +430,13 @@ using namespace ngraph::runtime;
for (const descriptor::Input& input : node->get_inputs())
{
const descriptor::Output& output = input.get_output();
- auto tv = output.get_tensor_view();
+ shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
in.push_back(TensorViewWrapper(tv));
}
vector<TensorViewWrapper> out;
for (const descriptor::Output& output : node->get_outputs())
{
- auto tv = output.get_tensor_view();
+ shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
out.push_back(TensorViewWrapper(tv));
}
if (m_emit_timing)
@@ -427,6 +444,7 @@ using namespace ngraph::runtime;
emit_debug_function_entry(writer, node.get(), in, out);
}
handler->second(&emitter, node.get(), in, out);
handle_output_alias(writer, *node, output_alias_map);
if (m_emit_timing)
{
emit_debug_function_exit(writer, node.get(), in, out);
@@ -471,6 +489,35 @@ using namespace ngraph::runtime;
}
}
void runtime::cpu::CPU_ExternalFunction::handle_output_alias(
codegen::CodeWriter& writer,
const Node& node,
const unordered_map<descriptor::TensorView*, vector<size_t>>& output_alias_map)
{
for (const descriptor::Output& output : node.get_outputs())
{
shared_ptr<descriptor::TensorView> otv = output.get_tensor_view();
auto it = output_alias_map.find(otv.get());
if (it != output_alias_map.end())
{
const vector<size_t>& outputs = it->second;
if (outputs.size() > 1)
{
writer << "{ // handle output alias for previous op\n";
writer.indent++;
for (size_t i = 1; i < outputs.size(); i++)
{
writer << "memcpy(static_cast<void*>(outputs[" << outputs[i]
<< "]), static_cast<void*>(outputs[" << outputs[0] << "]), "
<< otv->get_tensor().size() << ");\n";
}
writer.indent--;
writer << "}\n";
}
}
}
}
shared_ptr<ngraph::runtime::CallFrame> runtime::cpu::CPU_ExternalFunction::make_call_frame()
{
if (!m_is_compiled)
......
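For the `aliased_output` test below, `handle_output_alias` emits a copy immediately after the op that writes the shared tensor. A sketch of the shape of the generated code under those assumptions (the `Add_5_0` tensor name and the function signature are illustrative; 16 = 2 * 2 * sizeof(float) for the test's 2x2 float32 tensor):

```cpp
#include <cstring>

// Illustrative mock-up of the compiled function the CPU emitter produces.
void generated_function(void** inputs, void** outputs)
{
    // "// Define outputs": output 1 aliases output 0, so only slot 0 gets a name.
    float* Add_5_0 = static_cast<float*>(outputs[0]);
    // ... emitted Add kernel writes Add_5_0 from the two input tensors ...
    for (int i = 0; i < 4; i++)
    {
        Add_5_0[i] = static_cast<float*>(inputs[0])[i] + static_cast<float*>(inputs[1])[i];
    }
    { // handle output alias for previous op
        memcpy(static_cast<void*>(outputs[1]), static_cast<void*>(outputs[0]), 16);
    }
}
```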
@@ -69,6 +69,10 @@ namespace ngraph
Node* node,
const std::vector<TensorViewWrapper>& in,
const std::vector<TensorViewWrapper>& out);
void handle_output_alias(
codegen::CodeWriter& writer,
const Node&,
const std::unordered_map<descriptor::TensorView*, std::vector<size_t>>&);
std::unique_ptr<codegen::Compiler> m_compiler;
std::unique_ptr<codegen::ExecutionEngine> m_execution_engine;
......
@@ -34,36 +34,53 @@ void runtime::interpreter::INT_CallFrame::call(
const vector<shared_ptr<runtime::interpreter::INT_TensorView>>& input_tvs,
const vector<shared_ptr<runtime::interpreter::INT_TensorView>>& output_tvs)
{
- unordered_map<string, shared_ptr<runtime::interpreter::INT_TensorView>> tensor_map;
+ unordered_map<descriptor::TensorView*, shared_ptr<runtime::interpreter::INT_TensorView>>
+     tensor_map;
size_t arg_index = 0;
for (shared_ptr<op::Parameter> param : function->get_parameters())
{
for (const descriptor::Output& output : param->get_outputs())
{
- shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
+ descriptor::TensorView* tv = output.get_tensor_view().get();
- string name = tv->get_tensor().get_name();
- tensor_map.insert({name, input_tvs[arg_index++]});
+ tensor_map.insert({tv, input_tvs[arg_index++]});
}
}
for (size_t i = 0; i < output_tvs.size(); i++)
{
descriptor::Output* output = function->get_outputs().at(i);
- shared_ptr<descriptor::TensorView> tv = output->get_tensor_view();
+ descriptor::TensorView* tv = output->get_tensor_view().get();
- string name = tv->get_tensor().get_name();
- if (contains_key(tensor_map, name))
+ if (contains_key(tensor_map, tv))
{
// Here we handle the special case where an output is just a copy of an input
memcpy(output_tvs[i]->get_data_ptr(),
- tensor_map.at(name)->get_data_ptr(),
+ tensor_map.at(tv)->get_data_ptr(),
tv->get_tensor().size());
}
else
{
- tensor_map.insert({name, output_tvs[i]});
+ tensor_map.insert({tv, output_tvs[i]});
}
}
// create alias list
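// same bookkeeping as the CPU backend: record, for each result tensor view,
// every output index it backs, so aliased outputs can be filled after each op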
size_t output_index = 0;
unordered_map<descriptor::TensorView*, vector<size_t>> output_alias_map;
vector<size_t> aliases;
for (const descriptor::Output* output : function->get_outputs())
{
shared_ptr<descriptor::TensorView> otv = output->get_tensor_view();
vector<size_t>& al = output_alias_map[otv.get()];
al.push_back(output_index);
if (al.size() > 1)
{
aliases.push_back(output_index);
}
output_index++;
}
// Invoke computation
for (shared_ptr<Node> op : function->get_ordered_ops())
{
@@ -76,16 +93,16 @@ void runtime::interpreter::INT_CallFrame::call(
vector<shared_ptr<runtime::interpreter::INT_TensorView>> outputs;
for (const descriptor::Input& input : op->get_inputs())
{
- shared_ptr<descriptor::TensorView> tv = input.get_output().get_tensor_view();
+ descriptor::TensorView* tv = input.get_output().get_tensor_view().get();
- string name = tv->get_tensor().get_name();
- inputs.push_back(tensor_map.at(name));
+ inputs.push_back(tensor_map.at(tv));
}
for (descriptor::Output& output : op->get_outputs())
{
- shared_ptr<descriptor::TensorView> tv = output.get_tensor_view();
+ descriptor::TensorView* tv = output.get_tensor_view().get();
- string name = tv->get_tensor().get_name();
shared_ptr<runtime::interpreter::INT_TensorView> itv;
- if (!contains_key(tensor_map, name))
+ if (!contains_key(tensor_map, tv))
{
// The output tensor is not in the tensor map so create a new tensor
const Shape& shape = output.get_shape();
@@ -93,11 +110,11 @@ void runtime::interpreter::INT_CallFrame::call(
string tensor_name = output.get_tensor().get_name();
itv = make_shared<runtime::interpreter::INT_TensorView>(
element_type, shape, tensor_name);
- tensor_map.insert({name, itv});
+ tensor_map.insert({tv, itv});
}
else
{
- itv = tensor_map.at(name);
+ itv = tensor_map.at(tv);
}
outputs.push_back(itv);
}
@@ -135,6 +152,8 @@ void runtime::interpreter::INT_CallFrame::call(
generate_calls(base_type, secondary_type, *op, inputs, outputs);
}
handle_output_alias(*op, output_alias_map, output_tvs);
// Delete any obsolete tensors
for (const descriptor::Tensor* t : op->liveness_free_list)
{
@@ -150,6 +169,31 @@ void runtime::interpreter::INT_CallFrame::call(
}
}
void runtime::interpreter::INT_CallFrame::handle_output_alias(
const Node& node,
const unordered_map<descriptor::TensorView*, vector<size_t>>& output_alias_map,
const vector<shared_ptr<runtime::interpreter::INT_TensorView>>& output_tvs)
{
for (const descriptor::Output& output : node.get_outputs())
{
shared_ptr<descriptor::TensorView> otv = output.get_tensor_view();
auto it = output_alias_map.find(otv.get());
if (it != output_alias_map.end())
{
const vector<size_t>& outputs = it->second;
if (outputs.size() > 1)
{
for (size_t i = 1; i < outputs.size(); i++)
{
// copy the first (computed) output slot into each aliased slot; index
// output_tvs by the recorded output indices, as the CPU emitter does
memcpy(output_tvs[outputs[i]]->get_data_ptr(),
       output_tvs[outputs[0]]->get_data_ptr(),
       otv->get_tensor().size());
}
}
}
}
}
void runtime::interpreter::INT_CallFrame::generate_calls(
const element::Type& base_type,
const element::Type& secondary_type,
......
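Alongside the alias handling, this file re-keys the interpreter's tensor map from tensor name strings to `descriptor::TensorView*`. Aliasing is a question of object identity, so a pointer key detects it directly, and lookups no longer build and hash a `std::string`. A small sketch contrasting the two keyings (simplified stand-in types, not the actual nGraph classes):

```cpp
#include <memory>
#include <string>
#include <unordered_map>

struct TensorView
{
    std::string name;
}; // stand-in for descriptor::TensorView

struct INT_TensorView
{
}; // stand-in for the interpreter's tensor storage

int main()
{
    auto tv = std::make_shared<TensorView>(); // one descriptor shared by two outputs
    auto data = std::make_shared<INT_TensorView>();

    // Before: keyed by name; every lookup copies and hashes a string.
    std::unordered_map<std::string, std::shared_ptr<INT_TensorView>> by_name;
    by_name.insert({tv->name, data});

    // After: keyed by object identity; aliased outputs collide by construction.
    std::unordered_map<TensorView*, std::shared_ptr<INT_TensorView>> by_ptr;
    by_ptr.insert({tv.get(), data});
}
```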
@@ -120,6 +120,10 @@ private:
void call(std::shared_ptr<Function> function,
const std::vector<std::shared_ptr<runtime::interpreter::INT_TensorView>>& input_tvs,
const std::vector<std::shared_ptr<runtime::interpreter::INT_TensorView>>& output_tvs);
void handle_output_alias(
const Node& node,
const std::unordered_map<descriptor::TensorView*, std::vector<size_t>>& output_alias_map,
const std::vector<std::shared_ptr<runtime::interpreter::INT_TensorView>>& output_tvs);
std::shared_ptr<ExternalFunction> m_external_function;
std::shared_ptr<Function> m_function;
......
@@ -34,6 +34,42 @@ static void copy_data(shared_ptr<runtime::TensorView> tv, const vector<T>& data)
tv->write(data.data(), 0, data_size);
}
TEST(${BACKEND_NAME}, aliased_output)
{
using f32 = element::Float32;
auto shape = Shape{2, 2};
auto A = make_shared<op::Parameter>(f32::element_type(), shape);
auto B = make_shared<op::Parameter>(f32::element_type(), shape);
auto rt1 = make_shared<TensorViewType>(f32::element_type(), shape);
auto rt2 = make_shared<TensorViewType>(f32::element_type(), shape);
auto C = A + B;
auto f = make_shared<Function>(Nodes{C, C}, ValueTypes{rt1, rt2}, op::Parameters{A, B});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
shared_ptr<runtime::TensorView> a =
backend->make_primary_tensor_view(f32::element_type(), shape);
shared_ptr<runtime::TensorView> b =
backend->make_primary_tensor_view(f32::element_type(), shape);
shared_ptr<runtime::TensorView> out1 =
backend->make_primary_tensor_view(f32::element_type(), shape);
shared_ptr<runtime::TensorView> out2 =
backend->make_primary_tensor_view(f32::element_type(), shape);
copy_data(a, vector<float>{0, 1, 2, 3});
copy_data(b, vector<float>{1, 2, 3, 4});
vector<float> expected{1, 3, 5, 7};
cf->call({a, b}, {out1, out2});
EXPECT_EQ(expected, out1->get_vector<float>());
EXPECT_EQ(expected, out2->get_vector<float>());
}
TEST(${BACKEND_NAME}, parameter_as_output)
{
auto shape = Shape{3, 4};
......