Commit 9acdfe04 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by nmostafa

[MLIR] MLIR version upgrade (#28)

* Upgrade MLIR. Several code fixes based on API changes

* Fixes due to DialectConv API changes

* style-apply

* PR fixes
parent eda52385
...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git) ...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git) set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions # Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID 43d110b) set(MLIR_LLVM_COMMIT_ID bb2b527)
set(MLIR_COMMIT_ID 1fbf407) set(MLIR_COMMIT_ID 49f7efc)
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project) set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects) set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir) set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
......
...@@ -353,9 +353,8 @@ void MLIRCompiler::execute() ...@@ -353,9 +353,8 @@ void MLIRCompiler::execute()
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer to run // don't run MLIR passes that were already run. We also pass a default transformer to run
// LLVM optimizations at level 3. // LLVM optimizations at level 3.
mlir::PassManager* mlir_pm = nullptr;
auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/); auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), mlir_pm, llvm_transformer); auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
NGRAPH_ASSERT(maybeEngine) << "failed to construct an execution engine"; NGRAPH_ASSERT(maybeEngine) << "failed to construct an execution engine";
m_engine = std::move(maybeEngine.get()); m_engine = std::move(maybeEngine.get());
......
...@@ -36,11 +36,19 @@ include "mlir/IR/OpBase.td" ...@@ -36,11 +36,19 @@ include "mlir/IR/OpBase.td"
// Each def corresponds to a C++ class
// Record describing the nGraph MLIR dialect itself. TableGen generates the
// C++ dialect class from this; the NG_Op class below registers ops under it.
def NG_Dialect : Dialect {
// Ops in this dialect are prefixed "ng." in the textual IR.
let name = "ng";
// TODO: Have the dialect under its own mlir::ngraph namespace
// At mlir top-level for now
let cppNamespace = "";
}
// NGraph Types // NGraph Types
// This defines records equivalent to NGraph types. It doesn't generate code. // This defines records equivalent to NGraph types. It doesn't generate code.
// This is used as a type in the DAG input/outputs. // This is used as a type in the DAG input/outputs.
// Constraints (CPred) are used to type-check args/results of that type during op verification // Constraints (CPred) are used to type-check args/results of that type during op verification
def NG_TensorType : Type<CPred<"{0}.isa<mlir::NGTensorType>()">, def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">,
"NGraph Tensor Type">; "NGraph Tensor Type">;
// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering // A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
...@@ -49,7 +57,7 @@ def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">; ...@@ -49,7 +57,7 @@ def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;
// NGraph operation base class. // NGraph operation base class.
// Prepends "ng." to operation name // Prepends "ng." to operation name
class NG_Op<string mnemonic, list<OpTrait> traits = []> : class NG_Op<string mnemonic, list<OpTrait> traits = []> :
Op<!strconcat("ng.", mnemonic), traits> {} Op<NG_Dialect, mnemonic, traits> {}
// Operations producing single result. // Operations producing single result.
// Will set OneResult trait based on Results out dag. // Will set OneResult trait based on Results out dag.
...@@ -71,7 +79,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -71,7 +79,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$arg)> Arguments<(ins NG_TensorType:$arg)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return false; }]; let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
let verifier = [{ return verifyUnaryArithOp(this); }]; let verifier = [{ return verifyUnaryArithOp(this); }];
} }
...@@ -83,7 +91,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -83,7 +91,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)> Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{ {
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return false; }]; let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
let verifier = [{ return verifyBinaryArithOp(this); }]; let verifier = [{ return verifyBinaryArithOp(this); }];
} }
......
...@@ -116,7 +116,7 @@ namespace mlir ...@@ -116,7 +116,7 @@ namespace mlir
/// Convert to equivalent std type /// Convert to equivalent std type
/// std types are sign-agnostic. /// std types are sign-agnostic.
mlir::Type toStdType() const { return mlir::IntegerType::get(getWidth(), getContext()); } mlir::Type toStdType() { return mlir::IntegerType::get(getWidth(), getContext()); }
/// Check if signed type /// Check if signed type
bool isSigned() const; bool isSigned() const;
...@@ -164,7 +164,7 @@ namespace mlir ...@@ -164,7 +164,7 @@ namespace mlir
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; } static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); } static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
/// Convert to equivalent std type. Integer of width 1 in that case /// Convert to equivalent std type. Integer of width 1 in that case
mlir::Type toStdType() const { return mlir::IntegerType::get(1, getContext()); } mlir::Type toStdType() { return mlir::IntegerType::get(1, getContext()); }
}; };
// Note that dialect types don't add new data members, so always possible // Note that dialect types don't add new data members, so always possible
......
...@@ -37,6 +37,7 @@ namespace ...@@ -37,6 +37,7 @@ namespace
using namespace ngraph::runtime; using namespace ngraph::runtime;
class DialectLoweringPass; class DialectLoweringPass;
#include "op_lowerers.inc" #include "op_lowerers.inc"
/// Use Dialect Conversion Framework /// Use Dialect Conversion Framework
...@@ -53,11 +54,10 @@ namespace ...@@ -53,11 +54,10 @@ namespace
protected: protected:
// Initialize the list of converters. // Initialize the list of converters.
llvm::DenseSet<DialectOpConversion*> initConverters(MLIRContext* context) override void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
{ {
return ConversionListBuilder<NGAddOpConversion, RewriteListBuilder<NGAddOpConversion, NGMatMulBiasOpConversion, NGReturnOpConversion>::
NGMatMulBiasOpConversion, build(patterns, mlirContext, m_pass);
NGReturnOpConversion>::build(&allocator, context, m_pass);
} }
private: private:
...@@ -75,24 +75,22 @@ namespace ...@@ -75,24 +75,22 @@ namespace
{ {
} }
void runOnModule() override; void runOnModule() override;
std::map<Value*, unsigned>& getOutputValueMap() { return m_outputValueMap; }; SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
SmallVector<Value*, 4> buildOutputDefs(Operation* op, FuncBuilder& rewriter);
private: private:
mlir::Function* getCallDecl(StringRef name, mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args, ArrayRef<Type> args,
ArrayRef<Type> output, ArrayRef<Type> output,
FuncBuilder& rewriter); PatternRewriter& rewriter);
void findOutputValues(); void findOutputValues();
void processFakeInstrs(); void processFakeInstrs();
Value* insertMemMgrDef(FuncBuilder* rewriter = nullptr); Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
private: private:
DialectLowerer m_dialectLowerer; DialectLowerer m_dialectLowerer;
// Value holding mem manager passed pointer // Value holding mem manager passed pointer
SmallVector<Value*, 4> m_memMgrDefs; SmallVector<Value*, 4> m_memMgrDefs;
// maps output ng dialect values to args pos
std::map<Value*, unsigned> m_outputValueMap;
// list of results values to add to func signature // list of results values to add to func signature
SmallVector<Value*, 4> m_loweredOutputValues; SmallVector<Value*, 4> m_loweredOutputValues;
ngmlir::MLIRCompiler& m_compiler; ngmlir::MLIRCompiler& m_compiler;
...@@ -134,7 +132,10 @@ namespace ...@@ -134,7 +132,10 @@ namespace
f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) { f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++) for (unsigned i = 0; i < ret.getNumOperands(); i++)
{ {
this->m_outputValueMap.insert(std::pair<Value*, unsigned>(ret.getOperand(i), i)); auto outputValue = ret.getOperand(i);
auto op = outputValue->getDefiningOp();
op->setAttr("graphOutputIdx",
mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i));
} }
NGRAPH_ASSERT(outputCount == 0 || outputCount == ret.getNumOperands()) NGRAPH_ASSERT(outputCount == 0 || outputCount == ret.getNumOperands())
<< "Inconsistent returns in function"; << "Inconsistent returns in function";
...@@ -145,7 +146,7 @@ namespace ...@@ -145,7 +146,7 @@ namespace
} }
/// Inserts a fake def for Mem Mgr pointer at converted func start /// Inserts a fake def for Mem Mgr pointer at converted func start
Value* DialectLoweringPass::insertMemMgrDef(FuncBuilder* rewriter) Value* DialectLoweringPass::insertMemMgrDef(PatternRewriter* rewriter)
{ {
// it would be nice to insert one fake def at the start of the new func // it would be nice to insert one fake def at the start of the new func
// however, due to how DialectConversion framework works, new func is only // however, due to how DialectConversion framework works, new func is only
...@@ -159,17 +160,15 @@ namespace ...@@ -159,17 +160,15 @@ namespace
} }
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op, SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
FuncBuilder& rewriter) PatternRewriter& rewriter)
{ {
auto& outputMap = getOutputValueMap();
SmallVector<Value*, 4> newResults; SmallVector<Value*, 4> newResults;
for (auto origResult : op->getResults()) for (auto origResult : op->getResults())
{ {
auto it = outputMap.find(origResult);
// create output def if this operation produces any sub-graph outputs // create output def if this operation produces any sub-graph outputs
if (it != outputMap.end()) if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx"))
{ {
unsigned argId = (*it).second; unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>( auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(), op->getLoc(),
m_dialectLowerer.convertType( m_dialectLowerer.convertType(
...@@ -237,7 +236,7 @@ namespace ...@@ -237,7 +236,7 @@ namespace
for (auto value : m_loweredOutputValues) for (auto value : m_loweredOutputValues)
{ {
auto op = value->getDefiningOp(); auto op = value->getDefiningOp();
NGRAPH_ASSERT(op->isa<NGFakeInputOp>()) << "output value not defined by fake output?"; NGRAPH_ASSERT(isa<NGFakeInputOp>(op)) << "output value not defined by fake output?";
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i)); value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase(); op->erase();
i++; i++;
...@@ -252,7 +251,7 @@ namespace ...@@ -252,7 +251,7 @@ namespace
mlir::Function* DialectLoweringPass::getCallDecl(StringRef name, mlir::Function* DialectLoweringPass::getCallDecl(StringRef name,
ArrayRef<Type> args, ArrayRef<Type> args,
ArrayRef<Type> output, ArrayRef<Type> output,
FuncBuilder& rewriter) PatternRewriter& rewriter)
{ {
auto callBackFuncPtr = getModule().getNamedFunction(name); auto callBackFuncPtr = getModule().getNamedFunction(name);
if (callBackFuncPtr == nullptr) if (callBackFuncPtr == nullptr)
...@@ -292,12 +291,15 @@ namespace ...@@ -292,12 +291,15 @@ namespace
return t; return t;
} }
// Helper macro: expands to the out-of-line definition header of the
// rewrite() method for an op-conversion class (declared via DECL_OP_CONV
// in op_lowerers.inc). The body follows the macro invocation in braces.
#define REWRITER(OP) \
void OP##Conversion::rewrite( \
Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const
// ADD // ADD
SmallVector<Value*, 4> NGAddOpConversion::rewrite(Operation* op, REWRITER(NGAddOp)
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{ {
auto add = op->cast<NGAddOp>(); auto add = cast<NGAddOp>(op);
auto loc = add.getLoc(); auto loc = add.getLoc();
Value *origResult, *newResult; Value *origResult, *newResult;
...@@ -323,18 +325,19 @@ namespace ...@@ -323,18 +325,19 @@ namespace
auto pivs = IndexHandle::makeIndexHandlePointers(ivs); auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
// clang-format off
LoopNestBuilder(pivs, lbs, ubs, steps)({// single stmt body LoopNestBuilder(pivs, lbs, ubs, steps)(
iRes(ivs) = iLHS(ivs) + iRHS(ivs)}); // single stmt body
// return result memref [&] {
return {result}; iRes(ivs) = iLHS(ivs) + iRHS(ivs);
});
// clang-format on
rewriter.replaceOp(op, {result});
} }
SmallVector<Value*, 4> NGMatMulBiasOpConversion::rewrite(Operation* op, REWRITER(NGMatMulBiasOp)
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{ {
auto matmul = op->cast<NGMatMulBiasOp>(); auto matmul = cast<NGMatMulBiasOp>(op);
auto loc = matmul.getLoc(); auto loc = matmul.getLoc();
NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation"; NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation";
...@@ -382,27 +385,26 @@ namespace ...@@ -382,27 +385,26 @@ namespace
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty))); ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
// clang-format off // clang-format off
LoopBuilder(&n, n_lb, n_ub, n_step)({ LoopBuilder(&n, n_lb, n_ub, n_step)(
LoopBuilder(&k, k_lb, k_ub, k_step)({ [&]{
i_res(n, k) = zero_init, LoopBuilder(&k, k_lb, k_ub, k_step)(
LoopBuilder(&m, m_lb, m_ub, m_step)({ [&]{
i_res(n, k) += i_lhs(n, m) * i_rhs(m, k) i_res(n, k) = zero_init;
}) LoopBuilder(&m, m_lb, m_ub, m_step)(
}), [&]{
}); i_res(n, k) += i_lhs(n, m) * i_rhs(m, k);
}
);
}
);
}
);
// clang-format on // clang-format on
rewriter.replaceOp(op, {result});
// Return result memref.
return {result};
} }
SmallVector<Value*, 4> NGReturnOpConversion::rewrite(Operation* op, REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
ArrayRef<Value*> operands, #undef REWRITER
FuncBuilder& rewriter) const
{
rewriter.create<ReturnOp>(op->getLoc());
return {};
}
} }
namespace mlir namespace mlir
......
...@@ -18,14 +18,14 @@ ...@@ -18,14 +18,14 @@
// Add new dialect ops lowerers to this file // Add new dialect ops lowerers to this file
#define DECL_OP_CONV(OP) \ #define DECL_OP_CONV(OP) \
class OP##Conversion : public mlir::DialectOpConversion \ class OP##Conversion : public mlir::DialectConversionPattern \
{\ {\
public:\ public:\
explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\ explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
: mlir::DialectOpConversion(mlir::OP::getOperationName(), 1, context),\ : mlir::DialectConversionPattern(mlir::OP::getOperationName(), 1, context),\
m_pass(pass)\ m_pass(pass)\
{} \ {} \
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override; \ void rewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override; \
DialectLoweringPass& m_pass;\ DialectLoweringPass& m_pass;\
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment