Commit 67547753 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Implement op::Slice

parent c5aa44c8
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ngraph/ops/get_tuple_element.hpp" #include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/reduce.hpp" #include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp" #include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/emitter.hpp" #include "ngraph/runtime/cpu/emitter.hpp"
...@@ -1326,3 +1327,87 @@ void Emitter::EMITTER_DECL(EmitSign) ...@@ -1326,3 +1327,87 @@ void Emitter::EMITTER_DECL(EmitSign)
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ").sign();\n" EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ").sign();\n"
" }\n"; " }\n";
} }
void Emitter::EMITTER_DECL(EmitSlice)
{
auto slice = static_cast<const op::Slice*>(n);
for (auto d : slice->get_step())
{
if (1 != d)
{
throw ngraph_error("Slice does not support non-unit step yet");
}
}
auto arg_type = slice->get_arguments().at(0)->get_value_type();
auto arg_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(arg_type);
assert(arg_tensor_view_type);
auto arg_shape = arg_tensor_view_type->get_shape();
auto arg_rank = arg_shape.size();
auto& arg_element_type = arg_tensor_view_type->get_element_type();
auto& lower_bounds = slice->get_lower_bounds();
auto& upper_bounds = slice->get_upper_bounds();
// Scalar slice is necessarily just a copy.
if (arg_rank == 0)
{
TU +=
" {\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(arg_element_type)] + ">(" +
to_string(outputs.at(0).get_index()) +
")->get_vector() =\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(arg_element_type)] + ">(" +
to_string(inputs.at(0).get_index()) +
")->get_vector();\n"
" }\n";
}
else if (arg_rank == 1)
{
TU +=
" {\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(arg_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(arg_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenVector<" + element_type_names[TI(arg_element_type)] +
">(out, " EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenVector<" + element_type_names[TI(arg_element_type)] +
">(arg0, " EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ").segment(\n"
" " + to_string(lower_bounds[0]) + ", " + to_string(upper_bounds[0] - lower_bounds[0]) + ");\n"
" }\n";
}
else if (arg_rank == 2)
{
auto arg0_layout = inputs[0].get_layout<DenseTensorViewLayout>();
auto out_layout = outputs[0].get_layout<DenseTensorViewLayout>();
TU +=
" {\n"
" auto arg0 = call_frame->get_tensor_view_data<" +
element_type_names[TI(arg_element_type)] + ">(" + to_string(inputs[0].get_index()) +
");\n"
" auto out = call_frame->get_tensor_view_data<" +
element_type_names[TI(arg_element_type)] + ">(" + to_string(outputs[0].get_index()) +
");\n"
" EigenMatrix<" +
element_type_names[TI(arg_element_type)] + ">(out, " +
EIGEN_MATRIX_FORMAT(out_layout->get_shape(), out_layout->get_strides()) +
") = \n"
" EigenMatrix<" +
element_type_names[TI(arg_element_type)] + ">(arg0, " +
EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) +
").block(" + to_string(lower_bounds[0]) + ", " + to_string(lower_bounds[1]) + ",\n"
" " + to_string(upper_bounds[0] - lower_bounds[0]) + ",\n"
" " + to_string(upper_bounds[1] - lower_bounds[1])+ ");\n"
" }\n";
}
// Other cases (reordering of axes for tensors with rank>2) are not handled yet.
else
{
throw ngraph_error("Slice is not implemented yet for tensors with rank>2 in VM");
}
}
...@@ -81,6 +81,7 @@ namespace ngraph ...@@ -81,6 +81,7 @@ namespace ngraph
void EMITTER_DECL(EmitFunctionCall); void EMITTER_DECL(EmitFunctionCall);
void EMITTER_DECL(EmitReduce); void EMITTER_DECL(EmitReduce);
void EMITTER_DECL(EmitSign); void EMITTER_DECL(EmitSign);
void EMITTER_DECL(EmitSlice);
}; };
} }
} }
......
...@@ -51,6 +51,7 @@ ...@@ -51,6 +51,7 @@
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/ops/select.hpp" #include "ngraph/ops/select.hpp"
#include "ngraph/ops/sign.hpp" #include "ngraph/ops/sign.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/subtract.hpp" #include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/tuple.hpp" #include "ngraph/ops/tuple.hpp"
#include "ngraph/pass/assign_layout.hpp" #include "ngraph/pass/assign_layout.hpp"
...@@ -115,6 +116,7 @@ static const OpMap dispatcher{ ...@@ -115,6 +116,7 @@ static const OpMap dispatcher{
{TI(ngraph::op::FunctionCall), &Emitter::EmitFunctionCall}, {TI(ngraph::op::FunctionCall), &Emitter::EmitFunctionCall},
{TI(ngraph::op::Reduce), &Emitter::EmitReduce}, {TI(ngraph::op::Reduce), &Emitter::EmitReduce},
{TI(ngraph::op::Sign), &Emitter::EmitSign}, {TI(ngraph::op::Sign), &Emitter::EmitSign},
{TI(ngraph::op::Slice), &Emitter::EmitSlice},
}; };
#undef TI #undef TI
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment