Commit 6bca3efd authored by Nick Korovaiko, committed by Scott Cyphers

Refactoring MMB (#1224)

* rank3xrank2 cpu_emitter version 1

* refactoring matmulbias

* add comment
parent 5927bbe4
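For orientation, ngraph::op::MatmulBias represents dot(op(A), op(B)) fused with an optional broadcast bias add, and the CPU emitter refactored below lowers it to cblas_sgemm / cblas_sgemm_batch calls. The following plain-C++ loop nest is only a hedged sketch of the math those generated calls implement for the common row-broadcast case; the function name and layout choices are illustrative, not code from this commit.

#include <cstddef>
#include <vector>

// Reference semantics only: out = A (m x k) * B (k x n) + bias broadcast over rows.
// Row-major storage is assumed, matching the cblas::Layout::RowMajor calls in the emitter.
static void matmul_bias_reference(const std::vector<float>& A,
                                  const std::vector<float>& B,
                                  const std::vector<float>& bias, // size n, may be empty
                                  std::vector<float>& out,
                                  size_t m, size_t k, size_t n)
{
    for (size_t i = 0; i < m; ++i)
    {
        for (size_t j = 0; j < n; ++j)
        {
            float acc = bias.empty() ? 0.0f : bias[j];
            for (size_t p = 0; p < k; ++p)
            {
                acc += A[i * k + p] * B[p * n + j];
            }
            out[i * n + j] = acc;
        }
    }
}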
@@ -39,8 +39,8 @@ namespace ngraph
const ngraph::op::MatmulBias* mm = static_cast<const ngraph::op::MatmulBias*>(node);
const auto& arg0_shape = mm->get_arg0_shape();
const auto& arg1_shape = mm->get_arg1_shape();
const auto& arg0_shape = mm->get_a_shape();
const auto& arg1_shape = mm->get_b_shape();
const auto& arg2_shape = node->get_shape();
auto m = arg0_shape[0];
@@ -51,14 +51,14 @@ namespace ngraph
auto lda = arg0_shape[1];
auto ldb = arg1_shape[1];
if (mm->get_is_arg0_transposed())
if (mm->get_is_a_transposed())
{
transpose_A = true;
m = arg0_shape[1];
k = arg0_shape[0];
}
if (mm->get_is_arg1_transposed())
if (mm->get_is_b_transposed())
{
transpose_B = true;
n = arg1_shape[0];
......
@@ -233,159 +233,19 @@ namespace ngraph
}
#endif
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::MatmulBias)
{
const ngraph::op::MatmulBias* cg = static_cast<const ngraph::op::MatmulBias*>(node);
const Shape& arg0_shape = cg->get_arg0_shape(); //W
const Shape& arg1_shape = cg->get_arg1_shape(); //x
const Shape& arg2_shape = node->get_shape(); //bias (C)
static const char* ctranspose = "cblas::Transpose::Transpose, ";
static const char* cnotranspose = "cblas::Transpose::None, ";
size_t m = arg0_shape[0];
size_t n = arg1_shape[1];
size_t k = arg0_shape[1];
//
const char* tranpose_a = cnotranspose;
const char* tranpose_b = cnotranspose;
size_t lda = arg0_shape[1];
size_t ldb = arg1_shape[1];
if (cg->get_is_arg0_transposed())
{
tranpose_a = ctranspose;
m = arg0_shape[1];
k = arg0_shape[0];
}
if (cg->get_is_arg1_transposed())
{
tranpose_b = ctranspose;
n = arg1_shape[0];
}
writer.block_begin();
const char* cbeta = "0.0f";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << tranpose_a << tranpose_b << m << ", " << n
<< ", " << k << ",\n"
<< " 1.0f, " << args[0].get_name() << ", " << max(1UL, lda) << ", "
<< args[1].get_name() << ", " << max(1UL, ldb) << ", " << cbeta << ",\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1])
<< ");\n";
if (args.size() > 2)
{
auto axes = cg->get_broadcast_axes();
if (axes.size() == 1)
{
if (*(axes.begin()) == 0)
{
writer << "static " << out[0].get_element_type().c_type_string()
<< " ones_row[" << arg2_shape[0] << "]"
<< " = { 1.0f";
for (size_t i = 1; i < arg2_shape[0]; ++i)
{
writer << ", 1.0f";
}
writer << "};\n";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << cnotranspose << cnotranspose
<< arg2_shape[0] << ", " << arg2_shape[1] << ", 1"
<< ",\n"
<< " 1.0f, ones_row, "
<< "1"
<< ", " << args[2].get_name() << ", " << max(1UL, arg2_shape[1])
<< ", "
<< "1.0f"
<< ",\n"
<< " " << out[0].get_name() << ", "
<< max(1UL, arg2_shape[1]) << ");\n";
}
else
{
writer << "static " << out[0].get_element_type().c_type_string()
<< " ones_col[" << arg2_shape[1] << "]"
<< " = { 1.0f";
for (size_t i = 1; i < arg2_shape[1]; ++i)
{
writer << ", 1.0f";
}
writer << "};\n";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << cnotranspose << cnotranspose
<< arg2_shape[0] << ", " << arg2_shape[1] << ", 1,\n"
<< "1.0f, " << args[2].get_name() << ", 1, "
<< "ones_col, " << max(1UL, arg2_shape[1]) << ", "
<< "1.0f"
<< ",\n"
<< " " << out[0].get_name() << ", "
<< max(1UL, arg2_shape[1]) << ");\n";
}
}
else
{
if (axes.size() != 2)
{
throw ngraph_error("unexpected broadcast rank");
}
writer << out[0].get_element_type().c_type_string() << " bias["
<< arg2_shape[1] << "]"
<< " = { " << args[2].get_name() << "[0]";
for (size_t i = 1; i < arg2_shape[1]; ++i)
{
writer << "," << args[2].get_name() << "[0]";
}
writer << "};\n";
writer << "static " << out[0].get_element_type().c_type_string()
<< " ones_scalar[" << arg2_shape[0] << "]"
<< " = { 1.0f";
for (size_t i = 1; i < arg2_shape[0]; ++i)
{
writer << ", 1.0f";
}
writer << "};\n";
writer << "cblas::cblas_sgemm("
<< "cblas::Layout::RowMajor, " << cnotranspose << cnotranspose
<< arg2_shape[0] << ", " << arg2_shape[1] << ", 1"
<< ",\n"
<< " 1.0f, ones_scalar, "
<< "1"
<< ", "
<< "bias"
<< ", " << max(1UL, arg2_shape[1]) << ", "
<< "1.0f"
<< ",\n"
<< " " << out[0].get_name() << ", " << max(1UL, arg2_shape[1])
<< ");\n";
}
}
writer.block_end();
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchDot)
void emitCblasSgemmBatch(codegen::CodeWriter& writer,
const Shape& shape_a,
const Shape& shape_b,
const Shape& shape_c,
bool transpose_a,
bool transpose_b,
const std::string& data_a,
const std::string& data_b,
const std::string& data_c,
const std::string& alpha,
const std::string& beta,
size_t group_size)
{
const ngraph::op::BatchDot* batch_dot =
static_cast<const ngraph::op::BatchDot*>(node);
auto mat_a = args[0];
auto mat_b = args[1];
auto mat_c = out[0];
const Shape& shape_a = mat_a.get_shape();
const Shape& shape_b = mat_b.get_shape();
static const char* cblas_transpose = "cblas::Transpose::Transpose";
static const char* cblas_no_transpose = "cblas::Transpose::None";
@@ -394,31 +254,30 @@ namespace ngraph
size_t n = shape_b[2];
size_t lda = std::max(1UL, k);
size_t ldb = std::max(1UL, n);
const char* transpose_a = cblas_no_transpose;
const char* transpose_b = cblas_no_transpose;
if (batch_dot->get_is_a_transposed())
const char* ctranspose_a = cblas_no_transpose;
const char* ctranspose_b = cblas_no_transpose;
if (transpose_a)
{
transpose_a = cblas_transpose;
ctranspose_a = cblas_transpose;
m = shape_a[2];
k = shape_a[1];
lda = std::max(1UL, m);
}
if (batch_dot->get_is_b_transposed())
if (transpose_b)
{
transpose_b = cblas_transpose;
ctranspose_b = cblas_transpose;
n = shape_b[1];
ldb = std::max(1UL, k);
}
size_t ldc = std::max(1UL, n);
const size_t offset_a = m * k;
const size_t offset_b = k * n;
const size_t offset_c = m * n;
const size_t offset_a = (shape_a.at(0) > 1) ? m * k : 0;
const size_t offset_b = (shape_b.at(0) > 1) ? k * n : 0;
const size_t offset_c = (shape_c.at(0) > 1) ? m * n : 0;
writer.block_begin();
const size_t group_count = 1;
const size_t group_size = shape_a[0];
auto populate_array =
[&writer](const std::string& var, size_t size, size_t offset) {
for (size_t i = 0; i < size; ++i)
@@ -426,25 +285,23 @@ namespace ngraph
writer << var << "+" << i * offset << ((i < size - 1) ? ", " : "");
}
};
writer << "cblas::Transpose transa_array[] = {" << transpose_a << "};\n";
writer << "cblas::Transpose transb_array[] = {" << transpose_b << "};\n";
writer << "cblas::Transpose transa_array[] = {" << ctranspose_a << "};\n";
writer << "cblas::Transpose transb_array[] = {" << ctranspose_b << "};\n";
writer << "int64_t m_array[] = {" << m << "};\n";
writer << "int64_t n_array[] = {" << n << "};\n";
writer << "int64_t k_array[] = {" << k << "};\n";
writer << "float alpha_array[] = {1.0f};\n";
writer << "std::vector<const float*> a{";
populate_array(mat_a.get_name(), group_size, offset_a);
populate_array(data_a, group_size, offset_a);
writer << "};\n";
writer << "const float** a_array = &a[0];\n";
writer << "int64_t lda_array[] = {" << lda << "};\n";
writer << "std::vector<const float*> b{";
populate_array(mat_b.get_name(), group_size, offset_b);
populate_array(data_b, group_size, offset_b);
writer << "};\n";
writer << "const float** b_array = &b[0];\n";
writer << "int64_t ldb_array[] = {" << ldb << "};\n";
writer << "float beta_array[] = {0.0f};\n";
writer << "std::vector<float*> c{";
populate_array(mat_c.get_name(), group_size, offset_c);
populate_array(data_c, group_size, offset_c);
writer << "};\n";
writer << "float** c_array = &c[0];\n";
writer << "int64_t ldc_array[] = {" << ldc << "};\n";
@@ -452,11 +309,210 @@ namespace ngraph
writer << "cblas_sgemm_batch(cblas::Layout::RowMajor, ";
writer << "transa_array, transb_array, m_array, n_array, k_array, \n";
writer << "alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, \n";
writer << alpha << ", a_array, lda_array, b_array, ldb_array, " << beta << ", \n";
writer << "c_array, ldc_array, " << group_count << ", group_size);\n";
writer.block_end();
}
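To make the writer calls above concrete, here is roughly the source text this helper would emit for a hypothetical call with shape_a {2,4,3}, shape_b {2,3,5}, shape_c {2,4,5}, no transposes, tensor names a_data/b_data/c_data, and alpha/beta passed as "alpha_array"/"beta_array" (all names and sizes are assumed for illustration; this is a sketch, not captured emitter output):

cblas::Transpose transa_array[] = {cblas::Transpose::None};
cblas::Transpose transb_array[] = {cblas::Transpose::None};
int64_t m_array[] = {4};
int64_t n_array[] = {5};
int64_t k_array[] = {3};
std::vector<const float*> a{a_data+0, a_data+12};   // offset_a = m * k = 12
const float** a_array = &a[0];
int64_t lda_array[] = {3};
std::vector<const float*> b{b_data+0, b_data+15};   // offset_b = k * n = 15
const float** b_array = &b[0];
int64_t ldb_array[] = {5};
std::vector<float*> c{c_data+0, c_data+20};         // offset_c = m * n = 20
float** c_array = &c[0];
int64_t ldc_array[] = {5};
cblas_sgemm_batch(cblas::Layout::RowMajor,
                  transa_array, transb_array, m_array, n_array, k_array,
                  alpha_array, a_array, lda_array, b_array, ldb_array, beta_array,
                  c_array, ldc_array, 1, group_size);  // group_count = 1, group_size = 2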
template <typename T>
static void emitBatchDot(const ngraph::Node* node,
const Shape& shape_a,
const Shape& shape_b,
const Shape& shape_c,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out,
codegen::CodeWriter& writer)
{
writer.block_begin();
const T* batch_dot = static_cast<const T*>(node);
auto mat_a = args[0];
auto mat_b = args[1];
auto mat_c = out[0];
writer << "float alpha_array[] = {1.0f};\n";
writer << "float beta_array[] = {0.0f};\n";
const size_t group_size = shape_a[0];
emitCblasSgemmBatch(writer,
shape_a,
shape_b,
shape_c,
batch_dot->get_is_a_transposed(),
batch_dot->get_is_b_transposed(),
mat_a.get_name(),
mat_b.get_name(),
mat_c.get_name(),
"alpha_array",
"beta_array",
group_size);
writer.block_end();
}
static Shape pad_with(Shape v, size_t val, size_t length)
{
if (length <= v.size())
{
return v;
}
Shape tv(length - v.size(), val);
v.insert(v.begin(), tv.begin(), tv.end());
return v;
}
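pad_with left-pads a shape with the given value until it reaches the requested rank, which lets rank-2 matmul operands reuse the rank-3 batch path as a batch of size 1. For example (values illustrative):

// pad_with(Shape{2, 3}, 1, 3)    -> Shape{1, 2, 3}
// pad_with(Shape{4, 2, 3}, 1, 3) -> Shape{4, 2, 3}  (already rank 3, returned unchanged)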
static std::string emit_constant_array(const std::string& type,
const std::string& name,
const string& val,
size_t size)
{
std::stringstream writer;
writer << "static " << type << " " << name << "[" << size << "]"
<< " = { " << val;
for (size_t i = 1; i < size; ++i)
{
writer << ", " << val;
}
writer << "};\n";
return writer.str();
}
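For reference, a call such as emit_constant_array("float", "ones", "1.0f", 4) would return the single line below (an illustrative expansion, not output captured from a build):

static float ones[4] = { 1.0f, 1.0f, 1.0f, 1.0f};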
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::MatmulBias)
{
const ngraph::op::MatmulBias* cg = static_cast<const ngraph::op::MatmulBias*>(node);
const Shape& arg0_shape = pad_with(cg->get_a_shape(), 1, 3); //A
const Shape& arg1_shape = pad_with(cg->get_b_shape(), 1, 3); //B
const Shape& arg2_shape = node->get_shape(); //bias (C)
const Shape& padded_result_shape = pad_with(node->get_shape(), 1, 3);
//Step 1: dot(A,B)
emitBatchDot<ngraph::op::MatmulBias>(
node, arg0_shape, arg1_shape, padded_result_shape, args, out, writer);
//Step 2: add bias
if (args.size() < 3)
{
//no bias
return;
}
auto mat_c = args[2];
//the bias argument of add(dot(A,B), broadcast(C)) is typically C
//In order to broadcast C to the same shape as dot(A,B)
//we use cblas_gemm_batch(ones, C) or cblas_gemm_batch(C, ones)
//where ones is a tensor of appropriate shape
//consisting of the identity element
// Consider an example of broadcasting a tensor of Shape{1,3}
// to Shape{4,3}
//
// [1                    [1 2 3
//  1                     1 2 3
//  1  *  [1 2 3]   =     1 2 3
//  1]                    1 2 3]
//
// The next example is broadcasting a tensor of Shape{3,1} to Shape{3,4}
//
// [1                    [1 1 1 1
//  2  *  [1 1 1 1]  =    2 2 2 2
//  3]                    3 3 3 3]
writer << "float alpha_beta_array[] = {1.0f};\n";
const size_t group_size = 1;
auto axes = cg->get_broadcast_axes();
if (axes.size() == 1)
{
auto second_broadcast_axis = *axes.begin();
if (second_broadcast_axis == 0)
{
writer << emit_constant_array(out[0].get_element_type().c_type_string(),
"ones",
"1.0f",
arg2_shape.at(0));
;
emitCblasSgemmBatch(writer,
Shape{1, arg2_shape.at(0), 1}, // ones shape
Shape{1, 1, arg2_shape.at(1)}, // C shape
node->get_shape(),
false,
false,
"ones", // ones
mat_c.get_name(), // C
out[0].get_name(), // dot(A,B)
"alpha_beta_array",
"alpha_beta_array",
group_size);
}
else
{
writer << emit_constant_array(out[0].get_element_type().c_type_string(),
"ones",
"1.0f",
arg2_shape.at(1));
emitCblasSgemmBatch(writer,
Shape{1, arg2_shape.at(0), 1}, //C shape
Shape{1, 1, arg2_shape.at(1)}, // ones shape
node->get_shape(),
false, // C transpose
false, // ones transpose
mat_c.get_name(),
"ones",
out[0].get_name(), // dot(A,B)
"alpha_beta_array",
"alpha_beta_array",
group_size);
}
}
else
{
if (axes.size() != 2)
{
throw ngraph_error("unexpected broadcast rank");
}
writer << emit_constant_array(out[0].get_element_type().c_type_string(),
"ones",
"1.0f",
arg2_shape.at(1));
auto bias_scalar = args[2].get_name() + "[0]";
writer << emit_constant_array(out[0].get_element_type().c_type_string(),
"bias_vector",
bias_scalar,
arg2_shape.at(0));
emitCblasSgemmBatch(writer,
Shape{1, arg2_shape.at(0), 1}, // bias_vector shape
Shape{1, 1, arg2_shape.at(1)}, // ones shape
node->get_shape(),
false, // bias_vector transpose
false, // ones transpose
"bias_vector",
"ones",
out[0].get_name(), // dot(A,B)
"alpha_beta_array",
"alpha_beta_array",
group_size);
}
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::BatchDot)
{
const auto* cg = static_cast<const ngraph::op::BatchDot*>(node);
emitBatchDot<ngraph::op::BatchDot>(node,
cg->get_a_shape(),
cg->get_b_shape(),
out[0].get_shape(),
args,
out,
writer);
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Lstm)
{
......
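The bias path above relies on the identity that an outer product with a ones vector reproduces a broadcast: ones (m x 1) times bias (1 x n) is an m x n matrix whose every row equals bias, and with beta = 1 the second sgemm accumulates it into the dot(A,B) result already in the output buffer. Below is a minimal self-contained check of that identity using plain loops instead of MKL; all names and sizes are illustrative assumptions.

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const size_t m = 4, n = 3;
    std::vector<float> dot_ab(m * n, 2.0f);          // stand-in for the dot(A,B) result
    std::vector<float> bias = {10.f, 20.f, 30.f};    // row to broadcast over m rows
    std::vector<float> ones(m, 1.0f);

    // Emulate sgemm(ones (m x 1), bias (1 x n)) with alpha = beta = 1:
    // dot_ab += ones * bias, which is exactly the row broadcast of bias.
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j)
            dot_ab[i * n + j] += ones[i] * bias[j];

    // Every element now equals its original value plus the broadcast bias entry.
    for (size_t i = 0; i < m; ++i)
        for (size_t j = 0; j < n; ++j)
            assert(dot_ab[i * n + j] == 2.0f + bias[j]);
    return 0;
}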
@@ -32,6 +32,8 @@ namespace ngraph
bool get_is_a_transposed() const { return m_transpose_a; }
bool get_is_b_transposed() const { return m_transpose_b; }
Shape get_a_shape() const { return get_argument(0)->get_shape(); }
Shape get_b_shape() const { return get_argument(1)->get_shape(); }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -35,10 +35,10 @@ namespace ngraph
bool transpose_x,
AxisSet axes = AxisSet{});
bool get_is_arg0_transposed() const { return m_transpose_w; }
bool get_is_arg1_transposed() const { return m_transpose_x; }
Shape get_arg0_shape() const { return m_shape_w; }
Shape get_arg1_shape() const { return m_shape_x; }
bool get_is_a_transposed() const { return m_transpose_w; }
bool get_is_b_transposed() const { return m_transpose_x; }
Shape get_a_shape() const { return m_shape_w; }
Shape get_b_shape() const { return m_shape_x; }
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
@@ -153,10 +153,10 @@ void ngraph::runtime::cpu::pass::CPUFusion::construct_matmulbias()
auto mmb = std::make_shared<op::MatmulBias>(pattern_map[W],
pattern_map[x],
m_bias,
m_matmul->get_arg0_shape(),
m_matmul->get_arg1_shape(),
m_matmul->get_is_arg0_transposed(),
m_matmul->get_is_arg1_transposed(),
m_matmul->get_a_shape(),
m_matmul->get_b_shape(),
m_matmul->get_is_a_transposed(),
m_matmul->get_is_b_transposed(),
m_broadcast->get_broadcast_axes());
ngraph::replace_node(m.get_match_root(), mmb);
......