Unverified Commit 2574be4d authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3137 from NervanaSystems/dcaballe/mlir_bump

[MLIR] Bump MLIR repo to commit 82d5084 Wed Jun 26.
parents 05195925 c5e06cc3
...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git) ...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git) set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions # Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID bb2b527) set(MLIR_LLVM_COMMIT_ID c0cad98)
set(MLIR_COMMIT_ID 49f7efc) set(MLIR_COMMIT_ID 82d5084)
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project) set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects) set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir) set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
......
...@@ -56,6 +56,7 @@ if (NGRAPH_MLIR_ENABLE) ...@@ -56,6 +56,7 @@ if (NGRAPH_MLIR_ENABLE)
MLIRExecutionEngine MLIRExecutionEngine
MLIRIR MLIRIR
MLIRLLVMIR MLIRLLVMIR
MLIRStandardToLLVM
MLIRParser MLIRParser
MLIRPass MLIRPass
MLIRTargetLLVMIR MLIRTargetLLVMIR
......
...@@ -34,11 +34,12 @@ ...@@ -34,11 +34,12 @@
#include <llvm/Support/MemoryBuffer.h> #include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h> #include <llvm/Support/TargetSelect.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h> #include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h> #include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h> #include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/LLVMIR/LLVMDialect.h> #include <mlir/LLVMIR/LLVMDialect.h>
#include <mlir/LLVMIR/Transforms.h>
#include <mlir/Pass/PassManager.h> #include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR.h> #include <mlir/Target/LLVMIR.h>
#include <mlir/Transforms/DialectConversion.h> #include <mlir/Transforms/DialectConversion.h>
...@@ -50,6 +51,7 @@ ...@@ -50,6 +51,7 @@
using llvm::SmallVector; using llvm::SmallVector;
using llvm::StringRef; using llvm::StringRef;
using llvm::make_unique; using llvm::make_unique;
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
#define COMPILE_OP_DECL(op_name) \ #define COMPILE_OP_DECL(op_name) \
...@@ -75,7 +77,7 @@ void MLIRCompiler::init_mlir() ...@@ -75,7 +77,7 @@ void MLIRCompiler::init_mlir()
if (!initialized) if (!initialized)
{ {
mlir::registerDialect<mlir::NGDialect>(); mlir::registerDialect<mlir::NGraphOpsDialect>();
// Register any LLVM command line options // Register any LLVM command line options
llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", ""); llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
initialized = true; initialized = true;
...@@ -133,7 +135,7 @@ void MLIRCompiler::build_ng_dialect_module() ...@@ -133,7 +135,7 @@ void MLIRCompiler::build_ng_dialect_module()
} }
// create builder // create builder
m_builder = llvm::make_unique<mlir::FuncBuilder>(function.get()); m_builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
build_ng_dialect(); build_ng_dialect();
m_module->getFunctions().push_back(function.release()); m_module->getFunctions().push_back(function.release());
if (failed(m_module->verify())) if (failed(m_module->verify()))
...@@ -359,10 +361,14 @@ void MLIRCompiler::execute() ...@@ -359,10 +361,14 @@ void MLIRCompiler::execute()
NGRAPH_CHECK(m_module, "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect. // Lower Standard dialect to LLVM dialect.
auto converter = mlir::createStdToLLVMConverter(); mlir::LLVMTypeConverter llvm_converter(&m_context);
auto r = converter->convert(m_module.get()); OwningRewritePatternList patterns;
(void)r; mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
NGRAPH_CHECK(succeeded(r), "second conversion failed");
mlir::ConversionTarget target(m_context);
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
auto result = applyConversionPatterns(*m_module, target, llvm_converter, std::move(patterns));
NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
dump_mlir_module("LLVM-IR Dialect Dump:"); dump_mlir_module("LLVM-IR Dialect Dump:");
......
...@@ -132,7 +132,7 @@ namespace ngraph ...@@ -132,7 +132,7 @@ namespace ngraph
mlir::MLIRContext m_context; mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module; std::unique_ptr<mlir::Module> m_module;
std::unique_ptr<mlir::FuncBuilder> m_builder; std::unique_ptr<mlir::OpBuilder> m_builder;
std::unique_ptr<mlir::ExecutionEngine> m_engine; std::unique_ptr<mlir::ExecutionEngine> m_engine;
using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>; using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>;
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
using namespace mlir; using namespace mlir;
NGDialect::NGDialect(mlir::MLIRContext* ctx) NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("ng", ctx) : mlir::Dialect(getDialectNamespace(), ctx)
{ {
addTypes<NGTensorType>(); addTypes<NGTensorType>();
addTypes<NGIntegerType>(); addTypes<NGIntegerType>();
...@@ -34,7 +34,7 @@ NGDialect::NGDialect(mlir::MLIRContext* ctx) ...@@ -34,7 +34,7 @@ NGDialect::NGDialect(mlir::MLIRContext* ctx)
>(); >();
} }
void NGDialect::printType(mlir::Type type, raw_ostream& os) const void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
{ {
switch (type.getKind()) switch (type.getKind())
{ {
......
...@@ -25,15 +25,17 @@ ...@@ -25,15 +25,17 @@
#include "ngraph/check.hpp" #include "ngraph/check.hpp"
namespace mlir namespace mlir
{ {
class NGDialect : public mlir::Dialect class NGraphOpsDialect : public mlir::Dialect
{ {
public: public:
explicit NGDialect(mlir::MLIRContext* ctx); explicit NGraphOpsDialect(mlir::MLIRContext* ctx);
mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
{ {
NGRAPH_CHECK(false, "Unsupported type parsing."); NGRAPH_CHECK(false, "Unsupported type parsing.");
return mlir::Type(); return mlir::Type();
} }
void printType(mlir::Type type, llvm::raw_ostream& os) const override; void printType(mlir::Type type, llvm::raw_ostream& os) const override;
static StringRef getDialectNamespace() { return "ng"; }
}; };
} }
...@@ -41,31 +41,34 @@ namespace ...@@ -41,31 +41,34 @@ namespace
class DialectLoweringPass; class DialectLoweringPass;
/// Base class for nGraph operation conversions to affine/standard dialect. Provides
/// conversion patterns with an access to the DialectLoweringPass which holds the state of the
/// conversion.
class NGraphOpLowering : public ConversionPattern
{
public:
NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass)
: ConversionPattern(rootOpName, /*benefit=*/1, context)
, m_pass(pass){};
protected:
// Back-reference to the lowering pass which contains the lowering state, including the
// nGraph type converter.
DialectLoweringPass& m_pass;
};
#include "op_lowerers.inc" #include "op_lowerers.inc"
/// Use Dialect Converson Framework /// Conversion from types in the nGraph dialect to the Standard dialect.
class DialectLowerer : public DialectConversion class NGraphTypeConverter : public TypeConverter
{ {
public: public:
DialectLowerer(DialectLoweringPass& pass) NGraphTypeConverter()
: DialectConversion() : TypeConverter()
, m_pass(pass)
{ {
} }
Type convertType(Type t) override; Type convertType(Type t) override;
protected:
// Initialize the list of converters.
void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
{
RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
patterns, mlirContext, m_pass);
}
private:
DialectLoweringPass& m_pass;
llvm::BumpPtrAllocator allocator;
}; };
/// Dialect Lowering Pass to affine ops /// Dialect Lowering Pass to affine ops
...@@ -73,14 +76,17 @@ namespace ...@@ -73,14 +76,17 @@ namespace
{ {
public: public:
DialectLoweringPass(ngmlir::MLIRCompiler& compiler) DialectLoweringPass(ngmlir::MLIRCompiler& compiler)
: m_dialectLowerer(*this) : m_compiler(compiler)
, m_compiler(compiler)
{ {
} }
void runOnModule() override; void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter); SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
private: private:
/// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
mlir::Function* getCallDecl(StringRef name, mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args, ArrayRef<Type> args,
ArrayRef<Type> output, ArrayRef<Type> output,
...@@ -90,7 +96,7 @@ namespace ...@@ -90,7 +96,7 @@ namespace
Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr); Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
private: private:
DialectLowerer m_dialectLowerer; NGraphTypeConverter m_typeConverter;
// Value holding mem manager passed pointer // Value holding mem manager passed pointer
SmallVector<Value*, 4> m_memMgrDefs; SmallVector<Value*, 4> m_memMgrDefs;
...@@ -101,21 +107,39 @@ namespace ...@@ -101,21 +107,39 @@ namespace
void DialectLoweringPass::runOnModule() void DialectLoweringPass::runOnModule()
{ {
// Create type converter and initialize conversion patterns.
NGraphTypeConverter converter;
OwningRewritePatternList patterns;
populateNGraphToAffineConversionPatterns(patterns);
// Create target that defines legal ops for nGraph dialect to be lowered to.
ConversionTarget target(getContext());
// TODO: Remove NGFakeInputOp. We need to set NGFakeInputOp as legal op because we generate
// it as part of the lowering to affine/standard.
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
target.addLegalOp<NGFakeInputOp>();
// capture output values by looking for the Return and grabbing the values // capture output values by looking for the Return and grabbing the values
// the order of the returned values matches the order of the lowered func signature for // the order of the returned values matches the order of the lowered func signature for
// results. This is used to find the arg_id that a defined value maps to if it is an output // results. This is used to find the arg_id that a defined value maps to if it is an output
findOutputValues(); findOutputValues();
if (failed(m_dialectLowerer.convert(&getModule()))) if (failed(applyConversionPatterns(getModule(), target, converter, std::move(patterns))))
{ {
getModule().getContext()->emitError(mlir::UnknownLoc::get(getModule().getContext()), emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
"Error lowering dialect\n");
signalPassFailure(); signalPassFailure();
} }
processFakeInstrs(); processFakeInstrs();
} }
void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
OwningRewritePatternList& patterns)
{
RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
patterns, &getContext(), *this);
}
void DialectLoweringPass::findOutputValues() void DialectLoweringPass::findOutputValues()
{ {
// get original function // get original function
...@@ -138,6 +162,9 @@ namespace ...@@ -138,6 +162,9 @@ namespace
outputCount = ret.getNumOperands(); outputCount = ret.getNumOperands();
}); });
// will be populated with lowered output values later // will be populated with lowered output values later
// TODO: This resize is making debugging obscure. When the container is not populated due
// to a bug, null pointers are used by the consumer leading to a crash more difficult to
// root-cause. We should try to change the current approach or introduce verification code.
m_loweredOutputValues.resize(outputCount, nullptr); m_loweredOutputValues.resize(outputCount, nullptr);
} }
...@@ -146,10 +173,11 @@ namespace ...@@ -146,10 +173,11 @@ namespace
{ {
// it would be nice to insert one fake def at the start of the new func // it would be nice to insert one fake def at the start of the new func
// however, due to how DialectConversion framework works, new func is only // however, due to how DialectConversion framework works, new func is only
// materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction() // materialized after conversion is done (rewriter->getFunction, or even
// will give you the original func). This makes it very convoluted to insert instructions at entry block. // rewriter->getInsertionBlock()->getFunction() will give you the original func). This
// makes it very convoluted to insert instructions at entry block.
auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(), auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
IndexType::get(getModule().getContext())); IndexType::get(&getContext()));
// will be fixed later to read passed arg instead. // will be fixed later to read passed arg instead.
m_memMgrDefs.push_back(op.getResult()); m_memMgrDefs.push_back(op.getResult());
return op.getResult(); return op.getResult();
...@@ -167,8 +195,7 @@ namespace ...@@ -167,8 +195,7 @@ namespace
unsigned argId = (int)attr.getInt(); unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>( auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(), op->getLoc(),
m_dialectLowerer.convertType( m_typeConverter.convertType(origResult->getType()) /* convert to lowered type */
origResult->getType()) /* convert to lowered type */
); );
// Fake instrution is short-lived. Verify here. // Fake instrution is short-lived. Verify here.
fakeOp.verify(); fakeOp.verify();
...@@ -181,7 +208,7 @@ namespace ...@@ -181,7 +208,7 @@ namespace
auto tensorType = origResult->getType().cast<NGTensorType>(); auto tensorType = origResult->getType().cast<NGTensorType>();
auto callBackFunc = getCallDecl("__mlir_allocate", auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()}, {rewriter.getIndexType(), rewriter.getIndexType()},
{m_dialectLowerer.convertType(tensorType)}, {m_typeConverter.convertType(tensorType)},
rewriter); rewriter);
auto size = tensorType.getSizeInBytes(); auto size = tensorType.getSizeInBytes();
...@@ -261,10 +288,10 @@ namespace ...@@ -261,10 +288,10 @@ namespace
return callBackFuncPtr; return callBackFuncPtr;
} }
// NGDialect converters // NGDialect converters
Type DialectLowerer::convertType(Type type) Type NGraphTypeConverter::convertType(Type type)
{ {
// We may need to refactor this code to a external utility if type conversion is needed // We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since DialectLowerer is private. // outside of the lowering context since NGraphTypeConverter is private.
if (auto tensor_type = type.dyn_cast<NGTensorType>()) if (auto tensor_type = type.dyn_cast<NGTensorType>())
{ {
...@@ -294,7 +321,7 @@ namespace ...@@ -294,7 +321,7 @@ namespace
} }
#define REWRITER(OP) \ #define REWRITER(OP) \
void OP##Conversion::rewrite( \ PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const
// ADD // ADD
...@@ -334,6 +361,8 @@ namespace ...@@ -334,6 +361,8 @@ namespace
}); });
// clang-format on // clang-format on
rewriter.replaceOp(op, {result}); rewriter.replaceOp(op, {result});
return matchSuccess();
} }
REWRITER(NGDotOp) REWRITER(NGDotOp)
...@@ -396,9 +425,16 @@ namespace ...@@ -396,9 +425,16 @@ namespace
}); });
rewriter.replaceOp(op, {result}); rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGReturnOp)
{
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
} }
REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
#undef REWRITER #undef REWRITER
} }
......
...@@ -27,6 +27,8 @@ namespace ngraph ...@@ -27,6 +27,8 @@ namespace ngraph
namespace ngmlir namespace ngmlir
{ {
class MLIRCompiler; class MLIRCompiler;
using OwningRewritePatternList = std::vector<std::unique_ptr<mlir::RewritePattern>>;
} }
} }
} }
......
...@@ -17,17 +17,19 @@ ...@@ -17,17 +17,19 @@
// Add new dialect ops lowerers to this file // Add new dialect ops lowerers to this file
#define DECL_OP_CONV(OP) \ #define DECL_OP_CONV(OP) \
class OP##Conversion : public mlir::DialectConversionPattern \ class OP##Conversion : public NGraphOpLowering \
{\ { \
public:\ public: \
explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\ explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass) \
: mlir::DialectConversionPattern(mlir::OP::getOperationName(), 1, context),\ : NGraphOpLowering(mlir::OP::getOperationName(), context, pass) \
m_pass(pass)\ { \
{} \ } \
void rewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override; \ \
DialectLoweringPass& m_pass;\ PatternMatchResult matchAndRewrite(Operation* op, \
}; ArrayRef<Value*> operands, \
PatternRewriter& rewriter) const override; \
};
DECL_OP_CONV(NGAddOp) DECL_OP_CONV(NGAddOp)
DECL_OP_CONV(NGDotOp) DECL_OP_CONV(NGDotOp)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment