Commit e4cefb21 authored by Jaikrishnan Menon, committed by Robert Kimball

DEX SGEMM batch for rank-3 rank-2 dot (#1425)

parent af618a2b
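The fast path added here targets a dot between a rank-3 tensor A of shape (g, m, k) and a rank-2 tensor B of shape (k, n) with one reduction axis, producing C of shape (g, m, n); every batch slice A[i] is multiplied by the same B. For orientation, a minimal reference sketch of those semantics in plain C++ (the helper name reference_dot_3d_2d is hypothetical, not part of the commit):

#include <cstddef>
#include <vector>

// Reference semantics of the lowered op: A is (g, m, k), B is (k, n), C is (g, m, n),
// all row-major; the same B is reused for every batch slice of A.
void reference_dot_3d_2d(const std::vector<float>& a,
                         const std::vector<float>& b,
                         std::vector<float>& c,
                         size_t g, size_t m, size_t k, size_t n)
{
    for (size_t i = 0; i < g; i++)           // batch (group) index
        for (size_t r = 0; r < m; r++)       // row within the i-th slice
            for (size_t col = 0; col < n; col++)
            {
                float acc = 0.0f;
                for (size_t x = 0; x < k; x++)
                    acc += a[i * m * k + r * k + x] * b[x * n + col];
                c[i * m * n + r * n + col] = acc;
            }
}

The diff below maps this onto a single batched SGEMM with group_count = 1 and group_size = g = shape_a[0], per-slice offsets of m * k into A and m * n into C, and a zero offset into B.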
@@ -18,6 +18,7 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/kernel/dot.hpp"
using namespace std;
@@ -151,13 +152,86 @@ namespace ngraph
if ((arg0_shape.size() == 3) && (arg1_shape.size() == 2) &&
reduction_axes_count == 1)
{
std::function<decltype(runtime::cpu::kernel::dot_3d_2d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_3d_2d_1rd);
auto functor =
[&, kernel, arg0_shape, arg1_shape, result_shape](CPURuntimeContext* ctx) {
if (args[0].get_element_type() == element::f32)
{
auto shape_a = args[0].get_shape();
auto shape_b = args[1].get_shape();
const int64_t m = shape_a[1];
const int64_t k = shape_a[2];
const int64_t n = shape_b[1];
// this also works when mat_a is shape (1, m, k)
const int64_t offset_a = m * k;
// we do not offset mat_b
const int64_t offset_b = 0;
const int64_t offset_c = m * n;
const int64_t group_count = 1;
const int64_t group_size = shape_a[0];
auto functor =
[&, offset_a, offset_b, offset_c, m, n, k, group_size, group_count](
CPURuntimeContext* ctx) {
cblas::Transpose transpose = cblas::Transpose::None;
float alpha = 1.0f;
vector<const float*> a;
for (size_t i = 0; i < group_size; i++)
{
a.emplace_back(static_cast<const float*>(arg0_tensor) +
i * offset_a);
}
const float** a_array = a.data();
int64_t lda_array = std::max(1L, k);
vector<const float*> b;
for (size_t i = 0; i < group_size; i++)
{
b.emplace_back(static_cast<const float*>(arg1_tensor) +
i * offset_b);
}
const float** b_array = b.data();
const int64_t ldb_array = std::max(1L, n);
float beta = 0.0f;
vector<float*> c;
for (size_t i = 0; i < group_size; i++)
{
c.emplace_back(static_cast<float*>(out_tensor) + i * offset_c);
}
float** c_array = c.data();
const int64_t ldc_array = std::max(1L, n);
cblas_sgemm_batch(cblas::Layout::RowMajor,
&transpose,
&transpose,
&m,
&n,
&k,
&alpha,
a_array,
&lda_array,
b_array,
&ldb_array,
&beta,
c_array,
&ldc_array,
group_count,
&group_size);
};
functors.emplace_back(functor);
return;
}
else
{
std::function<decltype(runtime::cpu::kernel::dot_3d_2d_1rd<float>)> kernel;
SELECT_KERNEL(
kernel, out[0].get_element_type(), runtime::cpu::kernel::dot_3d_2d_1rd);
auto functor = [&, kernel, arg0_shape, arg1_shape, result_shape](
CPURuntimeContext* ctx) {
kernel(arg0_tensor,
arg1_tensor,
out_tensor,
@@ -165,8 +239,9 @@ namespace ngraph
arg1_shape,
result_shape);
};
functors.emplace_back(functor);
return;
functors.emplace_back(functor);
return;
}
}
std::function<decltype(runtime::cpu::kernel::dot<float>)> kernel;
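Since the builder uses group_count = 1 with group_size = shape_a[0] and fills the per-group pointer arrays with addresses into the same B, the single batched call computes the same result as issuing group_size independent SGEMMs. A hedged sketch of that equivalence through the standard CBLAS interface (assumptions: the <cblas.h> header and plain-int leading dimensions; the commit itself goes through the cblas:: wrappers that the added cpu_kernels.hpp include brings in):

#include <cblas.h>
#include <cstdint>

// Sketch only: g separate SGEMMs over a shared B, equivalent to the one
// cblas_sgemm_batch call in the diff (alpha = 1, beta = 0, no transposes).
void dot_3d_2d_f32_looped(const float* a, const float* b, float* c,
                          int64_t g, int64_t m, int64_t k, int64_t n)
{
    const int64_t offset_a = m * k; // stride between consecutive slices of A
    const int64_t offset_c = m * n; // stride between consecutive slices of C
    const int lda = k > 0 ? static_cast<int>(k) : 1;
    const int ldb = n > 0 ? static_cast<int>(n) : 1;
    const int ldc = ldb;
    for (int64_t i = 0; i < g; i++)
    {
        // C[i] = A[i] * B, row-major
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                    static_cast<int>(m), static_cast<int>(n), static_cast<int>(k),
                    1.0f, a + i * offset_a, lda, b, ldb,
                    0.0f, c + i * offset_c, ldc);
    }
}

Folding the loop into one cblas_sgemm_batch call hands all of the independent multiplies to the BLAS at once instead of invoking it once per slice, which is the point of the batched API.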