Commit f6e536cd authored by Scott Cyphers's avatar Scott Cyphers Committed by Adam Procter

Clean up some names. (#149)

parent 3ad5140c
...@@ -85,5 +85,7 @@ namespace ngraph ...@@ -85,5 +85,7 @@ namespace ngraph
std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout; std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout;
std::string m_name; std::string m_name;
}; };
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
} }
} }
...@@ -22,25 +22,25 @@ using namespace ngraph::runtime; ...@@ -22,25 +22,25 @@ using namespace ngraph::runtime;
CallFrame::CallFrame(size_t n_inputs, CallFrame::CallFrame(size_t n_inputs,
size_t n_outputs, size_t n_outputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& temps, const TensorViewPtrs& temps,
size_t initial_pc, size_t initial_pc,
const shared_ptr<vector<shared_ptr<Instruction>>>& instructions) const shared_ptr<vector<shared_ptr<Instruction>>>& instructions)
: m_n_inputs(n_inputs) : m_n_inputs(n_inputs)
, m_n_outputs(n_outputs) , m_n_outputs(n_outputs)
, m_tensors(n_inputs + n_outputs + temps.size()) , m_tensor_views(n_inputs + n_outputs + temps.size())
, m_initial_pc(initial_pc) , m_initial_pc(initial_pc)
, m_instructions(instructions) , m_instructions(instructions)
{ {
copy(temps.begin(), temps.end(), m_tensors.begin() + m_n_inputs + m_n_outputs); copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
} }
void CallFrame::tensor_call( void CallFrame::tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs, const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs) const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
{ {
copy(inputs.begin(), inputs.end(), m_tensors.begin()); copy(inputs.begin(), inputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensors.begin() + m_n_inputs); copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
m_next_pc = m_initial_pc; m_next_pc = m_initial_pc;
m_return = false; m_return = false;
while (!m_return) while (!m_return)
...@@ -50,7 +50,7 @@ void CallFrame::tensor_call( ...@@ -50,7 +50,7 @@ void CallFrame::tensor_call(
m_instructions->at(m_pc)->execute(*this); m_instructions->at(m_pc)->execute(*this);
} }
// Don't hold onto inputs/outputs // Don't hold onto inputs/outputs
fill_n(m_tensors.begin(), m_n_inputs + m_n_outputs, nullptr); fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
} }
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments, void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
...@@ -58,12 +58,14 @@ void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Va ...@@ -58,12 +58,14 @@ void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Va
{ {
// TODO: Check types of args and result // TODO: Check types of args and result
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs; std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs;
for (auto argument : arguments){ for (auto argument : arguments)
{
argument->collect_tensor_views(inputs, argument); argument->collect_tensor_views(inputs, argument);
} }
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> outputs; std::vector<std::shared_ptr<ngraph::runtime::TensorView>> outputs;
for (auto result : results){ for (auto result : results)
{
result->collect_tensor_views(outputs, result); result->collect_tensor_views(outputs, result);
} }
......
...@@ -34,7 +34,7 @@ namespace ngraph ...@@ -34,7 +34,7 @@ namespace ngraph
CallFrame( CallFrame(
size_t n_inputs, size_t n_inputs,
size_t n_outputs, size_t n_outputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& temps, const TensorViewPtrs& temps,
size_t initial_pc, size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions); const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
...@@ -45,24 +45,22 @@ namespace ngraph ...@@ -45,24 +45,22 @@ namespace ngraph
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outpus); const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outpus);
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views. /// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call( void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outpus);
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outpus);
void set_return() { m_return = true; } void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor(size_t i) { return m_tensors[i]; } std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET> template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor(size_t i) ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{ {
return m_tensors[i]->get_parameterized_tensor<ET>(); return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
} }
protected: protected:
size_t m_n_inputs; size_t m_n_inputs;
size_t m_n_outputs; size_t m_n_outputs;
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> m_tensors; TensorViewPtrs m_tensor_views;
size_t m_initial_pc; size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions; std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc; size_t m_pc;
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename T> template <typename T>
void abs(T arg, T out) void abs(T arg, T out)
{ {
set_map(&*out, Eigen::abs(get_map(&*arg))); set_map_array(&*out, Eigen::abs(get_map_array(&*arg)));
} }
template <typename ET> template <typename ET>
...@@ -44,8 +44,8 @@ namespace ngraph ...@@ -44,8 +44,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::abs( runtime::eigen::abs(
call_frame.get_parameterized_tensor<ET>(m_arg), call_frame.get_parameterized_tensor_view<ET>(m_arg),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename T> template <typename T>
void add(T arg0, T arg1, T out) void add(T arg0, T arg1, T out)
{ {
set_map(&*out, get_map(&*arg0) + get_map(&*arg1)); set_map_array(&*out, get_map_array(&*arg0) + get_map_array(&*arg1));
} }
template <typename ET> template <typename ET>
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::add( runtime::eigen::add(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -76,11 +76,11 @@ namespace ngraph ...@@ -76,11 +76,11 @@ namespace ngraph
std::vector<ParameterizedTensorView<ET>*> ptvs; std::vector<ParameterizedTensorView<ET>*> ptvs;
for(size_t arg : m_args) for(size_t arg : m_args)
{ {
ptvs.push_back(call_frame.get_parameterized_tensor<ET>(arg)); ptvs.push_back(call_frame.get_parameterized_tensor_view<ET>(arg));
} }
runtime::eigen::concat_matrix( runtime::eigen::concat_matrix(
ptvs, ptvs,
call_frame.get_parameterized_tensor<ET>(m_out), call_frame.get_parameterized_tensor_view<ET>(m_out),
m_axis); m_axis);
} }
......
...@@ -64,11 +64,11 @@ namespace ngraph ...@@ -64,11 +64,11 @@ namespace ngraph
std::vector<ParameterizedTensorView<ET>*> ptvs; std::vector<ParameterizedTensorView<ET>*> ptvs;
for(size_t arg : m_args) for(size_t arg : m_args)
{ {
ptvs.push_back(call_frame.get_parameterized_tensor<ET>(arg)); ptvs.push_back(call_frame.get_parameterized_tensor_view<ET>(arg));
} }
runtime::eigen::concat_vector( runtime::eigen::concat_vector(
ptvs, ptvs,
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
{ {
runtime::eigen::assign_constant( runtime::eigen::assign_constant(
m_value, m_value,
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -42,8 +42,8 @@ namespace ngraph ...@@ -42,8 +42,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
call_frame.get_parameterized_tensor<ET>(m_out)->get_vector() = call_frame.get_parameterized_tensor_view<ET>(m_out)->get_vector() =
call_frame.get_parameterized_tensor<ET>(m_in)->get_vector(); call_frame.get_parameterized_tensor_view<ET>(m_in)->get_vector();
} }
protected: protected:
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename T> template <typename T>
void divide(T arg0, T arg1, T out) void divide(T arg0, T arg1, T out)
{ {
set_map(&*out, get_map(&*arg0) / get_map(&*arg1)); set_map_array(&*out, get_map_array(&*arg0) / get_map_array(&*arg1));
} }
template <typename ET> template <typename ET>
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::divide( runtime::eigen::divide(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::dot( runtime::eigen::dot(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -28,9 +28,9 @@ namespace ngraph ...@@ -28,9 +28,9 @@ namespace ngraph
template <typename TI,typename TO> template <typename TI,typename TO>
void equal(TI arg0, TI arg1, TO out) void equal(TI arg0, TI arg1, TO out)
{ {
auto result_as_float = get_map(&*arg0) == get_map(&*arg1); auto result_as_float = get_map_array(&*arg0) == get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>(); auto result_as_char = result_as_float.template cast<char>();
set_map(&*out, result_as_char); set_map_array(&*out, result_as_char);
} }
template <typename ET> template <typename ET>
...@@ -47,9 +47,9 @@ namespace ngraph ...@@ -47,9 +47,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::equal( runtime::eigen::equal(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<element::Bool>(m_out)); call_frame.get_parameterized_tensor_view<element::Bool>(m_out));
} }
protected: protected:
......
...@@ -28,9 +28,9 @@ namespace ngraph ...@@ -28,9 +28,9 @@ namespace ngraph
template <typename TI,typename TO> template <typename TI,typename TO>
void less_than(TI arg0, TI arg1, TO out) void less_than(TI arg0, TI arg1, TO out)
{ {
auto result_as_float = get_map(&*arg0) < get_map(&*arg1); auto result_as_float = get_map_array(&*arg0) < get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>(); auto result_as_char = result_as_float.template cast<char>();
set_map(&*out, result_as_char); set_map_array(&*out, result_as_char);
} }
template <typename ET> template <typename ET>
...@@ -47,9 +47,9 @@ namespace ngraph ...@@ -47,9 +47,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::less_than( runtime::eigen::less_than(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<element::Bool>(m_out)); call_frame.get_parameterized_tensor_view<element::Bool>(m_out));
} }
protected: protected:
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename T> template <typename T>
void log(T arg, T out) void log(T arg, T out)
{ {
set_map(&*out, Eigen::log(get_map(&*arg))); set_map_array(&*out, Eigen::log(get_map_array(&*arg)));
} }
template <typename ET> template <typename ET>
...@@ -44,8 +44,8 @@ namespace ngraph ...@@ -44,8 +44,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::log( runtime::eigen::log(
call_frame.get_parameterized_tensor<ET>(m_arg), call_frame.get_parameterized_tensor_view<ET>(m_arg),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::matrix_mult( runtime::eigen::matrix_mult(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::matrix_vector_product( runtime::eigen::matrix_vector_product(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename T> template <typename T>
void maximum(T arg0, T arg1, T out) void maximum(T arg0, T arg1, T out)
{ {
set_map(out, get_map(&*arg0).max(get_map(&*arg1))); set_map_array(out, get_map_array(&*arg0).max(get_map_array(&*arg1)));
} }
template <typename ET> template <typename ET>
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::maximum( runtime::eigen::maximum(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -27,7 +27,7 @@ namespace ngraph ...@@ -27,7 +27,7 @@ namespace ngraph
template <typename T> template <typename T>
void multiply(T arg0, T arg1, T out) void multiply(T arg0, T arg1, T out)
{ {
set_map(&*out, get_map(&*arg0) * get_map(&*arg1)); set_map_array(&*out, get_map_array(&*arg0) * get_map_array(&*arg1));
} }
template <typename ET> template <typename ET>
...@@ -44,9 +44,9 @@ namespace ngraph ...@@ -44,9 +44,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::multiply( runtime::eigen::multiply(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename T> template <typename T>
void negate(T arg, T out) void negate(T arg, T out)
{ {
set_map(&*out, -(get_map(&*arg))); set_map_array(&*out, -(get_map_array(&*arg)));
} }
template <typename ET> template <typename ET>
...@@ -44,8 +44,8 @@ namespace ngraph ...@@ -44,8 +44,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::negate( runtime::eigen::negate(
call_frame.get_parameterized_tensor<ET>(m_arg), call_frame.get_parameterized_tensor_view<ET>(m_arg),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -28,9 +28,9 @@ namespace ngraph ...@@ -28,9 +28,9 @@ namespace ngraph
template <typename TI,typename TO> template <typename TI,typename TO>
void not_equal(TI arg0, TI arg1, TO out) void not_equal(TI arg0, TI arg1, TO out)
{ {
auto result_as_float = get_map(&*arg0) != get_map(&*arg1); auto result_as_float = get_map_array(&*arg0) != get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>(); auto result_as_char = result_as_float.template cast<char>();
set_map(&*out, result_as_char); set_map_array(&*out, result_as_char);
} }
template <typename ET> template <typename ET>
...@@ -47,9 +47,9 @@ namespace ngraph ...@@ -47,9 +47,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::not_equal( runtime::eigen::not_equal(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<element::Bool>(m_out)); call_frame.get_parameterized_tensor_view<element::Bool>(m_out));
} }
protected: protected:
......
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::scalar_tensor_product( runtime::eigen::scalar_tensor_product(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename TA,typename TB> template <typename TA,typename TB>
void select(TA arg0, TB arg1, TB arg2, TB out) void select(TA arg0, TB arg1, TB arg2, TB out)
{ {
set_map(&*out, get_map(&*arg0).select(get_map(&*arg1),get_map(&*arg2))); set_map_array(&*out, get_map_array(&*arg0).select(get_map_array(&*arg1),get_map_array(&*arg2)));
} }
template <typename ET> template <typename ET>
...@@ -46,10 +46,10 @@ namespace ngraph ...@@ -46,10 +46,10 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::select( runtime::eigen::select(
call_frame.get_parameterized_tensor<element::Bool>(m_arg0), call_frame.get_parameterized_tensor_view<element::Bool>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_arg2), call_frame.get_parameterized_tensor_view<ET>(m_arg2),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -28,7 +28,7 @@ namespace ngraph ...@@ -28,7 +28,7 @@ namespace ngraph
template <typename T> template <typename T>
void subtract(T arg0, T arg1, T out) void subtract(T arg0, T arg1, T out)
{ {
set_map(&*out, get_map(&*arg0) - get_map(&*arg1)); set_map_array(&*out, get_map_array(&*arg0) - get_map_array(&*arg1));
} }
template <typename ET> template <typename ET>
...@@ -45,9 +45,9 @@ namespace ngraph ...@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::subtract( runtime::eigen::subtract(
call_frame.get_parameterized_tensor<ET>(m_arg0), call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1), call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out)); call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
......
...@@ -25,7 +25,7 @@ namespace ngraph ...@@ -25,7 +25,7 @@ namespace ngraph
namespace eigen namespace eigen
{ {
template <typename T, typename U> template <typename T, typename U>
void set_map(std::shared_ptr<T>& t, const U& u) void set_map_array(std::shared_ptr<T>& t, const U& u)
{ {
auto& v = t->get_vector(); auto& v = t->get_vector();
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>( Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
...@@ -33,7 +33,7 @@ namespace ngraph ...@@ -33,7 +33,7 @@ namespace ngraph
} }
template <typename T, typename U> template <typename T, typename U>
void set_map(T* t, const U& u) void set_map_array(T* t, const U& u)
{ {
auto& v = t->get_vector(); auto& v = t->get_vector();
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>( Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
...@@ -57,7 +57,7 @@ namespace ngraph ...@@ -57,7 +57,7 @@ namespace ngraph
} }
template <typename T, typename U> template <typename T, typename U>
void set_map_2d(std::shared_ptr<T>& t, const U& u) void set_map_array_2d(std::shared_ptr<T>& t, const U& u)
{ {
auto& v = t->get_vector(); auto& v = t->get_vector();
auto& s = t->get_shape(); auto& s = t->get_shape();
...@@ -69,7 +69,7 @@ namespace ngraph ...@@ -69,7 +69,7 @@ namespace ngraph
} }
template <typename T, typename U> template <typename T, typename U>
void set_map_2d(T* t, const U& u) void set_map_array_2d(T* t, const U& u)
{ {
auto& v = t->get_vector(); auto& v = t->get_vector();
auto& s = t->get_shape(); auto& s = t->get_shape();
...@@ -106,7 +106,7 @@ namespace ngraph ...@@ -106,7 +106,7 @@ namespace ngraph
template <typename T> template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>> Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>
get_map(std::shared_ptr<T>& arg) get_map_array(std::shared_ptr<T>& arg)
{ {
auto& v = arg->get_vector(); auto& v = arg->get_vector();
return Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>( return Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
...@@ -114,7 +114,7 @@ namespace ngraph ...@@ -114,7 +114,7 @@ namespace ngraph
} }
template <typename T> template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>> get_map(T* arg) Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>> get_map_array(T* arg)
{ {
auto& v = arg->get_vector(); auto& v = arg->get_vector();
return Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>( return Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
...@@ -140,7 +140,7 @@ namespace ngraph ...@@ -140,7 +140,7 @@ namespace ngraph
template <typename T> template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
get_map_2d(std::shared_ptr<T>& arg) get_map_array_2d(std::shared_ptr<T>& arg)
{ {
auto& v = arg->get_vector(); auto& v = arg->get_vector();
auto& s = arg->get_shape(); auto& s = arg->get_shape();
...@@ -152,7 +152,7 @@ namespace ngraph ...@@ -152,7 +152,7 @@ namespace ngraph
} }
template <typename T> template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> get_map_2d(T* arg) Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> get_map_array_2d(T* arg)
{ {
auto& v = arg->get_vector(); auto& v = arg->get_vector();
auto& s = arg->get_shape(); auto& s = arg->get_shape();
......
...@@ -45,10 +45,10 @@ ...@@ -45,10 +45,10 @@
#include "ngraph/pass/topological_sort.hpp" #include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/abs.hpp" #include "ngraph/runtime/eigen/abs.hpp"
#include "ngraph/runtime/eigen/add.hpp" #include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/copy.hpp"
#include "ngraph/runtime/eigen/concat_matrix.hpp" #include "ngraph/runtime/eigen/concat_matrix.hpp"
#include "ngraph/runtime/eigen/concat_vector.hpp" #include "ngraph/runtime/eigen/concat_vector.hpp"
#include "ngraph/runtime/eigen/constant.hpp" #include "ngraph/runtime/eigen/constant.hpp"
#include "ngraph/runtime/eigen/copy.hpp"
#include "ngraph/runtime/eigen/divide.hpp" #include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/dot.hpp" #include "ngraph/runtime/eigen/dot.hpp"
#include "ngraph/runtime/eigen/equal.hpp" #include "ngraph/runtime/eigen/equal.hpp"
...@@ -67,7 +67,7 @@ ...@@ -67,7 +67,7 @@
#include "ngraph/runtime/external_function.hpp" #include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/utils.hpp" #include "ngraph/runtime/utils.hpp"
using namespace std; using namespace std;
using namespace ngraph::runtime; using namespace ngraph::runtime;
ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function, ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
...@@ -95,20 +95,10 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func ...@@ -95,20 +95,10 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0]) REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0])
// Define code generators for handled ops. // Define code generators for handled ops.
std::unordered_map<std::type_index, ExternalFunction::OpMap& ExternalFunction::get_op_map()
std::function<void(const ngraph::Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>&
ExternalFunction::get_op_map()
{ {
static bool initialized = false; static bool initialized = false;
static std::unordered_map<std::type_index, static OpMap op_map;
std::function<void(const Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>
op_map;
if (!initialized) if (!initialized)
{ {
REGISTER_UNOP(op::Abs, runtime::eigen::AbsInstruction<element::Float32>); REGISTER_UNOP(op::Abs, runtime::eigen::AbsInstruction<element::Float32>);
...@@ -150,14 +140,16 @@ std::unordered_map<std::type_index, ...@@ -150,14 +140,16 @@ std::unordered_map<std::type_index,
if (result_shape.size() == 1) if (result_shape.size() == 1)
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>( make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(in,
in, out[0])); out[0]));
} }
else if(result_shape.size() == 2) else if (result_shape.size() == 2)
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>( make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>(
in, (dynamic_cast<const op::Concat *>(n))->get_concatenation_axis(), out[0])); in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]));
} }
else else
{ {
......
...@@ -27,6 +27,12 @@ namespace ngraph ...@@ -27,6 +27,12 @@ namespace ngraph
{ {
class ExternalFunction class ExternalFunction
{ {
using OpFunction = std::function<void(const ngraph::Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>;
public: public:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function, ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true); bool release_function = true);
...@@ -50,14 +56,9 @@ namespace ngraph ...@@ -50,14 +56,9 @@ namespace ngraph
size_t m_n_outputs; size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>> std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
m_instructions; m_instructions;
std::vector<std::shared_ptr<ngraph::descriptor::TensorView>> m_temp_views; ngraph::descriptor::TensorViewPtrs m_temp_views;
static std::unordered_map<std::type_index, static OpMap& get_op_map();
std::function<void(const ngraph::Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>&
get_op_map();
}; };
} }
} }
...@@ -43,7 +43,7 @@ namespace ngraph ...@@ -43,7 +43,7 @@ namespace ngraph
virtual ~TensorView() {} virtual ~TensorView() {}
template <typename ET> template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor() ParameterizedTensorView<ET>* get_parameterized_tensor_view()
{ {
return dynamic_cast<ParameterizedTensorView<ET>*>(this); return dynamic_cast<ParameterizedTensorView<ET>*>(this);
} }
...@@ -69,5 +69,7 @@ namespace ngraph ...@@ -69,5 +69,7 @@ namespace ngraph
protected: protected:
std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor; std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor;
}; };
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
} }
} }
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/descriptor/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
using TensorViewIndex = unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t>;
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment