Commit 034d0849 authored by Scott Cyphers, committed by Robert Kimball

Copy aliased outputs properly in interpreter, add test (#443)

parent 9a54569b
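
What the change does: previously, call() built an output_alias_map up front and patched duplicated outputs after every op via handle_output_alias(); as the removed code below shows, that copy loop indexed output_tvs with positions in the alias vector (output_tvs[j] from output_tvs[0]) rather than with the aliased output indices, so a duplicated computed value could land in the wrong output tensors. Now the setup loop distinguishes two cases: an output that directly forwards a Parameter is memcpy'd from the input right away, while a computed value requested in more than one output position is recorded in aliased_outputs and copied, after the graph has run, from the first output position that produced it.

The following is a minimal standalone sketch of that copy-at-end strategy, using plain vectors as stand-in output buffers (invented names for illustration, not the ngraph API):

// Sketch: run the graph once, writing each computed value only into the
// first output slot that requests it, then copy that slot into every
// later slot that aliases the same value. Buffers and names here are
// hypothetical; this is not the ngraph API.
#include <cstring>
#include <iostream>
#include <vector>

int main()
{
    // Five outputs requesting two distinct computed values, as in the
    // updated test: Nodes{C, C, D, D, C} -> value ids {0, 0, 1, 1, 0}.
    std::vector<int> value_of_output{0, 0, 1, 1, 0};
    std::vector<std::vector<float>> output_bufs(5, std::vector<float>(4));

    // "Execution": each value is written only to its first output slot.
    output_bufs[0] = {1, 3, 5, 7};  // C = A + B
    output_bufs[2] = {0, 2, 6, 12}; // D = A * B

    // Copy-at-end: fill every aliased slot from the first slot carrying
    // the same value; a slot that is itself the first occurrence finds
    // no earlier match and is left untouched.
    for (size_t i = 1; i < value_of_output.size(); ++i)
    {
        for (size_t first = 0; first < i; ++first)
        {
            if (value_of_output[first] == value_of_output[i])
            {
                std::memcpy(output_bufs[i].data(),
                            output_bufs[first].data(),
                            output_bufs[i].size() * sizeof(float));
                break;
            }
        }
    }

    for (float v : output_bufs[4])
    {
        std::cout << v << " "; // prints 1 3 5 7
    }
    std::cout << std::endl;
    return 0;
}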
@@ -41,27 +41,35 @@ void runtime::interpreter::INT_CallFrame::call(
         perform_nan_check(input_tvs);
     }
     unordered_map<descriptor::TensorView*, shared_ptr<runtime::HostTensorView>> tensor_map;
     size_t arg_index = 0;
     for (shared_ptr<op::Parameter> param : function->get_parameters())
     {
         for (size_t i = 0; i < param->get_output_size(); ++i)
         {
             descriptor::TensorView* tv = param->get_output_tensor_view(i).get();
             string name = tv->get_tensor().get_name();
             tensor_map.insert({tv, input_tvs[arg_index++]});
         }
     }
+    std::vector<size_t> aliased_outputs;
     for (size_t i = 0; i < output_tvs.size(); i++)
     {
-        descriptor::TensorView* tv = function->get_output_op(i)->get_output_tensor_view(0).get();
+        shared_ptr<Node> op = function->get_output_op(i);
+        descriptor::TensorView* tv = op->get_output_tensor_view(0).get();
         string name = tv->get_tensor().get_name();
         if (contains_key(tensor_map, tv))
         {
-            // Here we handle the special case where an output is just a copy of an input
-            memcpy(output_tvs[i]->get_data_ptr(),
-                   tensor_map.at(tv)->get_data_ptr(),
-                   tv->get_tensor().size());
+            if (op->description() == "Parameter")
+            {
+                // Here we handle the special case where an output is just a copy of an input
+                memcpy(output_tvs[i]->get_data_ptr(),
+                       tensor_map.at(tv)->get_data_ptr(),
+                       tv->get_tensor().size());
+            }
+            else
+            {
+                // This is a computed value returned more than once and will need to be copied at the end
+                aliased_outputs.push_back(i);
+            }
         }
         else
         {
@@ -69,23 +77,6 @@ void runtime::interpreter::INT_CallFrame::call(
         }
     }
-    // create alias list
-    size_t output_index = 0;
-    unordered_map<descriptor::TensorView*, vector<size_t>> output_alias_map;
-    vector<size_t> aliases;
-    for (size_t i = 0; i < function->get_output_size(); ++i)
-    {
-        shared_ptr<descriptor::TensorView> otv =
-            function->get_output_op(i)->get_output_tensor_view(0);
-        vector<size_t>& al = output_alias_map[otv.get()];
-        al.push_back(output_index);
-        if (al.size() > 1)
-        {
-            aliases.push_back(output_index);
-        }
-        output_index++;
-    }
     // Invoke computation
     for (shared_ptr<Node> op : function->get_ordered_ops())
     {
@@ -157,8 +148,6 @@ void runtime::interpreter::INT_CallFrame::call(
             perform_nan_check(outputs, op.get());
         }
-        handle_output_alias(*op, output_alias_map, output_tvs);
         // Delete any obsolete tensors
         for (const descriptor::Tensor* t : op->liveness_free_list)
         {
@@ -172,30 +161,28 @@
             }
         }
     }
-}
-
-void runtime::interpreter::INT_CallFrame::handle_output_alias(
-    const Node& node,
-    const unordered_map<descriptor::TensorView*, vector<size_t>>& output_alias_map,
-    const vector<shared_ptr<runtime::HostTensorView>>& output_tvs)
-{
-    for (size_t i = 0; i < node.get_output_size(); ++i)
+    for (size_t i : aliased_outputs)
     {
-        shared_ptr<descriptor::TensorView> otv = node.get_output_tensor_view(i);
-        auto it = output_alias_map.find(otv.get());
-        if (it != output_alias_map.end())
+        shared_ptr<Node> op = function->get_output_op(i);
+        size_t first_output;
+        for (first_output = 0; first_output <= i; ++first_output)
         {
-            const vector<size_t>& outputs = it->second;
-            if (outputs.size() > 1)
+            if (function->get_output_op(first_output) == op)
             {
-                for (size_t j = 1; j < outputs.size(); j++)
-                {
-                    memcpy(static_cast<void*>(output_tvs[j]->get_data_ptr()),
-                           static_cast<void*>(output_tvs[0]->get_data_ptr()),
-                           otv->get_tensor().size());
-                }
+                break;
             }
         }
+        if (first_output == i)
+        {
+            throw ngraph_error("Internal error: duplicate output missing");
+        }
+        descriptor::TensorView* tv = op->get_output_tensor_view(0).get();
+        string name = tv->get_tensor().get_name();
+        // This output aliases an earlier computed output; copy it from the
+        // first output position that holds the value
+        memcpy(output_tvs[i]->get_data_ptr(),
+               output_tvs[first_output]->get_data_ptr(),
+               tv->get_tensor().size());
     }
 }
@@ -134,10 +134,6 @@ private:
     void call(std::shared_ptr<Function> function,
               const std::vector<std::shared_ptr<runtime::HostTensorView>>& input_tvs,
              const std::vector<std::shared_ptr<runtime::HostTensorView>>& output_tvs);
-    void handle_output_alias(
-        const Node& node,
-        const std::unordered_map<descriptor::TensorView*, std::vector<size_t>>& output_alias_map,
-        const std::vector<std::shared_ptr<runtime::HostTensorView>>& output_tvs);
     static void perform_nan_check(const std::vector<std::shared_ptr<HostTensorView>>&,
                                   const Node* op = nullptr);
@@ -36,7 +36,8 @@ TEST(${BACKEND_NAME}, aliased_output)
     auto A = make_shared<op::Parameter>(element::f32, shape);
     auto B = make_shared<op::Parameter>(element::f32, shape);
     auto C = A + B;
-    auto f = make_shared<Function>(Nodes{C, C}, op::Parameters{A, B});
+    auto D = A * B;
+    auto f = make_shared<Function>(Nodes{C, C, D, D, C}, op::Parameters{A, B});
     auto manager = runtime::Manager::get("${BACKEND_NAME}");
     auto external = manager->compile(f);
@@ -48,14 +49,21 @@ TEST(${BACKEND_NAME}, aliased_output)
     shared_ptr<runtime::TensorView> b = backend->make_primary_tensor_view(element::f32, shape);
     shared_ptr<runtime::TensorView> out1 = backend->make_primary_tensor_view(element::f32, shape);
     shared_ptr<runtime::TensorView> out2 = backend->make_primary_tensor_view(element::f32, shape);
+    shared_ptr<runtime::TensorView> out3 = backend->make_primary_tensor_view(element::f32, shape);
+    shared_ptr<runtime::TensorView> out4 = backend->make_primary_tensor_view(element::f32, shape);
+    shared_ptr<runtime::TensorView> out5 = backend->make_primary_tensor_view(element::f32, shape);

     copy_data(a, vector<float>{0, 1, 2, 3});
     copy_data(b, vector<float>{1, 2, 3, 4});

-    vector<float> expected{1, 3, 5, 7};
-    cf->call({a, b}, {out1, out2});
-    EXPECT_EQ(expected, read_vector<float>(out1));
-    EXPECT_EQ(expected, read_vector<float>(out2));
+    vector<float> expectedC{1, 3, 5, 7};
+    vector<float> expectedD{0, 2, 6, 12};
+    cf->call({a, b}, {out1, out2, out3, out4, out5});
+    EXPECT_EQ(expectedC, read_vector<float>(out1));
+    EXPECT_EQ(expectedC, read_vector<float>(out2));
+    EXPECT_EQ(expectedD, read_vector<float>(out3));
+    EXPECT_EQ(expectedD, read_vector<float>(out4));
+    EXPECT_EQ(expectedC, read_vector<float>(out5));
 }

 TEST(${BACKEND_NAME}, parameter_as_output)
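
A note on what the expanded test covers: the function now returns C in three output positions and D in two, interleaved, so one call exercises both the first-occurrence search and the copy-at-end path for multiple distinct values. The expected data follow directly from the inputs: elementwise, C = a + b = {0+1, 1+2, 2+3, 3+4} = {1, 3, 5, 7} and D = a * b = {0*1, 1*2, 2*3, 3*4} = {0, 2, 6, 12}. The remaining special case, an output that directly forwards a Parameter, is covered by the separate parameter_as_output test above.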