Commit 238ce788 authored by Louis Feng's avatar Louis Feng Committed by Scott Cyphers

Batch dot operation for rank 3 multiply with rank 2 tensors (#1180)

* hacking to support dot of 3 by 2 inputs with gemm_batch.

* clean up.
parent 9d09c7e5
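
For context on the operation being specialized: with one reduction axis, a rank 3 by rank 2 dot multiplies each (m, k) slice of the first argument by the same (k, n) second argument. A minimal reference sketch of that semantics under those assumptions (plain row-major loops, not the ngraph kernel itself):

#include <cstddef>

// Reference semantics sketch: C[g] = A[g] * B for g in [0, groups).
// A is (groups, m, k), B is (k, n), C is (groups, m, n), all row-major.
void batch_dot_reference(const float* A, const float* B, float* C,
                         std::size_t groups, std::size_t m, std::size_t k, std::size_t n)
{
    for (std::size_t g = 0; g < groups; ++g)
    {
        const float* a = A + g * m * k; // same per-group offset the codegen uses
        float* c = C + g * m * n;
        for (std::size_t i = 0; i < m; ++i)
        {
            for (std::size_t j = 0; j < n; ++j)
            {
                float sum = 0.0f;
                for (std::size_t p = 0; p < k; ++p)
                {
                    sum += a[i * k + p] * B[p * n + j]; // B is shared across groups
                }
                c[i * n + j] = sum;
            }
        }
    }
}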
@@ -409,7 +409,7 @@ namespace ngraph
n = shape_b[1];
ldb = std::max(1UL, k);
}
size_t ldc = max(1UL, n);
size_t ldc = std::max(1UL, n);
const size_t offset_a = m * k;
const size_t offset_b = k * n;
const size_t offset_c = m * n;
@@ -423,14 +423,7 @@ namespace ngraph
[&writer](const std::string& var, size_t size, size_t offset) {
for (size_t i = 0; i < size; ++i)
{
if (i < size - 1)
{
writer << var << "+" << i * offset << ", ";
}
else
{
writer << var << "+" << i * offset;
}
writer << var << "+" << i * offset << ((i < size - 1) ? ", " : "");
}
};
writer << "cblas::Transpose transa_array[] = {" << transpose_a << "};\n";
@@ -991,6 +984,69 @@ namespace ngraph
writer.block_end();
}
}
// Specialized handling of a rank 3 tensor multiplied by a rank 2 tensor, where
// each rank 2 slice of the first input is multiplied by the second input
// using a single batched GEMM (cblas_sgemm_batch) call.
else if ((arg0_shape.size() == 3) && (arg1_shape.size() == 2) &&
dot->get_reduction_axes_count() == 1 &&
args[0].get_element_type() == element::f32)
{
auto mat_a = args[0];
auto mat_b = args[1];
auto mat_c = out[0];
const Shape& shape_a = mat_a.get_shape();
const Shape& shape_b = mat_b.get_shape();
const size_t m = shape_a[1];
const size_t k = shape_a[2];
const size_t n = shape_b[1];
// this also works when mat_a is shape (1, m, k)
const size_t offset_a = m * k;
// we do not offset mat_b
const size_t offset_b = 0;
const size_t offset_c = m * n;
const size_t group_count = 1;
const size_t group_size = shape_a[0];
auto populate_array =
[&writer](const std::string& var, size_t size, size_t offset) {
for (size_t i = 0; i < size; ++i)
{
writer << var << "+" << i * offset << ((i < size - 1) ? ", " : "");
}
};
writer.block_begin();
writer << "cblas::Transpose transa_array[] = {cblas::Transpose::None};\n";
writer << "cblas::Transpose transb_array[] = {cblas::Transpose::None};\n";
writer << "int64_t m_array[] = {" << m << "};\n";
writer << "int64_t n_array[] = {" << n << "};\n";
writer << "int64_t k_array[] = {" << k << "};\n";
writer << "float alpha_array[] = {1.0f};\n";
writer << "std::vector<const float*> a{";
populate_array(mat_a.get_name(), group_size, offset_a);
writer << "};\n";
writer << "const float** a_array = &a[0];\n";
writer << "int64_t lda_array[] = {" << std::max(1UL, k) << "};\n";
writer << "std::vector<const float*> b{";
populate_array(mat_b.get_name(), group_size, offset_b);
writer << "};\n";
writer << "const float** b_array = &b[0];\n";
writer << "int64_t ldb_array[] = {" << std::max(1UL, n) << "};\n";
writer << "float beta_array[] = {0.0f};\n";
writer << "std::vector<float*> c{";
populate_array(mat_c.get_name(), group_size, offset_c);
writer << "};\n";
writer << "float** c_array = &c[0];\n";
writer << "int64_t ldc_array[] = {" << std::max(1UL, n) << "};\n";
writer << "int64_t group_size[] = {" << group_size << "};\n";
writer << "cblas_sgemm_batch(cblas::Layout::RowMajor, ";
writer << "transa_array, transb_array, m_array, n_array, k_array, \n";
writer << "alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, \n";
writer << "c_array, ldc_array, " << group_count << ", group_size);\n";
writer.block_end();
}
else
{
writer << "reference::dot(" << args[0].get_name() << ",\n";
......
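
To make the new branch concrete, the following is roughly the code it would emit for a hypothetical arg0 of shape (2, 3, 4) and arg1 of shape (4, 5), so m = 3, k = 4, n = 5, group_size = 2, offset_a = 12, offset_b = 0, offset_c = 15. A, B and C stand in for the names returned by get_name(), and the cblas types are the wrappers the codegen already references; this is emitted-code text, not standalone compilable code:

{
    cblas::Transpose transa_array[] = {cblas::Transpose::None};
    cblas::Transpose transb_array[] = {cblas::Transpose::None};
    int64_t m_array[] = {3};
    int64_t n_array[] = {5};
    int64_t k_array[] = {4};
    float alpha_array[] = {1.0f};
    std::vector<const float*> a{A+0, A+12};  // one (m, k) slice per group
    const float** a_array = &a[0];
    int64_t lda_array[] = {4};
    std::vector<const float*> b{B+0, B+0};   // B is reused for every group (offset_b = 0)
    const float** b_array = &b[0];
    int64_t ldb_array[] = {5};
    float beta_array[] = {0.0f};
    std::vector<float*> c{C+0, C+15};
    float** c_array = &c[0];
    int64_t ldc_array[] = {5};
    int64_t group_size[] = {2};
    cblas_sgemm_batch(cblas::Layout::RowMajor, transa_array, transb_array, m_array, n_array, k_array,
                      alpha_array, a_array, lda_array, b_array, ldb_array, beta_array,
                      c_array, ldc_array, 1, group_size);
}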
@@ -2634,3 +2634,34 @@ TEST(cpu_fusion, fuse_bounded_relu_inter_vs_cpu)
check_bounded_relu(Shape{4, 3}, 4.0f);
check_bounded_relu(Shape{4, 3, 2}, 2.0f);
}
TEST(cpu_fusion, dot_batch_forward)
{
const Shape shape_a{2, 3, 2};
const Shape shape_b{2, 3};
auto generate_func = [&shape_a, &shape_b]() -> shared_ptr<Function> {
auto a = make_shared<op::Parameter>(element::f32, shape_a);
auto b = make_shared<op::Parameter>(element::f32, shape_b);
auto dot = make_shared<op::Dot>(a, b);
return make_shared<Function>(dot, op::ParameterVector{a, b});
};
shared_ptr<Function> cpu_func = generate_func();
shared_ptr<Function> int_func = generate_func();
test::Uniform<float> rng(0.0f, 1.0f);
vector<vector<float>> args;
for (shared_ptr<op::Parameter> param : int_func->get_parameters())
{
vector<float> tensor_val(shape_size(param->get_shape()));
rng.initialize(tensor_val);
args.push_back(tensor_val);
}
auto int_results = execute(int_func, args, "INTERPRETER");
auto cpu_results = execute(cpu_func, args, "CPU");
for (size_t i = 0; i < cpu_results.size(); i++)
{
EXPECT_TRUE(test::all_close(cpu_results.at(i), int_results.at(i), 1.0e-4f, 1.0e-4f));
}
}
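
For reference, these shapes exercise the new branch with m = 3, k = 2, n = 3 and group_size = 2, so the INTERPRETER result the CPU output is compared against has shape (2, 3, 3). A small sketch of how that dot result shape is derived (the helper name dot_result_shape is hypothetical, not an ngraph API):

#include <cstddef>
#include <vector>

// Result shape of a dot with r reduction axes: drop the last r axes of shape_a,
// the first r axes of shape_b, and concatenate what remains.
std::vector<std::size_t> dot_result_shape(const std::vector<std::size_t>& shape_a,
                                          const std::vector<std::size_t>& shape_b,
                                          std::size_t r)
{
    std::vector<std::size_t> out(shape_a.begin(), shape_a.end() - r);
    out.insert(out.end(), shape_b.begin() + r, shape_b.end());
    return out; // for {2, 3, 2} and {2, 3} with r = 1 this is {2, 3, 3}
}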