Commit f6fe106d authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Remaining reduction cases

parent c056e69c
...@@ -1125,6 +1125,8 @@ void Emitter::EMITTER_DECL(EmitReduce) ...@@ -1125,6 +1125,8 @@ void Emitter::EMITTER_DECL(EmitReduce)
auto& reduction_axes = reduce->get_reduction_axes(); auto& reduction_axes = reduce->get_reduction_axes();
auto arg0_layout = inputs[0].get_layout<DenseTensorViewLayout>();
// Trivial case: no reduction axes (this includes the scalar-reductee case). // Trivial case: no reduction axes (this includes the scalar-reductee case).
if (reduction_axes.empty()) if (reduction_axes.empty())
{ {
...@@ -1230,35 +1232,77 @@ void Emitter::EMITTER_DECL(EmitReduce) ...@@ -1230,35 +1232,77 @@ void Emitter::EMITTER_DECL(EmitReduce)
} }
else else
{ {
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type, std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
// "Reduce has unhandled element type", external->make_call_frame());
// runtime::ngvm::eigen::ReduceMatrixRowsInstruction, ef->get_callees().emplace_back(cf);
// external,
// in[0], TU +=
// in[1], " {\n"
// out[0]); " using ET = " + element_type_names[TI(f_result_element_type)] + ";\n"
" auto cf = callees.at(" + to_string(ef->get_callees().size() - 1) + ");\n"
" auto f = [cf](typename ET::type x, typename ET::type y) -> typename ET::type {\n"
" auto tx = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *tx = std::vector<typename ET::type>({x});\n"
" auto ty = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *ty = std::vector<typename ET::type>({y});\n"
" auto tr = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" (*cf)({tx, ty}, {tr});\n"
" return tr->get_vector()[0];\n"
" };\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenVector<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenMatrix<" + element_type_names[TI(f_result_element_type)] + ">(arg0, " +
EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) + ").rowwise().redux(f);\n"
" }\n";
} }
} }
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{0}) else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{0})
{ {
if (reductee_shape.at(0) == 0) if (reductee_shape.at(0) == 0)
{ {
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type, TU += " {\n"
// "Reduce has unhandled element type", " auto arg1 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
// runtime::ngvm::eigen::BroadcastScalarInstruction, ">(" + to_string(inputs[1].get_index()) + ");\n"
// in[1], " auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
// out[0]); ">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(arg1, "
EIGEN_VECTOR_FORMAT(inputs[1].get_layout<DenseTensorViewLayout>()->get_size()) ")(0, 0);\n"
" }\n";
} }
else else
{ {
// PUSH_POLYMORPHIC_INSTRUCTION( std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
// f_result_element_type, external->make_call_frame());
// "Reduce has unhandled element type", ef->get_callees().emplace_back(cf);
// runtime::ngvm::eigen::ReduceMatrixColumnsInstruction,
// external, TU +=
// in[0], " {\n"
// in[1], " using ET = " + element_type_names[TI(f_result_element_type)] + ";\n"
// out[0]); " auto cf = callees.at(" + to_string(ef->get_callees().size() - 1) + ");\n"
" auto f = [cf](typename ET::type x, typename ET::type y) -> typename ET::type {\n"
" auto tx = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *tx = std::vector<typename ET::type>({x});\n"
" auto ty = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *ty = std::vector<typename ET::type>({y});\n"
" auto tr = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" (*cf)({tx, ty}, {tr});\n"
" return tr->get_vector()[0];\n"
" };\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenVector<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenMatrix<" + element_type_names[TI(f_result_element_type)] + ">(arg0, " +
EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) + ").colwise().redux(f);\n"
" }\n";
} }
} }
else else
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment