//***************************************************************************** // Copyright 2017-2019 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** #include "lowerer.hpp" #include "compiler.hpp" #include "dialect/ops.hpp" #include "dialect/type.hpp" #include "ngraph/assertion.hpp" #include <llvm/ADT/DenseSet.h> #include <mlir/EDSC/Builders.h> #include <mlir/EDSC/Helpers.h> #include <mlir/EDSC/Intrinsics.h> #include <mlir/IR/MLIRContext.h> #include <mlir/IR/StandardTypes.h> #include <mlir/Transforms/DialectConversion.h> #include <map> // anonymous namespace // no need to expose any of the following outside of this file namespace { using namespace mlir; using namespace mlir::edsc; using namespace ngraph::runtime; class DialectLoweringPass; #include "op_lowerers.inc" /// Use Dialect Converson Framework class DialectLowerer : public DialectConversion { public: DialectLowerer(DialectLoweringPass& pass) : DialectConversion() , m_pass(pass) { } Type convertType(Type t) override; protected: // Initialize the list of converters. void initConverters(OwningRewritePatternList& patterns, MLIRContext* mlirContext) override { RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build( patterns, mlirContext, m_pass); } private: DialectLoweringPass& m_pass; llvm::BumpPtrAllocator allocator; }; /// Dialect Lowering Pass to affine ops class DialectLoweringPass : public ModulePass<DialectLoweringPass> { public: DialectLoweringPass(ngmlir::MLIRCompiler& compiler) : m_dialectLowerer(*this) , m_compiler(compiler) { } void runOnModule() override; SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter); private: mlir::Function* getCallDecl(StringRef name, ArrayRef<Type> args, ArrayRef<Type> output, PatternRewriter& rewriter); void findOutputValues(); void processFakeInstrs(); Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr); private: DialectLowerer m_dialectLowerer; // Value holding mem manager passed pointer SmallVector<Value*, 4> m_memMgrDefs; // list of results values to add to func signature SmallVector<Value*, 4> m_loweredOutputValues; ngmlir::MLIRCompiler& m_compiler; }; void DialectLoweringPass::runOnModule() { // capture output values by looking for the Return and grabbing the values // the order of the returned values matches the order of the lowered func signature for // results. This is used to find the arg_id that a defined value maps to if it is an output findOutputValues(); if (failed(m_dialectLowerer.convert(&getModule()))) { getModule().getContext()->emitError(mlir::UnknownLoc::get(getModule().getContext()), "Error lowering dialect\n"); signalPassFailure(); } processFakeInstrs(); } void DialectLoweringPass::findOutputValues() { // get original function auto f = getModule().getNamedFunction("main"); SmallVector<Value*, 4> outputList; unsigned outputCount = 0; // we find out output values by looking at returned values // any return should return all outputs of the subgraph f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) { for (unsigned i = 0; i < ret.getNumOperands(); i++) { auto outputValue = ret.getOperand(i); auto op = outputValue->getDefiningOp(); op->setAttr("graphOutputIdx", mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i)); } NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(), "Inconsistent returns in function"); outputCount = ret.getNumOperands(); }); // will be populated with lowered output values later m_loweredOutputValues.resize(outputCount, nullptr); } /// Inserts a fake def for Mem Mgr pointer at converted func start Value* DialectLoweringPass::insertMemMgrDef(PatternRewriter* rewriter) { // it would be nice to insert one fake def at the start of the new func // however, due to how DialectConversion framework works, new func is only // materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction() // will give you the original func). This makes it very convoluted to insert instructions at entry block. auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(), IndexType::get(getModule().getContext())); // will be fixed later to read passed arg instead. m_memMgrDefs.push_back(op.getResult()); return op.getResult(); } SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op, PatternRewriter& rewriter) { SmallVector<Value*, 4> newResults; for (auto origResult : op->getResults()) { // create output def if this operation produces any sub-graph outputs if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx")) { unsigned argId = (int)attr.getInt(); auto fakeOp = rewriter.create<NGFakeInputOp>( op->getLoc(), m_dialectLowerer.convertType( origResult->getType()) /* convert to lowered type */ ); // Fake instrution is short-lived. Verify here. fakeOp.verify(); auto newResult = fakeOp.getResult(); newResults.push_back(newResult); m_loweredOutputValues[argId] = newResult; } else { auto tensorType = origResult->getType().cast<NGTensorType>(); auto callBackFunc = getCallDecl("__mlir_allocate", {rewriter.getIndexType(), rewriter.getIndexType()}, {m_dialectLowerer.convertType(tensorType)}, rewriter); auto size = tensorType.getSizeInBytes(); SmallVector<mlir::Value*, 4> args = { insertMemMgrDef(&rewriter), /* pointer to mem manager */ rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(), size)}; /* size to allocate */ auto newResult = rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args) .getResult(0); newResults.push_back(newResult); } } return newResults; } void DialectLoweringPass::processFakeInstrs() { auto context = getModule().getContext(); auto f = getModule().getNamedFunction("main"); mlir::Block* entryBlock = &*(f->begin()); auto oldFuncType = f->getType(); ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs(); ArrayRef<mlir::Type> opArgs = oldFuncType.getResults(); SmallVector<mlir::Type, 4> allArgs; // Move all args as inputs in new type for (auto type : ipArgs) { allArgs.push_back(type); } for (auto type : opArgs) { allArgs.push_back(type); // add new value for result entryBlock->addArgument(type); } // Mem Manager Ptr auto indexType = mlir::IndexType::get(context); allArgs.push_back(indexType); entryBlock->addArgument(indexType); // update type auto newFuncType = mlir::FunctionType::get(allArgs, {}, context); f->setType(newFuncType); // RAUW fake outputs with result values unsigned i = 0; for (auto value : m_loweredOutputValues) { auto op = value->getDefiningOp(); NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?"); value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i)); op->erase(); i++; } for (auto v : m_memMgrDefs) { v->replaceAllUsesWith(entryBlock->getArgument(m_compiler.get_mem_mgr_arg_id(f))); v->getDefiningOp()->erase(); } } mlir::Function* DialectLoweringPass::getCallDecl(StringRef name, ArrayRef<Type> args, ArrayRef<Type> output, PatternRewriter& rewriter) { auto callBackFuncPtr = getModule().getNamedFunction(name); if (callBackFuncPtr == nullptr) { auto callBackType = rewriter.getFunctionType(args, output); auto callBackFunc = llvm::make_unique<mlir::Function>(rewriter.getUnknownLoc(), name, callBackType); callBackFuncPtr = callBackFunc.get(); getModule().getFunctions().push_back(callBackFunc.release()); } return callBackFuncPtr; } // NGDialect converters Type DialectLowerer::convertType(Type type) { // We may need to refactor this code to a external utility if type conversion is needed // outside of the lowering context since DialectLowerer is private. if (auto tensor_type = type.dyn_cast<NGTensorType>()) { // Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType. // This may change in the future. return MemRefType::get(tensor_type.getShape(), convertType(tensor_type.getElementType()), {/* no map used */}, 0); } if (auto float_type = type.dyn_cast<NGFloatType>()) { // Float types are already std type. return float_type; } if (auto int_type = type.dyn_cast<NGIntegerType>()) { return mlir::IntegerType::get(int_type.getWidth(), int_type.getContext()); } if (auto bool_type = type.dyn_cast<NGBoolType>()) { return mlir::IntegerType::get(1 /* width */, bool_type.getContext()); } NGRAPH_CHECK(false, "Unsupported type to lower"); return type; } #define REWRITER(OP) \ void OP##Conversion::rewrite( \ Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter) const // ADD REWRITER(NGAddOp) { auto add = cast<NGAddOp>(op); auto loc = add.getLoc(); auto result = m_pass.buildOutputDefs(op, rewriter)[0]; NGRAPH_CHECK(result->getType().isa<MemRefType>()); // Note that builder's current function is still the original function body. // use getBlock to get the new block instead. // get new operands Value* lhs = operands[0]; Value* rhs = operands[1]; ScopedContext scope(rewriter, loc); // Views MemRefView vRes(result), vLHS(lhs), vRHS(rhs); // Index Values IndexedValue iRes(result), iLHS(lhs), iRHS(rhs); // Bounds Index Handles auto lbs = vLHS.getLbs(); auto ubs = vLHS.getUbs(); // Loop induction vars auto ivs = IndexHandle::makeIndexHandles(vLHS.rank()); auto pivs = IndexHandle::makeIndexHandlePointers(ivs); // Steps auto steps = vLHS.getSteps(); // clang-format off LoopNestBuilder(pivs, lbs, ubs, steps)( // single stmt body [&] { iRes(ivs) = iLHS(ivs) + iRHS(ivs); }); // clang-format on rewriter.replaceOp(op, {result}); } REWRITER(NGDotOp) { auto dot = cast<NGDotOp>(op); auto loc = dot.getLoc(); // Retrieve/generate Values for operands and result. ScopedContext scope(rewriter, loc); Value* lhs = operands[0]; Value* rhs = operands[1]; Value* result = m_pass.buildOutputDefs(op, rewriter)[0]; NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp"); auto result_ty = result->getType().dyn_cast<MemRefType>(); auto lhs_ty = lhs->getType().dyn_cast<MemRefType>(); auto rhs_ty = rhs->getType().dyn_cast<MemRefType>(); NGRAPH_CHECK(result_ty, "Unexpected non-memref result type"); NGRAPH_CHECK(lhs_ty, "Unexpected non-memref LHS type"); NGRAPH_CHECK(rhs_ty, "Unexpected non-memref RHS type"); Type elem_ty = result_ty.getElementType(); NGRAPH_CHECK(elem_ty == lhs_ty.getElementType() && elem_ty == rhs_ty.getElementType(), "Types mismatch in DotOp"); // Create the following loop nest for matmul operation: // for(n, N, 1) // for(m, M, 1) // for(k, K, 1) // res[n, k] += lhs[n, m] * rhs[m, k] // TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout. MemRefView v_res(result), v_lhs(lhs), v_rhs(rhs); NGRAPH_CHECK(v_lhs.rank() == 2 && v_rhs.rank() == 2 && v_res.rank() == 2, "Dot operation is only supported for 2D tensors"); // Create induction variables, lower bounds, upper bounds and steps of the loop nest. // It's important to note that MemRefView priovides lb/ub/step info is "reverse order", // i.e., fastest varying dimension is the last one, slowest varying dimention is the first // one. IndexHandle n, m, k; unsigned n_dim = v_lhs.fastestVarying() - 1; unsigned m_dim = v_rhs.fastestVarying(); unsigned k_dim = v_rhs.fastestVarying(); IndexHandle n_lb(v_lhs.lb(n_dim)), m_lb(v_lhs.lb(m_dim)), k_lb(v_rhs.lb(k_dim)); IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim)); int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim); // Constants, indexed values and indexes to be used inside the loop nest. IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs); ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty))); LoopBuilder(&n, n_lb, n_ub, n_step)([&] { LoopBuilder(&k, k_lb, k_ub, k_step)([&] { i_res(n, k) = zero_init; LoopBuilder(&m, m_lb, m_ub, m_step)( [&] { i_res(n, k) += i_lhs(n, m) * i_rhs(m, k); }); }); }); rewriter.replaceOp(op, {result}); } REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); } #undef REWRITER } namespace mlir { Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler) { return new DialectLoweringPass(*compiler); } }