Commit 3aa4db1d authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Improve gather op lowering (#3667)

* [MLIR] Add support for parsing nGraph tensor type

Initial commit that enables nGraph parsing. It's needed for testing.

* [MLIR] Enable nGraph dialect in ngraph-opt

This PR registers nGraph dialect in ngraph-opt and prepares
nGraph lowering pass for LIT testing, fixing all the related issues.
Among other things, lowering pass has to be turned into a function pass,
dead argument in constructor was removed and `convert-ngraph-to-affine`
flag was added.

* Fix issue with function name and multiple functions

* Extend module_function.mlir lit test

* [MLIR] Add support for parsing nGraph element types

It introduces initial support for parsing nGraph signed/unsigned
integer and floating point data types.

* [MLIR] Improve gather op lowering

This PR interchanges indices and param loops in gather lowering so
that a better memory access patter is generated. Fusion of gather
with other ops is also observed with this change.
parent 8ccddb19
...@@ -737,28 +737,29 @@ namespace ...@@ -737,28 +737,29 @@ namespace
// Let indices rank : M // Let indices rank : M
// Let axis be A // Let axis be A
// Generate // Generate
// params loops // indices loops
// for P_0: 0 -> params.dim[0] // for I_0:0 -> indices.dim[0]
// for P_1: 0 -> params.dim[1]
// for P_2: 0 -> params.dim[2]
// ... // ...
// for P_(A-1):0 -> params.dim[A-1] // for I_(M-1):0 -> indices.dim[M-1]
// for P_(A+1):0 -> params.dim[A+1] // params loops
// for P_0: 0 -> params.dim[0]
// for P_1: 0 -> params.dim[1]
// for P_2: 0 -> params.dim[2]
// ... // ...
// for P_(N-1):0 -> params.dim[N-1] // for P_(A-1):0 -> params.dim[A-1]
// indices loops // for P_(A+1):0 -> params.dim[A+1]
// for I_0:0 -> indices.dim[0]
// ... // ...
// for I_(M-1):0 -> indices.dim[M-1] // for P_(N-1):0 -> params.dim[N-1]
// res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] = // res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
// params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], // params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)],
// P_(A+1), ... P_(N-1)]; // P_(A+1), ... P_(N-1)];
LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] { LoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] {
LoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] { // Load axis value from indices array and cast it to Index Type
// Load axis value from indices array and cast it to Index Type ValueHandle axisIdx = ValueHandle::create<IndexCastOp>(
ValueHandle axisIdx = ValueHandle::create<IndexCastOp>( (ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());
(ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());
LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
// construct indices for param // construct indices for param
// [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1] // [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank(); i++) for (auto i = 0, j = 0; i < vParams.rank(); i++)
......
// RUN: ngraph-opt %s -convert-ngraph-to-affine -split-input-file | FileCheck %s
// Verify that core operations are properly converted to affine dialect.
// -----
// Gather Op
// CHECK: affine.for [[i:%.*]] = 0 to 16 {
// CHECK: [[L0:%.*]] = affine.load %{{.*\[}}[[i]]{{\]}}
// CHECK: [[GATHER_IDX:%.*]] = index_cast [[L0]]
// CHECK: affine.for [[j:%.*]] = 0 to 32 {
// CHECK: [[VALUE:%.*]] = load %{{.*\[}}[[GATHER_IDX]], [[j]]{{\]}}
// CHECK: affine.store [[VALUE]], %{{.*\[}}[[i]], [[j]]{{\]}}
func @simple_gather(%arg0: !ng.tensor<16x!ng.i64>, %arg1: !ng.tensor<512x32xf32>) -> !ng.tensor<16x32xf32> {
%0 = "ng.gather"(%arg1, %arg0) {axis = 0 : i64} : (!ng.tensor<512x32xf32>, !ng.tensor<16x!ng.i64>) -> !ng.tensor<16x32xf32>
"ng.return"(%0) : (!ng.tensor<16x32xf32>) -> ()
}
...@@ -83,4 +83,3 @@ func @u32(%arg0: !ng.u32) { ...@@ -83,4 +83,3 @@ func @u32(%arg0: !ng.u32) {
func @u64(%arg0: !ng.u64) { func @u64(%arg0: !ng.u64) {
"ng.return"() : () -> () "ng.return"() : () -> ()
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment