Commit f6949fef authored by Jaikrishnan Menon

CPU: Switch inputs/outputs order so outputs come first

parent 0d21bbf4
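
With this change the CPU call frame lays out its tensor views as [outputs | inputs | temps] instead of [inputs | outputs | temps]. A minimal standalone sketch of the new offset arithmetic (plain std types only; the counts and variable names are illustrative, not ngraph API):

    #include <cstddef>
    #include <iostream>

    int main()
    {
        // Hypothetical function with 2 results and 3 parameters.
        const std::size_t n_outputs = 2;
        const std::size_t n_inputs = 3;

        // Offsets into the call frame's tensor-view vector after this commit:
        const std::size_t output_base = 0;                  // outputs now come first
        const std::size_t input_base = n_outputs;           // inputs follow the outputs
        const std::size_t temp_base = n_outputs + n_inputs; // temporaries keep the tail

        std::cout << "outputs at [" << output_base << ", " << input_base << ")\n"
                  << "inputs at  [" << input_base << ", " << temp_base << ")\n"
                  << "temps from " << temp_base << " onward\n";
    }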
@@ -20,30 +20,30 @@ using namespace std;
 using namespace ngraph::runtime::cpu;
 
 CallFrame::CallFrame(EntryPoint compiled_function,
-                     size_t n_inputs,
                      size_t n_outputs,
+                     size_t n_inputs,
                      const TensorViewPtrs& temps)
-    : m_n_inputs(n_inputs)
-    , m_n_outputs(n_outputs)
+    : m_n_outputs(n_outputs)
+    , m_n_inputs(n_inputs)
     , m_tensor_views(n_inputs + n_outputs + temps.size())
     , m_compiled_function(compiled_function)
 {
-    copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
+    copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_outputs + m_n_inputs);
 }
 
 void CallFrame::tensor_call(
     const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
     const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
 {
-    copy(inputs.begin(), inputs.end(), m_tensor_views.begin());
-    copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
+    copy(outputs.begin(), outputs.end(), m_tensor_views.begin());
+    copy(inputs.begin(), inputs.end(), m_tensor_views.begin() + m_n_outputs);
 
     // Invoke compiled computation
     m_compiled_function(this, m_tensor_views);
 
     // Don't hold onto inputs/outputs
-    fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
+    fill_n(m_tensor_views.begin(), m_n_outputs + m_n_inputs, nullptr);
 }
 
 void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
@@ -39,8 +39,8 @@ namespace ngraph
             {
             public:
                 CallFrame(EntryPoint compiled_function,
-                          size_t n_inputs,
                           size_t n_outputs,
+                          size_t n_inputs,
                           const TensorViewPtrs& temps);
 
                 /// @brief Invoke the function with values matching the signature of the function.
@@ -68,8 +68,8 @@ namespace ngraph
                 }
 
             protected:
-                size_t m_n_inputs;
                 size_t m_n_outputs;
+                size_t m_n_inputs;
                 TensorViewPtrs m_tensor_views;
                 bool m_return;
                 EntryPoint m_compiled_function;
@@ -149,26 +149,27 @@ void ExternalFunction::compile(FunctionMap& function_map)
     // Determine tensor requirements for the call frame
    unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index;
 
-    // First come the function inputs
-    for (auto param : m_function->get_parameters())
-    {
-        for (const descriptor::Output& output : param->get_outputs())
+    // First come the function outputs
+    for (const descriptor::Output& output : m_function->get_result()->get_outputs())
     {
         auto tv = output.get_tensor_view();
         size_t index = tensor_index.size();
         tensor_index[tv] = index;
     }
-    }
-    m_n_inputs = tensor_index.size();
+    m_n_outputs = tensor_index.size();
 
-    // Next are the function outputs
-    for (const descriptor::Output& output : m_function->get_result()->get_outputs())
+    // Next are the function inputs
+    for (auto param : m_function->get_parameters())
     {
+        for (const descriptor::Output& output : param->get_outputs())
+        {
             auto tv = output.get_tensor_view();
             size_t index = tensor_index.size();
             tensor_index[tv] = index;
         }
-    m_n_outputs = tensor_index.size() - m_n_inputs;
+    }
+    m_n_inputs = tensor_index.size() - m_n_outputs;
 
     // All remaining tensor views
     for (shared_ptr<Node> node : m_function->get_ordered_ops())
@@ -336,5 +337,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
 #undef M
     }
     return make_shared<ngraph::runtime::cpu::CallFrame>(
-        compiled_function, m_n_inputs, m_n_outputs, temps);
+        compiled_function, m_n_outputs, m_n_inputs, temps);
 }
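
Note that in this diff CallFrame::tensor_call keeps its public (inputs, outputs) parameter order; only the internal staging into m_tensor_views changes. A standalone sketch of that staging logic under the new layout (std types stand in for ngraph's TensorView pointers; sizes are illustrative):

    #include <algorithm>
    #include <memory>
    #include <vector>

    int main()
    {
        using View = std::shared_ptr<int>; // stand-in for a TensorView pointer

        std::vector<View> outputs{std::make_shared<int>(0)};
        std::vector<View> inputs{std::make_shared<int>(1), std::make_shared<int>(2)};

        // One extra slot models a temporary owned by the frame.
        std::vector<View> tensor_views(outputs.size() + inputs.size() + 1);

        // Outputs are copied to the front, inputs directly after them.
        std::copy(outputs.begin(), outputs.end(), tensor_views.begin());
        std::copy(inputs.begin(), inputs.end(), tensor_views.begin() + outputs.size());

        // ... the compiled function would run against tensor_views here ...

        // Don't hold onto the caller's inputs/outputs once the call returns.
        std::fill_n(tensor_views.begin(), outputs.size() + inputs.size(), nullptr);
    }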