Commit c77c9c67 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Implement reduce to scalar

parent a4d4d161
...@@ -1171,21 +1171,46 @@ void Emitter::EMITTER_DECL(EmitReduce) ...@@ -1171,21 +1171,46 @@ void Emitter::EMITTER_DECL(EmitReduce)
if (reductee_shape.at(0) == 0 || if (reductee_shape.at(0) == 0 ||
(reductee_shape.size() == 2 && reductee_shape.at(1) == 0)) (reductee_shape.size() == 2 && reductee_shape.at(1) == 0))
{ {
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type, TU +=
// "Reduce has unhandled element type", " {\n"
// runtime::ngvm::eigen::CopyInstruction, " call_frame->get_parameterized_tensor_view<" +
// in.at(1).get_index(), element_type_names[TI(f_result_element_type)] + ">(" +
// out.at(0).get_index()); to_string(outputs.at(0).get_index()) +
")->get_vector() =\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(f_result_element_type)] + ">(" +
to_string(inputs.at(1).get_index()) +
")->get_vector();\n"
" }\n";
} }
else else
{ {
// PUSH_POLYMORPHIC_INSTRUCTION(f_result_element_type, std::shared_ptr<CallFrame> cf = std::dynamic_pointer_cast<CallFrame>(
// "Reduce has unhandled element type", external->make_call_frame());
// runtime::ngvm::eigen::ReduceToScalarInstruction, ef->get_callees().emplace_back(cf);
// external,
// in[0], TU +=
// in[1], " {\n"
// out[0]); " using ET = " + element_type_names[TI(f_result_element_type)] + ";\n"
" auto cf = callees.at(" + to_string(ef->get_callees().size() - 1) + ");\n"
" auto f = [cf](typename ET::type x, typename ET::type y) -> typename ET::type {\n"
" auto tx = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *tx = std::vector<typename ET::type>({x});\n"
" auto ty = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" *ty = std::vector<typename ET::type>({y});\n"
" auto tr = ngraph::runtime::make_tensor<ET>(ngraph::Shape{});\n"
" (*cf)({tx, ty}, {tr});\n"
" return tr->get_vector()[0];\n"
" };\n"
" auto arg0 = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(inputs[0].get_index()) + ");\n"
" auto out = call_frame->get_tensor_view_data<" + element_type_names[TI(f_result_element_type)] +
">(" + to_string(outputs[0].get_index()) + ");\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(out, "
EIGEN_VECTOR_FORMAT(outputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ") =\n"
" EigenArray1d<" + element_type_names[TI(f_result_element_type)] + ">(arg0, "
EIGEN_VECTOR_FORMAT(inputs[0].get_layout<DenseTensorViewLayout>()->get_size()) ").redux(f);\n"
" }\n";
} }
} }
else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{1}) else if (reductee_shape.size() == 2 && reduction_axes == AxisSet{1})
......
...@@ -189,6 +189,7 @@ void ExternalFunction::compile(FunctionMap& function_map) ...@@ -189,6 +189,7 @@ void ExternalFunction::compile(FunctionMap& function_map)
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp" #include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/utils.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp" #include "ngraph/runtime/cpu/call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp" #include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/eigen_utils.hpp" #include "ngraph/runtime/cpu/eigen_utils.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment