Commit d3027ca3 authored by Jaikrishnan Menon's avatar Jaikrishnan Menon

CPU Direct Execution: Refactor

parent e1aa2621
......@@ -166,6 +166,39 @@ using namespace ngraph;
KV = K<uint64_t>; \
}
#define BUILD_UNARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
#define BUILD_BINARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& arg1_tensor = tensor_data[args[1].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
namespace ngraph
{
namespace runtime
......@@ -175,117 +208,37 @@ namespace ngraph
template <>
void Builder::BUILDER_DECL(ngraph::op::Add)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::add);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
BUILD_BINARY_ELEMWISE_FUNCTOR(add);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Multiply)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::multiply);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
BUILD_BINARY_ELEMWISE_FUNCTOR(multiply);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Abs)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::abs);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
BUILD_UNARY_ELEMWISE_FUNCTOR(abs);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Ceiling)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::ceil);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
BUILD_UNARY_ELEMWISE_FUNCTOR(ceil);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Relu)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::relu);
auto element_count = out[0].get_size();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, element_count);
};
functors.emplace_back(functor);
BUILD_UNARY_ELEMWISE_FUNCTOR(relu);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Result)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
std::function<void(void*, void*, size_t)> kernel;
SELECT_KERNEL(kernel, out[0].get_element_type(), runtime::cpu::kernel::result);
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
auto size = shape_size(node->get_shape());
auto functor = [&, kernel, size](CPURuntimeContext* ctx) {
kernel(arg0_tensor, out0_tensor, size);
};
functors.emplace_back(functor);
BUILD_UNARY_ELEMWISE_FUNCTOR(result);
}
template <>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment