//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// NOTE: This file follows nGraph format style and MLIR naming convention since it does
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.

#include "lowerer.hpp"

#include "compiler.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
#include "ngraph/assertion.hpp"

#include <llvm/ADT/DenseSet.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Transforms/DialectConversion.h>

#include <map>

// anonymous namespace
// no need to expose any of the following outside of this file
namespace
{
    using namespace mlir;
    using namespace mlir::edsc;
    using namespace mlir::edsc::op;
    using namespace ngraph::runtime;
    using namespace ngraph::runtime::ngmlir;

    class DialectLoweringPass;

    /// Base class for nGraph operation conversions to affine/standard dialect. Provides
    /// conversion patterns with access to the DialectLoweringPass which holds the state of the
    /// conversion.
    class NGraphOpLowering : public ConversionPattern
    {
    public:
        NGraphOpLowering(StringRef rootOpName, MLIRContext* context, DialectLoweringPass& pass)
            : ConversionPattern(rootOpName, /*benefit=*/1, context)
            , pass(pass)
        {
        }

    protected:
        // Back-reference to the lowering pass which contains the lowering state, including the
        // nGraph type converter.
        DialectLoweringPass& pass;
    };

// Conversion classes declarations
#define MLIR_OP(OP)                                                                                \
    class OP##Conversion : public NGraphOpLowering                                                 \
    {                                                                                              \
    public:                                                                                        \
        explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass)             \
            : NGraphOpLowering(mlir::OP::getOperationName(), context, pass)                        \
        {                                                                                          \
        }                                                                                          \
                                                                                                   \
        PatternMatchResult matchAndRewrite(Operation* op,                                          \
                                           ArrayRef<Value*> operands,                              \
                                           ConversionPatternRewriter& rewriter) const override;    \
    };

#include "op_lowerers.inc"

    // Helpers
    template <typename RedOp>
    void lowerIndexReduction(Operation* op,
                             ArrayRef<Value*> operands,
                             PatternRewriter& rewriter,
                             DialectLoweringPass& pass);

    template <typename OP>
    void lowerBinaryElementwise(Operation* op,
                                ArrayRef<Value*> operands,
                                PatternRewriter& rewriter,
                                DialectLoweringPass& pass);

    /// Conversion from types in the nGraph dialect to the Standard dialect.
    class NGraphTypeConverter : public TypeConverter
    {
    public:
        NGraphTypeConverter()
            : TypeConverter()
        {
        }

        Type convertType(Type t) override;
    };

    /// Dialect Lowering Pass to affine ops
    class DialectLoweringPass : public ModulePass<DialectLoweringPass>
    {
    public:
        DialectLoweringPass(ngmlir::MLIRCompiler& compiler)
            : compiler(compiler)
        {
        }

        void runOnModule() override;
        SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
        Value* createTempTensor(Type type, PatternRewriter& rewriter);

        mlir::FuncOp getCallDecl(StringRef name,
                                 ArrayRef<Type> args,
                                 ArrayRef<Type> output,
                                 PatternRewriter& rewriter);

        /// Inserts dealloc Ops for each temporary allocated by AllocOp
        void insertDeallocs(PatternRewriter& rewriter);

        NGraphTypeConverter& getTypeConverter() { return typeConverter; }
    private:
        /// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
        void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);

        void findOutputValues();
        void processFakeInstrs();
        void insertNoAliasArgAttrs();
        Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);

    private:
        NGraphTypeConverter typeConverter;
        // Values holding the mem manager pointer passed to the function
        SmallVector<Value*, 4> memMgrDefs;
        // List of temporary memrefs to deallocate at the end of the function
        SmallVector<Value*, 4> memRefsToDealloc;
        // List of result values to add to the func signature
        SmallVector<Value*, 4> loweredOutputValues;
        ngmlir::MLIRCompiler& compiler;
    };

    void DialectLoweringPass::runOnModule()
    {
        // Initialize conversion patterns.
        OwningRewritePatternList patterns;
        // Add default FuncOp type conversion. It replaces the incoming FuncOp with a *new* one
        // with the converted types.
        mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), typeConverter);
        populateNGraphToAffineConversionPatterns(patterns);

        // Create target that defines legal ops for nGraph dialect to be lowered to.
        ConversionTarget target(getContext());
        // TODO: Remove NGFakeInputOp. We need to set NGFakeInputOp as legal op because we generate
        // it as part of the lowering to affine/standard.
        target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
        target.addLegalOp<ModuleOp, ModuleTerminatorOp, NGFakeInputOp>();
        target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
            // FuncOp is legal only if types have been converted to Std types.
            return typeConverter.isSignatureLegal(op.getType());
        });

        // Capture output values by looking for the Return op and grabbing its operand values.
        // The order of the returned values matches the order of the lowered func signature for
        // results. This is used to find the arg id that a defined value maps to if it is an
        // output.
        findOutputValues();

        if (failed(applyFullConversion(getModule(), target, std::move(patterns), &typeConverter)))
        {
            emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering nGraph dialect\n");
            signalPassFailure();
        }

        processFakeInstrs();
        insertNoAliasArgAttrs();
    }

    void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
        OwningRewritePatternList& patterns)
    {
#define MLIR_OP(OP) OP##Conversion,
#define MLIR_LAST_OP(OP) OP##Conversion
        RewriteListBuilder<
#include "op_lowerers.inc"
            >::build(patterns, &getContext(), *this);
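        // For illustration, if op_lowerers.inc listed MLIR_OP(NGAddOp) ... MLIR_LAST_OP(NGDotOp)
        // (ordering hypothetical), the defines plus the include above would expand to:
        //     RewriteListBuilder<NGAddOpConversion, ..., NGDotOpConversion>::build(
        //         patterns, &getContext(), *this);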
    }

    void DialectLoweringPass::findOutputValues()
    {
        // Get the original function.
        auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
        unsigned outputCount = 0;

        // We find output values by looking at returned values; any return should return all
        // outputs of the subgraph.
        f.walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
            for (unsigned i = 0; i < ret.getNumOperands(); i++)
            {
                auto outputValue = ret.getOperand(i);
                auto op = outputValue->getDefiningOp();
                op->setAttr("graphOutputIdx",
                            mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i));
            }
            NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(),
                         "Inconsistent returns in function");
            outputCount = ret.getNumOperands();
        });
        // Will be populated with lowered output values later.
        // TODO: This resize is making debugging obscure. When the container is not populated due
        // to a bug, null pointers are used by the consumer leading to a crash more difficult to
        // root-cause. We should try to change the current approach or introduce verification code.
        loweredOutputValues.resize(outputCount, nullptr);
    }

    /// Inserts a fake def for the Mem Mgr pointer at the start of the converted func.
    Value* DialectLoweringPass::insertMemMgrDef(PatternRewriter* rewriter)
    {
        // It would be nice to insert one fake def at the start of the new func, but with the
        // DialectConversion framework the new func is only materialized after conversion is done
        // (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction(), will give
        // you the original func). This makes it very convoluted to insert instructions at the
        // entry block.
        auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
                                                  IndexType::get(&getContext()));
        // Will be fixed later to read the passed arg instead.
        memMgrDefs.push_back(op.getResult());
        return op.getResult();
    }

    SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
                                                                PatternRewriter& rewriter)
    {
        SmallVector<Value*, 4> newResults;
        for (auto origResult : op->getResults())
        {
            // create output def if this operation produces any sub-graph outputs
            if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx"))
            {
                unsigned argId = static_cast<unsigned>(attr.getInt());
                auto fakeOp = rewriter.create<NGFakeInputOp>(
                    op->getLoc(),
                    typeConverter.convertType(origResult->getType()) /* convert to lowered type */
                    );
                // Fake instruction is short-lived. Verify here.
                fakeOp.verify();
                auto newResult = fakeOp.getResult();
                newResults.push_back(newResult);
                loweredOutputValues[argId] = newResult;
            }
            else
            {
                auto tensorType = origResult->getType().cast<NGTensorType>();
                auto newResult = createTempTensor(typeConverter.convertType(tensorType), rewriter);
                newResults.push_back(newResult);
            }
        }
        return newResults;
    }

    Value* DialectLoweringPass::createTempTensor(Type type, PatternRewriter& rewriter)
    {
        MemRefType memRefType = type.cast<MemRefType>();

        NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");

        Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
        memRefsToDealloc.push_back(alloc);
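        // Net effect per temp, once insertDeallocs() runs at the return (illustrative IR):
        //     %t = alloc() : memref<2x3xf32>
        //     ...
        //     dealloc %t : memref<2x3xf32>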

        // TODO:
        // Enable dynamic memref allocation via call-back to the nGraph allocator.
        // We should create a list of Values representing each dynamic dim.
        // The values would be computed based on the shape of the input to the ng op we are
        // lowering. E.g., if lowering concat, the Value for the dynamic concat axis will be the
        // sum of the input dims. This would be better done via std.AllocOp, but we need to make
        // it hookable to the nGraph allocator call-back.

        return alloc;
    }

    void DialectLoweringPass::processFakeInstrs()
    {
        auto context = getModule().getContext();
        auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
        mlir::Block* entryBlock = &*(f.begin());
        auto oldFuncType = f.getType();
        ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs();
        ArrayRef<mlir::Type> opArgs = oldFuncType.getResults();
        SmallVector<mlir::Type, 4> allArgs;

        // Move all args as inputs in new type
        for (auto type : ipArgs)
        {
            allArgs.push_back(type);
        }
        for (auto type : opArgs)
        {
            allArgs.push_back(type);
            // add new value for result
            entryBlock->addArgument(type);
        }
        // Mem Manager Ptr
        auto indexType = mlir::IndexType::get(context);
        allArgs.push_back(indexType);
        entryBlock->addArgument(indexType);
        // update type
        auto newFuncType = mlir::FunctionType::get(allArgs, {}, context);
        f.setType(newFuncType);
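        // Illustrative example (types hypothetical): an incoming signature
        //     (memref<2x3xf32>) -> (memref<2x3xf32>)
        // becomes, with results and the mem-manager pointer appended as arguments,
        //     (memref<2x3xf32>, memref<2x3xf32>, index) -> ()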

        // RAUW fake outputs with result values
        unsigned i = 0;
        for (auto value : loweredOutputValues)
        {
            auto op = value->getDefiningOp();
            NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?");
            value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
            op->erase();
            i++;
        }
        for (auto v : memMgrDefs)
        {
            v->replaceAllUsesWith(entryBlock->getArgument(compiler.get_mem_mgr_arg_id(f)));
            v->getDefiningOp()->erase();
        }
    }
    /// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
    /// by nGraph op semantics.
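    /// E.g., after this pass a lowered signature would look like (illustrative IR):
    ///     func @main(%arg0: memref<2x3xf32> {llvm.noalias = true}, ...)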
    void DialectLoweringPass::insertNoAliasArgAttrs()
    {
        auto func = getModule().lookupSymbol<mlir::FuncOp>("main");
        unsigned int argIdx = 0;
        for (auto* arg : func.getArguments())
        {
            if (arg->getType().isa<MemRefType>())
            {
                func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
            }

            ++argIdx;
        }
    }

    void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
    {
        for (auto value : memRefsToDealloc)
        {
            rewriter.create<DeallocOp>(rewriter.getUnknownLoc(), value);
        }
    }

    mlir::FuncOp DialectLoweringPass::getCallDecl(StringRef name,
                                                  ArrayRef<Type> args,
                                                  ArrayRef<Type> output,
                                                  PatternRewriter& rewriter)
    {
        auto callBackFunc = getModule().lookupSymbol<mlir::FuncOp>(name);
        if (!callBackFunc)
        {
            auto callBackType = rewriter.getFunctionType(args, output);
            callBackFunc = mlir::FuncOp::create(rewriter.getUnknownLoc(), name, callBackType);
            getModule().push_back(callBackFunc);
        }
        return callBackFunc;
    }
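    // Example use of getCallDecl from within a conversion, with a hypothetical callback name and
    // signature (sketch only; the callback does not necessarily exist in the runtime):
    //     auto callee = pass.getCallDecl(
    //         "__ng_some_callback", {rewriter.getIndexType()}, {}, rewriter);
    //     rewriter.create<mlir::CallOp>(loc, callee, ArrayRef<Value*>{idxVal});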

    // NGDialect converters
    Type NGraphTypeConverter::convertType(Type type)
    {
        // We may need to refactor this code to an external utility if type conversion is needed
        // outside of the lowering context, since NGraphTypeConverter is private.

        if (auto tensorType = type.dyn_cast<NGTensorType>())
        {
            // Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
            // This may change in the future.
            return MemRefType::get(tensorType.getShape(),
                                   convertType(tensorType.getElementType()),
                                   {/* no map used */},
                                   0);
        }
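        // E.g., an nGraph tensor of shape {2, 3} with f32 element type (illustrative textual
        // form: !ng.tensor<2x3xf32>) maps to memref<2x3xf32>.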
        if (auto floatType = type.dyn_cast<NGFloatType>())
        {
            // Float types are already std types.
            return floatType;
        }
        if (auto intType = type.dyn_cast<NGIntegerType>())
        {
            return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
        }
        if (auto boolType = type.dyn_cast<NGBoolType>())
        {
            return mlir::IntegerType::get(1 /* width */, boolType.getContext());
        }

        // Do not assert/NGRAPH_CHECK here. Type conversion infra expects `convertType` to return
        // the input type if the type is not supported.
        return type;
    }

#define REWRITER(OP)                                                                               \
    PatternMatchResult OP##Conversion::matchAndRewrite(                                            \
        Operation* op, ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) const
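// E.g., REWRITER(NGAddOp) expands to the definition header:
//     PatternMatchResult NGAddOpConversion::matchAndRewrite(
//         Operation* op, ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) const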

    REWRITER(NGAddOp)
    {
        lowerBinaryElementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGSubOp)
    {
        lowerBinaryElementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGMulOp)
    {
        lowerBinaryElementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGDivOp)
    {
        lowerBinaryElementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGGreaterOp)
    {
        lowerBinaryElementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGLessOp)
    {
        lowerBinaryElementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGMaxOp)
    {
        lowerBinaryElementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGMinOp)
    {
        lowerBinaryElementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGArgMaxRedOp)
    {
        lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    REWRITER(NGArgMinRedOp)
    {
        lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, pass);
        return matchSuccess();
    }

    // Relu
    REWRITER(NGReluOp)
    {
        auto loc = cast<NGReluOp>(op).getLoc();

        auto result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result->getType().isa<MemRefType>());
        // Note that the builder's current function is still the original function;
        // use getBlock to get the new block instead.

        // get new operands
        Value* lhs = operands[0];

        ScopedContext scope(rewriter, loc);
        // Views
        MemRefView vRes(result), vLHS(lhs);
        // Index Values
        IndexedValue iRes(result), iLHS(lhs);
        // Bounds Index Handles
        auto lbs = vLHS.getLbs();
        auto ubs = vLHS.getUbs();
        // Loop induction vars
        auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
        auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
        // Steps
        auto steps = vLHS.getSteps();

        NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
        Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType();
        NGRAPH_CHECK(!elemTy.isa<FloatType>(),
                     "NGReluOp with float element type should not be lowered until MLIR supports "
                     "lowering !std.CmpF");

        LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
            ValueHandle val = iLHS(ivs);
            if (auto floatTy = elemTy.dyn_cast<FloatType>())
            {
                ValueHandle zero = intrinsics::constant_float(llvm::APFloat(0.0f), floatTy);
                iRes(ivs) = intrinsics::select(val > zero, val, zero);
            }
            else if (auto intTy = elemTy.dyn_cast<IntegerType>())
            {
                ValueHandle zero = intrinsics::constant_int(0, intTy.getWidth());
                iRes(ivs) = intrinsics::select(val > zero, val, zero);
            }
            else
            {
                NGRAPH_CHECK(false, "Unsupported type for Relu");
            }
        });
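        // For an integer element type, the loop body above emits roughly (illustrative IR):
        //     %zero = constant 0 : i32
        //     %pred = cmpi "sgt", %val, %zero : i32
        //     %relu = select %pred, %val, %zero : i32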

        rewriter.replaceOp(op, {result});
        return matchSuccess();
    }

    REWRITER(NGDotOp)
    {
        auto dot = cast<NGDotOp>(op);
        auto loc = dot.getLoc();

        // Retrieve/generate Values for operands and result.
        ScopedContext scope(rewriter, loc);
        Value* lhs = operands[0];
        Value* rhs = operands[1];
        Value* result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");

        auto resultTy = result->getType().dyn_cast<MemRefType>();
        auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
        auto rhsTy = rhs->getType().dyn_cast<MemRefType>();
        NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
        NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
        NGRAPH_CHECK(rhsTy, "Unexpected non-memref RHS type");

        Type elemTy = resultTy.getElementType();
        NGRAPH_CHECK(elemTy == lhsTy.getElementType() && elemTy == rhsTy.getElementType(),
                     "Types mismatch in DotOp");

        // Create the following loop nest for the matmul operation (res is zero-initialized
        // inside the 'k' loop, matching the builders below):
        //   for(n, N, 1)
        //     for(k, K, 1)
        //       res[n, k] = 0
        //       for(m, M, 1)
        //         res[n, k] += lhs[n, m] * rhs[m, k]
        // TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.

        MemRefView vRes(result), vLhs(lhs), vRhs(rhs);

        NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2,
                     "Dot operation is only supported for 2D tensors");

        // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
        // It's important to note that MemRefView provides lb/ub/step info in "reverse order",
        // i.e., the fastest varying dimension is the last one, and the slowest varying dimension
        // is the first one.
        IndexHandle n, m, k;
        unsigned nDim = vLhs.fastestVarying() - 1;
        unsigned mDim = vLhs.fastestVarying();
        unsigned kDim = vRhs.fastestVarying();
        IndexHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
        IndexHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
        int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim);

        // Constants and indexed values to be used inside the loop nest.
        IndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
        ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));

        LoopBuilder(&n, nLb, nUb, nStep)([&] {
            LoopBuilder(&k, kLb, kUb, kStep)([&] {
                iRes(n, k) = zeroInit;
                LoopBuilder(&m, mLb, mUb, mStep)([&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
            });
        });
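        // In MLIR terms this materializes as three nested affine.for ops whose innermost body
        // loads from lhs/rhs, multiplies and accumulates into res, roughly (illustrative IR):
        //     %a = load %lhs[%n, %m] : memref<?x?xf32>
        //     %b = load %rhs[%m, %k] : memref<?x?xf32>
        //     %c = load %res[%n, %k] : memref<?x?xf32>
        //     %0 = mulf %a, %b : f32
        //     %1 = addf %c, %0 : f32
        //     store %1, %res[%n, %k] : memref<?x?xf32>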

        rewriter.replaceOp(op, {result});

        return matchSuccess();
    }

    REWRITER(NGConcatOp)
    {
        auto concat = cast<NGConcatOp>(op);
        auto loc = concat.getLoc();
        ScopedContext scope(rewriter, loc);

        // Create Value for result, and extract type info.
        Value* result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");

        // Create view to write into result.
        MemRefView vRes(result);
        auto rank = vRes.rank();

        // For each operand, generate a separate loop to copy into the target slice of "result".
        // We'll keep track of the slice offsets via concatenation_axis_pos.
        auto concatenationAxis = concat.concatenation_axis().getSExtValue();
        IndexHandle concatenationAxisPos(index_t(0));

        for (auto& operand : operands)
        {
            NGRAPH_CHECK(operand, "Unexpected null operand in ConcatOp");

            // Assuming rank = r, and the concatenation axis is A where A<r, we'll be creating
            // loops of this form:
            //
            //   for i_0 := 0 to operand.dims[0]:
            //    for i_1 := 0 to operand.dims[1]:
            //     ...
            //      for i_(r-2) := 0 to operand.dims[r-2]:
            //       for i_(r-1) := 0 to operand.dims[r-1]:
            //        result[i_0][i_1]...
            //              [i_(A-1)][i_A + concatenationAxisPos][i_(A+1)]...
            //              [i_(r-2)][i_(r-1)]
            //                  :=
            //        operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
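            // E.g. (illustrative): concatenating 2x3 and 4x3 operands along axis 0 copies the
            // first operand into rows [0, 2) and the second into rows [2, 6) of the 6x3 result.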
            MemRefView vOperand(operand);
            NGRAPH_CHECK(vOperand.rank() == rank, "Unexpected rank mismatch");

            llvm::SmallVector<ValueHandle, 5> indexVars;
            llvm::SmallVector<ValueHandle*, 5> indexVarPtrs;
            llvm::SmallVector<ValueHandle, 5> indexVarLbs;
            llvm::SmallVector<ValueHandle, 5> indexVarUbs;
            llvm::SmallVector<int64_t, 5> indexVarSteps;
            for (int i = 0; i < rank; i++)
            {
                indexVars.push_back(IndexHandle());
                indexVarPtrs.push_back(&(indexVars.back()));
                indexVarLbs.push_back(vOperand.lb(i));
                indexVarUbs.push_back(vOperand.ub(i));
                indexVarSteps.push_back(vOperand.step(i));
            }

            LoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] {
                IndexedValue ivRes(result);
                IndexedValue ivOperand(operand);

                // On the LHS of the assignment, adjust the index for the concatenation axis.
                llvm::SmallVector<ValueHandle, 5> resIndexHandles;
                for (int i = 0; i < rank; i++)
                {
                    resIndexHandles.push_back(i == concatenationAxis
                                                  ? indexVars[i] + concatenationAxisPos
                                                  : indexVars[i]);
                }

                ivRes(resIndexHandles) = ivOperand(indexVars);
            });

            // Move up concatenation_axis_pos for the next operand.
            concatenationAxisPos = concatenationAxisPos + vOperand.ub(concatenationAxis);
        }

        rewriter.replaceOp(op, {result});

        return matchSuccess();
    }

    REWRITER(NGGatherOp)
    {
        auto gatherOp = cast<NGGatherOp>(op);
        auto loc = gatherOp.getLoc();
        ScopedContext scope(rewriter, loc);

        // Get operands
        Value* result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result, "Unexpected null result in GatherOp");

        Value* params = operands[0];
        Value* indices = operands[1];
        auto axis = gatherOp.axis().getSExtValue();

        // Create view to write into result.
        MemRefView vRes(result), vParams(params), vIndices(indices);
        // Indexed Values
        IndexedValue iRes(result), iParams(params), iIndices(indices);

        // Construct outer loop for params dims. Exclude the axis dim.
        SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
        SmallVector<IndexHandle, 4> paramsIVs;
        SmallVector<int64_t, 4> paramsSteps;
        SmallVector<ValueHandle*, 4> paramsIVPtrs;
        for (auto i = 0; i < vParams.rank(); i++)
        {
            // skip gather axis
            if (i == axis)
                continue;
            paramsLbs.push_back(IndexHandle(vParams.lb(i)));
            paramsUbs.push_back(IndexHandle(vParams.ub(i)));
            paramsSteps.push_back(vParams.step(i));
        }
        NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
                         paramsUbs.size() == paramsLbs.size() &&
                         paramsSteps.size() == paramsLbs.size(),
                     "Incorrect loop nest bounds size for gather params");

        paramsIVs = IndexHandle::makeIndexHandles(vParams.rank() - 1);
        paramsIVPtrs = IndexHandle::makeIndexHandlePointers(paramsIVs);

        auto indicesLbs = vIndices.getLbs();
        auto indicesUbs = vIndices.getUbs();
        auto indicesSteps = vIndices.getSteps();

        auto indicesIVs = IndexHandle::makeIndexHandles(vIndices.rank());
        auto indicesIVPtrs = IndexHandle::makeIndexHandlePointers(indicesIVs);

        SmallVector<IndexHandle, 8> paramsIndices, resIndices;

        // Make sure we are going to create loops
        NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");

        // Let params rank : N
        // Let indices rank : M
        // Let axis be A
        // Generate
        // params loops
        // for P_0: 0 -> params.dim[0]
        //   for P_1: 0 -> params.dim[1]
        //     for P_2: 0 -> params.dim[2]
        // ...
        //       for P_(A-1):0 -> params.dim[A-1]
        //         for P_(A+1):0 -> params.dim[A+1]
        // ...
        //           for P_(N-1):0 -> params.dim[N-1]
        //             indices loops
        //             for I_0:0 -> indices.dim[0]
        // ...
        //               for I_(M-1):0 -> indices.dim[M-1]
        //                 res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
        //                   params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], P_(A+1), ... P_(N-1)];
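        // Worked example (illustrative): for params of shape 4x3, axis = 0, and indices of shape
        // [2] holding {2, 0}, the result has shape 2x3 with res[i, j] = params[indices[i], j].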

        LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
            LoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] {
                // Load axis value from indices array and cast it to Index Type
                ValueHandle axisIdx = ValueHandle::create<IndexCastOp>(
                    (ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());
                // construct indices for param
                // [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1]
                for (auto i = 0, j = 0; i < vParams.rank(); i++)
                {
                    if (i == axis)
                    {
                        paramsIndices.push_back(IndexHandle(axisIdx));
                    }
                    else
                    {
                        paramsIndices.push_back(paramsIVs[j++]);
                    }
                }

                // construct indices for result
                // [P_0, P_1, .. P_axis-1, I0, I1, .. I_k-1, P_axis+1, P_axis+2, .. P_n-1]
                for (auto i = 0, j = 0; i < vParams.rank() + vIndices.rank() - 1;)
                {
                    if (i == axis && indicesIVs.size() > 0)
                    {
                        resIndices.append(indicesIVs.begin(), indicesIVs.end());
                        i += indicesIVs.size();
                    }
                    else
                    {
                        resIndices.push_back(paramsIVs[j++]);
                        i++;
                    }
                }
                // Store into result
                iRes(resIndices) = iParams(paramsIndices);
            });
        });

        rewriter.replaceOp(op, {result});
        return matchSuccess();
    }

    REWRITER(NGReturnOp)
    {
        pass.insertDeallocs(rewriter);
        rewriter.replaceOpWithNewOp<ReturnOp>(op);
        return matchSuccess();
    }

#undef REWRITER
    /// End of pattern matchers
    template <typename OP>
    void lowerBinaryElementwise(Operation* op,
                                ArrayRef<Value*> operands,
                                PatternRewriter& rewriter,
                                DialectLoweringPass& pass)
    {
        auto loc = cast<OP>(op).getLoc();
        auto result = pass.buildOutputDefs(op, rewriter)[0];
        NGRAPH_CHECK(result->getType().isa<MemRefType>());
        // get new operands
        Value* lhs = operands[0];
        Value* rhs = operands[1];

        ScopedContext scope(rewriter, loc);
        // Views
        MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
        // Index Values
        IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
        // Bounds Index Handles
        auto lbs = vLHS.getLbs();
        auto ubs = vLHS.getUbs();
        // Loop induction vars
        auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
        auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
        // Steps
        auto steps = vLHS.getSteps();
        LoopNestBuilder(pivs, lbs, ubs, steps)(
            // single stmt body
            [&] {
                if (isa<NGAddOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) + iRHS(ivs);
                }
                else if (isa<NGSubOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) - iRHS(ivs);
                }
                else if (isa<NGMulOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) * iRHS(ivs);
                }
                else if (isa<NGDivOp>(op))
                {
                    iRes(ivs) = iLHS(ivs) / iRHS(ivs);
                }
                else if (isa<NGGreaterOp>(op))
                {
                    iRes(ivs) = ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs));
                }
                else if (isa<NGLessOp>(op))
                {
                    iRes(ivs) = ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs));
                }
                else if (isa<NGMaxOp>(op))
                {
                    iRes(ivs) =
                        edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
                                                 ValueHandle(iLHS(ivs)),
                                                 ValueHandle(iRHS(ivs)));
                }
                else if (isa<NGMinOp>(op))
                {
                    iRes(ivs) =
                        edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
                                                 ValueHandle(iLHS(ivs)),
                                                 ValueHandle(iRHS(ivs)));
                }
                else
                {
                    NGRAPH_CHECK(false, "Unsupported op");
                }
            });
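        // E.g., for NGAddOp the single-statement body above loads both operands at [ivs], adds
        // them (addf or addi depending on the element type), and stores into the result.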
        rewriter.replaceOp(op, {result});
    }

    template <typename RedOp>
    void lowerIndexReduction(Operation* op,
                             ArrayRef<Value*> operands,
                             PatternRewriter& rewriter,
                             DialectLoweringPass& pass)
    {
        static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
                      "Template parameter is not supported by lowerIndexReduction");

        RedOp redOp = cast<RedOp>(op);
        auto loc = redOp.getLoc();
        auto axesAttr = redOp.axes();

        NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
        Attribute axisAttr = *axesAttr.begin();
        unsigned axis = axisAttr.dyn_cast<IntegerAttr>().getInt();

        NGRAPH_CHECK(operands.size() == 1 && operands[0] != nullptr,
                     "Expected one non-null operand in Index Reduction op");

        // Retrieve/generate Values for operands and result.
        ScopedContext scope(rewriter, loc);
        Value* arg = operands[0];

        Value* result = pass.buildOutputDefs(op, rewriter)[0];

        // Views
        MemRefView vRes(result), vArg(arg);
        // Index Values
        IndexedValue iRes(result), iArg(arg);
        // Bounds Index Handles
        auto resLbs = vRes.getLbs();
        auto resUbs = vRes.getUbs();
        auto argLbs = vArg.getLbs();
        auto argUbs = vArg.getUbs();

        Type resTy = result->getType().cast<MemRefType>().getElementType();
        // Generate loop nest that initializes result to lower bound of the axis to be reduced.
        {
            auto ivs = IndexHandle::makeIndexHandles(vRes.rank());
            auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
            auto steps = vRes.getSteps();
            auto initVal = vArg.lb(axis);
            LoopNestBuilder(pivs, resLbs, resUbs, steps)(
                [&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); });
        }
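        // E.g., for argmin over axis 0, every result element starts out as that axis' lower
        // bound (typically 0), cast from index type to the integer result element type.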

        // Generate loop nest that computes the actual index reduction.
        {
            auto allIVs = IndexHandle::makeIndexHandles(vArg.rank());
            auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs);
            auto steps = vArg.getSteps();
            SmallVector<IndexHandle, 8> nonRedIVs;

            Type resTy = result->getType().cast<MemRefType>().getElementType();
            NGRAPH_CHECK(resTy.isa<IntegerType>(),
                         "Expected integer result type in index reduction");

            // iterate over all argument dimensions
            LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] {
                // build a list of non-reduction IVs
                for (auto i = 0; i < vArg.rank(); i++)
                {
                    if (i != axis)
                        nonRedIVs.push_back(allIVs[i]);
                }

                // Load current min index with integer data type and convert it to index data type.
                ValueHandle currRedIdx = ValueHandle::create<IndexCastOp>(
                    (ValueHandle)iRes(nonRedIVs), IndexType::get(resTy.getContext()));

                // Build list of IVs including current min index.
                auto tempIVs = allIVs;
                tempIVs[axis] = currRedIdx;

                // Select the min/max value and cast it back to integer type before storing it.
                ValueHandle newRedIdx =
                    std::is_same<RedOp, NGArgMinRedOp>()
                        ? edsc::intrinsics::select(
                              iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx)
                        : edsc::intrinsics::select(
                              iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx);

                iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
            });
        }

        rewriter.replaceOp(op, result);
    }
}

namespace mlir
{
    Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler)
    {
        return new DialectLoweringPass(*compiler);
    }
}