Commit fee3d1a7 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MKLDNN] Emit dgemm for 2D DP FP Dot op (#3990)

* [MLIR] Update MLIR/LLVM repos

* Move MLIR/LLVM repos forward

This includes fix to affine fusion algorithm.

* Fix issues after merge

* Fix lit test

* [MKLDNN] Emit dgemm for 2D DP FP Dot op

Add support for emitting MKLDNN's double precision FP gemm from a 2D double
precision floating point Dot operation.

* Removed unnecessarily duplicated pattern

* Add f64 matmul support to CPU Emitter + unit test

* Add check for DP unsupported bias in cpu_fusion.cpp
parent a9a3ae79
......@@ -23,6 +23,7 @@
using namespace std;
using namespace ngraph;
using namespace ngraph::element;
namespace ngraph
{
......@@ -44,6 +45,9 @@ namespace ngraph
const auto& arg0_shape = mm->get_a_shape();
const auto& arg1_shape = mm->get_b_shape();
const auto& arg2_shape = node->get_shape();
const auto element_type = mm->get_input_element_type(0);
NGRAPH_CHECK(element_type == element::f32 || element_type == element::f64,
"MatmulBias element type not supported");
auto m = arg0_shape[0];
auto n = arg1_shape[1];
......@@ -80,8 +84,12 @@ namespace ngraph
arg2_shape,
arg0_buffer_index,
arg1_buffer_index,
out0_buffer_index](CPURuntimeContext* ctx,
out0_buffer_index,
element_type](CPURuntimeContext* ctx,
CPUExecutionContext* /* ectx */) {
switch (element_type)
{
case Type_t::f32:
cblas::cblas_sgemm(
cblas::Layout::RowMajor,
transpose_A ? cblas::Transpose::Transpose : cblas::Transpose::None,
......@@ -97,6 +105,26 @@ namespace ngraph
beta,
static_cast<float*>(ctx->buffer_data[out0_buffer_index]),
max<size_t>(1, arg2_shape[1]));
break;
case Type_t::f64:
cblas::cblas_dgemm(
cblas::Layout::RowMajor,
transpose_A ? cblas::Transpose::Transpose : cblas::Transpose::None,
transpose_B ? cblas::Transpose::Transpose : cblas::Transpose::None,
m,
n,
k,
1.0f,
static_cast<double*>(ctx->buffer_data[arg0_buffer_index]),
max<size_t>(1, lda),
static_cast<double*>(ctx->buffer_data[arg1_buffer_index]),
max<size_t>(1, ldb),
beta,
static_cast<double*>(ctx->buffer_data[out0_buffer_index]),
max<size_t>(1, arg2_shape[1]));
break;
default: NGRAPH_UNREACHABLE("Matmul element type is not supported");
}
};
CPUKernelFunctor bias_functor = [](CPURuntimeContext* /* ctx */,
......@@ -104,6 +132,8 @@ namespace ngraph
if (args.size() > 2)
{
NGRAPH_CHECK(element_type == element::f32,
"Bias element type is not supported");
auto arg2_buffer_index =
external_function->get_buffer_index(args[2].get_name());
......@@ -400,6 +430,6 @@ namespace ngraph
REGISTER_OP_BUILDER(BatchMatMul);
REGISTER_OP_BUILDER(BatchMatMulTranspose);
}
}
}
}
} // namespace cpu
} // namespace runtime
} // namespace ngraph
......@@ -945,7 +945,8 @@ namespace ngraph
dot->get_reduction_axes_count() == 1)
{
// Emit an MKL SGEMM call if possible
if (args[0].get_element_type() == element::f32)
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
writer.block_begin();
writer << "cblas::cblas_sgemm("
......@@ -960,6 +961,21 @@ namespace ngraph
<< ");\n";
writer.block_end();
}
else if (element_type == element::f64)
{
writer.block_begin();
writer << "cblas::cblas_dgemm("
<< "cblas::Layout::RowMajor, "
<< "cblas::Transpose::None, "
<< "cblas::Transpose::None, " << arg0_shape[0] << ", "
<< arg1_shape[1] << ", " << arg0_shape[1] << ",\n"
<< " 1.0f, " << args[0].get_name() << ", "
<< max(1UL, arg0_shape[1]) << ", " << args[1].get_name() << ", "
<< max(1UL, arg1_shape[1]) << ", 0.0f,\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg1_shape[1])
<< ");\n";
writer.block_end();
}
else
{
writer.block_begin();
......
......@@ -92,6 +92,21 @@ namespace cblas
float* C,
const int64_t ldc);
void cblas_dgemm(const Layout layout,
const Transpose TransA,
const Transpose TransB,
const int64_t M,
const int64_t N,
const int64_t K,
const double alpha,
const double* A,
const int64_t lda,
const double* B,
const int64_t ldb,
const double beta,
double* C,
const int64_t ldc);
void cblas_sgemm_batch(const Layout Layout,
const Transpose* transa_array,
const Transpose* transb_array,
......
......@@ -166,6 +166,9 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
auto m_bias = m_broadcast->get_argument(0);
auto pattern_map = m.get_pattern_map();
NGRAPH_CHECK(mpattern->get_element_type() != element::f64 || m_bias == nullptr,
"Bias in DP MatMulBias is not supported yet");
auto mmb = std::make_shared<ngraph::op::MatmulBias>(pattern_map[W],
pattern_map[x],
m_bias,
......@@ -207,10 +210,12 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmul()
auto mpattern = m.get_match_root();
auto dot = m.get_match_root();
auto element_type = mpattern->get_element_type();
if (mpattern->get_element_type() != element::f32)
if (element_type != element::f32 && element_type != element::f64)
{
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name() << " type is not float!";
NGRAPH_DEBUG << "mpattern = " << mpattern->get_name()
<< " type is not float or double!";
return false;
}
......
......@@ -295,6 +295,25 @@ TEST(cpu_fusion, cpu_fusion_pass_basic)
ASSERT_NE(as_type_ptr<op::MatmulBias>(graph->get_argument(0)), nullptr);
}
TEST(cpu_fusion, matmul_f64)
{
Shape shape{};
Shape shape_w{2, 4};
Shape shape_x{4, 1};
Shape shape_b{1};
auto A = make_shared<op::Parameter>(element::f64, shape_w);
auto B = make_shared<op::Parameter>(element::f64, shape_x);
auto C = make_shared<op::Parameter>(element::f64, shape_b);
auto dot = make_shared<op::Dot>(A, B);
auto graph = make_shared<op::Abs>(dot);
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>(pass::FusionType::REGULAR_FUSIONS);
auto func = make_shared<Function>(graph, ParameterVector{A, B, C});
pass_manager.run_passes(func);
ASSERT_NE(as_type_ptr<op::MatmulBias>(graph->get_argument(0)), nullptr);
}
TEST(cpu_fusion, commutative_matmul_bias)
{
Shape shape{};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment