Commit 6576932f authored by Jaikrishnan Menon

CPU Direct Execution: Refactor and implement builder auto-registration

This allows op builders to be self-contained changesets
parent 1df7602e
...@@ -31,32 +31,21 @@ ...@@ -31,32 +31,21 @@
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/asin.hpp" #include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp" #include "ngraph/op/atan.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/ceiling.hpp" #include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/cos.hpp" #include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp" #include "ngraph/op/cosh.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
#include "ngraph/op/function_call.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp" #include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp" #include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp" #include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp" #include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
...@@ -74,9 +63,7 @@ ...@@ -74,9 +63,7 @@
#include "ngraph/op/relu.hpp" #include "ngraph/op/relu.hpp"
#include "ngraph/op/remainder.hpp" #include "ngraph/op/remainder.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/select.hpp" #include "ngraph/op/select.hpp"
#include "ngraph/op/select_and_scatter.hpp" #include "ngraph/op/select_and_scatter.hpp"
...@@ -87,7 +74,6 @@ ...@@ -87,7 +74,6 @@
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp" #include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp" #include "ngraph/op/tanh.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp" #include "ngraph/runtime/cpu/cpu_kernels.hpp"
...@@ -96,18 +82,27 @@ ...@@ -96,18 +82,27 @@
#include "ngraph/runtime/cpu/kernel/add.hpp" #include "ngraph/runtime/cpu/kernel/add.hpp"
#include "ngraph/runtime/cpu/kernel/broadcast.hpp" #include "ngraph/runtime/cpu/kernel/broadcast.hpp"
#include "ngraph/runtime/cpu/kernel/ceil.hpp" #include "ngraph/runtime/cpu/kernel/ceil.hpp"
#include "ngraph/runtime/cpu/kernel/cwise_pow.hpp"
#include "ngraph/runtime/cpu/kernel/divide.hpp"
#include "ngraph/runtime/cpu/kernel/equal.hpp"
#include "ngraph/runtime/cpu/kernel/exp.hpp"
#include "ngraph/runtime/cpu/kernel/floor.hpp"
#include "ngraph/runtime/cpu/kernel/greater.hpp"
#include "ngraph/runtime/cpu/kernel/greater_eq.hpp"
#include "ngraph/runtime/cpu/kernel/less.hpp"
#include "ngraph/runtime/cpu/kernel/less_eq.hpp"
#include "ngraph/runtime/cpu/kernel/log.hpp"
#include "ngraph/runtime/cpu/kernel/maximum.hpp"
#include "ngraph/runtime/cpu/kernel/minimum.hpp"
#include "ngraph/runtime/cpu/kernel/multiply.hpp" #include "ngraph/runtime/cpu/kernel/multiply.hpp"
#include "ngraph/runtime/cpu/kernel/negative.hpp"
#include "ngraph/runtime/cpu/kernel/not.hpp"
#include "ngraph/runtime/cpu/kernel/not_equal.hpp"
#include "ngraph/runtime/cpu/kernel/relu.hpp" #include "ngraph/runtime/cpu/kernel/relu.hpp"
#include "ngraph/runtime/cpu/kernel/result.hpp" #include "ngraph/runtime/cpu/kernel/result.hpp"
#include "ngraph/runtime/cpu/op/batch_norm_relu.hpp" #include "ngraph/runtime/cpu/kernel/sqrt.hpp"
#include "ngraph/runtime/cpu/op/conv_bias.hpp" #include "ngraph/runtime/cpu/kernel/subtract.hpp"
#include "ngraph/runtime/cpu/op/conv_relu.hpp"
#include "ngraph/runtime/cpu/op/convert_layout.hpp" #include "ngraph/runtime/cpu/op/convert_layout.hpp"
#include "ngraph/runtime/cpu/op/lstm.hpp"
#include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/op/max_pool_with_indices.hpp"
#include "ngraph/runtime/cpu/op/rnn.hpp"
#include "ngraph/runtime/cpu/op/sigmoid.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
...@@ -119,39 +114,6 @@ ...@@ -119,39 +114,6 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// BUILD_UNARY_ELEMWISE_FUNCTOR(OP)
//
// Expands to the statements that build and register the functor for a
// one-input element-wise kernel OP.
//
// Names the expansion reads from the enclosing scope: `external_function`,
// `args`, `out`, plus the SELECT_KERNEL macro.
// Names the expansion declares in the enclosing scope: `functors`,
// `tensor_data`, `kernel`, `element_count`, `arg0_tensor`, `out0_tensor`,
// `functor`. Because these are plain declarations (no wrapping block), the
// macro can be expanded only once per scope.
//
// In this version the kernel is selected from the element type of out[0],
// and the element count also comes from out[0].
//
// `arg0_tensor` / `out0_tensor` are references to entries of `tensor_data`
// and are captured by reference via `[&]`; `kernel` and `element_count` are
// captured by copy.
// NOTE(review): presumably this lets the functor use whatever buffer pointers
// are stored in tensor_data when it is invoked rather than when it is built --
// confirm against the owner of tensor_data.
//
// The `ctx` argument of the functor is not used by its body. The expansion
// already ends with a semicolon.
#define BUILD_UNARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, out[0].get_element_type(), OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
// BUILD_BINARY_ELEMWISE_FUNCTOR(OP)
//
// Expands to the statements that build and register the functor for a
// two-input element-wise kernel OP.
//
// Names the expansion reads from the enclosing scope: `external_function`,
// `args`, `out`, plus the SELECT_KERNEL macro.
// Names the expansion declares in the enclosing scope: `functors`,
// `tensor_data`, `kernel`, `element_count`, `arg0_tensor`, `arg1_tensor`,
// `out0_tensor`, `functor`. Because these are plain declarations (no wrapping
// block), the macro can be expanded only once per scope.
//
// In this version the kernel is selected from the element type of out[0],
// and the element count also comes from out[0].
//
// The three tensor names are references to entries of `tensor_data` and are
// captured by reference via `[&]`; `kernel` and `element_count` are captured
// by copy.
// NOTE(review): presumably this lets the functor use whatever buffer pointers
// are stored in tensor_data when it is invoked rather than when it is built --
// confirm against the owner of tensor_data.
//
// The `ctx` argument of the functor is not used by its body. The expansion
// already ends with a semicolon.
#define BUILD_BINARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, out[0].get_element_type(), OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& arg1_tensor = tensor_data[args[1].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -222,148 +184,21 @@ namespace ngraph ...@@ -222,148 +184,21 @@ namespace ngraph
} }
template <> template <>
void Builder::BUILDER_DECL(ngraph::op::MatmulBias) void Builder::BUILDER_DECL(ngraph::op::Exp)
{
auto& functors = external_function->get_functors();
auto& tensor_data = external_function->get_tensor_data();
auto& arg0_tensor = tensor_data[args[0].get_name()];
auto& arg1_tensor = tensor_data[args[1].get_name()];
auto& out0_tensor = tensor_data[out[0].get_name()];
const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node);
const auto& arg0_shape = mm->get_arg0_shape();
const auto& arg1_shape = mm->get_arg1_shape();
const auto& arg2_shape = node->get_shape();
auto m = arg0_shape[0];
auto n = arg1_shape[1];
auto k = arg0_shape[1];
bool transpose_A = false, transpose_B = false;
auto lda = arg0_shape[1];
auto ldb = arg1_shape[1];
if (mm->get_is_arg0_transposed())
{ {
transpose_A = true; BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::exp);
m = arg0_shape[1];
k = arg0_shape[0];
} }
if (mm->get_is_arg1_transposed()) template <>
void Builder::BUILDER_DECL(ngraph::op::Log)
{ {
transpose_B = true; BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::log);
n = arg1_shape[0];
} }
const float beta = 0.0f; template <>
void Builder::BUILDER_DECL(ngraph::op::Not)
auto mm_functor =
[&, transpose_A, transpose_B, m, n, k, lda, ldb, beta, arg2_shape](
CPURuntimeContext* ctx) {
cblas::cblas_sgemm(
cblas::Layout::RowMajor,
transpose_A ? cblas::Transpose::Transpose : cblas::Transpose::None,
transpose_B ? cblas::Transpose::Transpose : cblas::Transpose::None,
m,
n,
k,
1.0f,
static_cast<float*>(arg0_tensor),
max(1UL, lda),
static_cast<float*>(arg1_tensor),
max(1UL, ldb),
beta,
static_cast<float*>(out0_tensor),
max(1UL, arg2_shape[1]));
};
function<void(CPURuntimeContext*)> bias_functor = [](CPURuntimeContext* ctx) {};
if (args.size() > 2)
{
auto& arg2_tensor = tensor_data[args[2].get_name()];
auto axes = mm->get_broadcast_axes();
if (axes.size() == 1)
{
if (*(axes.begin()) == 0)
{
vector<float> ones_row(arg2_shape[0], 1.0f);
bias_functor = [&, ones_row, arg2_shape](CPURuntimeContext* ctx) {
cblas::cblas_sgemm(cblas::Layout::RowMajor,
cblas::Transpose::None,
cblas::Transpose::None,
arg2_shape[0],
arg2_shape[1],
1,
1.0f,
ones_row.data(),
1UL,
static_cast<float*>(arg2_tensor),
max(1UL, arg2_shape[1]),
1.0f,
static_cast<float*>(out0_tensor),
max(1UL, arg2_shape[1]));
};
}
else
{
vector<float> ones_col(arg2_shape[1], 1.0f);
bias_functor = [&, ones_col, arg2_shape](CPURuntimeContext* ctx) {
cblas::cblas_sgemm(cblas::Layout::RowMajor,
cblas::Transpose::None,
cblas::Transpose::None,
arg2_shape[0],
arg2_shape[1],
1,
1.0f,
static_cast<float*>(arg2_tensor),
1UL,
ones_col.data(),
max(1UL, arg2_shape[1]),
1.0f,
static_cast<float*>(out0_tensor),
max(1UL, arg2_shape[1]));
};
}
}
else
{
if (axes.size() != 2)
{ {
throw ngraph_error("unexpected broadcast rank"); BUILD_UNARY_ELEMWISE_FUNCTOR(runtime::cpu::kernel::logical_not);
}
vector<float> ones_scalar(arg2_shape[0], 1.0f);
bias_functor = [&, ones_scalar, arg2_shape](CPURuntimeContext* ctx) {
vector<float> bias(arg2_shape[1], *static_cast<float*>(arg2_tensor));
cblas::cblas_sgemm(cblas::Layout::RowMajor,
cblas::Transpose::None,
cblas::Transpose::None,
arg2_shape[0],
arg2_shape[1],
1,
1.0f,
ones_scalar.data(),
1UL,
bias.data(),
max(1UL, arg2_shape[1]),
1.0f,
static_cast<float*>(out0_tensor),
max(1UL, arg2_shape[1]));
};
}
}
auto functor = [&, mm_functor, bias_functor](CPURuntimeContext* ctx) {
mm_functor(ctx);
bias_functor(ctx);
};
functors.emplace_back(functor);
} }
template <> template <>
...@@ -393,35 +228,36 @@ namespace ngraph ...@@ -393,35 +228,36 @@ namespace ngraph
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
const BuildOpMap build_dispatcher{ BuildOpMap build_dispatcher{
{TI(ngraph::op::Add), &runtime::cpu::Builder::build<ngraph::op::Add>},
{TI(ngraph::op::Multiply), &runtime::cpu::Builder::build<ngraph::op::Multiply>},
{TI(ngraph::op::Parameter), &runtime::cpu::Builder::nop}, {TI(ngraph::op::Parameter), &runtime::cpu::Builder::nop},
{TI(ngraph::op::Abs), &runtime::cpu::Builder::build<ngraph::op::Abs>},
{TI(ngraph::op::AvgPool), &runtime::cpu::Builder::build<ngraph::op::AvgPool>},
{TI(ngraph::op::Broadcast), &runtime::cpu::Builder::build<ngraph::op::Broadcast>},
{TI(ngraph::op::Ceiling), &runtime::cpu::Builder::build<ngraph::op::Ceiling>},
{TI(ngraph::runtime::cpu::op::ConvertLayout), {TI(ngraph::runtime::cpu::op::ConvertLayout),
&runtime::cpu::Builder::build<ngraph::runtime::cpu::op::ConvertLayout>}, &runtime::cpu::Builder::build<ngraph::runtime::cpu::op::ConvertLayout>}};
{TI(ngraph::op::Convolution),
&runtime::cpu::Builder::build<ngraph::op::Convolution>}, REGISTER_OP_BUILDER(Constant);
{TI(ngraph::op::ConvolutionRelu), REGISTER_OP_BUILDER(Result);
&runtime::cpu::Builder::build<ngraph::op::ConvolutionRelu>}, REGISTER_OP_BUILDER(Add);
{TI(ngraph::op::ConvolutionBias), REGISTER_OP_BUILDER(Subtract);
&runtime::cpu::Builder::build<ngraph::op::ConvolutionBias>}, REGISTER_OP_BUILDER(Multiply);
{TI(ngraph::op::ConvolutionBiasAdd), REGISTER_OP_BUILDER(Divide);
&runtime::cpu::Builder::build<ngraph::op::ConvolutionBiasAdd>}, REGISTER_OP_BUILDER(Power);
{TI(ngraph::op::ConvolutionBackpropData), REGISTER_OP_BUILDER(Abs);
&runtime::cpu::Builder::build<ngraph::op::ConvolutionBackpropData>}, REGISTER_OP_BUILDER(Ceiling);
{TI(ngraph::op::ConvolutionBackpropFilters), REGISTER_OP_BUILDER(Floor);
&runtime::cpu::Builder::build<ngraph::op::ConvolutionBackpropFilters>}, REGISTER_OP_BUILDER(Negative);
{TI(ngraph::op::ConvolutionBiasBackpropFiltersBias), REGISTER_OP_BUILDER(Relu);
&runtime::cpu::Builder::build<ngraph::op::ConvolutionBiasBackpropFiltersBias>}, REGISTER_OP_BUILDER(Exp);
{TI(ngraph::op::Relu), &runtime::cpu::Builder::build<ngraph::op::Relu>}, REGISTER_OP_BUILDER(Log);
{TI(ngraph::op::Reshape), &runtime::cpu::Builder::build<ngraph::op::Reshape>}, REGISTER_OP_BUILDER(Sqrt);
{TI(ngraph::op::Result), &runtime::cpu::Builder::build<ngraph::op::Result>},
{TI(ngraph::op::MatmulBias), &runtime::cpu::Builder::build<ngraph::op::MatmulBias>}, REGISTER_OP_BUILDER(Not);
{TI(ngraph::op::Constant), &runtime::cpu::Builder::build<ngraph::op::Constant>}}; REGISTER_OP_BUILDER(Equal);
REGISTER_OP_BUILDER(NotEqual);
REGISTER_OP_BUILDER(Greater);
REGISTER_OP_BUILDER(GreaterEq);
REGISTER_OP_BUILDER(Less);
REGISTER_OP_BUILDER(LessEq);
REGISTER_OP_BUILDER(Maximum);
REGISTER_OP_BUILDER(Minimum);
} }
} }
} }
...@@ -157,6 +157,49 @@ ...@@ -157,6 +157,49 @@
SELECT_RANK(KV, uint64_t, R, K); \ SELECT_RANK(KV, uint64_t, R, K); \
} }
// BUILD_UNARY_ELEMWISE_FUNCTOR(OP)
//
// Expands to the statements that build and register the functor for a
// one-input element-wise kernel OP.
//
// Names the expansion reads from the enclosing scope: `external_function`,
// `args`, `out`, plus the SELECT_KERNEL macro.
// Names the expansion declares in the enclosing scope: `functors`,
// `tensor_data`, `kernel`, `element_count`, `arg0_tensor`, `out0_tensor`,
// `functor`. Because these are plain declarations (no wrapping block), the
// macro can be expanded only once per scope.
//
// The kernel is selected from the element type of the INPUT (args[0]), while
// the element count is taken from the output (out[0]).
//
// `arg0_tensor` / `out0_tensor` are references to entries of `tensor_data`
// and are captured by reference via `[&]`; `kernel` and `element_count` are
// captured by copy.
// NOTE(review): presumably this lets the functor use whatever buffer pointers
// are stored in tensor_data when it is invoked rather than when it is built --
// confirm against the owner of tensor_data.
//
// The `ctx` argument of the functor is not used by its body. The expansion
// already ends with a semicolon.
#define BUILD_UNARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, args[0].get_element_type(), OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
// BUILD_BINARY_ELEMWISE_FUNCTOR(OP)
//
// Expands to the statements that build and register the functor for a
// two-input element-wise kernel OP.
//
// Names the expansion reads from the enclosing scope: `external_function`,
// `args`, `out`, plus the SELECT_KERNEL macro.
// Names the expansion declares in the enclosing scope: `functors`,
// `tensor_data`, `kernel`, `element_count`, `arg0_tensor`, `arg1_tensor`,
// `out0_tensor`, `functor`. Because these are plain declarations (no wrapping
// block), the macro can be expanded only once per scope.
//
// The kernel is selected from the element type of the first INPUT (args[0]),
// while the element count is taken from the output (out[0]).
//
// The three tensor names are references to entries of `tensor_data` and are
// captured by reference via `[&]`; `kernel` and `element_count` are captured
// by copy.
// NOTE(review): presumably this lets the functor use whatever buffer pointers
// are stored in tensor_data when it is invoked rather than when it is built --
// confirm against the owner of tensor_data.
//
// The `ctx` argument of the functor is not used by its body. The expansion
// already ends with a semicolon.
#define BUILD_BINARY_ELEMWISE_FUNCTOR(OP) \
auto& functors = external_function->get_functors(); \
auto& tensor_data = external_function->get_tensor_data(); \
std::function<void(void*, void*, void*, size_t)> kernel; \
\
SELECT_KERNEL(kernel, args[0].get_element_type(), OP); \
\
auto element_count = out[0].get_size(); \
auto& arg0_tensor = tensor_data[args[0].get_name()]; \
auto& arg1_tensor = tensor_data[args[1].get_name()]; \
auto& out0_tensor = tensor_data[out[0].get_name()]; \
\
auto functor = [&, kernel, element_count](CPURuntimeContext* ctx) { \
kernel(arg0_tensor, arg1_tensor, out0_tensor, element_count); \
}; \
functors.emplace_back(functor);
// REGISTER_OP_BUILDER(OP)
//
// Defines a file-local static object whose constructor inserts the builder
// for ngraph::op::OP into `build_dispatcher`, keyed by the op's
// std::type_index. Placing one expansion next to an op's builder
// specialization registers that builder without editing a central table.
//
// Requirements at the point of use: `build_dispatcher`, `ngraph::op::OP` and
// `runtime::cpu::Builder::build` must be nameable exactly as written below.
//
// Details:
//  * `std::type_index` is spelled in full so the macro does not depend on a
//    `using namespace std;` being in effect where it is expanded.
//  * The generated type/instance names deliberately avoid a leading double
//    underscore; such identifiers are reserved for the implementation.
//  * The map's `insert` does not overwrite: if an entry for the same op type
//    already exists, the existing builder is kept.
//  * NOTE(review): registration runs during static initialization. When the
//    macro is expanded in a translation unit other than the one that defines
//    `build_dispatcher`, the relative initialization order is not guaranteed
//    by the language -- confirm the dispatcher is constructed first (or move
//    it behind a function-local static accessor).
#define REGISTER_OP_BUILDER(OP) \
    static struct register_##OP##_builder_t \
    { \
        register_##OP##_builder_t() \
        { \
            build_dispatcher.insert({std::type_index(typeid(ngraph::op::OP)), \
                                     &runtime::cpu::Builder::build<ngraph::op::OP>}); \
        } \
    } register_##OP##_builder_instance;
namespace ngraph namespace ngraph
{ {
namespace runtime namespace runtime
...@@ -171,7 +214,7 @@ namespace ngraph ...@@ -171,7 +214,7 @@ namespace ngraph
using BuildOpMap = std::unordered_map<std::type_index, BuildOpFunction>; using BuildOpMap = std::unordered_map<std::type_index, BuildOpFunction>;
extern const BuildOpMap build_dispatcher; extern BuildOpMap build_dispatcher;
class Builder class Builder
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment