Unverified Commit 8fc6473e authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Parameter as output support (#315)

* add unit test

* fix Parameter as output in INTERPRETER

* cpu working

* parameter_as_output passing for INTERPRETER

* still works

* cleanup
parent 8a569f27
...@@ -370,9 +370,28 @@ using namespace ngraph::runtime; ...@@ -370,9 +370,28 @@ using namespace ngraph::runtime;
{ {
shared_ptr<descriptor::TensorView> tv = output->get_tensor_view(); shared_ptr<descriptor::TensorView> tv = output->get_tensor_view();
const element::Type& et = tv->get_tensor_view_type()->get_element_type(); const element::Type& et = tv->get_tensor_view_type()->get_element_type();
string type = et.c_type_string(); bool parameter_as_output = false;
writer << type << "* " << tv->get_tensor().get_name() << " = static_cast<" << type for (shared_ptr<op::Parameter> param : current_function->get_parameters())
<< "*>(outputs[" << output_index << "]);\n"; {
for (const descriptor::Output& pout : param->get_outputs())
{
shared_ptr<descriptor::TensorView> ptv = pout.get_tensor_view();
if (tv == ptv)
{
parameter_as_output = true;
writer << "memcpy(static_cast<" << et.c_type_string() << "*>(outputs["
<< output_index << "]), " << ptv->get_tensor().get_name() << ", "
<< ptv->get_tensor().size() << ");\n";
break;
}
}
}
if (!parameter_as_output)
{
string type = et.c_type_string();
writer << type << "* " << tv->get_tensor().get_name() << " = static_cast<" << type
<< "*>(outputs[" << output_index << "]);\n";
}
output_names.insert(tv->get_tensor().get_name()); output_names.insert(tv->get_tensor().get_name());
output_index++; output_index++;
} }
......
...@@ -51,7 +51,17 @@ void runtime::interpreter::INT_CallFrame::call( ...@@ -51,7 +51,17 @@ void runtime::interpreter::INT_CallFrame::call(
descriptor::Output* output = function->get_outputs().at(i); descriptor::Output* output = function->get_outputs().at(i);
shared_ptr<descriptor::TensorView> tv = output->get_tensor_view(); shared_ptr<descriptor::TensorView> tv = output->get_tensor_view();
string name = tv->get_tensor().get_name(); string name = tv->get_tensor().get_name();
tensor_map.insert({name, output_tvs[i]}); if (contains_key(tensor_map, name))
{
// Here we handle the special case where an output is just a copy of an input
memcpy(output_tvs[i]->get_data_ptr(),
tensor_map.at(name)->get_data_ptr(),
tv->get_tensor().size());
}
else
{
tensor_map.insert({name, output_tvs[i]});
}
} }
// Invoke computation // Invoke computation
...@@ -91,6 +101,7 @@ void runtime::interpreter::INT_CallFrame::call( ...@@ -91,6 +101,7 @@ void runtime::interpreter::INT_CallFrame::call(
} }
outputs.push_back(itv); outputs.push_back(itv);
} }
auto tuple = dynamic_pointer_cast<op::XLATuple>(op); auto tuple = dynamic_pointer_cast<op::XLATuple>(op);
if (tuple) if (tuple)
{ {
......
...@@ -34,6 +34,30 @@ static void copy_data(shared_ptr<runtime::TensorView> tv, const vector<T>& data) ...@@ -34,6 +34,30 @@ static void copy_data(shared_ptr<runtime::TensorView> tv, const vector<T>& data)
tv->write(data.data(), 0, data_size); tv->write(data.data(), 0, data_size);
} }
TEST(${BACKEND_NAME}, parameter_as_output)
{
auto shape = Shape{3, 4};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto rt = make_shared<TensorViewType>(element::f32, shape);
auto f = make_shared<Function>(A, rt, op::Parameters{A});
auto manager = runtime::Manager::get("${BACKEND_NAME}");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
// Create some tensors for input/output
shared_ptr<runtime::TensorView> a = backend->make_primary_tensor_view(element::f32, shape);
shared_ptr<runtime::TensorView> result = backend->make_primary_tensor_view(element::f32, shape);
vector<float> expected{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
vector<float> zero(shape_size(shape), 0);
copy_data(a, expected);
cf->call({a}, {result});
EXPECT_EQ(result->get_vector<float>(), expected);
}
TEST(${BACKEND_NAME}, ab) TEST(${BACKEND_NAME}, ab)
{ {
using f32 = element::Float32; using f32 = element::Float32;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment