Commit 27199cee authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Change order of dot operation loop nest (#3748)

* Change order of gemm operation

* Add Dot lit test

* style-apply

* Fix captures to avoid escapes on []
parent 15c99fe5
...@@ -589,10 +589,14 @@ namespace ...@@ -589,10 +589,14 @@ namespace
IndexedValue iRes(result), iLhs(lhs), iRhs(rhs); IndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy))); ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));
{
IndexHandle n, k;
LoopBuilder(&n, nLb, nUb, nStep)(
[&] { LoopBuilder(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; }); });
}
LoopBuilder(&n, nLb, nUb, nStep)([&] { LoopBuilder(&n, nLb, nUb, nStep)([&] {
LoopBuilder(&k, kLb, kUb, kStep)([&] { LoopBuilder(&m, mLb, mUb, mStep)([&] {
iRes(n, k) = zeroInit; LoopBuilder(&k, kLb, kUb, kStep)([&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
LoopBuilder(&m, mLb, mUb, mStep)([&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
}); });
}); });
......
...@@ -5,13 +5,34 @@ ...@@ -5,13 +5,34 @@
// ----- // -----
// Gather Op // Gather Op
// CHECK: affine.for [[i:%.*]] = 0 to 16 { // CHECK: affine.for %[[I:.*]] = 0 to 16 {
// CHECK: [[L0:%.*]] = affine.load %{{.*\[}}[[i]]{{\]}} // CHECK: %[[L0:.*]] = affine.load %{{.*}}[%[[I]]]
// CHECK: [[GATHER_IDX:%.*]] = index_cast [[L0]] // CHECK: %[[GATHER_IDX:.*]] = index_cast %[[L0]]
// CHECK: affine.for [[j:%.*]] = 0 to 32 { // CHECK: affine.for %[[J:.*]] = 0 to 32 {
// CHECK: [[VALUE:%.*]] = load %{{.*\[}}[[GATHER_IDX]], [[j]]{{\]}} // CHECK: %[[VALUE:.*]] = load %{{.*}}[%[[GATHER_IDX]], %[[J]]]
// CHECK: affine.store [[VALUE]], %{{.*\[}}[[i]], [[j]]{{\]}} // CHECK: affine.store %[[VALUE]], {{.*}}[%[[I]], %[[J]]]
func @simple_gather(%arg0: !ng.tensor<16x!ng.i64>, %arg1: !ng.tensor<512x32xf32>) -> !ng.tensor<16x32xf32> { func @simple_gather(%arg0: !ng.tensor<16x!ng.i64>, %arg1: !ng.tensor<512x32xf32>) -> !ng.tensor<16x32xf32> {
%0 = "ng.gather"(%arg1, %arg0) {axis = 0 : i64} : (!ng.tensor<512x32xf32>, !ng.tensor<16x!ng.i64>) -> !ng.tensor<16x32xf32> %0 = "ng.gather"(%arg1, %arg0) {axis = 0 : i64} : (!ng.tensor<512x32xf32>, !ng.tensor<16x!ng.i64>) -> !ng.tensor<16x32xf32>
"ng.return"(%0) : (!ng.tensor<16x32xf32>) -> () "ng.return"(%0) : (!ng.tensor<16x32xf32>) -> ()
} }
// -----
// Dot Op
// CHECK: affine.for %[[I:.*]] = 0 to 16
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 32
// CHECK-NEXT: affine.store %{{.*}}, %[[RESULT:.*]][%[[I]], %[[J]]]
// CHECK: }
// CHECK-NEXT: }
// CHECK: affine.for %[[K:.*]] = 0 to 16
// CHECK-NEXT: affine.for {{%.*}} = 0 to 8
// CHECK-NEXT: affine.for %[[M:.*]] = 0 to 32
// CHECK: affine.load
// CHECK: affine.load
// CHECK: mulf
// CHECK: %[[R:.*]] = addf
// CHECK: affine.store %[[R]], %[[RESULT]][%[[K]], %[[M]]]
func @simple_dot(%arg0: !ng.tensor<16x8xf32>, %arg1: !ng.tensor<8x32xf32>) -> !ng.tensor<16x32xf32> {
%0 = "ng.dot"(%arg0, %arg1) : (!ng.tensor<16x8xf32>, !ng.tensor<8x32xf32>) -> !ng.tensor<16x32xf32>
"ng.return"(%0) : (!ng.tensor<16x32xf32>) -> ()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment