Commit c1d0e594 authored by Jaikrishnan Menon

CPU: Implement op::Sum

parent 67547753
......@@ -29,6 +29,7 @@
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/emitter.hpp"
......@@ -1411,3 +1412,90 @@ void Emitter::EMITTER_DECL(EmitSlice)
throw ngraph_error("Slice is not implemented yet for tensors with rank>2 in VM");
}
}
/// Emits the generated-code snippet that evaluates an op::Sum node.
///
/// Depending on the argument rank and the node's reduction axes, appends to TU:
///  - no reduction axes: a whole-vector copy from input 0 to output 0;
///  - rank 1 reduced over {0}, or rank 2 reduced over {0, 1}: a flat Eigen
///    `.sum()` over the input, assigned to the output array;
///  - rank 2 reduced over {1}: `rowwise().sum()`;
///  - rank 2 reduced over {0}: `colwise().sum()`.
///
/// Throws ngraph_error for every other rank / reduction-axes combination.
void Emitter::EMITTER_DECL(EmitSum)
{
    auto s = static_cast<const op::Sum*>(n);

    auto s_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(s->get_value_type());
    assert(s_tensor_view_type);
    auto& s_element_type = s_tensor_view_type->get_element_type();

    auto arg = s->get_arguments().at(0);
    auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg->get_value_type());
    assert(arg_tensor_view_type);
    // Only the rank is needed for dispatch, so the shape itself is not kept.
    const auto arg_rank = arg_tensor_view_type->get_shape().size();

    auto& reduction_axes = s->get_reduction_axes();

    // Name used for the element type in the emitted code. Both the input and
    // the output buffer are declared with the result's element type.
    const auto& et_name = element_type_names[TI(s_element_type)];

    // Opening brace plus the typed `arg0` / `out` buffer declarations shared by
    // all Eigen-based cases. Built lazily so the unsupported path touches
    // neither `inputs` nor `outputs` before throwing.
    const auto open_scope_with_buffers = [&]() -> std::string {
        return " {\n"
               " auto arg0 = call_frame->get_tensor_view_data<" + et_name +
               ">(" + to_string(inputs[0].get_index()) + ");\n"
               " auto out = call_frame->get_tensor_view_data<" + et_name +
               ">(" + to_string(outputs[0].get_index()) + ");\n";
    };

    // Trivial case: no reduction axes, so the output is a copy of the input.
    if (reduction_axes.empty())
    {
        TU +=
            " {\n"
            " call_frame->get_parameterized_tensor_view<" + et_name + ">(" +
            to_string(outputs.at(0).get_index()) +
            ")->get_vector() =\n"
            " call_frame->get_parameterized_tensor_view<" + et_name + ">(" +
            to_string(inputs.at(0).get_index()) +
            ")->get_vector();\n"
            " }\n";
        return;
    }

    // Full reduction: view the input as a flat array and sum it.
    if ((arg_rank == 1 && reduction_axes == AxisSet{0}) ||
        (arg_rank == 2 && reduction_axes == AxisSet{0, 1}))
    {
        TU +=
            open_scope_with_buffers() +
            " EigenArray1d<" + et_name + ">(out, "
            EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
            " EigenArray1d<" + et_name + ">(arg0, "
            EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ").sum();\n"
            " }\n";
        return;
    }

    // Matrix reduced along a single axis: axis 1 maps to Eigen's rowwise()
    // reduction, axis 0 to colwise(). The emitted code is otherwise identical.
    if (arg_rank == 2 && (reduction_axes == AxisSet{1} || reduction_axes == AxisSet{0}))
    {
        const char* const eigen_direction =
            (reduction_axes == AxisSet{1}) ? "rowwise" : "colwise";
        auto arg0_layout = inputs[0].get_layout<DenseTensorViewLayout>();
        TU +=
            open_scope_with_buffers() +
            " EigenVector<" + et_name + ">(out, "
            EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
            " EigenMatrix<" + et_name + ">(arg0, " +
            EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) +
            ")." + eigen_direction + "().sum();\n"
            " }\n";
        return;
    }

    throw ngraph_error("Sum: only vectors and matrices are currently supported");
}
......@@ -82,6 +82,7 @@ namespace ngraph
void EMITTER_DECL(EmitReduce);
void EMITTER_DECL(EmitSign);
void EMITTER_DECL(EmitSlice);
void EMITTER_DECL(EmitSum);
};
}
}
......
......@@ -53,6 +53,7 @@
#include "ngraph/ops/sign.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/assign_tensors.hpp"
......@@ -117,6 +118,7 @@ static const OpMap dispatcher{
{TI(ngraph::op::Reduce), &Emitter::EmitReduce},
{TI(ngraph::op::Sign), &Emitter::EmitSign},
{TI(ngraph::op::Slice), &Emitter::EmitSlice},
{TI(ngraph::op::Sum), &Emitter::EmitSum},
};
#undef TI
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment