Commit f6e536cd authored by Scott Cyphers's avatar Scott Cyphers Committed by Adam Procter

Clean up some names. (#149)

parent 3ad5140c
......@@ -85,5 +85,7 @@ namespace ngraph
std::shared_ptr<layout::TensorViewLayout> m_tensor_view_layout;
std::string m_name;
};
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
}
}
......@@ -22,25 +22,25 @@ using namespace ngraph::runtime;
CallFrame::CallFrame(size_t n_inputs,
size_t n_outputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& temps,
const TensorViewPtrs& temps,
size_t initial_pc,
const shared_ptr<vector<shared_ptr<Instruction>>>& instructions)
: m_n_inputs(n_inputs)
, m_n_outputs(n_outputs)
, m_tensors(n_inputs + n_outputs + temps.size())
, m_tensor_views(n_inputs + n_outputs + temps.size())
, m_initial_pc(initial_pc)
, m_instructions(instructions)
{
copy(temps.begin(), temps.end(), m_tensors.begin() + m_n_inputs + m_n_outputs);
copy(temps.begin(), temps.end(), m_tensor_views.begin() + m_n_inputs + m_n_outputs);
}
void CallFrame::tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outputs)
{
copy(inputs.begin(), inputs.end(), m_tensors.begin());
copy(outputs.begin(), outputs.end(), m_tensors.begin() + m_n_inputs);
copy(inputs.begin(), inputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
m_next_pc = m_initial_pc;
m_return = false;
while (!m_return)
......@@ -50,7 +50,7 @@ void CallFrame::tensor_call(
m_instructions->at(m_pc)->execute(*this);
}
// Don't hold onto inputs/outputs
fill_n(m_tensors.begin(), m_n_inputs + m_n_outputs, nullptr);
fill_n(m_tensor_views.begin(), m_n_inputs + m_n_outputs, nullptr);
}
void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Value>>& arguments,
......@@ -58,12 +58,14 @@ void CallFrame::operator()(const std::vector<std::shared_ptr<ngraph::runtime::Va
{
// TODO: Check types of args and result
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> inputs;
for (auto argument : arguments){
for (auto argument : arguments)
{
argument->collect_tensor_views(inputs, argument);
}
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> outputs;
for (auto result : results){
for (auto result : results)
{
result->collect_tensor_views(outputs, result);
}
......
......@@ -34,7 +34,7 @@ namespace ngraph
CallFrame(
size_t n_inputs,
size_t n_outputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& temps,
const TensorViewPtrs& temps,
size_t initial_pc,
const std::shared_ptr<std::vector<std::shared_ptr<Instruction>>>& instructions);
......@@ -45,24 +45,22 @@ namespace ngraph
const std::vector<std::shared_ptr<ngraph::runtime::Value>>& outpus);
/// @brief Invoke the function with tuples pre-expanded to their underlying tensor views.
void tensor_call(
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& inputs,
const std::vector<std::shared_ptr<ngraph::runtime::TensorView>>& outpus);
void tensor_call(const TensorViewPtrs& inputs, const TensorViewPtrs& outpus);
void set_return() { m_return = true; }
std::shared_ptr<TensorView> get_tensor(size_t i) { return m_tensors[i]; }
std::shared_ptr<TensorView> get_tensor_view(size_t i) { return m_tensor_views[i]; }
template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor(size_t i)
ParameterizedTensorView<ET>* get_parameterized_tensor_view(size_t i)
{
return m_tensors[i]->get_parameterized_tensor<ET>();
return m_tensor_views[i]->get_parameterized_tensor_view<ET>();
}
protected:
size_t m_n_inputs;
size_t m_n_outputs;
std::vector<std::shared_ptr<ngraph::runtime::TensorView>> m_tensors;
TensorViewPtrs m_tensor_views;
size_t m_initial_pc;
std::shared_ptr<std::vector<std::shared_ptr<Instruction>>> m_instructions;
size_t m_pc;
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename T>
void abs(T arg, T out)
{
set_map(&*out, Eigen::abs(get_map(&*arg)));
set_map_array(&*out, Eigen::abs(get_map_array(&*arg)));
}
template <typename ET>
......@@ -44,8 +44,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::abs(
call_frame.get_parameterized_tensor<ET>(m_arg),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename T>
void add(T arg0, T arg1, T out)
{
set_map(&*out, get_map(&*arg0) + get_map(&*arg1));
set_map_array(&*out, get_map_array(&*arg0) + get_map_array(&*arg1));
}
template <typename ET>
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::add(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -76,11 +76,11 @@ namespace ngraph
std::vector<ParameterizedTensorView<ET>*> ptvs;
for(size_t arg : m_args)
{
ptvs.push_back(call_frame.get_parameterized_tensor<ET>(arg));
ptvs.push_back(call_frame.get_parameterized_tensor_view<ET>(arg));
}
runtime::eigen::concat_matrix(
ptvs,
call_frame.get_parameterized_tensor<ET>(m_out),
call_frame.get_parameterized_tensor_view<ET>(m_out),
m_axis);
}
......
......@@ -64,11 +64,11 @@ namespace ngraph
std::vector<ParameterizedTensorView<ET>*> ptvs;
for(size_t arg : m_args)
{
ptvs.push_back(call_frame.get_parameterized_tensor<ET>(arg));
ptvs.push_back(call_frame.get_parameterized_tensor_view<ET>(arg));
}
runtime::eigen::concat_vector(
ptvs,
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -45,7 +45,7 @@ namespace ngraph
{
runtime::eigen::assign_constant(
m_value,
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -42,8 +42,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
call_frame.get_parameterized_tensor<ET>(m_out)->get_vector() =
call_frame.get_parameterized_tensor<ET>(m_in)->get_vector();
call_frame.get_parameterized_tensor_view<ET>(m_out)->get_vector() =
call_frame.get_parameterized_tensor_view<ET>(m_in)->get_vector();
}
protected:
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename T>
void divide(T arg0, T arg1, T out)
{
set_map(&*out, get_map(&*arg0) / get_map(&*arg1));
set_map_array(&*out, get_map_array(&*arg0) / get_map_array(&*arg1));
}
template <typename ET>
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::divide(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::dot(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -28,9 +28,9 @@ namespace ngraph
template <typename TI,typename TO>
void equal(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map(&*arg0) == get_map(&*arg1);
auto result_as_float = get_map_array(&*arg0) == get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map(&*out, result_as_char);
set_map_array(&*out, result_as_char);
}
template <typename ET>
......@@ -47,9 +47,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::equal(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<element::Bool>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<element::Bool>(m_out));
}
protected:
......
......@@ -28,9 +28,9 @@ namespace ngraph
template <typename TI,typename TO>
void less_than(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map(&*arg0) < get_map(&*arg1);
auto result_as_float = get_map_array(&*arg0) < get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map(&*out, result_as_char);
set_map_array(&*out, result_as_char);
}
template <typename ET>
......@@ -47,9 +47,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::less_than(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<element::Bool>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<element::Bool>(m_out));
}
protected:
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename T>
void log(T arg, T out)
{
set_map(&*out, Eigen::log(get_map(&*arg)));
set_map_array(&*out, Eigen::log(get_map_array(&*arg)));
}
template <typename ET>
......@@ -44,8 +44,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::log(
call_frame.get_parameterized_tensor<ET>(m_arg),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::matrix_mult(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::matrix_vector_product(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename T>
void maximum(T arg0, T arg1, T out)
{
set_map(out, get_map(&*arg0).max(get_map(&*arg1)));
set_map_array(out, get_map_array(&*arg0).max(get_map_array(&*arg1)));
}
template <typename ET>
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::maximum(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -27,7 +27,7 @@ namespace ngraph
template <typename T>
void multiply(T arg0, T arg1, T out)
{
set_map(&*out, get_map(&*arg0) * get_map(&*arg1));
set_map_array(&*out, get_map_array(&*arg0) * get_map_array(&*arg1));
}
template <typename ET>
......@@ -44,9 +44,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::multiply(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename T>
void negate(T arg, T out)
{
set_map(&*out, -(get_map(&*arg)));
set_map_array(&*out, -(get_map_array(&*arg)));
}
template <typename ET>
......@@ -44,8 +44,8 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::negate(
call_frame.get_parameterized_tensor<ET>(m_arg),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -28,9 +28,9 @@ namespace ngraph
template <typename TI,typename TO>
void not_equal(TI arg0, TI arg1, TO out)
{
auto result_as_float = get_map(&*arg0) != get_map(&*arg1);
auto result_as_float = get_map_array(&*arg0) != get_map_array(&*arg1);
auto result_as_char = result_as_float.template cast<char>();
set_map(&*out, result_as_char);
set_map_array(&*out, result_as_char);
}
template <typename ET>
......@@ -47,9 +47,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::not_equal(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<element::Bool>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<element::Bool>(m_out));
}
protected:
......
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::scalar_tensor_product(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename TA,typename TB>
void select(TA arg0, TB arg1, TB arg2, TB out)
{
set_map(&*out, get_map(&*arg0).select(get_map(&*arg1),get_map(&*arg2)));
set_map_array(&*out, get_map_array(&*arg0).select(get_map_array(&*arg1),get_map_array(&*arg2)));
}
template <typename ET>
......@@ -46,10 +46,10 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::select(
call_frame.get_parameterized_tensor<element::Bool>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_arg2),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<element::Bool>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_arg2),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -28,7 +28,7 @@ namespace ngraph
template <typename T>
void subtract(T arg0, T arg1, T out)
{
set_map(&*out, get_map(&*arg0) - get_map(&*arg1));
set_map_array(&*out, get_map_array(&*arg0) - get_map_array(&*arg1));
}
template <typename ET>
......@@ -45,9 +45,9 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override
{
runtime::eigen::subtract(
call_frame.get_parameterized_tensor<ET>(m_arg0),
call_frame.get_parameterized_tensor<ET>(m_arg1),
call_frame.get_parameterized_tensor<ET>(m_out));
call_frame.get_parameterized_tensor_view<ET>(m_arg0),
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
}
protected:
......
......@@ -25,7 +25,7 @@ namespace ngraph
namespace eigen
{
template <typename T, typename U>
void set_map(std::shared_ptr<T>& t, const U& u)
void set_map_array(std::shared_ptr<T>& t, const U& u)
{
auto& v = t->get_vector();
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
......@@ -33,7 +33,7 @@ namespace ngraph
}
template <typename T, typename U>
void set_map(T* t, const U& u)
void set_map_array(T* t, const U& u)
{
auto& v = t->get_vector();
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
......@@ -57,7 +57,7 @@ namespace ngraph
}
template <typename T, typename U>
void set_map_2d(std::shared_ptr<T>& t, const U& u)
void set_map_array_2d(std::shared_ptr<T>& t, const U& u)
{
auto& v = t->get_vector();
auto& s = t->get_shape();
......@@ -69,7 +69,7 @@ namespace ngraph
}
template <typename T, typename U>
void set_map_2d(T* t, const U& u)
void set_map_array_2d(T* t, const U& u)
{
auto& v = t->get_vector();
auto& s = t->get_shape();
......@@ -106,7 +106,7 @@ namespace ngraph
template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>
get_map(std::shared_ptr<T>& arg)
get_map_array(std::shared_ptr<T>& arg)
{
auto& v = arg->get_vector();
return Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
......@@ -114,7 +114,7 @@ namespace ngraph
}
template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>> get_map(T* arg)
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>> get_map_array(T* arg)
{
auto& v = arg->get_vector();
return Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, 1>>(
......@@ -140,7 +140,7 @@ namespace ngraph
template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
get_map_2d(std::shared_ptr<T>& arg)
get_map_array_2d(std::shared_ptr<T>& arg)
{
auto& v = arg->get_vector();
auto& s = arg->get_shape();
......@@ -152,7 +152,7 @@ namespace ngraph
}
template <typename T>
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> get_map_2d(T* arg)
Eigen::Map<Eigen::Array<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> get_map_array_2d(T* arg)
{
auto& v = arg->get_vector();
auto& s = arg->get_shape();
......
......@@ -45,10 +45,10 @@
#include "ngraph/pass/topological_sort.hpp"
#include "ngraph/runtime/eigen/abs.hpp"
#include "ngraph/runtime/eigen/add.hpp"
#include "ngraph/runtime/eigen/copy.hpp"
#include "ngraph/runtime/eigen/concat_matrix.hpp"
#include "ngraph/runtime/eigen/concat_vector.hpp"
#include "ngraph/runtime/eigen/constant.hpp"
#include "ngraph/runtime/eigen/copy.hpp"
#include "ngraph/runtime/eigen/divide.hpp"
#include "ngraph/runtime/eigen/dot.hpp"
#include "ngraph/runtime/eigen/equal.hpp"
......@@ -67,7 +67,7 @@
#include "ngraph/runtime/external_function.hpp"
#include "ngraph/runtime/utils.hpp"
using namespace std;
using namespace std;
using namespace ngraph::runtime;
ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
......@@ -95,20 +95,10 @@ ExternalFunction::ExternalFunction(const std::shared_ptr<ngraph::Function>& func
REGISTER_INSTRUCTION(op_class, instr_class, in[0], in[1], in[2], out[0])
// Define code generators for handled ops.
std::unordered_map<std::type_index,
std::function<void(const ngraph::Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>&
ExternalFunction::get_op_map()
ExternalFunction::OpMap& ExternalFunction::get_op_map()
{
static bool initialized = false;
static std::unordered_map<std::type_index,
std::function<void(const Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>
op_map;
static OpMap op_map;
if (!initialized)
{
REGISTER_UNOP(op::Abs, runtime::eigen::AbsInstruction<element::Float32>);
......@@ -150,14 +140,16 @@ std::unordered_map<std::type_index,
if (result_shape.size() == 1)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(
in, out[0]));
make_shared<runtime::eigen::ConcatVectorInstruction<element::Float32>>(in,
out[0]));
}
else if(result_shape.size() == 2)
else if (result_shape.size() == 2)
{
ef->get_instructions()->push_back(
make_shared<runtime::eigen::ConcatMatrixInstruction<element::Float32>>(
in, (dynamic_cast<const op::Concat *>(n))->get_concatenation_axis(), out[0]));
in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]));
}
else
{
......
......@@ -27,6 +27,12 @@ namespace ngraph
{
class ExternalFunction
{
using OpFunction = std::function<void(const ngraph::Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>;
using OpMap = std::unordered_map<std::type_index, OpFunction>;
public:
ExternalFunction(const std::shared_ptr<ngraph::Function>& function,
bool release_function = true);
......@@ -50,14 +56,9 @@ namespace ngraph
size_t m_n_outputs;
std::shared_ptr<std::vector<std::shared_ptr<ngraph::runtime::Instruction>>>
m_instructions;
std::vector<std::shared_ptr<ngraph::descriptor::TensorView>> m_temp_views;
ngraph::descriptor::TensorViewPtrs m_temp_views;
static std::unordered_map<std::type_index,
std::function<void(const ngraph::Node*,
ExternalFunction*,
const std::vector<size_t>& inputs,
const std::vector<size_t>& outputs)>>&
get_op_map();
static OpMap& get_op_map();
};
}
}
......@@ -43,7 +43,7 @@ namespace ngraph
virtual ~TensorView() {}
template <typename ET>
ParameterizedTensorView<ET>* get_parameterized_tensor()
ParameterizedTensorView<ET>* get_parameterized_tensor_view()
{
return dynamic_cast<ParameterizedTensorView<ET>*>(this);
}
......@@ -69,5 +69,7 @@ namespace ngraph
protected:
std::shared_ptr<ngraph::descriptor::TensorView> m_descriptor;
};
using TensorViewPtrs = std::vector<std::shared_ptr<TensorView>>;
}
}
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/descriptor/tensor_view.hpp"
namespace ngraph
{
namespace runtime
{
using TensorViewIndex = unordered_map<shared_ptr<ngraph::descriptor::TensorView>, size_t>;
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment