Commit ea6a5b85 authored by Nick Korovaiko, committed by Scott Cyphers

any/all stop-gap CPU implementation (#2250)

* any/all stop-gap CPU implementation

* remove pass
parent fa3200f1
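
This commit wires ngraph::op::Any and ngraph::op::All into both CPU execution paths: the direct-execution (builder) path registers functors that invoke the reference kernels, and the codegen (emitter) path writes out calls to the same kernels. For orientation only (this sketch is not part of the commit; any_sketch and its parameters are illustrative), a reference "any" kernel amounts to a logical-OR reduction that folds each input element into the output element obtained by dropping the reduced axes from its coordinate:

#include <cstddef>
#include <set>
#include <vector>

// Illustrative only -- not the nGraph reference kernel. Boolean values are
// assumed to be stored one per byte, tensors laid out in row-major order.
void any_sketch(const char* arg,
                char* out,
                const std::vector<size_t>& in_shape,
                const std::set<size_t>& reduction_axes)
{
    // The output shape is the input shape with the reduced axes removed,
    // so the output size is the product of the surviving dimensions.
    size_t in_size = 1;
    size_t out_size = 1;
    for (size_t a = 0; a < in_shape.size(); ++a)
    {
        in_size *= in_shape[a];
        if (reduction_axes.count(a) == 0)
        {
            out_size *= in_shape[a];
        }
    }
    for (size_t i = 0; i < out_size; ++i)
    {
        out[i] = 0; // identity for OR; "all" would start at 1 and use AND
    }

    std::vector<size_t> coord(in_shape.size(), 0);
    for (size_t i = 0; i < in_size; ++i)
    {
        // Project the input coordinate onto the output by skipping reduced axes.
        size_t out_index = 0;
        for (size_t a = 0; a < in_shape.size(); ++a)
        {
            if (reduction_axes.count(a) == 0)
            {
                out_index = out_index * in_shape[a] + coord[a];
            }
        }
        out[out_index] = out[out_index] || arg[i];

        // Odometer-style increment of the row-major input coordinate.
        for (size_t a = in_shape.size(); a-- > 0;)
        {
            if (++coord[a] < in_shape[a])
            {
                break;
            }
            coord[a] = 0;
        }
    }
}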
@@ -15,10 +15,14 @@
//*****************************************************************************
#include "ngraph/runtime/cpu/kernel/reduce_function.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/tensor.hpp"
using namespace std;
@@ -30,6 +34,52 @@ namespace ngraph
{
namespace cpu
{
template <>
void Builder::BUILDER_DECL(ngraph::op::Any)
{
auto& functors = external_function->get_functors();
auto reduce = static_cast<const ngraph::op::Any*>(node);
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape();
auto reduction_axes = reduce->get_reduction_axes();
auto functor = [&, arg0_shape, out_shape, reduction_axes](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
runtime::reference::any(static_cast<char*>(arg0_tensor),
static_cast<char*>(out_tensor),
arg0_shape,
out_shape,
reduction_axes);
};
functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::All)
{
auto& functors = external_function->get_functors();
auto reduce = static_cast<const ngraph::op::All*>(node);
auto& arg0_tensor = external_function->get_tensor_data(args[0].get_name());
auto& out_tensor = external_function->get_tensor_data(out[0].get_name());
auto arg0_shape = args[0].get_shape();
auto out_shape = out[0].get_shape();
auto reduction_axes = reduce->get_reduction_axes();
auto functor = [&, arg0_shape, out_shape, reduction_axes](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
runtime::reference::all(static_cast<char*>(arg0_tensor),
static_cast<char*>(out_tensor),
arg0_shape,
out_shape,
reduction_axes);
};
functors.emplace_back(functor);
}
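
Note the capture lists on both functors above: the shape and axis metadata (arg0_shape, out_shape, reduction_axes) are captured by value since they are fixed at build time, while arg0_tensor and out_tensor are captured by reference so each invocation reads whatever buffer addresses the external function has bound by execution time.
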
template <>
void Builder::BUILDER_DECL(ngraph::op::Reduce)
{
@@ -136,6 +186,8 @@ namespace ngraph
}
REGISTER_OP_BUILDER(Reduce);
REGISTER_OP_BUILDER(Any);
REGISTER_OP_BUILDER(All);
}
}
}
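
The REGISTER_OP_BUILDER(Any) and REGISTER_OP_BUILDER(All) lines hook the two new template specializations into the CPU builder registry, so the direct-execution path can locate them by op type.
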
@@ -26,8 +26,10 @@
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
@@ -1177,6 +1179,38 @@ namespace ngraph
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Any)
{
const ngraph::op::Any* any = static_cast<const ngraph::op::Any*>(node);
writer.block_begin();
{
writer << "reference::any(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(any->get_reduction_axes()) << "});\n";
}
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::All)
{
const ngraph::op::All* all = static_cast<const ngraph::op::All*>(node);
writer.block_begin();
{
writer << "reference::all(";
writer << " " << args[0].get_name() << ",\n";
writer << " " << out[0].get_name() << ",\n";
writer << " {" << join(args[0].get_shape()) << "},\n";
writer << " {" << join(out[0].get_shape()) << "},\n";
writer << " {" << join(all->get_reduction_axes()) << "});\n";
}
writer.block_end();
}
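
For a concrete picture of what the Any emitter above writes into the generated source, a hypothetical 2x3 boolean input reduced over axis 0 would produce roughly the following call (arg0 and out0 stand in for the codegen-assigned tensor names):

reference::any( arg0,
                out0,
                {2, 3},
                {3},
                {0});

The All emitter produces the same shape of call against reference::all.
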
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::LRN)
{
@@ -42,8 +42,10 @@
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
@@ -294,6 +296,8 @@ static const runtime::cpu::OpMap dispatcher{
{TI(ngraph::op::Multiply), &runtime::cpu::CPU_Emitter::emit<op::Multiply>},
{TI(ngraph::op::Parameter), &runtime::cpu::CPU_Emitter::nop},
{TI(ngraph::op::Abs), &runtime::cpu::CPU_Emitter::emit<op::Abs>},
{TI(ngraph::op::Any), &runtime::cpu::CPU_Emitter::emit<op::Any>},
{TI(ngraph::op::All), &runtime::cpu::CPU_Emitter::emit<op::All>},
{TI(ngraph::op::BatchDot), &runtime::cpu::CPU_Emitter::emit<op::BatchDot>},
{TI(ngraph::op::Concat), &runtime::cpu::CPU_Emitter::emit<op::Concat>},
{TI(ngraph::op::Divide), &runtime::cpu::CPU_Emitter::emit<op::Divide>},
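
TI here is presumably a type-index helper macro: the dispatcher maps each op's concrete type to the matching CPU_Emitter::emit<> specialization, which is why the two new map entries are all that codegen needs to start handling Any and All.
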
@@ -490,7 +494,9 @@ void runtime::cpu::CPU_ExternalFunction::compile()
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/reference/all.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/argmax.hpp"
#include "ngraph/runtime/reference/argmin.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
@@ -1095,7 +1101,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
{
auto pass_map = pass_manager.get_pass_config().get_enables();
REGISTER_KNOBBED_PASS(AnyAllReplacement, true, ngraph::pass);
REGISTER_KNOBBED_PASS(LikeReplacement, true, ngraph::pass);
REGISTER_KNOBBED_PASS(NopElimination, true, ngraph::pass);
REGISTER_KNOBBED_PASS(ZeroDimTensorElimination, true, ngraph::pass);
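
This last hunk is the commit's "* remove pass" bullet: with the builder and emitter handling Any and All directly, the AnyAllReplacement pass (which presumably rewrote these ops into forms the backend already supported) is dropped from the common pass list; the REGISTER_KNOBBED_PASS(AnyAllReplacement, ...) line shown above is the one being removed.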