Commit d3027ca3 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU Direct Execution: Refactor

parent e1aa2621
...@@ -166,6 +166,39 @@ using namespace ngraph; ...@@ -166,6 +166,39 @@ using namespace ngraph;
KV = K<uint64_t>; \ KV = K<uint64_t>; \
} }
#define BUILD_UNARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
#define BUILD_BINARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& arg1_tensor = tensor_data[args[1].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -175,117 +208,37 @@ namespace ngraph ...@@ -175,117 +208,37 @@ namespace ngraph
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::Add) void Builder::BUILDER_DECL(ngraph::op::Add)
{ {
auto& functors = external_function->get_functors(); BUILD_BINARY_ELEMWISE_FUNCTOR(add);
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::add);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
} }
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::Multiply) void Builder::BUILDER_DECL(ngraph::op::Multiply)
{ {
auto& functors = external_function->get_functors(); BUILD_BINARY_ELEMWISE_FUNCTOR(multiply);
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::multiply);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
} }
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::Abs) void Builder::BUILDER_DECL(ngraph::op::Abs)
{ {
auto& functors = external_function->get_functors(); BUILD_UNARY_ELEMWISE_FUNCTOR(abs);
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::abs);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
} }
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::Ceiling) void Builder::BUILDER_DECL(ngraph::op::Ceiling)
{ {
auto& functors = external_function->get_functors(); BUILD_UNARY_ELEMWISE_FUNCTOR(ceil);
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::ceil);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
} }
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::Relu) void Builder::BUILDER_DECL(ngraph::op::Relu)
{ {
auto& functors = external_function->get_functors(); BUILD_UNARY_ELEMWISE_FUNCTOR(relu);
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::relu);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
} }
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::Result) void Builder::BUILDER_DECL(ngraph::op::Result)
{ {
auto& functors = external_function->get_functors(); BUILD_UNARY_ELEMWISE_FUNCTOR(result);
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::result);
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto size = shape_size(node->get_shape());
auto functor = [&, kernel, size](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, size);
};
functors.emplace_back(functor);
} }
template <> template <>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment