Commit c8d9c405 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU: Start op::Reduce implementation and implement no-axes case

parent 3871a4f0
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "ngraph/ops/constant.hpp" #include "ngraph/ops/constant.hpp"
#include "ngraph/ops/function_call.hpp" #include "ngraph/ops/function_call.hpp"
#include "ngraph/ops/get_tuple_element.hpp" #include "ngraph/ops/get_tuple_element.hpp"
#include "ngraph/ops/reduce.hpp"
#include "ngraph/ops/reshape.hpp" #include "ngraph/ops/reshape.hpp"
#include "ngraph/runtime/tensor_view_info.hpp" #include "ngraph/runtime/tensor_view_info.hpp"
#include "ngraph/runtime/cpu/call_frame.hpp" #include "ngraph/runtime/cpu/call_frame.hpp"
...@@ -1081,3 +1082,60 @@ void Emitter::EMITTER_DECL(EmitFunctionCall) ...@@ -1081,3 +1082,60 @@ void Emitter::EMITTER_DECL(EmitFunctionCall)
" (*cf)(inputs, outputs);\n" " (*cf)(inputs, outputs);\n"
" }\n"; " }\n";
} }
void Emitter::EMITTER_DECL(EmitReduce)
{
auto reduce = static_cast<const op::Reduce*>(n);
auto reduction_function = reduce->get_reduction_function();
std::shared_ptr<ExternalFunction> external;
try
{
external = function_map.at(reduction_function);
}
catch (const std::out_of_range)
{
external = make_shared<ExternalFunction>(reduction_function);
function_map.insert({reduction_function, external});
}
auto reductee_type = reduce->get_arguments().at(0)->get_value_type();
auto reductee_tensor_view_type =
dynamic_pointer_cast<const TensorViewType>(reductee_type);
assert(reductee_tensor_view_type);
auto reductee_shape = reductee_tensor_view_type->get_shape();
auto f_result_type = reduction_function->get_result_type();
auto f_result_tensor_view_type =
dynamic_pointer_cast<const TensorViewType>(f_result_type);
assert(f_result_tensor_view_type);
auto& f_result_element_type = f_result_tensor_view_type->get_element_type();
auto result_type = reduce->get_value_type();
auto result_tensor_view_type = dynamic_pointer_cast<const TensorViewType>(result_type);
assert(result_tensor_view_type);
auto result_shape = result_tensor_view_type->get_shape();
auto& reduction_axes = reduce->get_reduction_axes();
// Trivial case: no reduction axes (this includes the scalar-reductee case).
if (reduction_axes.empty())
{
TU +=
" {\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(f_result_element_type)] + ">(" +
to_string(outputs.at(0).get_index()) +
")->get_vector() =\n"
" call_frame->get_parameterized_tensor_view<" +
element_type_names[TI(f_result_element_type)] + ">(" +
to_string(inputs.at(0).get_index()) +
")->get_vector();\n"
" }\n";
}
else
{
throw ngraph_error("Reduce: only vectors and matrices are currently supported");
}
}
...@@ -79,6 +79,7 @@ namespace ngraph ...@@ -79,6 +79,7 @@ namespace ngraph
void EMITTER_DECL(EmitConstant); void EMITTER_DECL(EmitConstant);
void EMITTER_DECL(EmitReshape); void EMITTER_DECL(EmitReshape);
void EMITTER_DECL(EmitFunctionCall); void EMITTER_DECL(EmitFunctionCall);
void EMITTER_DECL(EmitReduce);
}; };
} }
} }
......
...@@ -111,6 +111,7 @@ static const OpMap dispatcher{ ...@@ -111,6 +111,7 @@ static const OpMap dispatcher{
{TI(ngraph::op::Constant), &Emitter::EmitConstant}, {TI(ngraph::op::Constant), &Emitter::EmitConstant},
{TI(ngraph::op::Reshape), &Emitter::EmitReshape}, {TI(ngraph::op::Reshape), &Emitter::EmitReshape},
{TI(ngraph::op::FunctionCall), &Emitter::EmitFunctionCall}, {TI(ngraph::op::FunctionCall), &Emitter::EmitFunctionCall},
{TI(ngraph::op::Reduce), &Emitter::EmitReduce},
}; };
#undef TI #undef TI
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment