Unverified Commit 2574be4d authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3137 from NervanaSystems/dcaballe/mlir_bump

[MLIR] Bump MLIR repo to commit 82d5084 Wed Jun 26.
parents 05195925 c5e06cc3
......@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID bb2b527)
set(MLIR_COMMIT_ID 49f7efc)
set(MLIR_LLVM_COMMIT_ID c0cad98)
set(MLIR_COMMIT_ID 82d5084)
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
......
......@@ -56,6 +56,7 @@ if (NGRAPH_MLIR_ENABLE)
MLIRExecutionEngine
MLIRIR
MLIRLLVMIR
MLIRStandardToLLVM
MLIRParser
MLIRPass
MLIRTargetLLVMIR
......
......@@ -34,11 +34,12 @@
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/LLVMIR/LLVMDialect.h>
#include <mlir/LLVMIR/Transforms.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR.h>
#include <mlir/Transforms/DialectConversion.h>
......@@ -50,6 +51,7 @@
using llvm::SmallVector;
using llvm::StringRef;
using llvm::make_unique;
using namespace ngraph::runtime::ngmlir;
#define COMPILE_OP_DECL(op_name) \
......@@ -75,7 +77,7 @@ void MLIRCompiler::init_mlir()
if (!initialized)
{
mlir::registerDialect<mlir::NGDialect>();
mlir::registerDialect<mlir::NGraphOpsDialect>();
// Register any LLVM command line options
llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
initialized = true;
......@@ -133,7 +135,7 @@ void MLIRCompiler::build_ng_dialect_module()
}
// create builder
m_builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
m_builder = llvm::make_unique<mlir::OpBuilder>(function->getBody());
build_ng_dialect();
m_module->getFunctions().push_back(function.release());
if (failed(m_module->verify()))
......@@ -359,10 +361,14 @@ void MLIRCompiler::execute()
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect.
auto converter = mlir::createStdToLLVMConverter();
auto r = converter->convert(m_module.get());
(void)r;
NGRAPH_CHECK(succeeded(r), "second conversion failed");
mlir::LLVMTypeConverter llvm_converter(&m_context);
OwningRewritePatternList patterns;
mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
mlir::ConversionTarget target(m_context);
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
auto result = applyConversionPatterns(*m_module, target, llvm_converter, std::move(patterns));
NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
dump_mlir_module("LLVM-IR Dialect Dump:");
......
......@@ -132,7 +132,7 @@ namespace ngraph
mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module;
std::unique_ptr<mlir::FuncBuilder> m_builder;
std::unique_ptr<mlir::OpBuilder> m_builder;
std::unique_ptr<mlir::ExecutionEngine> m_engine;
using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>;
......
......@@ -21,8 +21,8 @@
using namespace mlir;
NGDialect::NGDialect(mlir::MLIRContext* ctx)
: mlir::Dialect("ng", ctx)
NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
: mlir::Dialect(getDialectNamespace(), ctx)
{
addTypes<NGTensorType>();
addTypes<NGIntegerType>();
......@@ -34,7 +34,7 @@ NGDialect::NGDialect(mlir::MLIRContext* ctx)
>();
}
void NGDialect::printType(mlir::Type type, raw_ostream& os) const
void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const
{
switch (type.getKind())
{
......
......@@ -25,15 +25,17 @@
#include "ngraph/check.hpp"
namespace mlir
{
class NGDialect : public mlir::Dialect
class NGraphOpsDialect : public mlir::Dialect
{
public:
explicit NGDialect(mlir::MLIRContext* ctx);
explicit NGraphOpsDialect(mlir::MLIRContext* ctx);
mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
{
NGRAPH_CHECK(false, "Unsupported type parsing.");
return mlir::Type();
}
void printType(mlir::Type type, llvm::raw_ostream& os) const override;
static StringRef getDialectNamespace() { return "ng"; }
};
}
......@@ -41,31 +41,34 @@ namespace
class DialectLoweringPass;
/// Base class for nGraph operation conversions to affine/standard dialect. Provides
/// conversion patterns with an access to the DialectLoweringPass which holds the state of the
/// conversion.
class NGraphOpLowering : public ConversionPattern
{
public:
NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass)
: ConversionPattern(rootOpName, /*benefit=*/1, context)
, m_pass(pass){};
protected:
// Back-reference to the lowering pass which contains the lowering state, including the
// nGraph type converter.
DialectLoweringPass& m_pass;
};
#include "op_lowerers.inc"
/// Use Dialect Converson Framework
class DialectLowerer : public DialectConversion
/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
{
public:
DialectLowerer(DialectLoweringPass& pass)
: DialectConversion()
, m_pass(pass)
NGraphTypeConverter()
: TypeConverter()
{
}
Type convertType(Type t) override;
protected:
// Initialize the list of converters.
void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
{
RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
patterns, mlirContext, m_pass);
}
private:
DialectLoweringPass& m_pass;
llvm::BumpPtrAllocator allocator;
};
/// Dialect Lowering Pass to affine ops
......@@ -73,14 +76,17 @@ namespace
{
public:
DialectLoweringPass(ngmlir::MLIRCompiler& compiler)
: m_dialectLowerer(*this)
, m_compiler(compiler)
: m_compiler(compiler)
{
}
void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
private:
/// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
......@@ -90,7 +96,7 @@ namespace
Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
private:
DialectLowerer m_dialectLowerer;
NGraphTypeConverter m_typeConverter;
// Value holding mem manager passed pointer
SmallVector<Value*, 4> m_memMgrDefs;
......@@ -101,21 +107,39 @@ namespace
void DialectLoweringPass::runOnModule()
{
// Create type converter and initialize conversion patterns.
NGraphTypeConverter converter;
OwningRewritePatternList patterns;
populateNGraphToAffineConversionPatterns(patterns);
// Create target that defines legal ops for nGraph dialect to be lowered to.
ConversionTarget target(getContext());
// TODO: Remove NGFakeInputOp. We need to set NGFakeInputOp as legal op because we generate
// it as part of the lowering to affine/standard.
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
target.addLegalOp<NGFakeInputOp>();
// capture output values by looking for the Return and grabbing the values
// the order of the returned values matches the order of the lowered func signature for
// results. This is used to find the arg_id that a defined value maps to if it is an output
findOutputValues();
if (failed(m_dialectLowerer.convert(&getModule())))
if (failed(applyConversionPatterns(getModule(), target, converter, std::move(patterns))))
{
getModule().getContext()->emitError(mlir::UnknownLoc::get(getModule().getContext()),
"Error lowering dialect\n");
emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
signalPassFailure();
}
processFakeInstrs();
}
void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
OwningRewritePatternList& patterns)
{
RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
patterns, &getContext(), *this);
}
void DialectLoweringPass::findOutputValues()
{
// get original function
......@@ -138,6 +162,9 @@ namespace
outputCount = ret.getNumOperands();
});
// will be populated with lowered output values later
// TODO: This resize is making debugging obscure. When the container is not populated due
// to a bug, null pointers are used by the consumer leading to a crash more difficult to
// root-cause. We should try to change the current approach or introduce verification code.
m_loweredOutputValues.resize(outputCount, nullptr);
}
......@@ -146,10 +173,11 @@ namespace
{
// it would be nice to insert one fake def at the start of the new func
// however, due to how DialectConversion framework works, new func is only
// materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction()
// will give you the original func). This makes it very convoluted to insert instructions at entry block.
// materialized after conversion is done (rewriter->getFunction, or even
// rewriter->getInsertionBlock()->getFunction() will give you the original func). This
// makes it very convoluted to insert instructions at entry block.
auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
IndexType::get(getModule().getContext()));
IndexType::get(&getContext()));
// will be fixed later to read passed arg instead.
m_memMgrDefs.push_back(op.getResult());
return op.getResult();
......@@ -167,8 +195,7 @@ namespace
unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(),
m_dialectLowerer.convertType(
origResult->getType()) /* convert to lowered type */
m_typeConverter.convertType(origResult->getType()) /* convert to lowered type */
);
// Fake instrution is short-lived. Verify here.
fakeOp.verify();
......@@ -181,7 +208,7 @@ namespace
auto tensorType = origResult->getType().cast<NGTensorType>();
auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()},
{m_dialectLowerer.convertType(tensorType)},
{m_typeConverter.convertType(tensorType)},
rewriter);
auto size = tensorType.getSizeInBytes();
......@@ -261,10 +288,10 @@ namespace
return callBackFuncPtr;
}
// NGDialect converters
Type DialectLowerer::convertType(Type type)
Type NGraphTypeConverter::convertType(Type type)
{
// We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since DialectLowerer is private.
// outside of the lowering context since NGraphTypeConverter is private.
if (auto tensor_type = type.dyn_cast<NGTensorType>())
{
......@@ -294,7 +321,7 @@ namespace
}
#define REWRITER(OP) \
void OP##Conversion::rewrite( \
PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const
// ADD
......@@ -334,6 +361,8 @@ namespace
});
// clang-format on
rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGDotOp)
......@@ -396,9 +425,16 @@ namespace
});
rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGReturnOp)
{
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
}
REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
#undef REWRITER
}
......
......@@ -27,6 +27,8 @@ namespace ngraph
namespace ngmlir
{
class MLIRCompiler;
using OwningRewritePatternList = std::vector<std::unique_ptr<mlir::RewritePattern>>;
}
}
}
......
......@@ -18,16 +18,18 @@
// Add new dialect ops lowerers to this file
#define DECL_OP_CONV(OP) \
class OP##Conversion : public mlir::DialectConversionPattern \
{\
public:\
explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
: mlir::DialectConversionPattern(mlir::OP::getOperationName(), 1, context),\
m_pass(pass)\
{} \
void rewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override; \
DialectLoweringPass& m_pass;\
};
class OP##Conversion : public NGraphOpLowering \
{ \
public: \
explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass) \
: NGraphOpLowering(mlir::OP::getOperationName(), context, pass) \
{ \
} \
\
PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \
PatternRewriter& rewriter) const override; \
};
DECL_OP_CONV(NGAddOp)
DECL_OP_CONV(NGDotOp)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment