Commit a2b9c6b8 authored by nmostafa's avatar nmostafa

Refactor optimize() back

parent 6a7e1f24
...@@ -234,31 +234,20 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens ...@@ -234,31 +234,20 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
void MLIRCompiler::lower_ng_dialect() void MLIRCompiler::lower_ng_dialect()
{ {
// Lower NG dialect to Affine // Lower NG dialect to Affine
{ mlir::PassManager pm;
mlir::PassManager pm; pm.addPass(mlir::createDialectLoweringPass(this));
pm.addPass(mlir::createDialectLoweringPass(this)); pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createCanonicalizerPass());
pm.run(m_module.get());
if (failed(m_module->verify())) pm.run(m_module.get());
{
NGRAPH_CHECK(false, "Incorrect module after dialect lowering");
}
dump_mlir_module("Affine Dialect Dump:"); if (failed(m_module->verify()))
{
NGRAPH_CHECK(false, "Incorrect module after dialect lowering");
} }
// Lower Affine to Std Dialect dump_mlir_module("Affine Dialect Dump:");
{
mlir::PassManager pm;
// Lower affine ops
pm.addPass(mlir::createLowerAffinePass());
auto rr = pm.run(m_module.get());
NGRAPH_CHECK(succeeded(rr), "Affine loop lowering failed");
dump_mlir_module("Standard Dialect Dump:"); optimize();
}
NGRAPH_CHECK(m_module, "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
...@@ -296,6 +285,18 @@ void MLIRCompiler::lower_ng_dialect() ...@@ -296,6 +285,18 @@ void MLIRCompiler::lower_ng_dialect()
m_engine = std::move(maybeEngine.get()); m_engine = std::move(maybeEngine.get());
} }
void MLIRCompiler::optimize()
{
// Lower Affine to Std Dialect
mlir::PassManager pm;
// Lower affine ops
pm.addPass(mlir::createLowerAffinePass());
auto rr = pm.run(m_module.get());
NGRAPH_CHECK(succeeded(rr), "Affine loop lowering failed");
dump_mlir_module("Standard Dialect Dump:");
}
// MLIR builders // MLIR builders
#define TI(x) std::type_index(typeid(x)) #define TI(x) std::type_index(typeid(x))
......
...@@ -67,6 +67,7 @@ namespace ngraph ...@@ -67,6 +67,7 @@ namespace ngraph
/// Compiles a subgraph with MLIR /// Compiles a subgraph with MLIR
void compile(); void compile();
/// Executes a pre-compiled subgraph /// Executes a pre-compiled subgraph
void run(std::vector<void*>& external_tensors); void run(std::vector<void*>& external_tensors);
...@@ -88,6 +89,7 @@ namespace ngraph ...@@ -88,6 +89,7 @@ namespace ngraph
private: private:
void build_ng_dialect_module(); void build_ng_dialect_module();
void lower_ng_dialect(); void lower_ng_dialect();
void optimize();
void bind_arguments(std::vector<void*>& external_tensors); void bind_arguments(std::vector<void*>& external_tensors);
void execute(); void execute();
void cleanup(); void cleanup();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment