Commit c77c9c67 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Implement reduce to scalar

parent a4d4d161
......@@ -1171,21 +1171,46 @@ void Emitter::EMITTER_DECL(EmitReduce)
if (reductee_shape.at(0) == 0 ||
(reductee_shape.size() == 2 && reductee_shape.at(1) == 0))
{
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
// "Reduce has unhandled element type",
// runtime::ngvm::eigen::CopyInstruction,
// in.at(1).get_index(),
// out.at(0).get_index());
TU +=
" {\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(f_result_element_type)] + ">(" +
to_string(outputs.at(0).get_index()) +
")->get_vector() =\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(f_result_element_type)] + ">(" +
to_string(inputs.at(1).get_index()) +
")->get_vector();\n"
" }\n";
}
else
{
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type,
// "Reduce has unhandled element type",
// runtime::ngvm::eigen::ReduceToScalarInstruction,
// external,
// in[0],
// in[1],
// out[0]);
std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
external->make_call_frame());
ef->get_callees().emplace_back(cf);
TU +=
" {\n"
" using ET = " + element_type_names[TI(f_result_element_type)] + ";\n"
" auto cf = callees.at(" + to_string(ef->get_callees().size() - 1) + ");\n"
" auto f = [cf](typename ET::type x, typename ET::type y) -> typename ET::type {\n"
" auto tx = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *tx = std::vector<typename ET::type>({x});\n"
" auto ty = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *ty = std::vector<typename ET::type>({y});\n"
" auto tr = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" (*cf)({tx, ty}, {tr});\n"
" return tr->get_vector()[0];\n"
" };\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ").redux(f);\n"
" }\n";
}
}
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{1})
......
......@@ -189,6 +189,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/utils.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/eigen_utils.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment