Commit aedd8c2e authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Enable affine dialect loop fusion (#3290)

* [MLIR] Enable affine dialect loop fusion

Enable affine dialect loop fusion in nGraph pipeline. It also adds an
opt flag to enable/disable it when ngraph-opt is in place. Fusion seems
to work for simple cases. It wasn't able to fuse dot + add, though, at
least in my test case. One example that worked:

Input:
  %6 = alloc() : memref<2500x2500xf32>
  affine.for %i3 = 0 to 2500 {
    affine.for %i4 = 0 to 2500 {
      %7 = load %arg0[%i3, %i4] : memref<2500x2500xf32>
      %8 = load %0[%i3, %i4] : memref<2500x2500xf32>
      %9 = addf %8, %7 : f32
      store %9, %6[%i3, %i4] : memref<2500x2500xf32>
    }
  }
  %10 = alloc() : memref<2500x2500xf32>
  affine.for %i5 = 0 to 2500 {
    affine.for %i6 = 0 to 2500 {
      %11 = load %arg2[%i5, %i6] : memref<2500x2500xf32>
      %12 = load %0[%i5, %i6] : memref<2500x2500xf32>
      %13 = addf %12, %11 : f32
      store %13, %10[%i5, %i6] : memref<2500x2500xf32>
    }
  }
  %14 = alloc() : memref<2500x2500xf32>
  affine.for %i7 = 0 to 2500 {
    affine.for %i8 = 0 to 2500 {
      %15 = load %10[%i7, %i8] : memref<2500x2500xf32>
      %16 = load %6[%i7, %i8] : memref<2500x2500xf32>
      %17 = addf %16, %15 : f32
      store %17, %14[%i7, %i8] : memref<2500x2500xf32>
    }
  }

Output:
  %8 = alloc() : memref<2500x2500xf32>
  affine.for %i3 = 0 to 2500 {
    affine.for %i4 = 0 to 2500 {
      %9 = load %arg2[%i3, %i4] : memref<2500x2500xf32>
      %10 = load %2[%i3, %i4] : memref<2500x2500xf32>
      %11 = addf %10, %9 : f32
      %12 = affine.apply #map2(%i3, %i4, %i3, %i4)
      %13 = affine.apply #map3(%i3, %i4, %i3, %i4)
      store %11, %0[%12, %13] : memref<1x1xf32>
      %14 = load %arg0[%i3, %i4] : memref<2500x2500xf32>
      %15 = load %2[%i3, %i4] : memref<2500x2500xf32>
      %16 = addf %15, %14 : f32
      %17 = affine.apply #map2(%i3, %i4, %i3, %i4)
      %18 = affine.apply #map3(%i3, %i4, %i3, %i4)
      store %16, %1[%17, %18] : memref<1x1xf32>
      %19 = affine.apply #map2(%i3, %i4, %i3, %i4)
      %20 = affine.apply #map3(%i3, %i4, %i3, %i4)
      %21 = load %0[%19, %20] : memref<1x1xf32>
      %22 = affine.apply #map2(%i3, %i4, %i3, %i4)
      %23 = affine.apply #map3(%i3, %i4, %i3, %i4)
      %24 = load %1[%22, %23] : memref<1x1xf32>
      %25 = addf %24, %21 : f32
      store %25, %8[%i3, %i4] : memref<2500x2500xf32>
    }
  }

* Rename MLIR_LLVM_OPTIONS to NGRAPH_MLIR_OPTIONS

Something like this works now:
NGRAPH_MLIR_OPTIONS="--enable-affine-loop-fusion=false"

* Disable loop fusion by default and fix typo
parent 862aa5fe
......@@ -70,6 +70,11 @@ using llvm::make_unique;
using llvm::ArrayRef;
using namespace ngraph::runtime::ngmlir;
static llvm::cl::opt<bool>
clEnableAffineLoopFusion("enable-affine-loop-fusion",
llvm::cl::init(false),
llvm::cl::desc("Enable loop fusion optimization in Affine dialect"));
#define COMPILE_OP_DECL(op_name) \
create_op<op_name>(MLIRCompiler & compiler, const ngraph::Node* ng_node)
......@@ -85,7 +90,7 @@ void MLIRCompiler::init_mlir()
{
mlir::registerDialect<mlir::NGraphOpsDialect>();
// Register any LLVM command line options
llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
llvm::cl::ParseEnvironmentOptions("ngraph", "NGRAPH_MLIR_OPTIONS", "");
initialized = true;
}
}
......@@ -248,7 +253,7 @@ void MLIRCompiler::lower_ng_dialect()
NGRAPH_CHECK(false, "Incorrect module after dialect lowering");
}
dump_mlir_module("Affine Dialect Dump:");
dump_mlir_module("Affine Dialect Dump (Pre-Optimizations):");
optimize();
......@@ -290,14 +295,26 @@ void MLIRCompiler::lower_ng_dialect()
void MLIRCompiler::optimize()
{
// Lower Affine to Std Dialect
mlir::PassManager pm;
// Lower affine ops
pm.addPass(mlir::createLowerAffinePass());
auto rr = pm.run(m_module.get());
NGRAPH_CHECK(succeeded(rr), "Affine loop lowering failed");
// Run Affine dialect optimizations.
mlir::PassManager pm_opts;
if (clEnableAffineLoopFusion)
{
pm_opts.addPass(mlir::createLoopFusionPass());
}
auto opt_res = pm_opts.run(m_module.get());
NGRAPH_CHECK(succeeded(opt_res), "Affine optimizations failed");
dump_mlir_module("Affine Dialect Dump (Post-Optimizations):");
// Run Affine dialect to Std dialect conversion.
mlir::PassManager pm_lowering;
pm_lowering.addPass(mlir::createLowerAffinePass());
auto lowering_res = pm_lowering.run(m_module.get());
NGRAPH_CHECK(succeeded(lowering_res), "Affine convertion to Std dialect failed");
dump_mlir_module("Standard Dialect Dump:");
// Run Std dialect optimizations.
// TODO
}
// MLIR builders
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment