Commit 21012673 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

DEX BatchDot (#1319)

* batchdot with debug statements

* clean up

* address feedback
parent 15d39100
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "ngraph/runtime/cpu/op/matmul_bias.hpp" #include "ngraph/runtime/cpu/op/matmul_bias.hpp"
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "ngraph/runtime/cpu/cpu_kernels.hpp" #include "ngraph/runtime/cpu/cpu_kernels.hpp"
#include "ngraph/runtime/cpu/op/batch_dot.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -172,7 +173,179 @@ namespace ngraph ...@@ -172,7 +173,179 @@ namespace ngraph
functors.emplace_back(functor); functors.emplace_back(functor);
} }
struct CblasGemmOptions
{
CblasGemmOptions(void*& da, void*& db, void*& dc)
: data_a(da)
, data_b(db)
, data_c(dc)
{
}
std::vector<cblas::Transpose> transa_array;
std::vector<cblas::Transpose> transb_array;
std::vector<int64_t> m_array;
std::vector<int64_t> n_array;
std::vector<int64_t> k_array;
std::vector<int64_t> lda_array;
std::vector<int64_t> ldb_array;
std::vector<int64_t> ldc_array;
std::vector<int64_t> group_sizes;
std::vector<float> alpha_array;
std::vector<float> beta_array;
size_t offset_a;
size_t offset_b;
size_t offset_c;
void*& data_a;
void*& data_b;
void*& data_c;
int64_t group_count;
void call(CPURuntimeContext* ctx)
{
std::vector<float*> a_array(group_sizes[0]);
std::vector<float*> b_array(group_sizes[0]);
std::vector<float*> c_array(group_sizes[0]);
auto populate_array = [](std::vector<float*>& offsets_vector,
void* data,
int64_t size,
size_t offset) {
for (size_t i = 0; i < size; ++i)
{
offsets_vector.at(i) = static_cast<float*>(data) + (i * offset);
}
};
populate_array(a_array, data_a, group_sizes[0], offset_a);
populate_array(b_array, data_b, group_sizes[0], offset_b);
populate_array(c_array, data_c, group_sizes[0], offset_c);
const float** a = const_cast<const float**>(&a_array[0]);
const float** b = const_cast<const float**>(&b_array[0]);
cblas_sgemm_batch(cblas::Layout::RowMajor,
&transa_array[0],
&transb_array[0],
&m_array[0],
&n_array[0],
&k_array[0],
&alpha_array[0],
a,
&lda_array[0],
b,
&ldb_array[0],
&beta_array[0],
&c_array[0],
&ldc_array[0],
group_count,
&group_sizes[0]);
}
};
// Builds a functor that executes a batched SGEMM, C = alpha * A x B + beta * C,
// over `group_size` matrices via cblas_sgemm_batch.
//
// shape_a/shape_b/shape_c are 3D batch shapes {batch, rows, cols}; data_a/b/c
// are references to tensor-data slots resolved at execution time. An input
// whose batch dimension is 1 is broadcast across the whole group (stride 0).
static function<void(CPURuntimeContext*)> emitCblasSgemmBatch(const Shape& shape_a,
                                                              const Shape& shape_b,
                                                              const Shape& shape_c,
                                                              bool transpose_a,
                                                              bool transpose_b,
                                                              void*& data_a,
                                                              void*& data_b,
                                                              void*& data_c,
                                                              const float alpha,
                                                              const float beta,
                                                              size_t group_size)
{
    size_t m = shape_a[1];
    size_t k = shape_a[2];
    size_t n = shape_b[2];
    // Leading dimensions must be >= 1 even for degenerate matrices.
    // Use an explicit size_t literal: 1UL is not size_t on all platforms
    // (e.g. LLP64 Windows), which would break std::max's type deduction.
    size_t lda = std::max(size_t{1}, k);
    size_t ldb = std::max(size_t{1}, n);

    cblas::Transpose ctranspose_a = cblas::Transpose::None;
    cblas::Transpose ctranspose_b = cblas::Transpose::None;
    if (transpose_a)
    {
        // A is stored transposed: swap its logical dims and restride.
        ctranspose_a = cblas::Transpose::Transpose;
        m = shape_a[2];
        k = shape_a[1];
        lda = std::max(size_t{1}, m);
    }
    if (transpose_b)
    {
        ctranspose_b = cblas::Transpose::Transpose;
        n = shape_b[1];
        ldb = std::max(size_t{1}, k);
    }
    size_t ldc = std::max(size_t{1}, n);

    CblasGemmOptions options(data_a, data_b, data_c);

    // Stride 0 broadcasts a batch-1 input across every GEMM in the group.
    options.offset_a = (shape_a.at(0) > 1) ? m * k : 0;
    options.offset_b = (shape_b.at(0) > 1) ? k * n : 0;
    options.offset_c = (shape_c.at(0) > 1) ? m * n : 0;

    // Only one group for now; generalize here if more groups are needed.
    options.group_count = 1;
    options.transa_array.push_back(ctranspose_a);
    options.transb_array.push_back(ctranspose_b);
    options.m_array.push_back(m);
    options.n_array.push_back(n);
    options.k_array.push_back(k);
    options.alpha_array.push_back(alpha);
    options.beta_array.push_back(beta);
    options.lda_array.push_back(lda);
    options.ldb_array.push_back(ldb);
    options.ldc_array.push_back(ldc);
    options.group_sizes.push_back(group_size);

    // The lambda owns a copy of the fully-populated options; `mutable` lets
    // call() (a non-const member) run on that copy at execution time.
    function<void(CPURuntimeContext*)> cblas_func =
        [options](CPURuntimeContext* ctx) mutable { options.call(ctx); };
    return cblas_func;
}
template <>
void Builder::BUILDER_DECL(ngraph::op::BatchDot)
{
    // Emit a batched-GEMM functor for a BatchDot node and queue it for
    // execution in the DEX functor list.
    auto& functors = external_function->get_functors();
    auto& tensor_data = external_function->get_tensor_data();

    const auto* batch_dot = static_cast<const ngraph::op::BatchDot*>(node);

    auto& arg0_data = tensor_data[args[0].get_name()];
    auto& arg1_data = tensor_data[args[1].get_name()];
    auto& out_data = tensor_data[out[0].get_name()];

    const auto& shape_a = batch_dot->get_a_shape();
    const auto& shape_b = batch_dot->get_b_shape();
    const auto& shape_c = out[0].get_shape();

    // One GEMM per entry along the leading (batch) axis of A;
    // plain product semantics: alpha = 1, beta = 0.
    functors.emplace_back(emitCblasSgemmBatch(shape_a,
                                              shape_b,
                                              shape_c,
                                              batch_dot->get_is_a_transposed(),
                                              batch_dot->get_is_b_transposed(),
                                              arg0_data,
                                              arg1_data,
                                              out_data,
                                              1.0f,
                                              0.0f,
                                              shape_a.at(0)));
}
REGISTER_OP_BUILDER(MatmulBias); REGISTER_OP_BUILDER(MatmulBias);
REGISTER_OP_BUILDER(BatchDot);
} }
} }
} }
...@@ -225,18 +225,18 @@ namespace ngraph ...@@ -225,18 +225,18 @@ namespace ngraph
} }
#endif #endif
void emitCblasSgemmBatch(codegen::CodeWriter& writer, static void emitCblasSgemmBatch(codegen::CodeWriter& writer,
const Shape& shape_a, const Shape& shape_a,
const Shape& shape_b, const Shape& shape_b,
const Shape& shape_c, const Shape& shape_c,
bool transpose_a, bool transpose_a,
bool transpose_b, bool transpose_b,
const std::string& data_a, const std::string& data_a,
const std::string& data_b, const std::string& data_b,
const std::string& data_c, const std::string& data_c,
const std::string& alpha, const std::string& alpha,
const std::string& beta, const std::string& beta,
size_t group_size) size_t group_size)
{ {
static const char* cblas_transpose = "cblas::Transpose::Transpose"; static const char* cblas_transpose = "cblas::Transpose::Transpose";
static const char* cblas_no_transpose = "cblas::Transpose::None"; static const char* cblas_no_transpose = "cblas::Transpose::None";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment