Commit f6fe106d authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Remaining reduction cases

parent c056e69c
......@@ -1125,6 +1125,8 @@ void Emitter::EMITTER_DECL(EmitReduce)
auto& reduction_axes = reduce->get_reduction_axes();
auto arg0_layout = inputs[0].get_layout<DenseTensorViewLayout>();
// Trivial case: no reduction axes (this includes the scalar-reductee case).
if (reduction_axes.empty())
{
......@@ -1230,35 +1232,77 @@ void Emitter::EMITTER_DECL(EmitReduce)
}
else
{
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
// "Reduce has unhandled element type",
// runtime::ngvm::eigen::ReduceMatrixRowsInstruction,
// external,
// in[0],
// in[1],
// out[0]);
std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
external->make_call_frame());
ef->get_callees().emplace_back(cf);
TU +=
" {\n"
" using ET = " + element_type_names[TI(f_result_element_type)] + ";\n"
" auto cf = callees.at(" + to_string(ef->get_callees().size() - 1) + ");\n"
" auto f = [cf](typename ET::type x, typename ET::type y) -> typename ET::type {\n"
" auto tx = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *tx = std::vector<typename ET::type>({x});\n"
" auto ty = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *ty = std::vector<typename ET::type>({y});\n"
" auto tr = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" (*cf)({tx, ty}, {tr});\n"
" return tr->get_vector()[0];\n"
" };\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenVector<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenMatrix<" + element_type_names[TI(f_result_element_type)] + ">(arg0, " +
EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) + ").rowwise().redux(f);\n"
" }\n";
}
}
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{0})
{
if (reductee_shape.at(0) == 0)
{
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
// "Reduce has unhandled element type",
// runtime::ngvm::eigen::BroadcastScalarInstruction,
// in[1],
// out[0]);
TU += " {\n"
" auto arg1 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(inputs[1].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(arg1, "
EIGEN_VECTOR_FORMAT(inputs[1].get_layout<DenseTensorViewLayout>()->get_size()) ")(0, 0);\n"
" }\n";
}
else
{
// PUSH_POLYMORPHIC_INSTRUCTION(
// f_result_element_type,
// "Reduce has unhandled element type",
// runtime::ngvm::eigen::ReduceMatrixColumnsInstruction,
// external,
// in[0],
// in[1],
// out[0]);
std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
external->make_call_frame());
ef->get_callees().emplace_back(cf);
TU +=
" {\n"
" using ET = " + element_type_names[TI(f_result_element_type)] + ";\n"
" auto cf = callees.at(" + to_string(ef->get_callees().size() - 1) + ");\n"
" auto f = [cf](typename ET::type x, typename ET::type y) -> typename ET::type {\n"
" auto tx = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *tx = std::vector<typename ET::type>({x});\n"
" auto ty = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *ty = std::vector<typename ET::type>({y});\n"
" auto tr = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" (*cf)({tx, ty}, {tr});\n"
" return tr->get_vector()[0];\n"
" };\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenVector<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenMatrix<" + element_type_names[TI(f_result_element_type)] + ">(arg0, " +
EIGEN_MATRIX_FORMAT(arg0_layout->get_shape(), arg0_layout->get_strides()) + ").colwise().redux(f);\n"
" }\n";
}
}
else
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment