Commit 1d919bfc authored by Robert Kimball's avatar Robert Kimball Committed by Adam Procter

make call_frame call a little more readable (#265)

parent eebf0b70
...@@ -36,8 +36,8 @@ namespace ngraph ...@@ -36,8 +36,8 @@ namespace ngraph
/// ///
/// Tuples will be expanded into their tensor views to build the call frame. /// Tuples will be expanded into their tensor views to build the call frame.
virtual void virtual void
operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs, call(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs) = 0; const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs) = 0;
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views. /// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
virtual void tensor_call(const TensorViewPtrs& inputs, virtual void tensor_call(const TensorViewPtrs& inputs,
......
...@@ -50,8 +50,8 @@ void CallFrame::tensor_call( ...@@ -50,8 +50,8 @@ void CallFrame::tensor_call(
m_compiled_function(inputs.data(), outputs.data()); m_compiled_function(inputs.data(), outputs.data());
} }
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments, void CallFrame::call(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results) const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results)
{ {
// TODO: Check types of args and result // TODO: Check types of args and result
vector<shared_ptr<ngraph::runtime::TensorView>> inputs; vector<shared_ptr<ngraph::runtime::TensorView>> inputs;
......
...@@ -47,9 +47,8 @@ namespace ngraph ...@@ -47,9 +47,8 @@ namespace ngraph
/// @brief Invoke the function with values matching the signature of the function. /// @brief Invoke the function with values matching the signature of the function.
/// ///
/// Tuples will be expanded into their tensor views to build the call frame. /// Tuples will be expanded into their tensor views to build the call frame.
void void call(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs, const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
/// @brief Invoke the function with tuples pre-expanded to their underlying /// @brief Invoke the function with tuples pre-expanded to their underlying
/// tensor views. /// tensor views.
......
...@@ -63,8 +63,8 @@ void CallFrame::tensor_call( ...@@ -63,8 +63,8 @@ void CallFrame::tensor_call(
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr); fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
} }
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments, void CallFrame::call(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results) const std::vector<std::shared_ptr<ngraph::runtime::Value>>& results)
{ {
// TODO: Check types of args and result // TODO: Check types of args and result
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs; std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs;
......
...@@ -47,9 +47,8 @@ namespace ngraph ...@@ -47,9 +47,8 @@ namespace ngraph
/// @brief Invoke the function with values matching the signature of the function. /// @brief Invoke the function with values matching the signature of the function.
/// ///
/// Tuples will be expanded into their tensor views to build the call frame. /// Tuples will be expanded into their tensor views to build the call frame.
void void call(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs,
operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& inputs, const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outputs);
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views. /// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs); void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outputs);
......
...@@ -58,7 +58,7 @@ namespace ngraph ...@@ -58,7 +58,7 @@ namespace ngraph
{ {
outputs.push_back(call_frame.get_tensor_view(out.get_index())); outputs.push_back(call_frame.get_tensor_view(out.get_index()));
} }
(*cf)(inputs, outputs); cf->call(inputs, outputs);
} }
protected: protected:
......
...@@ -55,7 +55,7 @@ namespace ngraph ...@@ -55,7 +55,7 @@ namespace ngraph
auto ty = ngraph::runtime::make_tensor<ET>(Shape{}, {y}); auto ty = ngraph::runtime::make_tensor<ET>(Shape{}, {y});
auto tr = ngraph::runtime::make_tensor<ET>(Shape{}); auto tr = ngraph::runtime::make_tensor<ET>(Shape{});
(*cf)({tx, ty}, {tr}); cf->call({tx, ty}, {tr});
return tr->get_vector()[0]; return tr->get_vector()[0];
}; };
EigenVector<ET>(call_frame, m_out) = EigenVector<ET>(call_frame, m_out) =
......
...@@ -55,7 +55,7 @@ namespace ngraph ...@@ -55,7 +55,7 @@ namespace ngraph
auto ty = ngraph::runtime::make_tensor<ET>(Shape{}, {y}); auto ty = ngraph::runtime::make_tensor<ET>(Shape{}, {y});
auto tr = ngraph::runtime::make_tensor<ET>(Shape{}); auto tr = ngraph::runtime::make_tensor<ET>(Shape{});
(*cf)({tx, ty}, {tr}); cf->call({tx, ty}, {tr});
return tr->get_vector()[0]; return tr->get_vector()[0];
}; };
EigenVector<ET>(call_frame, m_out) = EigenVector<ET>(call_frame, m_out) =
......
...@@ -55,7 +55,7 @@ namespace ngraph ...@@ -55,7 +55,7 @@ namespace ngraph
auto ty = ngraph::runtime::make_tensor<ET>(Shape{}, {y}); auto ty = ngraph::runtime::make_tensor<ET>(Shape{}, {y});
auto tr = ngraph::runtime::make_tensor<ET>(Shape{}); auto tr = ngraph::runtime::make_tensor<ET>(Shape{});
(*cf)({tx, ty}, {tr}); cf->call({tx, ty}, {tr});
return tr->get_vector()[0]; return tr->get_vector()[0];
}; };
EigenArray1d<ET>(call_frame, m_out) = EigenArray1d<ET>(call_frame, m_out) =
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment