Unverified Commit 2651f738 authored by Amy Zhuang's avatar Amy Zhuang Committed by GitHub

[MLIR] Use enum class for broadcast hint. (#4238)

Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
parent d3a8e62a
......@@ -1324,39 +1324,39 @@ namespace
}
attrs.gemmAttrs2d.ldc = attrs.gemmAttrs2d.n;
int broadcastHint = -2;
BroadcastType broadcastHint = BroadcastType::ERROR;
if (vBias.rank() == 0)
{
// Scalar
broadcastHint = 2;
broadcastHint = BroadcastType::ROWCOLUMN;
}
else if (vBias.rank() == 2)
{
if (biasShape[0] == attrs.gemmAttrs2d.m && biasShape[1] == 1)
{
broadcastHint = 1;
broadcastHint = BroadcastType::COLUMN;
}
else if (biasShape[0] == 1 && biasShape[1] == attrs.gemmAttrs2d.n)
{
broadcastHint = 0;
broadcastHint = BroadcastType::ROW;
}
else
else if (biasShape[0] == attrs.gemmAttrs2d.m && biasShape[1] == attrs.gemmAttrs2d.n)
{
broadcastHint = -1;
broadcastHint = BroadcastType::NONE;
}
}
else
{
if (biasShape[0] == attrs.gemmAttrs2d.m)
{
broadcastHint = 1;
broadcastHint = BroadcastType::COLUMN;
}
else if (biasShape[0] == attrs.gemmAttrs2d.n)
{
broadcastHint = 0;
broadcastHint = BroadcastType::ROW;
}
}
NGRAPH_CHECK(broadcastHint != -2, "Unhandled broadcast");
NGRAPH_CHECK(broadcastHint != BroadcastType::ERROR, "Unhandled broadcast");
attrs.gemmAttrs2d.broadcastHint = broadcastHint;
auto int64Ty = rewriter.getIntegerType(64);
......
......@@ -84,6 +84,15 @@ namespace ngraph
SOFTMAX
};
enum class BroadcastType
{
NONE,
ROW,
COLUMN,
ROWCOLUMN,
ERROR
};
// These structs and union are used to pass attributes to callbacks.
template <int N>
struct poolAttrs
......@@ -107,7 +116,7 @@ namespace ngraph
int64_t ldc;
float alpha;
float beta;
int64_t broadcastHint;
BroadcastType broadcastHint;
};
union opAttrs {
......
......@@ -636,7 +636,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
if (broadcastHint == 0)
if (broadcastHint == BroadcastType::ROW)
{
std::vector<float> ones(m, 1.0f);
cblas::cblas_sgemm(cblas::Layout::RowMajor,
......@@ -654,7 +654,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else if (broadcastHint == 1)
else if (broadcastHint == BroadcastType::COLUMN)
{
std::vector<float> ones(n, 1.0f);
cblas::cblas_sgemm(cblas::Layout::RowMajor,
......@@ -672,7 +672,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else if (broadcastHint == 2)
else if (broadcastHint == BroadcastType::ROWCOLUMN)
{
std::vector<float> ones(m, 1.0f);
std::vector<float> bias(n, *matC);
......@@ -691,7 +691,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else
else if (broadcastHint == BroadcastType::NONE)
{
std::vector<float> identity(n * n, 0.0f);
for (auto i = 0; i < n * n; i += n + 1)
......@@ -713,6 +713,10 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else
{
NGRAPH_UNREACHABLE("Unsupported broadcast");
}
}
extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment