Commit f6949fef authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Switch input/outputs order so outputs come first

parent 0d21bbf4
...@@ -20,30 +20,30 @@ using namespace std; ...@@ -20,30 +20,30 @@ using namespace std;
using namespace ngraph::runtime::cpu; using namespace ngraph::runtime::cpu;
CallFrame::CallFrame(EntryPoint compiled_function, CallFrame::CallFrame(EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs, size_t n_outputs,
size_t n_inputs,
const TensorViewPtrs& temps) const TensorViewPtrs& temps)
: m_n_inputs(n_inputs) : m_n_outputs(n_outputs)
, m_n_outputs(n_outputs) , m_n_inputs(n_inputs)
, m_tensor_views(n_inputs + n_outputs + temps.size()) , m_tensor_views(n_inputs + n_outputs + temps.size())
, m_compiled_function(compiled_function) , m_compiled_function(compiled_function)
{ {
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs); copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_outputs + m_n_inputs);
} }
void CallFrame::tensor_call( void CallFrame::tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs, const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs) const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
{ {
copy(inputs.begin(), inputs.end(), m_tensor_views.begin()); copy(outputs.begin(), outputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs); copy(inputs.begin(), inputs.end(), m_tensor_views.begin() + m_n_outputs);
// Invoke compiled computation // Invoke compiled computation
m_compiled_function(this, m_tensor_views); m_compiled_function(this, m_tensor_views);
// Don't hold onto inputs/outputs // Don't hold onto inputs/outputs
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr); fill_n(m_tensor_views.begin(), m_n_outputs + m_n_inputs, nullptr);
} }
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments, void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
......
...@@ -39,8 +39,8 @@ namespace ngraph ...@@ -39,8 +39,8 @@ namespace ngraph
{ {
public: public:
CallFrame(EntryPoint compiled_function, CallFrame(EntryPoint compiled_function,
size_t n_inputs,
size_t n_outputs, size_t n_outputs,
size_t n_inputs,
const TensorViewPtrs& temps); const TensorViewPtrs& temps);
/// @brief Invoke the function with values matching the signature of the function. /// @brief Invoke the function with values matching the signature of the function.
...@@ -68,8 +68,8 @@ namespace ngraph ...@@ -68,8 +68,8 @@ namespace ngraph
} }
protected: protected:
size_t m_n_inputs;
size_t m_n_outputs; size_t m_n_outputs;
size_t m_n_inputs;
TensorViewPtrs m_tensor_views; TensorViewPtrs m_tensor_views;
bool m_return; bool m_return;
EntryPoint m_compiled_function; EntryPoint m_compiled_function;
......
...@@ -149,26 +149,27 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -149,26 +149,27 @@ void ExternalFunction::compile(FunctionMap& function_map)
// Determine tensor requirements for the call frame // Determine tensor requirements for the call frame
unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index; unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t> tensor_index;
// First come the function inputs
for (auto param : m_function->get_parameters()) // First come the function outputs
{ for (const descriptor::Output& output : m_function->get_result()->get_outputs())
for (const descriptor::Output& output : param->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
size_t index = tensor_index.size(); size_t index = tensor_index.size();
tensor_index[tv] = index; tensor_index[tv] = index;
} }
} m_n_outputs = tensor_index.size();
m_n_inputs = tensor_index.size();
// Next are the function outputs // Next are the function inputs
for (const descriptor::Output& output : m_function->get_result()->get_outputs()) for (auto param : m_function->get_parameters())
{
for (const descriptor::Output& output : param->get_outputs())
{ {
auto tv = output.get_tensor_view(); auto tv = output.get_tensor_view();
size_t index = tensor_index.size(); size_t index = tensor_index.size();
tensor_index[tv] = index; tensor_index[tv] = index;
} }
m_n_outputs = tensor_index.size() - m_n_inputs; }
m_n_inputs = tensor_index.size() - m_n_outputs;
// All remaining tensor views // All remaining tensor views
for (shared_ptr<Node> node : m_function->get_ordered_ops()) for (shared_ptr<Node> node : m_function->get_ordered_ops())
...@@ -336,5 +337,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame() ...@@ -336,5 +337,5 @@ shared_ptr<ngraph::runtime::CallFrame> ExternalFunction::make_call_frame()
#undef M #undef M
} }
return make_shared<ngraph::runtime::cpu::CallFrame>( return make_shared<ngraph::runtime::cpu::CallFrame>(
compiled_function, m_n_inputs, m_n_outputs, temps); compiled_function, m_n_outputs, m_n_inputs, temps);
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment