Commit 504d3585 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Implement column-wise concatenation

parent 0e6aae41
......@@ -40,7 +40,7 @@ void CallFrame::tensor_call(
copy(inputs.begin(), inputs.end(), m_tensor_views.begin());
copy(outputs.begin(), outputs.end(), m_tensor_views.begin() + m_n_inputs);
// TODO: Execute!
// Invoke compiled computation
m_compiled_function(this, m_tensor_views);
// Don't hold onto inputs/outputs
......
......@@ -85,6 +85,13 @@ namespace ngraph
class M
{
M(const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: M(layout->get_shape(), layout->get_strides())
{
}
public:
M(const Shape& shape, const Strides& strides)
: l0(shape.at(0))
, l1(shape.at(1))
......@@ -93,13 +100,6 @@ namespace ngraph
{
}
M(const std::shared_ptr<ngraph::descriptor::layout::DenseTensorViewLayout>&
layout)
: M(layout->get_shape(), layout->get_strides())
{
}
public:
M(const TensorViewInfo& tensor_view_info)
: M(tensor_view_info.get_layout<
ngraph::descriptor::layout::DenseTensorViewLayout>())
......
......@@ -20,6 +20,7 @@
#include "ngraph/node.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/ops/concatenate.hpp"
#include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/external_function.hpp"
......@@ -44,22 +45,43 @@ static unordered_map<type_index, string> element_type_names = {{TI(ngraph::eleme
#define EIGEN_VECTOR_FORMAT(x) "{" + to_string(x) + "}"
//#define EIGEN_MATRIX_FORMAT(x)
void Emitter::EmitNop(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
static std::string EIGEN_MATRIX_FORMAT(const ngraph::Shape& shape,
const ngraph::Strides& strides)
{
std::string I;
for (size_t i = 0; i < shape.size(); i++)
{
if (!i)
{
I += "{" + to_string(shape[i]);
}
else
{
I += ", " + to_string(shape[i]);
}
}
I += "}, ";
for (size_t i = 0; i < strides.size(); i++)
{
if (!i)
{
I += "{" + to_string(strides[i]);
}
else
{
I += ", " + to_string(strides[i]);
}
}
I += "}";
return I;
}
void Emitter::EMITTER_DECL(EmitNop)
{
}
void Emitter::EmitAdd(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
void Emitter::EMITTER_DECL(EmitAdd)
{
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>(
n->get_arguments().at(0)->get_value_type()))
......@@ -78,19 +100,11 @@ void Emitter::EmitAdd(const ngraph::Node* n,
" }\n";
}
void Emitter::EmitDot(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
void Emitter::EMITTER_DECL(EmitDot)
{
}
void Emitter::EmitMultiply(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
void Emitter::EMITTER_DECL(EmitMultiply)
{
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>(
n->get_arguments().at(0)->get_value_type()))
......@@ -109,11 +123,7 @@ void Emitter::EmitMultiply(const ngraph::Node* n,
" }\n";
}
void Emitter::EmitGetTupleElement(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
void Emitter::EMITTER_DECL(EmitGetTupleElement)
{
auto get_tuple_element = static_cast<const op::GetTupleElement*>(n);
auto result_tensor_type =
......@@ -129,11 +139,7 @@ void Emitter::EmitGetTupleElement(const ngraph::Node* n,
" }\n";
}
void Emitter::EmitTuple(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
void Emitter::EMITTER_DECL(EmitTuple)
{
assert(inputs.size() == outputs.size());
......@@ -149,11 +155,7 @@ void Emitter::EmitTuple(const ngraph::Node* n,
TU += " }\n";
}
void Emitter::EmitAbs(const ngraph::Node* n,
ExternalFunction* ef,
FunctionMap& function_map,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs)
void Emitter::EMITTER_DECL(EmitAbs)
{
const element::Type& et = (dynamic_pointer_cast<const TensorViewType>(
n->get_arguments().at(0)->get_value_type()))
......@@ -168,3 +170,75 @@ void Emitter::EmitAbs(const ngraph::Node* n,
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) "));\n"
" }\n";
}
void Emitter::EMITTER_DECL(EmitConcat)
{
auto result_tensor_type =
dynamic_pointer_cast<const TensorViewType>(n->get_value_type());
assert(result_tensor_type);
auto result_shape = result_tensor_type->get_shape();
auto& result_element_type = result_tensor_type->get_element_type();
if (result_shape.size() == 1)
{
TU += " {\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenVector<" + element_type_names[TI(result_element_type)] + "> out_vector(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ");\n";
size_t concat_pos = 0;
for (size_t i = 0; i < inputs.size(); i++)
{
TU += " out_vector.segment(" + to_string(concat_pos) + ", " +
to_string(inputs[i].get_tensor_view_layout()->get_shape().at(0)) + ") << "
"EigenVector<" + element_type_names[TI(result_element_type)] + ">(call_frame->"
"get_tensor_view_data<" + element_type_names[TI(result_element_type)] + ">(" +
to_string(outputs[0].get_index()) + "), "
EIGEN_VECTOR_FORMAT(inputs[i].get_layout<DenseTensorViewLayout>()->get_size()) ");\n";
concat_pos += inputs[i].get_tensor_view_layout()->get_shape().at(0);
}
TU += " }\n";
}
else if (result_shape.size() == 2)
{
/*
PUSH_POLYMORPHIC_INSTRUCTION(
result_element_type,
"Concat has unhandled element type",
eigen::ConcatMatrixInstruction,
in,
(dynamic_cast<const op::Concat*>(n))->get_concatenation_axis(),
out[0]);
*/
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
auto axis = (dynamic_cast<const op::Concat*>(n))->get_concatenation_axis();
TU += " {\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenMatrix<" + element_type_names[TI(result_element_type)] + "> out_matrix(out, {" +
EIGEN_MATRIX_FORMAT(out_layout->get_shape(), out_layout->get_strides()) + "});\n";
size_t concat_pos[2]{0, 0};
for (size_t i = 0; i < inputs.size(); i++)
{
auto arg_layout = inputs[i].get_layout<DenseTensorViewLayout>();
auto& arg_shape = inputs[i].get_tensor_view_layout()->get_shape();
TU += " out_matrix.block(" + to_string(concat_pos[0]) + ", " +
to_string(concat_pos[1]) + ", " + to_string(arg_shape.at(0)) + ", " +
to_string(arg_shape.at(1)) + ") << "
"EigenMatrix<" + element_type_names[TI(result_element_type)] + ">(call_frame->"
"get_tensor_view_data<" + element_type_names[TI(result_element_type)] + ">(" +
to_string(inputs[i].get_index()) + "), {" +
EIGEN_MATRIX_FORMAT(arg_layout->get_shape(), arg_layout->get_strides()) + "});\n";
concat_pos[axis] += arg_shape.at(axis);
}
TU += " }\n";
}
}
......@@ -21,6 +21,13 @@
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/external_function.hpp"
#define EMITTER_DECL(E) E(const ngraph::Node* n, \
ExternalFunction* ef, \
FunctionMap& function_map, \
const std::vector<TensorViewInfo>& inputs, \
const std::vector<TensorViewInfo>& outputs)
namespace ngraph
{
namespace runtime
......@@ -35,48 +42,15 @@ namespace ngraph
public:
Emitter() : TU("") { }
std::string& GetTU() { return TU; }
void EmitNop(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs);
void EmitAdd(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs);
void EmitDot(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs);
void EmitMultiply(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs);
void EmitGetTupleElement(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs);
void EmitTuple(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs);
void EmitAbs(const ngraph::Node*,
ExternalFunction*,
FunctionMap&,
const std::vector<TensorViewInfo>& inputs,
const std::vector<TensorViewInfo>& outputs);
void EMITTER_DECL(EmitNop);
void EMITTER_DECL(EmitAdd);
void EMITTER_DECL(EmitDot);
void EMITTER_DECL(EmitMultiply);
void EMITTER_DECL(EmitGetTupleElement);
void EMITTER_DECL(EmitTuple);
void EMITTER_DECL(EmitAbs);
void EMITTER_DECL(EmitConcat);
};
}
......
......@@ -72,7 +72,8 @@ static const OpMap dispatcher{{TI(ngraph::op::Add), &Emitter::EmitAdd},
{TI(ngraph::op::Parameter), &Emitter::EmitNop},
{TI(ngraph::op::GetTupleElement), &Emitter::EmitGetTupleElement},
{TI(ngraph::op::Tuple), &Emitter::EmitTuple},
{TI(ngraph::op::Abs), &Emitter::EmitAbs}
{TI(ngraph::op::Abs), &Emitter::EmitAbs},
{TI(ngraph::op::Concat), &Emitter::EmitConcat}
};
#undef TI
......
......@@ -261,8 +261,7 @@ TEST(cpu, abs)
ASSERT_EQ((vector<float>{1, 2, 0, 4.8f}), result->get_vector());
}
/*
TEST(execute, concat_matrix_colwise)
TEST(cpu, concat_matrix_colwise)
{
auto shape_a = Shape{2, 2};
auto A = make_shared<op::Parameter>(element::Float32::element_type(), shape_a);
......@@ -275,7 +274,7 @@ TEST(execute, concat_matrix_colwise)
auto f = make_shared<Function>(
make_shared<op::Concat>(Nodes{A, B, C}, 1), rt, op::Parameters{A, B, C});
auto manager = runtime::Manager::get("NGVM");
auto manager = runtime::Manager::get("CPU");
auto external = manager->compile(f);
auto backend = manager->allocate_backend();
auto cf = backend->make_call_frame(external);
......@@ -294,6 +293,7 @@ TEST(execute, concat_matrix_colwise)
result->get_vector());
}
/*
TEST(execute, concat_matrix_rowwise)
{
auto shape_a = Shape{2, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment