Commit c3113593 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

Remove old {get,set}_{array_matrix}{_2d,} methods; update dot product to use new…

Remove old {get,set}_{array_matrix}{_2d,} methods; update dot product to use new Eigen wrapper (#158)
parent a90f6bf4
...@@ -25,17 +25,13 @@ namespace ngraph ...@@ -25,17 +25,13 @@ namespace ngraph
{ {
namespace eigen namespace eigen
{ {
template <typename T>
void dot(T arg0, T arg1, T out)
{
(&*out)->get_vector()[0] = get_map_matrix(&*arg0).dot(get_map_matrix(&*arg1));
}
template <typename ET> template <typename ET>
class DotInstruction : public Instruction class DotInstruction : public Instruction
{ {
public: public:
DotInstruction(size_t arg0, size_t arg1, size_t out) DotInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -44,16 +40,14 @@ namespace ngraph ...@@ -44,16 +40,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::dot( EigenArray1d<ET>(call_frame, m_out) <<
call_frame.get_parameterized_tensor_view<ET>(m_arg0), EigenVector<ET>(call_frame, m_arg0).dot(EigenVector<ET>(call_frame, m_arg1));
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
size_t m_arg0; TensorViewInfo m_arg0;
size_t m_arg1; TensorViewInfo m_arg1;
size_t m_out; TensorViewInfo m_out;
}; };
} }
} }
......
...@@ -25,17 +25,13 @@ namespace ngraph ...@@ -25,17 +25,13 @@ namespace ngraph
{ {
namespace eigen namespace eigen
{ {
template <typename T>
void matrix_mult(T arg0, T arg1, T out)
{
set_map_matrix_2d(&*out,get_map_matrix_2d(&*arg0) * get_map_matrix_2d(&*arg1));
}
template <typename ET> template <typename ET>
class MatrixMultInstruction : public Instruction class MatrixMultInstruction : public Instruction
{ {
public: public:
MatrixMultInstruction(size_t arg0, size_t arg1, size_t out) MatrixMultInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -44,16 +40,14 @@ namespace ngraph ...@@ -44,16 +40,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::matrix_mult( EigenMatrix<ET>(call_frame, m_out) =
call_frame.get_parameterized_tensor_view<ET>(m_arg0), EigenMatrix<ET>(call_frame, m_arg0) * EigenMatrix<ET>(call_frame, m_arg1);
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
size_t m_arg0; TensorViewInfo m_arg0;
size_t m_arg1; TensorViewInfo m_arg1;
size_t m_out; TensorViewInfo m_out;
}; };
} }
} }
......
...@@ -25,17 +25,13 @@ namespace ngraph ...@@ -25,17 +25,13 @@ namespace ngraph
{ {
namespace eigen namespace eigen
{ {
template <typename T>
void matrix_vector_product(T arg0, T arg1, T out)
{
set_map_matrix(&*out,get_map_matrix_2d(&*arg0) * get_map_matrix(&*arg1));
}
template <typename ET> template <typename ET>
class MatrixVectorProductInstruction : public Instruction class MatrixVectorProductInstruction : public Instruction
{ {
public: public:
MatrixVectorProductInstruction(size_t arg0, size_t arg1, size_t out) MatrixVectorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -44,16 +40,14 @@ namespace ngraph ...@@ -44,16 +40,14 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::matrix_vector_product( EigenVector<ET>(call_frame, m_out) =
call_frame.get_parameterized_tensor_view<ET>(m_arg0), EigenMatrix<ET>(call_frame, m_arg0) * EigenVector<ET>(call_frame, m_arg1);
call_frame.get_parameterized_tensor_view<ET>(m_arg1),
call_frame.get_parameterized_tensor_view<ET>(m_out));
} }
protected: protected:
size_t m_arg0; TensorViewInfo m_arg0;
size_t m_arg1; TensorViewInfo m_arg1;
size_t m_out; TensorViewInfo m_out;
}; };
} }
} }
......
...@@ -25,17 +25,13 @@ namespace ngraph ...@@ -25,17 +25,13 @@ namespace ngraph
{ {
namespace eigen namespace eigen
{ {
template <typename T>
void scalar_tensor_product(T arg0, T arg1, T out)
{
set_map_matrix(&*out,(&*arg0)->get_vector()[0] * get_map_matrix(&*arg1));
}
template <typename ET> template <typename ET>
class ScalarTensorProductInstruction : public Instruction class ScalarTensorProductInstruction : public Instruction
{ {
public: public:
ScalarTensorProductInstruction(size_t arg0, size_t arg1, size_t out) ScalarTensorProductInstruction(const TensorViewInfo& arg0,
const TensorViewInfo& arg1,
const TensorViewInfo& out)
: m_arg0(arg0) : m_arg0(arg0)
, m_arg1(arg1) , m_arg1(arg1)
, m_out(out) , m_out(out)
...@@ -44,16 +40,19 @@ namespace ngraph ...@@ -44,16 +40,19 @@ namespace ngraph
virtual void execute(CallFrame& call_frame) const override virtual void execute(CallFrame& call_frame) const override
{ {
runtime::eigen::scalar_tensor_product( // This is a bit hacky: regardless of the tensor rank we
call_frame.get_parameterized_tensor_view<ET>(m_arg0), // pull it out as a vector. This works because of the way
call_frame.get_parameterized_tensor_view<ET>(m_arg1), // fmt::V computes sizes---it lumps together any higher
call_frame.get_parameterized_tensor_view<ET>(m_out)); // dimensions---while fmt::M ignores them.
EigenVector<ET>(call_frame, m_out) =
call_frame.get_tensor_view_data<ET>(m_arg0.get_index())[0]
* EigenVector<ET>(call_frame, m_arg1);
} }
protected: protected:
size_t m_arg0; TensorViewInfo m_arg0;
size_t m_arg1; TensorViewInfo m_arg1;
size_t m_out; TensorViewInfo m_out;
}; };
} }
} }
......
...@@ -150,94 +150,6 @@ namespace ngraph ...@@ -150,94 +150,6 @@ namespace ngraph
template <typename ET, typename FMT = fmt::V> template <typename ET, typename FMT = fmt::V>
using EigenVector = EigenWrapper<ET, FMT, EigenVectorBase<ET>, VectorStrides>; using EigenVector = EigenWrapper<ET, FMT, EigenVectorBase<ET>, VectorStrides>;
template <typename T>
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>,
0,
Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>
get_map_array(T* t, size_t l0, size_t l1, size_t s0, size_t s1)
{
return Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>,
0,
Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>>(
t, l0, l1, Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>(s0, s1));
}
template <typename T, typename U>
void set_map_matrix(std::shared_ptr<T>& t, const U& u)
{
auto& v = t->get_vector();
Eigen::Map<Eigen::Matrix<typename T::value_type, Eigen::Dynamic, 1>>(
&v[0], v.size(), 1) = u;
}
template <typename T, typename U>
void set_map_matrix(T* t, const U& u)
{
auto& v = t->get_vector();
Eigen::Map<Eigen::Matrix<typename T::value_type, Eigen::Dynamic, 1>>(
&v[0], v.size(), 1) = u;
}
template <typename T, typename U>
void set_map_matrix_2d(std::shared_ptr<T>& t, const U& u)
{
auto& v = t->get_vector();
auto& s = t->get_shape();
auto s_rest = std::vector<size_t>(s.begin() + 1, s.end());
Eigen::Map<Eigen::Matrix<typename T::value_type,
Eigen::Dynamic,
Eigen::Dynamic,
Eigen::RowMajor>>(
&v[0], s[0], ngraph::shape_size(s_rest)) = u;
}
template <typename T, typename U>
void set_map_matrix_2d(T* t, const U& u)
{
auto& v = t->get_vector();
auto& s = t->get_shape();
auto s_rest = std::vector<size_t>(s.begin() + 1, s.end());
Eigen::Map<Eigen::Matrix<typename T::value_type,
Eigen::Dynamic,
Eigen::Dynamic,
Eigen::RowMajor>>(
&v[0], s[0], ngraph::shape_size(s_rest)) = u;
}
template <typename T>
Eigen::Map<Eigen::Matrix<typename T::value_type, Eigen::Dynamic, 1>>
get_map_matrix(std::shared_ptr<T>& arg)
{
auto& v = arg->get_vector();
return Eigen::Map<Eigen::Matrix<typename T::value_type, Eigen::Dynamic, 1>>(
&v[0], v.size(), 1);
}
template <typename T>
Eigen::Map<Eigen::Matrix<typename T::value_type, Eigen::Dynamic, 1>>
get_map_matrix(T* arg)
{
auto& v = arg->get_vector();
return Eigen::Map<Eigen::Matrix<typename T::value_type, Eigen::Dynamic, 1>>(
&v[0], v.size(), 1);
}
template <typename T>
Eigen::Map<
Eigen::
Matrix<typename T::value_type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
get_map_matrix_2d(T* arg)
{
auto& v = arg->get_vector();
auto& s = arg->get_shape();
auto s_rest = std::vector<size_t>(s.begin() + 1, s.end());
return Eigen::Map<Eigen::Matrix<typename T::value_type,
Eigen::Dynamic,
Eigen::Dynamic,
Eigen::RowMajor>>(
&v[0], s[0], ngraph::shape_size(s_rest));
}
} }
} }
} }
...@@ -188,14 +188,14 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -188,14 +188,14 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>( make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
in[0].get_index(), in[1].get_index(), out[0].get_index())); in[0], in[1], out[0]));
} }
else if (arg1_shape.size() == 0) else if (arg1_shape.size() == 0)
{ {
// If arg1 is the scalar, do the same thing but switch the order of operands. // If arg1 is the scalar, do the same thing but switch the order of operands.
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>( make_shared<runtime::eigen::ScalarTensorProductInstruction<element::Float32>>(
in[1].get_index(), in[0].get_index(), out[0].get_index())); in[1], in[0], out[0]));
} }
// If arg0 and arg1 are both vectors, emit a dot product. // If arg0 and arg1 are both vectors, emit a dot product.
...@@ -203,7 +203,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -203,7 +203,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::DotInstruction<element::Float32>>( make_shared<runtime::eigen::DotInstruction<element::Float32>>(
in[0].get_index(), in[1].get_index(), out[0].get_index())); in[0], in[1], out[0]));
} }
// If arg0 is a matrix and arg1 is a vector, emit a matrix-vector product. // If arg0 is a matrix and arg1 is a vector, emit a matrix-vector product.
...@@ -211,7 +211,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -211,7 +211,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::MatrixVectorProductInstruction<element::Float32>>( make_shared<runtime::eigen::MatrixVectorProductInstruction<element::Float32>>(
in[0].get_index(), in[1].get_index(), out[0].get_index())); in[0], in[1], out[0]));
} }
// If arg0 and arg1 are both matrices, emit a matrix product. // If arg0 and arg1 are both matrices, emit a matrix product.
...@@ -219,7 +219,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map() ...@@ -219,7 +219,7 @@ ExternalFunction::OpMap& ExternalFunction::get_op_map()
{ {
ef->get_instructions()->push_back( ef->get_instructions()->push_back(
make_shared<runtime::eigen::MatrixMultInstruction<element::Float32>>( make_shared<runtime::eigen::MatrixMultInstruction<element::Float32>>(
in[0].get_index(), in[1].get_index(), out[0].get_index())); in[0], in[1], out[0]));
} }
else else
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment