Commit 9acdfe04 authored by Nagy Mostafa, committed by nmostafa

[MLIR] MLIR version upgrade (#28)

* Upgrade MLIR. Several code fixes based on API changes

* Fixes due to DialectConv API changes

* style-apply

* PR fixes
parent eda52385
......@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID 43d110b)
set(MLIR_COMMIT_ID 1fbf407)
set(MLIR_LLVM_COMMIT_ID bb2b527)
set(MLIR_COMMIT_ID 49f7efc)
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
......
......@@ -353,9 +353,8 @@ void MLIRCompiler::execute()
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer to run
// LLVM optimizations at level 3.
mlir::PassManager* mlir_pm = nullptr;
auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), mlir_pm, llvm_transformer);
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
NGRAPH_ASSERT(maybeEngine) << "failed to construct an execution engine";
m_engine = std::move(maybeEngine.get());
......
......@@ -36,11 +36,19 @@ include "mlir/IR/OpBase.td"
// Each def will correspond to a C++ class
def NG_Dialect : Dialect {
let name = "ng";
// TODO: Have the dialect under its own mlir::ngraph namespace
// At mlir top-level for now
let cppNamespace = "";
}
// NGraph Types
// This defines records equivalent to NGraph types. It doesn't generate code.
// This is used as a type in the DAG input/outputs.
// Constraints (CPred) are used to type-check args/results of that type during op verification
def NG_TensorType : Type<CPred<"{0}.isa<mlir::NGTensorType>()">,
def NG_TensorType : Type<CPred<"$_self.isa<mlir::NGTensorType>()">,
"NGraph Tensor Type">;
// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
......@@ -49,7 +57,7 @@ def NG_MemRefType : Type<IsMemRefTypePred, "MemRef Type">;
// NGraph operation base class.
// Prepends "ng." to operation name
class NG_Op<string mnemonic, list<OpTrait> traits = []> :
Op<!strconcat("ng.", mnemonic), traits> {}
Op<NG_Dialect, mnemonic, traits> {}
// Operations producing single result.
// Will set OneResult trait based on Results out dag.
......@@ -71,7 +79,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$arg)>
{
// TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return false; }];
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
let verifier = [{ return verifyUnaryArithOp(this); }];
}
......@@ -83,7 +91,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
Arguments<(ins NG_TensorType:$lhs, NG_TensorType:$rhs)>
{
// TODO: Implement
let parser = [{ NGRAPH_FAIL() << "No parser support"; return false; }];
let parser = [{ NGRAPH_FAIL() << "No parser support"; return mlir::failure(); }];
let verifier = [{ return verifyBinaryArithOp(this); }];
}
......
......@@ -116,7 +116,7 @@ namespace mlir
/// Convert to equivalent std type
/// std types are sign-agnostic.
mlir::Type toStdType() const { return mlir::IntegerType::get(getWidth(), getContext()); }
mlir::Type toStdType() { return mlir::IntegerType::get(getWidth(), getContext()); }
/// Check if signed type
bool isSigned() const;
......@@ -164,7 +164,7 @@ namespace mlir
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
/// Convert to equivalent std type. Integer of width 1 in that case
mlir::Type toStdType() const { return mlir::IntegerType::get(1, getContext()); }
mlir::Type toStdType() { return mlir::IntegerType::get(1, getContext()); }
};
// Note that dialect types don't add new data members, so always possible
......
......@@ -37,6 +37,7 @@ namespace
using namespace ngraph::runtime;
class DialectLoweringPass;
#include "op_lowerers.inc"
/// Use Dialect Conversion Framework
......@@ -53,11 +54,10 @@ namespace
protected:
// Initialize the list of converters.
llvm::DenseSet<DialectOpConversion*> initConverters(MLIRContext* context) override
void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override
{
return ConversionListBuilder<NGAddOpConversion,
NGMatMulBiasOpConversion,
NGReturnOpConversion>::build(&allocator, context, m_pass);
RewriteListBuilder<NGAddOpConversion, NGMatMulBiasOpConversion, NGReturnOpConversion>::
build(patterns, mlirContext, m_pass);
}
private:
......@@ -75,24 +75,22 @@ namespace
{
}
void runOnModule() override;
std::map<Value*, unsigned>& getOutputValueMap() { return m_outputValueMap; };
SmallVector<Value*, 4> buildOutputDefs(Operation* op, FuncBuilder& rewriter);
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
private:
mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
FuncBuilder& rewriter);
PatternRewriter& rewriter);
void findOutputValues();
void processFakeInstrs();
Value* insertMemMgrDef(FuncBuilder* rewriter = nullptr);
Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
private:
DialectLowerer m_dialectLowerer;
// Value holding mem manager passed pointer
SmallVector<Value*, 4> m_memMgrDefs;
// maps output ng dialect values to args pos
std::map<Value*, unsigned> m_outputValueMap;
// list of results values to add to func signature
SmallVector<Value*, 4> m_loweredOutputValues;
ngmlir::MLIRCompiler& m_compiler;
......@@ -134,7 +132,10 @@ namespace
f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++)
{
this->m_outputValueMap.insert(std::pair<Value*, unsigned>(ret.getOperand(i), i));
auto outputValue = ret.getOperand(i);
auto op = outputValue->getDefiningOp();
op->setAttr("graphOutputIdx",
mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i));
}
NGRAPH_ASSERT(outputCount == 0 || outputCount == ret.getNumOperands())
<< "Inconsistent returns in function";
......@@ -145,7 +146,7 @@ namespace
}
/// Inserts a fake def for Mem Mgr pointer at converted func start
Value* DialectLoweringPass::insertMemMgrDef(FuncBuilder* rewriter)
Value* DialectLoweringPass::insertMemMgrDef(PatternRewriter* rewriter)
{
// it would be nice to insert one fake def at the start of the new func
// however, due to how DialectConversion framework works, new func is only
......@@ -159,17 +160,15 @@ namespace
}
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
FuncBuilder& rewriter)
PatternRewriter& rewriter)
{
auto& outputMap = getOutputValueMap();
SmallVector<Value*, 4> newResults;
for (auto origResult : op->getResults())
{
auto it = outputMap.find(origResult);
// create output def if this operation produces any sub-graph outputs
if (it != outputMap.end())
if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx"))
{
unsigned argId = (*it).second;
unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(),
m_dialectLowerer.convertType(
......@@ -237,7 +236,7 @@ namespace
for (auto value : m_loweredOutputValues)
{
auto op = value->getDefiningOp();
NGRAPH_ASSERT(op->isa<NGFakeInputOp>()) << "output value not defined by fake output?";
NGRAPH_ASSERT(isa<NGFakeInputOp>(op)) << "output value not defined by fake output?";
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase();
i++;
......@@ -252,7 +251,7 @@ namespace
mlir::Function* DialectLoweringPass::getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
FuncBuilder& rewriter)
PatternRewriter& rewriter)
{
auto callBackFuncPtr = getModule().getNamedFunction(name);
if (callBackFuncPtr == nullptr)
......@@ -292,12 +291,15 @@ namespace
return t;
}
#define REWRITER(OP) \
void OP##Conversion::rewrite( \
Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const
// ADD
SmallVector<Value*, 4> NGAddOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
REWRITER(NGAddOp)
{
auto add = op->cast<NGAddOp>();
auto add = cast<NGAddOp>(op);
auto loc = add.getLoc();
Value *origResult, *newResult;
......@@ -323,18 +325,19 @@ namespace
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
LoopNestBuilder(pivs, lbs, ubs, steps)({// single stmt body
iRes(ivs) = iLHS(ivs) + iRHS(ivs)});
// return result memref
return {result};
// clang-format off
LoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body
[&] {
iRes(ivs) = iLHS(ivs) + iRHS(ivs);
});
// clang-format on
rewriter.replaceOp(op, {result});
}
SmallVector<Value*, 4> NGMatMulBiasOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
REWRITER(NGMatMulBiasOp)
{
auto matmul = op->cast<NGMatMulBiasOp>();
auto matmul = cast<NGMatMulBiasOp>(op);
auto loc = matmul.getLoc();
NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation";
......@@ -382,27 +385,26 @@ namespace
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
// clang-format off
LoopBuilder(&n, n_lb, n_ub, n_step)({
LoopBuilder(&k, k_lb, k_ub, k_step)({
i_res(n, k) = zero_init,
LoopBuilder(&m, m_lb, m_ub, m_step)({
i_res(n, k) += i_lhs(n, m) * i_rhs(m, k)
})
}),
});
LoopBuilder(&n, n_lb, n_ub, n_step)(
[&]{
LoopBuilder(&k, k_lb, k_ub, k_step)(
[&]{
i_res(n, k) = zero_init;
LoopBuilder(&m, m_lb, m_ub, m_step)(
[&]{
i_res(n, k) += i_lhs(n, m) * i_rhs(m, k);
}
);
}
);
}
);
// clang-format on
// Return result memref.
return {result};
rewriter.replaceOp(op, {result});
}
SmallVector<Value*, 4> NGReturnOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands,
FuncBuilder& rewriter) const
{
rewriter.create<ReturnOp>(op->getLoc());
return {};
}
REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
#undef REWRITER
}
namespace mlir
......
......@@ -18,14 +18,14 @@
// Add new dialect ops lowerers to this file
#define DECL_OP_CONV(OP) \
class OP##Conversion : public mlir::DialectOpConversion \
class OP##Conversion : public mlir::DialectConversionPattern \
{\
public:\
explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
: mlir::DialectOpConversion(mlir::OP::getOperationName(), 1, context),\
: mlir::DialectConversionPattern(mlir::OP::getOperationName(), 1, context),\
m_pass(pass)\
{} \
SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override; \
void rewrite(Operation *op, ArrayRef<Value *> operands, PatternRewriter &rewriter) const override; \
DialectLoweringPass& m_pass;\
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment