Commit c59ea84c authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Sang Ik Lee

[MLIR] Bump MLIR commit to c61db4bb (#3879)

* WIP

* WIP

* WIP

* WIP

* style

* WIP

* WIP

* Add err msg

* Fix headers and cleanup
parent ef922f0c
...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git) ...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git) set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions # Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID 0845ac7331e) set(MLIR_LLVM_COMMIT_ID e0f1d9d8729)
set(MLIR_COMMIT_ID 1f7893e0) set(MLIR_COMMIT_ID c61db4bb)
# MLIR environment variables. Some of them are used by LIT tool. # MLIR environment variables. Some of them are used by LIT tool.
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project) set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
......
...@@ -521,7 +521,7 @@ namespace ...@@ -521,7 +521,7 @@ namespace
NGRAPH_CHECK(lhs->getType().isa<MemRefType>()); NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType(); Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType();
LoopNestBuilder(pivs, lbs, ubs, steps)([&] { AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs); ValueHandle val = iLHS(ivs);
ValueHandle zero = createZeroConstant(elemTy); ValueHandle zero = createZeroConstant(elemTy);
iRes(ivs) = intrinsics::select(val > zero, val, zero); iRes(ivs) = intrinsics::select(val > zero, val, zero);
...@@ -591,12 +591,14 @@ namespace ...@@ -591,12 +591,14 @@ namespace
{ {
IndexHandle n, k; IndexHandle n, k;
LoopBuilder(&n, nLb, nUb, nStep)( LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
[&] { LoopBuilder(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; }); }); LoopBuilder::makeAffine(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; });
});
} }
LoopBuilder(&n, nLb, nUb, nStep)([&] { LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
LoopBuilder(&m, mLb, mUb, mStep)([&] { LoopBuilder::makeAffine(&m, mLb, mUb, mStep)([&] {
LoopBuilder(&k, kLb, kUb, kStep)([&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); }); LoopBuilder::makeAffine(&k, kLb, kUb, kStep)(
[&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
}); });
}); });
...@@ -658,7 +660,7 @@ namespace ...@@ -658,7 +660,7 @@ namespace
indexVarSteps.push_back(vOperand.step(i)); indexVarSteps.push_back(vOperand.step(i));
} }
LoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] { AffineLoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] {
IndexedValue ivRes(result); IndexedValue ivRes(result);
IndexedValue ivOperand(operand); IndexedValue ivOperand(operand);
...@@ -758,12 +760,12 @@ namespace ...@@ -758,12 +760,12 @@ namespace
// params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], // params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)],
// P_(A+1), ... P_(N-1)]; // P_(A+1), ... P_(N-1)];
LoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] { AffineLoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] {
// Load axis value from indices array and cast it to Index Type // Load axis value from indices array and cast it to Index Type
ValueHandle axisIdx = ValueHandle::create<IndexCastOp>( ValueHandle axisIdx = ValueHandle::create<IndexCastOp>(
(ValueHandle)iIndices(indicesIVs), rewriter.getIndexType()); (ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());
LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] { AffineLoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
// construct indices for param // construct indices for param
// [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1] // [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank(); i++) for (auto i = 0, j = 0; i < vParams.rank(); i++)
...@@ -965,8 +967,7 @@ namespace ...@@ -965,8 +967,7 @@ namespace
NGRAPH_CHECK(affineExprs.size() == isEq.size() && isEq.size() == 2 * spatialRank, NGRAPH_CHECK(affineExprs.size() == isEq.size() && isEq.size() == 2 * spatialRank,
"Invalid number of expressions in the IntegerSet"); "Invalid number of expressions in the IntegerSet");
nonPaddedRange = nonPaddedRange = IntegerSet::get(spatialRank, 2 * spatialRank, affineExprs, isEq);
rewriter.getIntegerSet(spatialRank, 2 * spatialRank, affineExprs, isEq);
} }
// Initialize output to zero // Initialize output to zero
...@@ -975,9 +976,9 @@ namespace ...@@ -975,9 +976,9 @@ namespace
auto resSpatialIndices = makeIndexHandles(spatialRank); auto resSpatialIndices = makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs = makeIndexHandlePointers(resSpatialIndices); auto resSpatialIndicesPtrs = makeIndexHandlePointers(resSpatialIndices);
LoopBuilder(&n, batchLb, batchUb, 1)([&] { LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
LoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] { LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
LoopNestBuilder( AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] { resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
SmallVector<IndexHandle, 4> resIndices; SmallVector<IndexHandle, 4> resIndices;
// Result indices // Result indices
...@@ -994,13 +995,13 @@ namespace ...@@ -994,13 +995,13 @@ namespace
IndexHandle n, k, c; IndexHandle n, k, c;
// Convolution loop // Convolution loop
LoopBuilder(&n, batchLb, batchUb, 1)([&] { LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
// Number of filters loop // Number of filters loop
LoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] { LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
// Channels loop // Channels loop
LoopBuilder(&c, numChannelsLb, numChannelsUb, 1)([&] { LoopBuilder::makeAffine(&c, numChannelsLb, numChannelsUb, 1)([&] {
// Results loop // Results loop
LoopNestBuilder( AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] { resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
// Compute image start indices // Compute image start indices
SmallVector<IndexHandle, 4> imgStartIndices; SmallVector<IndexHandle, 4> imgStartIndices;
...@@ -1017,10 +1018,10 @@ namespace ...@@ -1017,10 +1018,10 @@ namespace
resIndices.insert( resIndices.insert(
resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end()); resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end());
// Filters spatial loop // Filters spatial loop
LoopNestBuilder(filtersSpatialIndicesPtrs, AffineLoopNestBuilder(filtersSpatialIndicesPtrs,
filtersSpatialLbs, filtersSpatialLbs,
filtersSpatialUbs, filtersSpatialUbs,
filtersSteps)([&] { filtersSteps)([&] {
SmallVector<IndexHandle, 4> imgIndices, filtersIndices; SmallVector<IndexHandle, 4> imgIndices, filtersIndices;
// Image indices // Image indices
// Here we compute the virtual start index into the padded image. // Here we compute the virtual start index into the padded image.
...@@ -1131,7 +1132,7 @@ namespace ...@@ -1131,7 +1132,7 @@ namespace
NGRAPH_CHECK(lhs->getType().isa<MemRefType>()); NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
Type elemTy = lhs->getType().cast<MemRefType>().getElementType(); Type elemTy = lhs->getType().cast<MemRefType>().getElementType();
LoopNestBuilder(pivs, lbs, ubs, steps)([&] { AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs); ValueHandle val = iLHS(ivs);
if (isa<NGNegOp>(op)) if (isa<NGNegOp>(op))
{ {
...@@ -1173,7 +1174,7 @@ namespace ...@@ -1173,7 +1174,7 @@ namespace
auto pivs = makeIndexHandlePointers(ivs); auto pivs = makeIndexHandlePointers(ivs);
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
LoopNestBuilder(pivs, lbs, ubs, steps)( AffineLoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body // single stmt body
[&] { [&] {
if (isa<NGAddOp>(op)) if (isa<NGAddOp>(op))
...@@ -1266,7 +1267,7 @@ namespace ...@@ -1266,7 +1267,7 @@ namespace
auto pivs = makeIndexHandlePointers(ivs); auto pivs = makeIndexHandlePointers(ivs);
auto steps = vRes.getSteps(); auto steps = vRes.getSteps();
auto initVal = vArg.lb(axis); auto initVal = vArg.lb(axis);
LoopNestBuilder(pivs, resLbs, resUbs, steps)( AffineLoopNestBuilder(pivs, resLbs, resUbs, steps)(
[&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); }); [&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); });
} }
...@@ -1282,7 +1283,7 @@ namespace ...@@ -1282,7 +1283,7 @@ namespace
"Expected integer result type in index reduction"); "Expected integer result type in index reduction");
// iterate over all argument dimensions // iterate over all argument dimensions
LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] { AffineLoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] {
// build a list of non-reduction IVs // build a list of non-reduction IVs
for (auto i = 0; i < vArg.rank(); i++) for (auto i = 0; i < vArg.rank(); i++)
{ {
......
...@@ -64,7 +64,6 @@ ...@@ -64,7 +64,6 @@
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h> #include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h> #include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h> #include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h> #include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Pass/PassManager.h> #include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR.h> #include <mlir/Target/LLVMIR.h>
......
...@@ -24,7 +24,6 @@ ...@@ -24,7 +24,6 @@
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/Module.h> #include <mlir/IR/Module.h>
#include <mlir/IR/Types.h> #include <mlir/IR/Types.h>
......
...@@ -18,12 +18,12 @@ ...@@ -18,12 +18,12 @@
// not expose public API to the rest of nGraph codebase and heavily depends on MLIR API. // not expose public API to the rest of nGraph codebase and heavily depends on MLIR API.
#include "dialect.hpp" #include "dialect.hpp"
#include <mlir/IR/DialectImplementation.h>
#include <mlir/Parser.h>
#include "ngraph/check.hpp" #include "ngraph/check.hpp"
#include "ops.hpp" #include "ops.hpp"
#include "type.hpp" #include "type.hpp"
#include <mlir/Parser.h>
using namespace mlir; using namespace mlir;
NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx) NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
...@@ -39,63 +39,64 @@ NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx) ...@@ -39,63 +39,64 @@ NGraphOpsDialect::NGraphOpsDialect(mlir::MLIRContext* ctx)
>(); >();
} }
mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location loc) const mlir::Type NGraphOpsDialect::parseType(mlir::DialectAsmParser& parser) const
{ {
StringRef origTypeStr = tyData;
MLIRContext* context = getContext(); MLIRContext* context = getContext();
// Process nGraph tensor type. // Process nGraph tensor type.
if (tyData.consume_front("tensor")) // failure is true
if (!parser.parseOptionalKeyword("tensor"))
{ {
if (!tyData.consume_front("<") || !tyData.consume_back(">")) llvm::SMLoc typeLoc = parser.getCurrentLocation();
if (parser.parseLess())
{ {
return (emitError(loc, "expected '<' and '>' enclosing the tensor shape: " + tyData), parser.emitError(typeLoc, "expected '<' and '>' enclosing the tensor shape");
Type()); return Type();
} }
// Get x-separated sub-strings.
SmallVector<StringRef, 8> subStrings;
tyData.split(subStrings, "x");
// Parse shape dimensions. // Parse shape dimensions.
SmallVector<int64_t, 4> shape; SmallVector<int64_t, 4> shape;
for (unsigned i = 0, end = subStrings.size() - 1; i < end; ++i) parser.parseDimensionList(shape);
{
StringRef dimStr = subStrings[i]; // Parse the current element type.
int64_t dim = -1; Type eltType;
// NOTE: `consumeInteger` returns false if an integer was parsed successfully.
if (dimStr.consumeInteger(/*Radix=*/10, dim) || !dimStr.empty())
{
return (
emitError(loc, "expected a list of '[0-9]+x' dimension specifiers: " + tyData),
Type());
}
shape.push_back(dim);
}
// Parse nGraph element type. parser.parseType(eltType);
auto elem_ty = mlir::parseType(subStrings.back(), context); if (!eltType)
if (!elem_ty)
{ {
return (emitError(loc, "Unexpected element type in tensor type: " + tyData), Type()); typeLoc = parser.getCurrentLocation();
parser.emitError(typeLoc, "Invalid tensor element type");
} }
parser.parseGreater();
return NGTensorType::get(context, elem_ty, shape); return NGTensorType::get(context, eltType, shape);
} }
else
{
// parse nGraph scalar type
return parseEltType(parser);
}
}
mlir::Type NGraphOpsDialect::parseEltType(mlir::DialectAsmParser& parser) const
{
// Process nGraph integer element types. // Process nGraph integer element types.
MLIRContext* context = getContext();
int width = 0;
bool isSigned = false;
llvm::SMLoc loc = parser.getCurrentLocation();
StringRef tyData = parser.getFullSymbolSpec();
StringRef origTypeStr = tyData;
if (tyData.startswith("i") || tyData.startswith("u")) if (tyData.startswith("i") || tyData.startswith("u"))
{ {
bool isSigned = tyData.consume_front("i"); isSigned = tyData.consume_front("i");
bool isUnsigned = tyData.consume_front("u"); tyData.consume_front("u");
NGRAPH_CHECK(isSigned != isUnsigned, "nGraph integer cannot be signed and unsigned");
unsigned width = 0; unsigned width = 0;
// NOTE: `consumeInteger` returns false if an integer was parsed successfully. // NOTE: `consumeInteger` returns false if an integer was parsed successfully.
if (tyData.consumeInteger(/*Radix=*/10, width) || width == 0 || !tyData.empty()) if (tyData.consumeInteger(/*Radix=*/10, width) || width == 0 || !tyData.empty())
{ {
return (emitError(loc, "Unexpected nGraph integer type: " + origTypeStr), Type()); parser.emitError(loc, "Unexpected nGraph integer type: " + origTypeStr);
} }
switch (width) switch (width)
...@@ -108,9 +109,7 @@ mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location lo ...@@ -108,9 +109,7 @@ mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location lo
return isSigned ? NGIntegerType::getInt32(context) : NGIntegerType::getUInt32(context); return isSigned ? NGIntegerType::getInt32(context) : NGIntegerType::getUInt32(context);
case 64: case 64:
return isSigned ? NGIntegerType::getInt64(context) : NGIntegerType::getUInt64(context); return isSigned ? NGIntegerType::getInt64(context) : NGIntegerType::getUInt64(context);
default: default: parser.emitError(loc, "Unexpected width for nGraph integer type: " + origTypeStr);
return (emitError(loc, "Unexpected width for nGraph integer type: " + origTypeStr),
Type());
} }
} }
...@@ -119,43 +118,49 @@ mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location lo ...@@ -119,43 +118,49 @@ mlir::Type NGraphOpsDialect::parseType(llvm::StringRef tyData, mlir::Location lo
"Floating point types should be processed by standard parser"); "Floating point types should be processed by standard parser");
// NOTE: We may hit this error if the nGraph type is not yet supported in parser. // NOTE: We may hit this error if the nGraph type is not yet supported in parser.
return (emitError(loc, "Unknown nGraph type: " + origTypeStr), Type()); parser.emitError(loc, "Unknown nGraph type: " + origTypeStr);
return Type();
} }
void NGraphOpsDialect::printType(mlir::Type type, raw_ostream& os) const void NGraphOpsDialect::printType(mlir::Type type, mlir::DialectAsmPrinter& printer) const
{ {
switch (type.getKind()) switch (type.getKind())
{ {
case NG_TENSOR_TYPE_ID: case NG_TENSOR_TYPE_ID:
{ {
os << "tensor<"; printer << "tensor<";
auto tensorTy = type.cast<NGTensorType>(); auto tensorTy = type.cast<NGTensorType>();
for (auto dim : tensorTy.getShape()) for (auto dim : tensorTy.getShape())
{ {
os << dim << 'x'; printer << dim << 'x';
} }
os << tensorTy.getElementType() << '>'; printer << tensorTy.getElementType() << '>';
return; return;
} }
case NG_I8_TYPE_ID: case NG_I8_TYPE_ID:
case NG_I16_TYPE_ID: case NG_I16_TYPE_ID:
case NG_I32_TYPE_ID: case NG_I32_TYPE_ID:
case NG_I64_TYPE_ID: case NG_I64_TYPE_ID:
{
auto intTy = type.cast<NGIntegerType>();
printer << "i" << intTy.getWidth();
return;
}
case NG_U8_TYPE_ID: case NG_U8_TYPE_ID:
case NG_U16_TYPE_ID: case NG_U16_TYPE_ID:
case NG_U32_TYPE_ID: case NG_U32_TYPE_ID:
case NG_U64_TYPE_ID: case NG_U64_TYPE_ID:
{ {
auto intTy = type.cast<NGIntegerType>(); auto intTy = type.cast<NGIntegerType>();
os << "i" << intTy.getWidth(); printer << "u" << intTy.getWidth();
return; return;
} }
case NG_BOOL_TYPE_ID: case NG_BOOL_TYPE_ID:
{ {
os << "bool"; printer << "bool";
return; return;
} }
default: { NGRAPH_CHECK(false, "Incorrect type to print?"); default: NGRAPH_UNREACHABLE("Incorrect type to print?");
}
} }
} }
...@@ -34,9 +34,12 @@ namespace mlir ...@@ -34,9 +34,12 @@ namespace mlir
{ {
public: public:
explicit NGraphOpsDialect(mlir::MLIRContext* ctx); explicit NGraphOpsDialect(mlir::MLIRContext* ctx);
mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override;
void printType(mlir::Type type, llvm::raw_ostream& os) const override; mlir::Type parseType(mlir::DialectAsmParser& parser) const override;
void printType(mlir::Type type, mlir::DialectAsmPrinter& printer) const override;
static StringRef getDialectNamespace() { return "ng"; } static StringRef getDialectNamespace() { return "ng"; }
private:
mlir::Type parseEltType(mlir::DialectAsmParser& parser) const;
}; };
} }
...@@ -46,32 +46,6 @@ ...@@ -46,32 +46,6 @@
#include "ngraph/op/util/index_reduction.hpp" #include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "contrib/mlir/utils.hpp"
#include <llvm/ADT/STLExtras.h>
#include <llvm/Analysis/TargetTransformInfo.h>
#include <llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/ErrorOr.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Target/TargetMachine.h>
#include <mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>
#include <memory>
#include <mutex>
// Defines a new LLVM debug type for this file to be used by LLVM_DEBUG macro. // Defines a new LLVM debug type for this file to be used by LLVM_DEBUG macro.
#define DEBUG_TYPE "mlir-compiler" #define DEBUG_TYPE "mlir-compiler"
......
...@@ -20,21 +20,13 @@ ...@@ -20,21 +20,13 @@
#pragma once #pragma once
#include "contrib/mlir/core/compiler.hpp" #include "contrib/mlir/core/compiler.hpp"
#include "contrib/mlir/runtime/cpu/memory_manager.hpp"
#include "ngraph/check.hpp" #include "ngraph/check.hpp"
#include "ngraph/descriptor/tensor.hpp" #include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/Types.h>
#include <mlir/Pass/Pass.h> #include <mlir/Pass/Pass.h>
#include <typeindex>
#include <unordered_map>
#include <vector>
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
namespace ngraph namespace ngraph
......
...@@ -30,8 +30,8 @@ ...@@ -30,8 +30,8 @@
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h> #include <llvm/Support/TargetSelect.h>
#include <llvm/Target/TargetMachine.h> #include <llvm/Target/TargetMachine.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h> #include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h> #include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/IR/Function.h> #include <mlir/IR/Function.h>
...@@ -81,7 +81,7 @@ void MLIRCPURuntime::bindArguments(std::vector<void*>& externalTensors) ...@@ -81,7 +81,7 @@ void MLIRCPURuntime::bindArguments(std::vector<void*>& externalTensors)
{ {
NGRAPH_CHECK(m_module, "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
mlir::FuncOp func = m_module->lookupSymbol<mlir::FuncOp>("main"); auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("main");
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found"); NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments // Set external arguments
...@@ -90,7 +90,7 @@ void MLIRCPURuntime::bindArguments(std::vector<void*>& externalTensors) ...@@ -90,7 +90,7 @@ void MLIRCPURuntime::bindArguments(std::vector<void*>& externalTensors)
// Create list with a type-erased double pointer for each invocation arguments. // Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemrefArgs', which creates the arguments list per call ABI (see // We currently use 'allocateMemrefArgs', which creates the arguments list per call ABI (see
// comment below). // comment below).
// StaticFloatMemref is just a struct with the actual pointer to the data. // StaticMemRef is just a struct with the actual pointer to the data.
auto expectedArguments = allocateMemrefArgs(); auto expectedArguments = allocateMemrefArgs();
NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created"); NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created");
...@@ -102,7 +102,7 @@ void MLIRCPURuntime::bindArguments(std::vector<void*>& externalTensors) ...@@ -102,7 +102,7 @@ void MLIRCPURuntime::bindArguments(std::vector<void*>& externalTensors)
// Assign external tensor pointers to invocation arguments. // Assign external tensor pointers to invocation arguments.
for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i) for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
{ {
auto* memRefArg = *(reinterpret_cast<mlir::StaticFloatMemRef**>(m_invokeArgs[i])); auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(m_invokeArgs[i]));
memRefArg->data = reinterpret_cast<float*>((*m_externalTensors)[i]); memRefArg->data = reinterpret_cast<float*>((*m_externalTensors)[i]);
} }
} }
...@@ -129,18 +129,18 @@ void MLIRCPURuntime::cleanup() ...@@ -129,18 +129,18 @@ void MLIRCPURuntime::cleanup()
// Free void double pointer arguments without freeing external tensor data. // Free void double pointer arguments without freeing external tensor data.
for (auto* arg : m_invokeArgs) for (auto* arg : m_invokeArgs)
{ {
auto* memRefArg = *(reinterpret_cast<mlir::StaticFloatMemRef**>(arg)); auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg));
free(memRefArg); free(memRefArg);
free(arg); free(arg);
} }
} }
// The current call ABI takes a single arg pointer (argPtr) pointing to a list of args. // The current call ABI takes a single arg pointer (argPtr) pointing to a list of args.
// Each arg is a pointer to a StaticFloatMemRef which contains a data pointer // Each arg is a pointer to a StaticMemRef which contains a data pointer
// //
// The args are laid out as follows // The args are laid out as follows
// argPtr-> arg[0]-> StaticFloatMemRef -> <data> // argPtr-> arg[0]-> StaticMemRef -> <data>
// arg[1]-> StaticFloatMemRef -> <data> // arg[1]-> StaticMemRef -> <data>
// ... // ...
SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs() SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs()
{ {
...@@ -148,20 +148,18 @@ SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs() ...@@ -148,20 +148,18 @@ SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs()
for (auto i = 0; i < m_externalTensors->size(); i++) for (auto i = 0; i < m_externalTensors->size(); i++)
{ {
auto descriptor = allocateMemrefDescriptor(); auto descriptor = allocateMemrefDescriptor();
mlir::StaticFloatMemRef** arg = StaticMemRef** arg = reinterpret_cast<StaticMemRef**>(malloc(sizeof(StaticMemRef*)));
reinterpret_cast<mlir::StaticFloatMemRef**>(malloc(sizeof(mlir::StaticFloatMemRef*)));
*arg = descriptor; *arg = descriptor;
args.push_back(arg); args.push_back(arg);
} }
return args; return args;
} }
mlir::StaticFloatMemRef* MLIRCPURuntime::allocateMemrefDescriptor() StaticMemRef* MLIRCPURuntime::allocateMemrefDescriptor()
{ {
// We only use StaticFloatMemRef because that's what MLIR currently offers. // We only use StaticMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs // We should expand this with different types and dynamic MemRefs
auto* descriptor = auto* descriptor = reinterpret_cast<StaticMemRef*>(malloc(sizeof(StaticMemRef)));
reinterpret_cast<mlir::StaticFloatMemRef*>(malloc(sizeof(mlir::StaticFloatMemRef)));
NGRAPH_CHECK(descriptor != nullptr, "NULL MemRef descriptor"); NGRAPH_CHECK(descriptor != nullptr, "NULL MemRef descriptor");
descriptor->data = nullptr; descriptor->data = nullptr;
return descriptor; return descriptor;
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include <memory> #include <memory>
#include <mlir/ExecutionEngine/ExecutionEngine.h> #include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/Module.h> #include <mlir/IR/Module.h>
#include <mlir/IR/Types.h> #include <mlir/IR/Types.h>
...@@ -34,6 +33,10 @@ namespace ngraph ...@@ -34,6 +33,10 @@ namespace ngraph
{ {
namespace ngmlir namespace ngmlir
{ {
struct StaticMemRef
{
void* data;
};
/// A CPU Runtime is an MLIR runtime that owns an MLIR context and a module /// A CPU Runtime is an MLIR runtime that owns an MLIR context and a module
/// The module should be in LLVM dialect and ready to be lowered via an MLIR /// The module should be in LLVM dialect and ready to be lowered via an MLIR
/// ExecutionEngine. The runtime owns the context and must out-live any MLIR /// ExecutionEngine. The runtime owns the context and must out-live any MLIR
...@@ -57,7 +60,7 @@ namespace ngraph ...@@ -57,7 +60,7 @@ namespace ngraph
llvm::SmallVector<void*, 8> allocateMemrefArgs(); llvm::SmallVector<void*, 8> allocateMemrefArgs();
/// Helper to allocate a mem ref object. Handles static shapes only for now. /// Helper to allocate a mem ref object. Handles static shapes only for now.
mlir::StaticFloatMemRef* allocateMemrefDescriptor(); StaticMemRef* allocateMemrefDescriptor();
private: private:
// Pointers to externally allocated memory for sub-graph's input and output tensors. // Pointers to externally allocated memory for sub-graph's input and output tensors.
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include <memory> #include <memory>
#include <mlir/ExecutionEngine/ExecutionEngine.h> #include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/Module.h> #include <mlir/IR/Module.h>
#include <mlir/IR/Types.h> #include <mlir/IR/Types.h>
......
...@@ -55,7 +55,7 @@ func @i64(%arg0: !ng.i64) { ...@@ -55,7 +55,7 @@ func @i64(%arg0: !ng.i64) {
// ----- // -----
// CHECK-LABEL: func @u8 // CHECK-LABEL: func @u8
// CHECK-SAME: (%{{.*}}: !ng.i8) // CHECK-SAME: (%{{.*}}: !ng.u8)
func @u8(%arg0: !ng.u8) { func @u8(%arg0: !ng.u8) {
"ng.return"() : () -> () "ng.return"() : () -> ()
} }
...@@ -63,7 +63,7 @@ func @u8(%arg0: !ng.u8) { ...@@ -63,7 +63,7 @@ func @u8(%arg0: !ng.u8) {
// ----- // -----
// CHECK-LABEL: func @u16 // CHECK-LABEL: func @u16
// CHECK-SAME: (%{{.*}}: !ng.i16) // CHECK-SAME: (%{{.*}}: !ng.u16)
func @u16(%arg0: !ng.u16) { func @u16(%arg0: !ng.u16) {
"ng.return"() : () -> () "ng.return"() : () -> ()
} }
...@@ -71,7 +71,7 @@ func @u16(%arg0: !ng.u16) { ...@@ -71,7 +71,7 @@ func @u16(%arg0: !ng.u16) {
// ----- // -----
// CHECK-LABEL: func @u32 // CHECK-LABEL: func @u32
// CHECK-SAME: (%{{.*}}: !ng.i32) // CHECK-SAME: (%{{.*}}: !ng.u32)
func @u32(%arg0: !ng.u32) { func @u32(%arg0: !ng.u32) {
"ng.return"() : () -> () "ng.return"() : () -> ()
} }
...@@ -83,3 +83,83 @@ func @u32(%arg0: !ng.u32) { ...@@ -83,3 +83,83 @@ func @u32(%arg0: !ng.u32) {
func @u64(%arg0: !ng.u64) { func @u64(%arg0: !ng.u64) {
"ng.return"() : () -> () "ng.return"() : () -> ()
} }
// -----
// CHECK: func @tensor_i8
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.i8>)
func @tensor_i8(%arg0: !ng.tensor<2x2x!ng.i8>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_i16
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.i16>)
func @tensor_i16(%arg0: !ng.tensor<2x2x!ng.i16>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_i32
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.i32>)
func @tensor_i32(%arg0: !ng.tensor<2x2x!ng.i32>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_i64
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.i64>)
func @tensor_i64(%arg0: !ng.tensor<2x2x!ng.i64>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_u8
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.u8>)
func @tensor_u8(%arg0: !ng.tensor<2x2x!ng.u8>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_u16
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.u16>)
func @tensor_u16(%arg0: !ng.tensor<2x2x!ng.u16>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_u32
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.u32>)
func @tensor_u32(%arg0: !ng.tensor<2x2x!ng.u32>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_u64
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2x!ng.u64>)
func @tensor_u64(%arg0: !ng.tensor<2x2x!ng.u64>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_f32
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2xf32>)
func @tensor_f32(%arg0: !ng.tensor<2x2xf32>) {
"ng.return"() : () -> ()
}
// -----
// CHECK: func @tensor_f64
// CHECK-SAME: (%{{.*}}: !ng.tensor<2x2xf64>)
func @tensor_f64(%arg0: !ng.tensor<2x2xf64>) {
"ng.return"() : () -> ()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment