//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.

#include "affine_lowerer.hpp"

#include "contrib/mlir/core/ngraph_dialect/ops.hpp"
#include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include "ngraph/assertion.hpp"

#include <llvm/ADT/DenseSet.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Transforms/DialectConversion.h>

#include <map>

#define PASS_NAME "convert-ngraph-to-affine"
#define DEBUG_TYPE PASS_NAME

// Anonymous namespace: nothing below needs to be exposed outside of this file.
namespace
{
    using namespace mlir;
    using namespace mlir::edsc;
    using namespace mlir::edsc::op;
    using namespace ngraph::runtime;
    using namespace ngraph::runtime::ngmlir;

    // Index notation to generate standard (i.e., non-affine) loads and stores.
    using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;

    class DialectLoweringPass;

    /// Base class for nGraph operation conversions to the affine/standard dialects. Provides
    /// conversion patterns with access to the DialectLoweringPass, which holds the state of
    /// the conversion.
    class NGraphOpLowering : public ConversionPattern
    {
    public:
        NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass)
            : ConversionPattern(rootOpName, /*benefit=*/1, context)
            , pass(pass)
        {
        }

    protected:
        // Back-reference to the lowering pass which contains the lowering state, including the
        // nGraph type converter.
        DialectLoweringPass& pass;
    };

    // Conversion class declarations.
#define MLIR_OP(OP, INPLACE)                                                                       \
    class OP##Conversion : public NGraphOpLowering                                                 \
    {                                                                                              \
    public:                                                                                        \
        explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass)             \
            : NGraphOpLowering(mlir::OP::getOperationName(), context, pass)                        \
        {                                                                                          \
        }                                                                                          \
                                                                                                   \
        PatternMatchResult matchAndRewrite(Operation* op,                                          \
                                           ArrayRef<Value*> operands,                              \
                                           ConversionPatternRewriter& rewriter) const override;    \
    };

#include "op_lowerers.inc"

    // FuncOp conversion pattern.
    class FuncOpSignatureConversion : public ConversionPattern
    {
    public:
        FuncOpSignatureConversion(MLIRContext* ctx, TypeConverter& converter)
            : ConversionPattern(FuncOp::getOperationName(), 1, ctx)
            , converter(converter)
        {
        }

        /// Hook for derived classes to implement combined matching and rewriting.
        PatternMatchResult matchAndRewrite(Operation* op,
                                           ArrayRef<Value*> operands,
                                           ConversionPatternRewriter& rewriter) const override
        {
            auto funcOp = cast<FuncOp>(op);
            FunctionType type = funcOp.getType();

            // Convert the original function arguments.
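            // Illustrative sketch (types hypothetical): a function of type
            //   (!ng.tensor<2x2xf32>) -> !ng.tensor<2x2xf32>
            // ends up, after this pattern runs, taking its converted result as an
            // extra output argument and returning nothing:
            //   (memref<2x2xf32>, memref<2x2xf32>) -> ()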
            TypeConverter::SignatureConversion result(type.getNumInputs());
            for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
            {
                if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
                {
                    return matchFailure();
                }
            }

            auto funcTypeResults = type.getResults();
            if (!funcTypeResults.empty())
            {
                // Convert the original function results.
                SmallVector<Type, 4> convertedResults;
                if (failed(converter.convertTypes(funcTypeResults, convertedResults)))
                {
                    return matchFailure();
                }

                // Add result types as input args without mapping.
                result.addInputs(convertedResults);
            }

            // Create a new function with an updated signature.
            auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
            rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end());
            newFuncOp.setType(
                FunctionType::get(result.getConvertedTypes(), {/*void*/}, funcOp.getContext()));

            // Tell the rewriter to convert the region signature.
            rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
            rewriter.replaceOp(op, llvm::None);
            return matchSuccess();
        }

        /// The type converter to use when rewriting the signature.
        TypeConverter& converter;
    };

    // Helpers
    template <typename RedOp>
    void lowerIndexReduction(Operation* op,
                             ArrayRef<Value*> operands,
                             PatternRewriter& rewriter,
                             DialectLoweringPass& pass);

    template <typename OP>
    void lowerBinaryElementwise(Operation* op,
                                ArrayRef<Value*> operands,
                                PatternRewriter& rewriter,
                                DialectLoweringPass& pass);

    template <typename OP>
    void lowerUnaryElementwise(Operation* op,
                               ArrayRef<Value*> operands,
                               PatternRewriter& rewriter,
                               DialectLoweringPass& pass);

    ValueHandle createZeroConstant(mlir::Type type);

    /// Conversion from types in the nGraph dialect to the Standard dialect.
    class NGraphTypeConverter : public TypeConverter
    {
    public:
        NGraphTypeConverter()
            : TypeConverter()
        {
        }

        Type convertType(Type t) override;
    };

    /// Dialect lowering pass to affine ops.
    class DialectLoweringPass : public ModulePass<DialectLoweringPass>
    {
    public:
        void runOnModule() override;

        SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);

        /// Allocates a linear buffer for a temporary tensor.
        Value* createTempBuffer(Type type, PatternRewriter& rewriter);

        /// Creates an allocation or view of a memref.
        /// type    MemRef type
        /// buffer  Optional buffer value to create the view over
        /// offset  Optional offset into the buffer this view starts at
        ///
        /// If buffer is null, a new memref allocation is created and offset is ignored.
        /// If buffer is non-null, a temp view is created over the pre-allocated buffer
        /// (see createTempBuffer).
        Value*
            createTempMemref(Type type, Value* buffer, unsigned offset, PatternRewriter& rewriter);

        /// Inserts dealloc ops for each temporary allocated by AllocOp.
        void insertDeallocs(PatternRewriter& rewriter);

        NGraphTypeConverter& getTypeConverter() { return typeConverter; }
    private:
        /// Collect a set of patterns to convert from the nGraph dialect to the Affine dialect.
        void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
        void findOutputValues();
        void insertNoAliasArgAttrs();

    private:
        NGraphTypeConverter typeConverter;
        // List of temporary memrefs to deallocate at end of function.
        SmallVector<Value*, 4> memRefsToDealloc;
        // Ops may be assigned memrefs in previous memory optimization passes.
        // Track pre-assigned buffers for each Value and re-use one if available.
        using IdToMemRefMap = std::unordered_map<unsigned, Value*>;
        IdToMemRefMap m_id_to_memref;
        // TODO: Workaround for findOutputValues and buildOutputDefs. See NGCPU-470.
        std::string funcName;
    };

    void DialectLoweringPass::runOnModule()
    {
        // Create type converter and initialize conversion patterns.
        NGraphTypeConverter converter;
        OwningRewritePatternList patterns;

        populateNGraphToAffineConversionPatterns(patterns);

        // Create target that defines legal ops for the nGraph dialect to be lowered to.
        ConversionTarget target(getContext());

        target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
        target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
        target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
            // FuncOp is legal only if types have been converted to Std types.
            return typeConverter.isSignatureLegal(op.getType());
        });

        // Gather functions to be processed. Note that new functions will be added to the module
        // as part of the function signature conversion, so we have to collect the original ones
        // beforehand.
        SmallVector<FuncOp, 2> origFuncOps(getModule().getOps<FuncOp>());

        for (auto origFunc : origFuncOps)
        {
            // TODO: Workaround for findOutputValues and buildOutputDefs. See NGCPU-470.
            funcName = origFunc.getName();

            // Capture output values by looking for the Return op and grabbing its values. The
            // order of the returned values matches the order of the lowered func signature for
            // results. This is used to find the arg_id that a defined value maps to if it is
            // an output.
            findOutputValues();

            // NOTE: Function signature conversion creates a new FuncOp that is inserted in the
            // module. References to the original FuncOp are no longer valid after this point.
            if (failed(applyFullConversion(origFunc, target, std::move(patterns), &converter)))
            {
                emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
                signalPassFailure();
            }

            // TODO: Encode the no-alias attribute as part of the function signature conversion
            // or as a separate rewrite pattern. Retrieve the new function after the signature
            // conversion.
            insertNoAliasArgAttrs();
        }
    }

    void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
        OwningRewritePatternList& patterns)
    {
#define MLIR_OP(OP, INPLACE) OP##Conversion,
#define MLIR_LAST_OP(OP, INPLACE) OP##Conversion
        patterns.insert<
#include "op_lowerers.inc"
            >(&getContext(), *this);

        // FuncOp pattern
        patterns.insert<FuncOpSignatureConversion>(&getContext(), typeConverter);
    }

    void DialectLoweringPass::findOutputValues()
    {
        FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
        NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");

        unsigned outputCount = 0;
        unsigned inputCount = f.getType().getNumInputs();
        // We find output values by looking at returned values; any return should return all
        // outputs of the subgraph.
        f.walk([this, &outputCount, inputCount](NGReturnOp ret) {
            for (unsigned i = 0; i < ret.getNumOperands(); i++)
            {
                // Annotate instructions defining outputs with the arg idx of the output.
                auto outputValue = ret.getOperand(i);
                auto op = outputValue->getDefiningOp();

                op->setAttr(
                    "graphOutputIdx",
                    mlir::IntegerAttr::get(IntegerType::get(32, op->getContext()), i + inputCount));
            }
            NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(),
                         "Inconsistent returns in function");
            // Record the arity so that a second return with a different operand count trips
            // the check above.
            outputCount = ret.getNumOperands();
        });
    }

    SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
                                                                PatternRewriter& rewriter)
    {
        FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
        NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");

        SmallVector<Value*, 4> newResults;
        for (auto origResult : op->getResults())
        {
            // Find the output arg if this operation produces any sub-graph outputs.
            if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx"))
            {
                mlir::Block* entryBlock = &*(f.begin());
                unsigned argId = (unsigned)attr.getInt();
                newResults.push_back(entryBlock->getArgument(argId));
            }
            else
            {
                // For temporaries, we create two instructions:
                // 1. Linear buffer allocation: If the ng value already has a buffer ID assigned,
                //    we re-use that linear buffer SSA value, else we generate an AllocOp.
                // 2. View creation: Create a view with the tensor shape and an N-D to 1-D map
                //    over the linear buffer.
                // If two memrefs are defined via two views over the same buffer, then they
                // share and re-use that buffer.
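                // As an illustrative sketch (buffer id and shapes hypothetical), two temps of
                // type memref<2x3xf32> that were assigned the same buffer id both become views
                // over a single 24-byte (2 * 3 * 4 bytes) i8 linear buffer, roughly:
                //   %buf   = alloc() : memref<24xi8>
                //   %view0 = view %buf : memref<24xi8> to memref<2x3xf32, #map>
                //   %view1 = view %buf : memref<24xi8> to memref<2x3xf32, #map>
                // (the exact std.view syntax depends on the MLIR revision in use).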
                auto tensorType = origResult->getType().cast<NGTensorType>();
                Value* newResult = nullptr;
                Attribute bufferIdAttr = getBufferId(op);
                Type memRefType = typeConverter.convertType(tensorType);

                Value* bufferValue = nullptr;
                if (!bufferIdAttr)
                {
                    // Allocate a new memref.
                    newResult = createTempMemref(memRefType, nullptr, 0, rewriter);
                }
                else
                {
                    unsigned bufferId = bufferIdAttr.cast<IntegerAttr>().getInt();
                    // Re-use a buffer if it exists, else create a new one and update the map.
                    IdToMemRefMap::iterator it = m_id_to_memref.find(bufferId);
                    if (it == m_id_to_memref.end())
                    {
                        // Create a new buffer.
                        bufferValue = createTempBuffer(memRefType, rewriter);
                        m_id_to_memref[bufferId] = bufferValue;
                    }
                    else
                    {
                        bufferValue = it->second;
                    }
                    // Create a temp view over the linear buffer.
                    newResult = createTempMemref(memRefType, bufferValue, 0, rewriter);
                }
                NGRAPH_CHECK(newResult != nullptr, "Temp memref value is not set");
                newResults.push_back(newResult);
            }
        }
        return newResults;
    }

    Value* DialectLoweringPass::createTempBuffer(Type type, PatternRewriter& rewriter)
    {
        MemRefType memRefType = type.cast<MemRefType>();

        NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");

        // Deduce the linear buffer shape.
        unsigned sizeInBytes = memRefType.getSizeInBits() / 8;

        MemRefType bufferType =
            MemRefType::get({sizeInBytes}, IntegerType::get(8, type.getContext()), {});
        Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), bufferType);
        memRefsToDealloc.push_back(alloc);

        // TODO:
        // Enable dynamic memref allocation via a call-back to the nGraph allocator.
        // We should create a list of Values representing each dynamic dim.
        // The values would be computed based on the shape of the input to the ng op we are
        // lowering.
        // E.g., if lowering concat, the Value for the dynamic concat axis will be the sum of
        // the input dims. The lowerer will generate code to compute the dims.
        // This is better done via std.AllocOp, but we need to make it hookable to the nGraph
        // allocator call-back.

        return alloc;
    }

    Value* DialectLoweringPass::createTempMemref(Type type,
                                                 Value* buffer,
                                                 unsigned offset,
                                                 PatternRewriter& rewriter)
    {
        NGRAPH_CHECK(offset == 0, "Only zero offset is supported");
        MemRefType memRefType = type.cast<MemRefType>();
        if (buffer)
        {
            // We have a buffer to map to. Create a view over it.
            // Create the N-D to 1-D affine expression mapping the memref shape to the
            // underlying linear buffer. This is simply
            //   (d0, d1, d2, .. dN-1) --> d0 * S0 + d1 * S1 + ... + dN-1 * SN-1
            // where Si is the stride along the i-th dimension.
            // E.g., for shape [2, 3, 4], strides = [12, 4, 1].
            auto shape = memRefType.getShape();
            SmallVector<int64_t, 4> strides(shape.size(), 0);
            strides[shape.size() - 1] = 1;
            for (int64_t i = shape.size() - 2; i >= 0; i--)
            {
                strides[i] = strides[i + 1] * shape[i + 1];
            }

            auto map = makeStridedLinearLayoutMap(strides, offset, rewriter.getContext());
            MemRefType newMemRefType = MemRefType::get(shape, memRefType.getElementType(), map);
            auto viewOp = rewriter.create<mlir::ViewOp>(
                buffer->getDefiningOp()->getLoc(), newMemRefType, buffer, llvm::None);
            return viewOp.getResult();
        }

        // No buffer; create a standalone memref with no underlying buffer.
        NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");

        Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
        memRefsToDealloc.push_back(alloc);
        return alloc;
    }

    /// Add the llvm.noalias attribute to all memref function arguments. We know that this is
    /// safe by nGraph op semantics.
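    /// Illustratively (function and argument names hypothetical), after this runs the
    /// generated function entry carries per-argument attributes along the lines of:
    ///   func @main(%arg0: memref<2x2xf32> {llvm.noalias = true}, ...)
    /// which later lets LLVM assume the buffers do not overlap when optimizing.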
    void DialectLoweringPass::insertNoAliasArgAttrs()
    {
        FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
        NGRAPH_CHECK(func, "FuncOp '" + funcName + "' not found");

        unsigned int argIdx = 0;
        for (auto* arg : func.getArguments())
        {
            if (arg->getType().isa<MemRefType>())
            {
                func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
            }

            ++argIdx;
        }
    }

    void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
    {
        for (auto value : memRefsToDealloc)
        {
            rewriter.create<DeallocOp>(rewriter.getUnknownLoc(), value);
        }
    }

    // NGDialect converters
    Type NGraphTypeConverter::convertType(Type type)
    {
        // We may need to refactor this code to an external utility if type conversion is
        // needed outside of the lowering context, since NGraphTypeConverter is private.

        if (auto tensorType = type.dyn_cast<NGTensorType>())
        {
            // Convert NGTensorType to Std MemRefType directly instead of going to Std
            // TensorType. This may change in the future.
            return MemRefType::get(tensorType.getShape(),
                                   convertType(tensorType.getElementType()),
                                   {/* no map used */},
                                   0);
        }
        if (auto floatType = type.dyn_cast<NGFloatType>())
        {
            // Float types are already std types.
            return floatType;
        }
        if (auto intType = type.dyn_cast<NGIntegerType>())
        {
            return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
        }
        if (auto boolType = type.dyn_cast<NGBoolType>())
        {
            return mlir::IntegerType::get(1 /* width */, boolType.getContext());
        }

        // Do not assert/NGRAPH_CHECK here. The type conversion infra expects `convertType`
        // to return the input type if the type is not supported.
        return type;
    }

#define REWRITER(OP)                                                                               \
    PatternMatchResult OP##Conversion::matchAndRewrite(                                            \
        Operation* op, ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) const

    REWRITER(NGAddOp)
    {
        lowerBinaryElementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGSubOp)
    {
        lowerBinaryElementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGMulOp)
    {
        lowerBinaryElementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGDivOp)
    {
        lowerBinaryElementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGGreaterOp)
    {
        lowerBinaryElementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGLessOp)
    {
        lowerBinaryElementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGMaxOp)
    {
        lowerBinaryElementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGMinOp)
    {
        lowerBinaryElementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGArgMaxRedOp)
    {
        lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGArgMinRedOp)
    {
        lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    // Relu
    REWRITER(NGReluOp)
    {
        auto loc = cast<NGReluOp>(op).getLoc();

        auto result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result->getType().isa<MemRefType>());
        // Note that the builder's current function is still the original function body;
        // use getBlock to get the new block instead.
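        // Sketch of the emitted IR (shapes and SSA names hypothetical): for a 2x3 input,
        // the loop nest built below is morally
        //   affine.for %i0 = 0 to 2 {
        //     affine.for %i1 = 0 to 3 {
        //       %v = affine.load %lhs[%i0, %i1]
        //       %r = select %v > %zero, %v, %zero
        //       affine.store %r, %res[%i0, %i1]
        //     }
        //   }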
        // Get new operands.
        Value* lhs = operands[0];

        ScopedContext scope(rewriter, loc);
        // Views
        MemRefView vRes(result), vLHS(lhs);
        // Index Values
        IndexedValue iRes(result), iLHS(lhs);
        // Bounds Index Handles
        auto lbs = vLHS.getLbs();
        auto ubs = vLHS.getUbs();
        // Loop induction vars
        auto ivs = makeIndexHandles(vLHS.rank());
        auto pivs = makeIndexHandlePointers(ivs);
        // Steps
        auto steps = vLHS.getSteps();

        NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
        Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType();
        AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
            ValueHandle val = iLHS(ivs);
            ValueHandle zero = createZeroConstant(elemTy);
            iRes(ivs) = intrinsics::select(val > zero, val, zero);
        });

        rewriter.replaceOp(op, {result});
        return matchSuccess();
    }

    // Negative
    REWRITER(NGNegOp)
    {
        lowerUnaryElementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGDotOp)
    {
        auto dot = cast<NGDotOp>(op);
        auto loc = dot.getLoc();

        // Retrieve/generate Values for operands and result.
        ScopedContext scope(rewriter, loc);
        Value* lhs = operands[0];
        Value* rhs = operands[1];
        Value* result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");

        auto resultTy = result->getType().dyn_cast<MemRefType>();
        auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
        auto rhsTy = rhs->getType().dyn_cast<MemRefType>();
        NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
        NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
        NGRAPH_CHECK(rhsTy, "Unexpected non-memref RHS type");

        Type elemTy = resultTy.getElementType();
        NGRAPH_CHECK(elemTy == lhsTy.getElementType() && elemTy == rhsTy.getElementType(),
                     "Types mismatch in DotOp");

        // Create the following loop nest for the matmul operation:
        //   for(n, N, 1)
        //     for(m, M, 1)
        //       for(k, K, 1)
        //         res[n, k] += lhs[n, m] * rhs[m, k]
        // TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.

        MemRefView vRes(result), vLhs(lhs), vRhs(rhs);

        NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2,
                     "Dot operation is only supported for 2D tensors");

        // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
        // It's important to note that MemRefView provides lb/ub/step info in "reverse order",
        // i.e., the fastest varying dimension is the last one, and the slowest varying
        // dimension is the first one.
        IndexHandle n, m, k;
        unsigned nDim = vLhs.fastestVarying() - 1;
        unsigned mDim = vLhs.fastestVarying();
        unsigned kDim = vRhs.fastestVarying();
        IndexHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
        IndexHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
        int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim);

        // Constants and indexed values to be used inside the loop nest.
        IndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
        ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));

        // Initialize the result to zero before accumulating into it.
        {
            IndexHandle n, k;
            LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
                LoopBuilder::makeAffine(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; });
            });
        }

        LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
            LoopBuilder::makeAffine(&m, mLb, mUb, mStep)([&] {
                LoopBuilder::makeAffine(&k, kLb, kUb, kStep)(
                    [&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
            });
        });

        rewriter.replaceOp(op, {result});

        return matchSuccess();
    }

    REWRITER(NGConcatOp)
    {
        auto concat = cast<NGConcatOp>(op);
        auto loc = concat.getLoc();
        ScopedContext scope(rewriter, loc);

        // Create Value for result, and extract type info.
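        // Illustrative example (shapes hypothetical): concatenating two memref<2x3xf32>
        // operands along axis 0 produces a memref<4x3xf32> result. The copy loop below
        // writes operand 0 into rows [0, 2); concatenationAxisPos then advances by
        // ub(0) == 2, so operand 1 lands in rows [2, 4).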
        Value* result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");

        // Create view to write into result.
        MemRefView vRes(result);
        auto rank = vRes.rank();

        // For each operand, generate a separate loop to copy into the target slice of
        // "result". We keep track of the slice offsets via concatenationAxisPos.
        auto concatenationAxis = concat.concatenation_axis().getSExtValue();
        IndexHandle concatenationAxisPos(index_t(0));

        for (auto& operand : operands)
        {
            NGRAPH_CHECK(operand, "Unexpected null operand in ConcatOp");

            // Assuming rank = r, and the concatenation axis is A where A < r, we'll be
            // creating loops of this form:
            //
            //   for i_0 := 0 to operand.dims[0]:
            //    for i_1 := 0 to operand.dims[1]:
            //     ...
            //      for i_(r-2) := 0 to operand.dims[r-2]:
            //       for i_(r-1) := 0 to operand.dims[r-1]:
            //        result[i_0][i_1]...
            //              [i_(A-1)][i_A + concatenationAxisPos][i_(A+1)]...
            //              [i_(r-2)][i_(r-1)]
            //            :=
            //        operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
            MemRefView vOperand(operand);
            NGRAPH_CHECK(vOperand.rank() == rank, "Unexpected rank mismatch");

            llvm::SmallVector<ValueHandle, 5> indexVars;
            llvm::SmallVector<ValueHandle*, 5> indexVarPtrs;
            llvm::SmallVector<ValueHandle, 5> indexVarLbs;
            llvm::SmallVector<ValueHandle, 5> indexVarUbs;
            llvm::SmallVector<int64_t, 5> indexVarSteps;
            for (int i = 0; i < rank; i++)
            {
                indexVars.push_back(IndexHandle());
                indexVarPtrs.push_back(&(indexVars.back()));
                indexVarLbs.push_back(vOperand.lb(i));
                indexVarUbs.push_back(vOperand.ub(i));
                indexVarSteps.push_back(vOperand.step(i));
            }

            AffineLoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] {
                IndexedValue ivRes(result);
                IndexedValue ivOperand(operand);

                // On the LHS of the assignment, adjust the index for the concatenation axis.
                llvm::SmallVector<ValueHandle, 5> resIndexHandles;
                for (int i = 0; i < rank; i++)
                {
                    resIndexHandles.push_back(i == concatenationAxis
                                                  ? indexVars[i] + concatenationAxisPos
                                                  : indexVars[i]);
                }

                ivRes(resIndexHandles) = ivOperand(indexVars);
            });

            // Move up concatenationAxisPos for the next operand.
            concatenationAxisPos = concatenationAxisPos + vOperand.ub(concatenationAxis);
        }

        rewriter.replaceOp(op, {result});

        return matchSuccess();
    }

    REWRITER(NGGatherOp)
    {
        auto gatherOp = cast<NGGatherOp>(op);
        auto loc = gatherOp.getLoc();
        ScopedContext scope(rewriter, loc);

        // Get operands.
        Value* result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result, "Unexpected null result in GatherOp");

        Value* params = operands[0];
        Value* indices = operands[1];
        auto axis = gatherOp.axis().getSExtValue();

        // Create views of result, params, and indices.
        MemRefView vRes(result), vParams(params), vIndices(indices);
        // Indexed Values
        IndexedValue iRes(result), iIndices(indices);
        StdIndexedValue iParams(params);

        // Construct the outer loops over the params dims, excluding the axis dim.
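        // Illustrative example (shapes hypothetical): with params : memref<5x10xf32>,
        // indices : memref<3xi64> and axis = 0, the result is memref<3x10xf32> with
        //   res[i, j] = params[indices[i], j]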
        SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
        SmallVector<IndexHandle, 4> paramsIVs;
        SmallVector<int64_t, 4> paramsSteps;
        SmallVector<ValueHandle*, 4> paramsIVPtrs;

        for (auto i = 0; i < vParams.rank(); i++)
        {
            // Skip the gather axis.
            if (i == axis)
            {
                continue;
            }
            paramsLbs.push_back(IndexHandle(vParams.lb(i)));
            paramsUbs.push_back(IndexHandle(vParams.ub(i)));
            paramsSteps.push_back(vParams.step(i));
        }
        NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
                         paramsUbs.size() == paramsLbs.size() &&
                         paramsSteps.size() == paramsLbs.size(),
                     "Incorrect loop nest bounds size for gather params");

        paramsIVs = makeIndexHandles(vParams.rank() - 1);
        paramsIVPtrs = makeIndexHandlePointers(paramsIVs);

        auto indicesLbs = vIndices.getLbs();
        auto indicesUbs = vIndices.getUbs();
        auto indicesSteps = vIndices.getSteps();

        auto indicesIVs = makeIndexHandles(vIndices.rank());
        auto indicesIVPtrs = makeIndexHandlePointers(indicesIVs);

        SmallVector<IndexHandle, 8> paramsIndices, resIndices;

        // Make sure we are going to create loops.
        NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");

        // Let params rank be N, indices rank be M, and let the axis be A.
        // Generate:
        //   indices loops
        //   for I_0: 0 -> indices.dim[0]
        //     ...
        //     for I_(M-1): 0 -> indices.dim[M-1]
        //       params loops
        //       for P_0: 0 -> params.dim[0]
        //         for P_1: 0 -> params.dim[1]
        //           ...
        //           for P_(A-1): 0 -> params.dim[A-1]
        //             for P_(A+1): 0 -> params.dim[A+1]
        //               ...
        //               for P_(N-1): 0 -> params.dim[N-1]
        //                 res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
        //                   params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)],
        //                          P_(A+1), ... P_(N-1)];

        AffineLoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] {
            // Load the axis value from the indices array and cast it to Index type.
            ValueHandle axisIdx = ValueHandle::create<IndexCastOp>(
                (ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());

            AffineLoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
                // Construct indices for params:
                // [P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], P_(A+1), .. P_(N-1)]
                for (auto i = 0, j = 0; i < vParams.rank(); i++)
                {
                    if (i == axis)
                    {
                        paramsIndices.push_back(IndexHandle(axisIdx));
                    }
                    else
                    {
                        paramsIndices.push_back(paramsIVs[j++]);
                    }
                }

                // Construct indices for the result:
                // [P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), .. P_(N-1)]
                for (auto i = 0, j = 0; i < vParams.rank() + vIndices.rank() - 1;)
                {
                    if (i == axis && indicesIVs.size() > 0)
                    {
                        resIndices.append(indicesIVs.begin(), indicesIVs.end());
                        i += indicesIVs.size();
                    }
                    else
                    {
                        resIndices.push_back(paramsIVs[j++]);
                        i++;
                    }
                }
                // Store into the result.
                iRes(resIndices) = iParams(paramsIndices);
            });
        });

        rewriter.replaceOp(op, {result});
        return matchSuccess();
    }

    REWRITER(NGConvolutionOp)
    {
        auto convolOp = cast<NGConvolutionOp>(op);
        auto loc = convolOp.getLoc();
        ScopedContext scope(rewriter, loc);

        // Get operands.
        Value* result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
        Value* images = operands[0];
        Value* filters = operands[1];
        auto strides = convolOp.strides().getValue();
        auto padBelow = convolOp.padBelow().getValue();
        auto padAbove = convolOp.padAbove().getValue();

        Type elemTy = images->getType().cast<MemRefType>().getElementType();

        // Let Images shape be  [N, C_IN, D_1, ... D_f]
        // Let Filters shape be [C_OUT, C_IN, F_1, ... F_f]
        // Output shape will be [N, C_OUT, R_1, ... R_f]
        //   where R_i = (AdjD_i - AdjF_i + 1) / Strides[i]
        //
        // AdjD_i is the adjusted image spatial dimension after padding and dilation:
        //   AdjD_i = padBelow[i] + (dilation[i] * (D_i - 1) + 1) + padAbove[i]
        //
        // AdjF_i is the adjusted filters spatial dimension after dilation:
        //   AdjF_i = dilation[i] * (F_i - 1) + 1
        //
        // If there is no padding, padAbove/Below[i] = 0.
        // If there is no dilation, dilation[i] = 1.
        //
        // Generate the following (currently without padding/dilation support):
        //
        // for n : 0 -> N
        //   for k : 0 -> C_OUT
        //     for <r_1 .. r_f> : <0 .. 0> -> <R_1 ... R_f>
        //       // initialize result to zero
        //       Output[n, k, r_1, .. r_f] = 0;
        //
        // for n : 0 -> N
        //   for k : 0 -> C_OUT
        //     for c : 0 -> C_IN
        //       // iterate over output spatial shape
        //       for <r_1 .. r_f> : <0 .. 0> -> <R_1 ... R_f>
        //         // compute image start input indices
        //         i_1 = r_1 * strides[0];
        //         ..
        //         i_f = r_f * strides[f - 1];
        //         // iterate over kernel spatial shape
        //         for <j_1 .. j_f> : <0 .. 0> -> <F_1 .. F_f>
        //           Output[n, k, r_1, .. r_f] +=
        //             Images[n, c, i_1 + j_1, .. i_f + j_f] * Filters[k, c, j_1, .. j_f]
        //
        // With padding, we check (using IntegerSets) whether each spatial dim in Images lies
        // inside the non-padded spatial region. If true, we perform the computation:
        //
        // for <j_1 .. j_f> : <0 .. 0> -> <F_1 .. F_f>
        //   if (indices in non-padded region):
        //     Output[n, k, r_1, .. r_f] +=
        //       Images[n, c, i_1 + j_1, .. i_f + j_f] * Filters[k, c, j_1, .. j_f]

        // Create views of result, images, and filters.
        MemRefView vRes(result), vImages(images), vFilters(filters);
        // Indexed Values
        IndexedValue iRes(result), iImages(images), iFilters(filters);

        // Bounds on batch size N
        ValueHandle batchLb = vImages.lb(0), batchUb = vImages.ub(0);
        // Bounds on number of filters
        ValueHandle numFiltersLb = vFilters.lb(0), numFiltersUb = vFilters.ub(0);
        // Bounds on number of channels
        ValueHandle numChannelsLb = vImages.lb(1), numChannelsUb = vImages.ub(1);
        // Bounds on result spatial dimensions
        SmallVector<ValueHandle, 4> resSpatialLbs, resSpatialUbs;
        SmallVector<ValueHandle, 4> imgSpatialLbs, imgSpatialUbs;
        SmallVector<ValueHandle, 4> filtersSpatialLbs, filtersSpatialUbs;
        // Spatial rank
        unsigned spatialRank = vImages.rank() - 2;

        // Result spatial indices and bounds
        auto resSpatialIndices = makeIndexHandles(spatialRank);
        auto resSpatialIndicesPtrs = makeIndexHandlePointers(resSpatialIndices);
        SmallVector<int64_t, 4> resSteps, filtersSteps;
        SmallVector<int, 4> padBelowIntValues;
        bool withPadding = false;

        for (auto i = 0; i < spatialRank; i++)
        {
            // Result spatial bounds and steps
            resSpatialLbs.push_back(vRes.lb(i + 2));
            resSpatialUbs.push_back(vRes.ub(i + 2));
            resSteps.push_back(vRes.step(i + 2));
            // Image spatial bounds
            imgSpatialLbs.push_back(vImages.lb(i + 2));
            imgSpatialUbs.push_back(vImages.ub(i + 2));

            // Check if we have any padding and collect pad values.
            IntegerAttr iAttr = padBelow[i].cast<IntegerAttr>();
            int padValue = iAttr.getInt();
            if (padValue)
            {
                withPadding = true;
            }
            padBelowIntValues.push_back(padValue);

            iAttr = padAbove[i].cast<IntegerAttr>();
            padValue = iAttr.getInt();
            if (padValue)
            {
                withPadding = true;
            }
        }

        NGRAPH_CHECK(vImages.rank() == vFilters.rank(), "Images and Filters have unequal ranks");
        NGRAPH_CHECK(resSpatialLbs.size() == resSpatialUbs.size() &&
                         resSpatialLbs.size() == spatialRank,
                     "Results spatial dims mismatches input");

        // Filters spatial indices and bounds
        auto filtersSpatialIndices = makeIndexHandles(spatialRank);
        auto filtersSpatialIndicesPtrs = makeIndexHandlePointers(filtersSpatialIndices);
        for (auto i = 0; i < spatialRank; i++)
        {
            filtersSpatialLbs.push_back(vFilters.lb(i + 2));
            filtersSpatialUbs.push_back(vFilters.ub(i + 2));
            filtersSteps.push_back(vFilters.step(i + 2));
        }

        IntegerSet nonPaddedRange;
        if (withPadding)
        {
            // Create affine expressions and IntegerSet
            // IntegerSet (d0, d1, .. d_N-1)[LB_0, LB_1, .. LB_N-1, UB_0, UB_1, .. UB_N-1], where
            // for each dim:
            //   (d_dim - padBelow[dim] - LB_dim >= 0),
            //   (padBelow[dim] + UB_dim - d_dim - 1 >= 0)
            SmallVector<AffineExpr, 4> affineExprs;
            // Bool to indicate if expr is equality or inequality
            SmallVector<bool, 4> isEq;

            for (unsigned dim = 0; dim < spatialRank; dim++)
            {
                // i_dim
                auto dimExpr = rewriter.getAffineDimExpr(dim);
                auto imgLbExpr = rewriter.getAffineSymbolExpr(dim);

                // expr1 : i_dim - padBelow[dim] - imgLB >= 0
                auto padBelowExpr = rewriter.getAffineConstantExpr(padBelowIntValues[dim]);
                affineExprs.push_back(dimExpr - padBelowExpr - imgLbExpr);
                isEq.push_back(false);

                // expr2: padBelow[dim] + imgUB - i_dim - 1 >= 0
                auto imgUbExpr = rewriter.getAffineSymbolExpr(spatialRank + dim);
                auto oneExpr = rewriter.getAffineConstantExpr(1);
                affineExprs.push_back(padBelowExpr + imgUbExpr - dimExpr - oneExpr);
                isEq.push_back(false);
            }

            NGRAPH_CHECK(affineExprs.size() == isEq.size() && isEq.size() == 2 * spatialRank,
                         "Invalid number of expressions in the IntegerSet");
            nonPaddedRange = IntegerSet::get(spatialRank, 2 * spatialRank, affineExprs, isEq);
        }

        // Initialize output to zero
        {
            IndexHandle n, k, c;
            auto resSpatialIndices = makeIndexHandles(spatialRank);
            auto resSpatialIndicesPtrs = makeIndexHandlePointers(resSpatialIndices);

            LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
                LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
                    AffineLoopNestBuilder(
                        resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
                        SmallVector<IndexHandle, 4> resIndices;
                        // Result indices
                        resIndices.push_back(n);
                        resIndices.push_back(k);
                        resIndices.insert(
                            resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end());
                        ValueHandle zero = createZeroConstant(elemTy);
                        iRes(resIndices) = zero;
                    });
                });
            });
        }

        IndexHandle n, k, c;
        // Convolution loop
        LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
            // Number of filters loop
            LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
                // Channels loop
                LoopBuilder::makeAffine(&c, numChannelsLb, numChannelsUb, 1)([&] {
                    // Results loop
                    AffineLoopNestBuilder(
                        resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
                        // Compute image start indices
                        SmallVector<IndexHandle, 4> imgStartIndices;
                        for (auto i = 0; i < spatialRank; i++)
                        {
                            IntegerAttr iAttr = strides[i].cast<IntegerAttr>();
                            auto stride = intrinsics::constant_index(iAttr.getInt());
                            imgStartIndices.push_back(IndexHandle(resSpatialIndices[i] * stride));
                        }
                        SmallVector<IndexHandle, 4> resIndices;
                        // Result indices
                        resIndices.push_back(n);
                        resIndices.push_back(k);
                        resIndices.insert(
                            resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end());
                        // Filters spatial loop
                        AffineLoopNestBuilder(filtersSpatialIndicesPtrs,
                                              filtersSpatialLbs,
                                              filtersSpatialUbs,
                                              filtersSteps)([&] {
                            SmallVector<IndexHandle, 4> imgIndices, filtersIndices;
                            // Image indices
                            // Here we compute the virtual start index into the padded image.
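                            // Illustrative 1-D check (numbers hypothetical): with
                            // padBelow = [1], image lb = 0 and ub = 4, the IntegerSet keeps
                            // virtual indices i with i - 1 - 0 >= 0 and 1 + 4 - i - 1 >= 0,
                            // i.e. 1 <= i <= 4: exactly the non-padded band of a virtually
                            // padded dimension of size 1 + 4 + 1 = 6 (padAbove = 1).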
                            imgIndices.push_back(n);
                            imgIndices.push_back(c);
                            for (auto i = 0; i < spatialRank; i++)
                            {
                                imgIndices.push_back(
                                    IndexHandle(imgStartIndices[i] + filtersSpatialIndices[i]));
                            }
                            // Filter indices
                            filtersIndices.push_back(k);
                            filtersIndices.push_back(c);
                            filtersIndices.insert(filtersIndices.end(),
                                                  filtersSpatialIndices.begin(),
                                                  filtersSpatialIndices.end());

                            if (withPadding)
                            {
                                // if args : img dims, img lbs, img ubs
                                SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin();
                                std::advance(it, 2);
                                SmallVector<Value*, 4> affineIfArgs(it, imgIndices.end());
                                affineIfArgs.insert(
                                    affineIfArgs.end(), imgSpatialLbs.begin(), imgSpatialLbs.end());
                                affineIfArgs.insert(
                                    affineIfArgs.end(), imgSpatialUbs.begin(), imgSpatialUbs.end());

                                auto affineIfOp =
                                    rewriter.create<AffineIfOp>(rewriter.getUnknownLoc(),
                                                                nonPaddedRange,
                                                                affineIfArgs,
                                                                /*withElseRegion=*/false);
                                {
                                    auto rewriter = affineIfOp.getThenBodyBuilder();
                                    ScopedContext scope(rewriter, loc);
                                    // We must subtract pad below before img load, since the
                                    // physical image is not padded
                                    SmallVector<IndexHandle, 4> adjustedImgIndices;
                                    adjustedImgIndices.push_back(n);
                                    adjustedImgIndices.push_back(c);
                                    for (auto i = 0; i < spatialRank; i++)
                                    {
                                        adjustedImgIndices.push_back(IndexHandle(
                                            imgIndices[2 + i] -
                                            intrinsics::constant_index(padBelowIntValues[i])));
                                    }
                                    iRes(resIndices) =
                                        iRes(resIndices) +
                                        (iImages(adjustedImgIndices) * iFilters(filtersIndices));
                                }
                            }
                            else
                            {
                                iRes(resIndices) =
                                    iRes(resIndices) +
                                    (iImages(imgIndices) * iFilters(filtersIndices));
                            }
                        });
                    });
                });
            });
        });

        rewriter.replaceOp(op, {result});
        return matchSuccess();
    }

    REWRITER(NGReturnOp)
    {
        pass.insertDeallocs(rewriter);
        rewriter.replaceOpWithNewOp<ReturnOp>(op);
        return matchSuccess();
    }

#undef REWRITER
    // End of pattern matchers.

    template <typename OP>
    void lowerUnaryElementwise(Operation* op,
                               ArrayRef<Value*> operands,
                               PatternRewriter& rewriter,
                               DialectLoweringPass& pass)
    {
        auto loc = cast<OP>(op).getLoc();

        auto result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result->getType().isa<MemRefType>());
        // Note that builder's current function is still the original function body.
        // use getBlock to get the new block instead.
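        // For NGNegOp (the only op handled below) this emits, per element,
        //   res[ivs] = 0 - lhs[ivs]
        // i.e. an affine loop nest over the operand's full iteration space.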
        // get new operands
        Value* lhs = operands[0];

        ScopedContext scope(rewriter, loc);
        // Views
        MemRefView vRes(result), vLHS(lhs);
        // Index Values
        IndexedValue iRes(result), iLHS(lhs);
        // Bounds Index Handles
        auto lbs = vLHS.getLbs();
        auto ubs = vLHS.getUbs();
        // Loop induction vars
        auto ivs = makeIndexHandles(vLHS.rank());
        auto pivs = makeIndexHandlePointers(ivs);
        // Steps
        auto steps = vLHS.getSteps();

        NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
        Type elemTy = lhs->getType().cast<MemRefType>().getElementType();

        AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
            ValueHandle val = iLHS(ivs);
            if (isa<NGNegOp>(op))
            {
                ValueHandle zero = createZeroConstant(elemTy);
                iRes(ivs) = zero - val;
            }
            else
            {
                NGRAPH_CHECK(false, "Unsupported op");
            }
        });

        rewriter.replaceOp(op, {result});
    }

    template <typename OP>
    void lowerBinaryElementwise(Operation* op,
                                ArrayRef<Value*> operands,
                                PatternRewriter& rewriter,
                                DialectLoweringPass& pass)
    {
        auto loc = cast<OP>(op).getLoc();
        auto result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result->getType().isa<MemRefType>());
        // get new operands
        Value* lhs = operands[0];
        Value* rhs = operands[1];

        ScopedContext scope(rewriter, loc);
        // Views
        MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
        // Index Values
        IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
        // Bounds Index Handles
        auto lbs = vLHS.getLbs();
        auto ubs = vLHS.getUbs();
        // Loop induction vars
        auto ivs = makeIndexHandles(vLHS.rank());
        auto pivs = makeIndexHandlePointers(ivs);
        // Steps
        auto steps = vLHS.getSteps();

        AffineLoopNestBuilder(pivs, lbs, ubs, steps)(
            // single stmt body
            [&] {
                if (isa<NGAddOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) + iRHS(ivs);
                }
                else if (isa<NGSubOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) - iRHS(ivs);
                }
                else if (isa<NGMulOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) * iRHS(ivs);
                }
                else if (isa<NGDivOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) / iRHS(ivs);
                }
                else if (isa<NGGreaterOp>(op))
                {
                    iRes(ivs) = ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs));
                }
                else if (isa<NGLessOp>(op))
                {
                    iRes(ivs) = ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs));
                }
                else if (isa<NGMaxOp>(op))
                {
                    iRes(ivs) =
                        edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
                                                 ValueHandle(iLHS(ivs)),
                                                 ValueHandle(iRHS(ivs)));
                }
                else if (isa<NGMinOp>(op))
                {
                    iRes(ivs) =
                        edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
                                                 ValueHandle(iLHS(ivs)),
                                                 ValueHandle(iRHS(ivs)));
                }
                else
                {
                    NGRAPH_CHECK(false, "Unsupported op");
                }
            });
        rewriter.replaceOp(op, {result});
    }

    template <typename RedOp>
    void lowerIndexReduction(Operation* op,
                             ArrayRef<Value*> operands,
                             PatternRewriter& rewriter,
                             DialectLoweringPass& pass)
    {
        static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
                      "Template parameter is not supported by lowerIndexReduction");

        RedOp redOp = cast<RedOp>(op);
        auto loc = redOp.getLoc();
        auto axesAttr = redOp.axes();

        NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
        Attribute axisAttr = *axesAttr.begin();
        unsigned axis = axisAttr.dyn_cast<IntegerAttr>().getInt();

        NGRAPH_CHECK(operands.size() == 1 && operands[0] != nullptr,
                     "Expected one non-null operand in Index Reduction op");

        // Retrieve/generate Values for operands and result.
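        // Illustrative example (values hypothetical): for arg = [3.0, 1.0, 2.0] and axis = 0,
        // NGArgMinRedOp yields res = 1. The result is seeded with the axis lower bound below,
        // then every iteration compares arg[i] against arg[res] and keeps the index of the
        // smaller (or larger, for NGArgMaxRedOp) element.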
        ScopedContext scope(rewriter, loc);
        Value* arg = operands[0];

        Value* result = pass.buildOutputDefs(op, rewriter)[0];

        // Views
        MemRefView vRes(result), vArg(arg);
        // Index Values
        StdIndexedValue iRes(result), stdArg(arg);
        IndexedValue affineArg(arg);
        // Bounds Index Handles
        auto resLbs = vRes.getLbs();
        auto resUbs = vRes.getUbs();
        auto argLbs = vArg.getLbs();
        auto argUbs = vArg.getUbs();

        Type resTy = result->getType().cast<MemRefType>().getElementType();
        // Generate loop nest that initializes result to the lower bound of the axis to be
        // reduced.
        {
            auto ivs = makeIndexHandles(vRes.rank());
            auto pivs = makeIndexHandlePointers(ivs);
            auto steps = vRes.getSteps();
            auto initVal = vArg.lb(axis);
            AffineLoopNestBuilder(pivs, resLbs, resUbs, steps)(
                [&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); });
        }

        // Generate loop nest that computes the actual index reduction.
        {
            auto allIVs = makeIndexHandles(vArg.rank());
            auto pAllIVs = makeIndexHandlePointers(allIVs);
            auto steps = vArg.getSteps();
            SmallVector<IndexHandle, 8> nonRedIVs;

            Type resTy = result->getType().cast<MemRefType>().getElementType();
            NGRAPH_CHECK(resTy.isa<IntegerType>(),
                         "Expected integer result type in index reduction");

            // Iterate over all argument dimensions.
            AffineLoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] {
                // Build a list of non-reduction IVs.
                for (auto i = 0; i < vArg.rank(); i++)
                {
                    if (i != axis)
                    {
                        nonRedIVs.push_back(allIVs[i]);
                    }
                }

                // Load the current min index with integer data type and convert it to index
                // data type.
                ValueHandle currRedIdx = ValueHandle::create<IndexCastOp>(
                    (ValueHandle)iRes(nonRedIVs), IndexType::get(resTy.getContext()));

                // Build list of IVs including the current min index.
                auto tempIVs = allIVs;
                tempIVs[axis] = currRedIdx;

                // Select the min/max value and cast it back to integer type before storing it.
                ValueHandle newRedIdx =
                    std::is_same<RedOp, NGArgMinRedOp>()
                        ? edsc::intrinsics::select(
                              affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
                        : edsc::intrinsics::select(
                              stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);

                iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
            });
        }

        rewriter.replaceOp(op, result);
    }

    ValueHandle createZeroConstant(mlir::Type type)
    {
        if (auto floatTy = type.dyn_cast<FloatType>())
        {
            if (floatTy.isF32())
            {
                return intrinsics::constant_float(llvm::APFloat(0.0f), floatTy);
            }
            else if (floatTy.isF64())
            {
                return intrinsics::constant_float(llvm::APFloat(0.0), floatTy);
            }
            else
            {
                NGRAPH_UNREACHABLE("Unsupported floating-point precision");
            }
        }
        else if (auto intTy = type.dyn_cast<IntegerType>())
        {
            return intrinsics::constant_int(0, intTy.getWidth());
        }
        NGRAPH_UNREACHABLE("Unsupported type");
    }
} // namespace

namespace mlir
{
    std::unique_ptr<Pass> createDialectLoweringPass()
    {
        return std::make_unique<DialectLoweringPass>();
    }
} // namespace mlir

static PassRegistration<DialectLoweringPass> pass(PASS_NAME,
                                                  "Convert nGraph dialect to affine dialect");