Unverified Commit 2651f738 authored by Amy Zhuang's avatar Amy Zhuang Committed by GitHub

[MLIR] Use enum class for broadcast hint. (#4238)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d3a8e62a
...@@ -1324,39 +1324,39 @@ namespace ...@@ -1324,39 +1324,39 @@ namespace
} }
attrs.gemmAttrs2d.ldc = attrs.gemmAttrs2d.n; attrs.gemmAttrs2d.ldc = attrs.gemmAttrs2d.n;
int broadcastHint = -2; BroadcastType broadcastHint = BroadcastType::ERROR;
if (vBias.rank() == 0) if (vBias.rank() == 0)
{ {
// Scalar // Scalar
broadcastHint = 2; broadcastHint = BroadcastType::ROWCOLUMN;
} }
else if (vBias.rank() == 2) else if (vBias.rank() == 2)
{ {
if (biasShape[0] == attrs.gemmAttrs2d.m && biasShape[1] == 1) if (biasShape[0] == attrs.gemmAttrs2d.m && biasShape[1] == 1)
{ {
broadcastHint = 1; broadcastHint = BroadcastType::COLUMN;
} }
else if (biasShape[0] == 1 && biasShape[1] == attrs.gemmAttrs2d.n) else if (biasShape[0] == 1 && biasShape[1] == attrs.gemmAttrs2d.n)
{ {
broadcastHint = 0; broadcastHint = BroadcastType::ROW;
} }
else else if (biasShape[0] == attrs.gemmAttrs2d.m && biasShape[1] == attrs.gemmAttrs2d.n)
{ {
broadcastHint = -1; broadcastHint = BroadcastType::NONE;
} }
} }
else else
{ {
if (biasShape[0] == attrs.gemmAttrs2d.m) if (biasShape[0] == attrs.gemmAttrs2d.m)
{ {
broadcastHint = 1; broadcastHint = BroadcastType::COLUMN;
} }
else if (biasShape[0] == attrs.gemmAttrs2d.n) else if (biasShape[0] == attrs.gemmAttrs2d.n)
{ {
broadcastHint = 0; broadcastHint = BroadcastType::ROW;
} }
} }
NGRAPH_CHECK(broadcastHint != -2, "Unhandled broadcast"); NGRAPH_CHECK(broadcastHint != BroadcastType::ERROR, "Unhandled broadcast");
attrs.gemmAttrs2d.broadcastHint = broadcastHint; attrs.gemmAttrs2d.broadcastHint = broadcastHint;
auto int64Ty = rewriter.getIntegerType(64); auto int64Ty = rewriter.getIntegerType(64);
......
...@@ -84,6 +84,15 @@ namespace ngraph ...@@ -84,6 +84,15 @@ namespace ngraph
SOFTMAX SOFTMAX
}; };
enum class BroadcastType
{
NONE,
ROW,
COLUMN,
ROWCOLUMN,
ERROR
};
// These structs and union are used to pass attributes to callbacks. // These structs and union are used to pass attributes to callbacks.
template <int N> template <int N>
struct poolAttrs struct poolAttrs
...@@ -107,7 +116,7 @@ namespace ngraph ...@@ -107,7 +116,7 @@ namespace ngraph
int64_t ldc; int64_t ldc;
float alpha; float alpha;
float beta; float beta;
int64_t broadcastHint; BroadcastType broadcastHint;
}; };
union opAttrs { union opAttrs {
......
...@@ -636,7 +636,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA, ...@@ -636,7 +636,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut, matOut,
std::max<size_t>(1, ldc)); std::max<size_t>(1, ldc));
if (broadcastHint == 0) if (broadcastHint == BroadcastType::ROW)
{ {
std::vector<float> ones(m, 1.0f); std::vector<float> ones(m, 1.0f);
cblas::cblas_sgemm(cblas::Layout::RowMajor, cblas::cblas_sgemm(cblas::Layout::RowMajor,
...@@ -654,7 +654,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA, ...@@ -654,7 +654,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut, matOut,
std::max<size_t>(1, ldc)); std::max<size_t>(1, ldc));
} }
else if (broadcastHint == 1) else if (broadcastHint == BroadcastType::COLUMN)
{ {
std::vector<float> ones(n, 1.0f); std::vector<float> ones(n, 1.0f);
cblas::cblas_sgemm(cblas::Layout::RowMajor, cblas::cblas_sgemm(cblas::Layout::RowMajor,
...@@ -672,7 +672,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA, ...@@ -672,7 +672,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut, matOut,
std::max<size_t>(1, ldc)); std::max<size_t>(1, ldc));
} }
else if (broadcastHint == 2) else if (broadcastHint == BroadcastType::ROWCOLUMN)
{ {
std::vector<float> ones(m, 1.0f); std::vector<float> ones(m, 1.0f);
std::vector<float> bias(n, *matC); std::vector<float> bias(n, *matC);
...@@ -691,7 +691,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA, ...@@ -691,7 +691,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut, matOut,
std::max<size_t>(1, ldc)); std::max<size_t>(1, ldc));
} }
else else if (broadcastHint == BroadcastType::NONE)
{ {
std::vector<float> identity(n * n, 0.0f); std::vector<float> identity(n * n, 0.0f);
for (auto i = 0; i < n * n; i += n + 1) for (auto i = 0; i < n * n; i += n + 1)
...@@ -713,6 +713,10 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA, ...@@ -713,6 +713,10 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut, matOut,
std::max<size_t>(1, ldc)); std::max<size_t>(1, ldc));
} }
else
{
NGRAPH_UNREACHABLE("Unsupported broadcast");
}
} }
extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type) extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment