Commit a643213d authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Update build system to use LLVM mono-repo for MLIR (#4188)

* [MLIR] Update build system to use LLVM mono-repo for MLIR

* [MLIR] LLVM mono-repo conflicts

* Disable lit tests

* Fix formatting

* Fix memopt tests

* PR fix

* Fix view test
Co-authored-by: 's avatarNagy Mostafa <nagy.mostafa@gmail.com>
parent 1e13ad94
......@@ -17,11 +17,9 @@
include(ExternalProject)
set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID c36773c7)
set(MLIR_COMMIT_ID 606e96a1)
set(MLIR_LLVM_COMMIT_ID d6295255)
# MLIR environment variables. Some of them are used by LIT tool.
......@@ -32,9 +30,10 @@ else()
endif()
set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
set(MLIR_BUILD_DIR ${MLIR_LLVM_ROOT}/build)
set(MLIR_TOOLS_DIR ${MLIR_BUILD_DIR}/bin)
set(MLIR_LLVM_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm)
set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/mlir)
set(MLIR_LLVM_BUILD_DIR ${MLIR_PROJECT_ROOT}/build)
set(MLIR_LLVM_TOOLS_DIR ${MLIR_LLVM_BUILD_DIR}/bin)
set(NGRAPH_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
set(NGRAPH_LIT_TEST_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/test/mlir)
......@@ -48,17 +47,13 @@ if (NOT NGRAPH_USE_PREBUILT_MLIR)
execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
WORKING_DIRECTORY "${MLIR_PROJECT_ROOT}")
# clone and build llvm
# Clone and build llvm + mlir.
execute_process(COMMAND "${CMAKE_COMMAND}" --build . --target ext_mlir_llvm
WORKING_DIRECTORY "${MLIR_PROJECT_ROOT}")
# clone and build mlir
execute_process(COMMAND "${CMAKE_COMMAND}" --build . --target ext_mlir
WORKING_DIRECTORY "${MLIR_PROJECT_ROOT}")
endif()
# Enable modules for LLVM.
set(LLVM_DIR "${MLIR_BUILD_DIR}/lib/cmake/llvm"
set(LLVM_DIR "${MLIR_LLVM_BUILD_DIR}/lib/cmake/llvm"
CACHE PATH "Path to LLVM cmake modules")
list(APPEND CMAKE_MODULE_PATH "${LLVM_DIR}")
include(AddLLVM)
......@@ -71,7 +66,7 @@ message(STATUS "Using modules in: ${LLVM_DIR}")
message(STATUS "LLVM RTTI is ${LLVM_ENABLE_RTTI}")
set(MLIR_SRC_INCLUDE_PATH ${MLIR_SOURCE_DIR}/include)
set(MLIR_BIN_INCLUDE_PATH ${MLIR_BUILD_DIR}/projects/mlir/include)
set(MLIR_BIN_INCLUDE_PATH ${MLIR_LLVM_BUILD_DIR}/tools/mlir/include)
set(MLIR_INCLUDE_PATHS ${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH})
set(MLIR_LLVM_INCLUDE_PATH ${LLVM_INCLUDE_DIRS})
......
......@@ -20,22 +20,6 @@ include(ExternalProject)
project(mlir-fetch NONE)
ExternalProject_Add(
ext_mlir_llvm
PREFIX mlir_llvm
GIT_REPOSITORY @MLIR_LLVM_REPO_URL@
GIT_TAG @MLIR_LLVM_COMMIT_ID@
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
UPDATE_COMMAND ""
SOURCE_DIR @MLIR_LLVM_ROOT@
DOWNLOAD_NO_PROGRESS TRUE
EXCLUDE_FROM_ALL TRUE
)
set(MLIR_DEPENDS ext_mlir_llvm)
include(ProcessorCount)
ProcessorCount(N)
if(N EQUAL 0)
......@@ -43,22 +27,18 @@ if(N EQUAL 0)
endif()
ExternalProject_Add(
ext_mlir
PREFIX mlir
DEPENDS ${MLIR_DEPENDS}
GIT_REPOSITORY @MLIR_REPO_URL@
GIT_TAG @MLIR_COMMIT_ID@
CONFIGURE_COMMAND ""
CMAKE_GENERATOR "@CMAKE_GENERATOR@"
ext_mlir_llvm
PREFIX mlir_llvm
GIT_REPOSITORY @MLIR_LLVM_REPO_URL@
GIT_TAG @MLIR_LLVM_COMMIT_ID@
CMAKE_GENERATOR @CMAKE_GENERATOR@
CMAKE_GENERATOR_PLATFORM @CMAKE_GENERATOR_PLATFORM@
CMAKE_GENERATOR_TOOLSET @CMAKE_GENERATOR_TOOLSET@
BUILD_COMMAND @CMAKE_COMMAND@ ../llvm -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=host -DLLVM_ENABLE_RTTI=ON -DCMAKE_BUILD_TYPE=@CMAKE_BUILD_TYPE@
COMMAND @CMAKE_COMMAND@ --build . --target check-mlir -- -j${N}
CONFIGURE_COMMAND @CMAKE_COMMAND@ @MLIR_LLVM_SOURCE_DIR@ -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=host -DLLVM_ENABLE_RTTI=ON -DCMAKE_BUILD_TYPE=@CMAKE_BUILD_TYPE@
BUILD_COMMAND @CMAKE_COMMAND@ --build . --target check-mlir -- -j${N}
INSTALL_COMMAND ""
UPDATE_COMMAND ""
SOURCE_DIR @MLIR_SOURCE_DIR@
BINARY_DIR @MLIR_BUILD_DIR@
SOURCE_DIR @MLIR_LLVM_ROOT@
BINARY_DIR @MLIR_LLVM_BUILD_DIR@
STAMP_DIR "@MLIR_PROJECT_ROOT@/mlir/stamp"
DOWNLOAD_NO_PROGRESS TRUE
EXCLUDE_FROM_ALL TRUE
......
......@@ -68,16 +68,16 @@ namespace
{
public:
/// Initialize the relationship for a number of syms
void init(std::unordered_set<Value*>& symbols);
void init(DenseSet<Value>& symbols);
/// Checks if values a and b can alias
bool canAlias(Value* a, Value* b);
void insertNoAlias(Value* a, Value* b);
bool canAlias(Value a, Value b);
void insertNoAlias(Value a, Value b);
private:
using BV = llvm::BitVector;
std::unordered_map<Value*, unsigned> m_valueToIdx;
std::unordered_map<unsigned, Value*> m_idxToValue;
std::unordered_map<Value*, BV*> m_valueToSet;
DenseMap<Value, unsigned> m_valueToIdx;
DenseMap<unsigned, Value> m_idxToValue;
DenseMap<Value, BV*> m_valueToSet;
SmallVector<BV, 10> m_sets;
};
......@@ -86,16 +86,15 @@ namespace
class LivenessAnalysis
{
public:
bool isLive(Value* v);
void setLive(Value* v);
void kill(Value* v);
void getLiveValues(llvm::SmallVectorImpl<Value*>& values);
void reset();
bool isLive(Value v);
void setLive(Value v);
void kill(Value v);
void getLiveValues(llvm::SmallVectorImpl<Value>& values);
private:
unsigned m_maxIdx = 0;
SmallVector<bool, 10> m_liveness;
std::unordered_map<Value*, unsigned> m_valueToIdx;
DenseMap<Value, unsigned> m_valueToIdx;
};
// Memory Assignment analysis
......@@ -121,7 +120,7 @@ namespace
void processDestructiveInPlace(mlir::Operation* op);
void processConcat(mlir::Operation* op);
bool isSafeInPlace(mlir::Operation* op);
bool isInputOrOutputValue(mlir::Value* value);
bool isInputOrOutputValue(mlir::Value value);
LivenessAnalysis m_liveness;
AliasRelation m_aliasRelation;
std::unordered_map<std::string, bool> m_inplaceOps;
......@@ -132,7 +131,7 @@ namespace
// helpers
// Determines the buffer size a value needs based on its type
// offset is where that value should start in the buffer
static unsigned getBufferSizeForOperand(mlir::Value* value, int offset);
static unsigned getBufferSizeForOperand(mlir::Value value, int offset);
// Go backwards over instructions
//
......@@ -184,14 +183,14 @@ namespace
auto& block = *(blocks.begin());
// count number of syms in the code and initialize alias relationship
std::unordered_set<Value*> syms;
DenseSet<Value> syms;
for (auto it = block.begin(); it != block.end(); it++)
{
Operation* op = &(*it);
for (auto it : op->getResults())
{
Value* v = it;
Value v = it;
if (syms.find(v) == syms.end())
{
syms.insert(v);
......@@ -199,7 +198,7 @@ namespace
}
for (auto it : op->getOperands())
{
Value* v = it;
Value v = it;
if (syms.find(v) == syms.end())
{
syms.insert(v);
......@@ -245,7 +244,7 @@ namespace
// concat on the highest non-one axis
auto concatAxis = concat.concatenation_axis();
auto result = concat.getResult();
auto shape = (result->getType().cast<NGTensorType>()).getShape();
auto shape = (result.getType().cast<NGTensorType>()).getShape();
std::vector<int> opndOffsets;
BufferInfo bufferInfo;
int bufferId = -1, baseOffset = 0;
......@@ -288,7 +287,7 @@ namespace
else
{
auto opnd = op->getOperand(i - 1);
auto tensorType = opnd->getType().cast<NGTensorType>();
auto tensorType = opnd.getType().cast<NGTensorType>();
opndOffset += tensorType.getNumElements();
opndOffsets.push_back(opndOffset);
}
......@@ -306,7 +305,7 @@ namespace
for (auto i = 0; i < op->getNumOperands(); i++)
{
auto opnd = op->getOperand(i);
auto defOp = opnd->getDefiningOp();
auto defOp = opnd.getDefiningOp();
NGRAPH_CHECK(defOp != nullptr, "Defining operation expected");
// calculate expected absolute offset in the buffer
bufferOffset = baseOffset + opndOffsets[i];
......@@ -357,7 +356,7 @@ namespace
// For now, assign only if all srcs have no prior assignments
for (auto opnd : op->getOperands())
{
if (m_memAnalysis->getBufferInfo(opnd->getDefiningOp()).isValid())
if (m_memAnalysis->getBufferInfo(opnd.getDefiningOp()).isValid())
{
return;
}
......@@ -381,7 +380,7 @@ namespace
for (auto i = 0; i < op->getNumOperands(); i++)
{
auto opnd = op->getOperand(i);
auto defOp = opnd->getDefiningOp();
auto defOp = opnd.getDefiningOp();
NGRAPH_CHECK(defOp != nullptr, "Defining operation expected");
auto opndOffset = baseOffset + opndOffsets[i];
m_memAnalysis->setBufferInfo(defOp, {bufferId, opndOffset});
......@@ -392,7 +391,7 @@ namespace
void MemoryAssignment::processDestructiveInPlace(mlir::Operation* op)
{
NGRAPH_CHECK(op->getNumResults() == 1, "Destructive in-place with multi-def ?");
Value* use = nullptr;
Value use = nullptr;
int useCount = -1;
if (isInputOrOutputValue(op->getResult(0)))
......@@ -405,11 +404,7 @@ namespace
{
if (!m_liveness.isLive(opnd) && !isInputOrOutputValue(opnd))
{
int uses = 0;
for (auto& i : opnd->getUses())
{
uses++;
}
int uses = std::distance(opnd.getUses().begin(), opnd.getUses().end());
if (useCount == -1 || uses < useCount)
{
use = opnd;
......@@ -428,28 +423,28 @@ namespace
// attach a new buffer id, and 0 offset on obth src and result
bufferInfo = {m_bufferId++, 0};
m_memAnalysis->setBufferInfo(op, bufferInfo);
m_memAnalysis->setBufferInfo(use->getDefiningOp(), bufferInfo);
m_memAnalysis->setBufferInfo(use.getDefiningOp(), bufferInfo);
}
else
{
// copy result buffer id and offset to src
m_memAnalysis->setBufferInfo(use->getDefiningOp(), bufferInfo);
m_memAnalysis->setBufferInfo(use.getDefiningOp(), bufferInfo);
}
auto bufferSize = 0;
bufferSize = getBufferSizeForOperand(op->getResult(0), bufferInfo.m_offset);
m_memAnalysis->setBufferSize(bufferInfo.m_bufferId, bufferSize);
// update aliasing info
// use value cannot alias any live value
SmallVector<Value*, 10> liveValues;
SmallVector<Value, 10> liveValues;
m_liveness.getLiveValues(liveValues);
for (auto& value : liveValues)
{
m_aliasRelation.insertNoAlias(use, value);
}
}
bool MemoryAssignment::isInputOrOutputValue(mlir::Value* value)
bool MemoryAssignment::isInputOrOutputValue(mlir::Value value)
{
auto defOp = value->getDefiningOp();
auto defOp = value.getDefiningOp();
// If no defining op, then this is a block arg, skip operand
//
// TODO: This check is assuming single BB function, improve to handle control-flow.
......@@ -464,7 +459,7 @@ namespace
//
// TODO: Improve to support control flow. Track value use-chain along branches/block-args,
// if we hit a use in a return, it is an output value.
for (auto& use : value->getUses())
for (auto& use : value.getUses())
{
auto useOp = use.getOwner();
if (isa<NGReturnOp>(useOp))
......@@ -482,7 +477,7 @@ namespace
return it != m_inplaceOps.end() ? it->second : false;
}
void AliasRelation::init(std::unordered_set<Value*>& symbols)
void AliasRelation::init(DenseSet<Value>& symbols)
{
unsigned numSyms = symbols.size();
m_sets.resize(numSyms);
......@@ -503,13 +498,13 @@ namespace
}
}
bool AliasRelation::canAlias(Value* a, Value* b)
bool AliasRelation::canAlias(Value a, Value b)
{
// check if a and b are in the same set
return m_valueToSet[a] != m_valueToSet[b];
}
void AliasRelation::insertNoAlias(Value* a, Value* b)
void AliasRelation::insertNoAlias(Value a, Value b)
{
// union the two sets that a and b belong to
// update the maps accordingly
......@@ -535,14 +530,7 @@ namespace
}
}
void LivenessAnalysis::reset()
{
m_valueToIdx.clear();
m_liveness.clear();
m_maxIdx = 0;
}
void LivenessAnalysis::getLiveValues(llvm::SmallVectorImpl<Value*>& values)
void LivenessAnalysis::getLiveValues(llvm::SmallVectorImpl<Value>& values)
{
for (auto& entry : m_valueToIdx)
{
......@@ -553,7 +541,7 @@ namespace
}
}
bool LivenessAnalysis::isLive(Value* v)
bool LivenessAnalysis::isLive(Value v)
{
auto it = m_valueToIdx.find(v);
if (it == m_valueToIdx.end())
......@@ -563,7 +551,7 @@ namespace
return m_liveness[it->second];
}
void LivenessAnalysis::setLive(Value* v)
void LivenessAnalysis::setLive(Value v)
{
auto it = m_valueToIdx.find(v);
if (it == m_valueToIdx.end())
......@@ -578,7 +566,7 @@ namespace
}
}
void LivenessAnalysis::kill(Value* v)
void LivenessAnalysis::kill(Value v)
{
auto it = m_valueToIdx.find(v);
if (it == m_valueToIdx.end())
......@@ -589,9 +577,9 @@ namespace
m_liveness[it->second] = false;
}
// helpers
unsigned getBufferSizeForOperand(mlir::Value* value, int offset)
unsigned getBufferSizeForOperand(mlir::Value value, int offset)
{
auto tensorType = value->getType().dyn_cast<NGTensorType>();
auto tensorType = value.getType().dyn_cast<NGTensorType>();
NGRAPH_CHECK(tensorType, "Invalid type to find buffer size for");
unsigned bufferSize = offset * std::ceil(tensorType.getElementBitWidth() / 8);
......
......@@ -86,7 +86,7 @@ namespace
} \
\
PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \
ArrayRef<Value> operands, \
ConversionPatternRewriter& rewriter) const override; \
};
......@@ -104,7 +104,7 @@ namespace
/// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult matchAndRewrite(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override
{
auto funcOp = cast<FuncOp>(op);
......@@ -153,19 +153,19 @@ namespace
// Helpers
template <typename RedOp>
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
template <typename OP>
void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
template <typename OP>
void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
......@@ -184,32 +184,30 @@ namespace
// cLb/Ub : Values representing bounds on channel dim in image (C_IN)
// kLb/Ub : Values representing bounds on numFilters dim in filters (C_OUT)
// gId : Value representing induction variable for the outer loop
void lowerConvolution(Value* result,
Value* images,
Value* filters,
void lowerConvolution(Value result,
Value images,
Value filters,
ArrayAttr stridesAttr,
ArrayAttr padBelowAttr,
ArrayAttr padAboveAttr,
PatternRewriter& rewriter,
DialectLoweringPass& pass,
Location loc,
Value* cLb = nullptr,
Value* cUb = nullptr,
Value* kLb = nullptr,
Value* kUb = nullptr,
Value* gId = nullptr);
Value cLb = nullptr,
Value cUb = nullptr,
Value kLb = nullptr,
Value kUb = nullptr,
Value gId = nullptr);
template <typename OP>
void lowerPooling(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
ValueHandle createZeroConstant(mlir::Type type);
ValueHandle createOneConstant(mlir::Type type);
bool isInPlaceConcat(mlir::Operation* op, DialectLoweringPass& pass);
/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
{
......@@ -228,10 +226,10 @@ namespace
public:
void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
SmallVector<Value, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
/// Allocates a linear buffer for a temporary memref that shares its
/// underlying memory. Used in conjunction with createTempMemref
Value* createTempBuffer(int bufferId, PatternRewriter& rewriter);
Value createTempBuffer(int bufferId, PatternRewriter& rewriter);
/// Creates an allocation or view of a memref.
/// type MemRef Type
/// buffer Optional buffer value to create view over
......@@ -239,8 +237,7 @@ namespace
///
/// If buffer is null it allocates a Memref directly and Offset is ignored.
/// If not, it creates a view over the pre-allocated buffer at the given offset.
Value*
createTempMemref(Type type, Value* buffer, unsigned offset, PatternRewriter& rewriter);
Value createTempMemref(Type type, Value buffer, unsigned offset, PatternRewriter& rewriter);
/// Inserts dealloc Ops for each temporary allocated by AllocOp
void insertDeallocs(PatternRewriter& rewriter);
NGraphTypeConverter& getTypeConverter() { return typeConverter; }
......@@ -261,11 +258,11 @@ namespace
private:
NGraphTypeConverter typeConverter;
// List of temporary memrefs to deallocate at end of function
SmallVector<Value*, 4> memRefsToDealloc;
SmallVector<Value, 4> memRefsToDealloc;
// Ops maybe assigned mem-refs in previous memory optimization passes.
// Track pre-assigned buffers for each Value and re-use it if one is available.
using IdToMemRefMap = std::unordered_map<unsigned, Value*>;
using IdToMemRefMap = std::unordered_map<unsigned, Value>;
IdToMemRefMap m_id_to_memref;
MemoryAnalysis* m_memAnalysis;
// TODO: Workaround for findOutputValues and buildOutputDefs. See NGCPU-470.
......@@ -344,7 +341,7 @@ namespace
FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");
SmallVector<Value*, 4> outputList;
SmallVector<Value, 4> outputList;
unsigned outputCount = 0;
unsigned inputCount = f.getType().getNumInputs();
// we find out output values by looking at returned values
......@@ -354,7 +351,7 @@ namespace
{
// annotate instructions defining outputs with the arg idx of the output
auto outputValue = ret.getOperand(i);
auto op = outputValue->getDefiningOp();
auto op = outputValue.getDefiningOp();
op->setAttr(
"graphOutputIdx",
......@@ -365,13 +362,13 @@ namespace
});
}
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
SmallVector<Value, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
PatternRewriter& rewriter)
{
FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");
SmallVector<Value*, 4> newResults;
SmallVector<Value, 4> newResults;
for (auto origResult : op->getResults())
{
// find output arg if this operation produces any sub-graph outputs
......@@ -390,11 +387,11 @@ namespace
// the linear buffer.
// If two memrefs are defined via 2 Views over the same buffer, then they share and
// will re-use the same buffer.
auto tensorType = origResult->getType().cast<NGTensorType>();
Value* newResult = nullptr;
auto tensorType = origResult.getType().cast<NGTensorType>();
Value newResult = nullptr;
auto bufferInfo = m_memAnalysis->getBufferInfo(op);
Type memRefType = typeConverter.convertType(tensorType);
Value* bufferValue = nullptr;
Value bufferValue = nullptr;
if (!bufferInfo.isValid())
{
......@@ -427,7 +424,7 @@ namespace
return newResults;
}
Value* DialectLoweringPass::createTempBuffer(int bufferId, PatternRewriter& rewriter)
Value DialectLoweringPass::createTempBuffer(int bufferId, PatternRewriter& rewriter)
{
unsigned sizeInBytes = getMemAnalysis()->getBufferSize(bufferId);
NGRAPH_CHECK(bufferId >= 0, "Invalid buffer id to allocate");
......@@ -438,7 +435,7 @@ namespace
MemRefType::get({sizeInBytes}, IntegerType::get(8, rewriter.getContext()), {});
// TODO: Set alignment
Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), bufferType);
Value alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), bufferType);
memRefsToDealloc.push_back(alloc);
......@@ -455,8 +452,8 @@ namespace
return alloc;
}
Value* DialectLoweringPass::createTempMemref(Type type,
Value* buffer,
Value DialectLoweringPass::createTempMemref(Type type,
Value buffer,
unsigned offset,
PatternRewriter& rewriter)
{
......@@ -481,14 +478,14 @@ namespace
auto map = makeStridedLinearLayoutMap(strides, offset, rewriter.getContext());
MemRefType newMemRefType = MemRefType::get(shape, memRefType.getElementType(), map);
auto viewOp = rewriter.create<mlir::ViewOp>(
buffer->getDefiningOp()->getLoc(), newMemRefType, buffer, llvm::None);
buffer.getDefiningOp()->getLoc(), newMemRefType, buffer, llvm::None);
return viewOp.getResult();
}
// No buffer, create an atomic memref without underlying buffer
NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");
Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
Value alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
memRefsToDealloc.push_back(alloc);
return alloc;
}
......@@ -501,9 +498,9 @@ namespace
NGRAPH_CHECK(func, "FuncOp '" + funcName + "' not found");
unsigned int argIdx = 0;
for (auto* arg : func.getArguments())
for (auto arg : func.getArguments())
{
if (arg->getType().isa<MemRefType>())
if (arg.getType().isa<MemRefType>())
{
func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
}
......@@ -526,7 +523,6 @@ namespace
PatternRewriter& rewriter)
{
auto module = getModule();
auto* context = getModule().getContext();
auto callBackFunc = module.lookupSymbol<mlir::FuncOp>(name);
if (!callBackFunc)
{
......@@ -583,7 +579,7 @@ namespace
#define REWRITER(OP) \
PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) const
Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const
REWRITER(NGAddOp)
{
......@@ -675,12 +671,12 @@ namespace
auto loc = cast<NGReluOp>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
NGRAPH_CHECK(result.getType().isa<MemRefType>());
// Note that builder's current function is still the original function body.
// use getBlock to get the new block instead.
// get new operands
Value* lhs = operands[0];
Value lhs = operands[0];
ScopedContext scope(rewriter, loc);
// Views
......@@ -696,8 +692,8 @@ namespace
// Steps
auto steps = vLHS.getSteps();
NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType();
NGRAPH_CHECK(lhs.getType().isa<MemRefType>());
Type elemTy = lhs.getType().dyn_cast<MemRefType>().getElementType();
AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs);
......@@ -723,14 +719,14 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value* rhs = operands[1];
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value lhs = operands[0];
Value rhs = operands[1];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");
auto resultTy = result->getType().dyn_cast<MemRefType>();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
auto rhsTy = rhs->getType().dyn_cast<MemRefType>();
auto resultTy = result.getType().dyn_cast<MemRefType>();
auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto rhsTy = rhs.getType().dyn_cast<MemRefType>();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
NGRAPH_CHECK(rhsTy, "Unexpected non-memref RHS type");
......@@ -792,7 +788,7 @@ namespace
ScopedContext scope(rewriter, loc);
// Create Value for result, and extract type info.
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
// Create view to write into result.
......@@ -870,11 +866,11 @@ namespace
ScopedContext scope(rewriter, loc);
// Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in GatherOp");
Value* params = operands[0];
Value* indices = operands[1];
Value params = operands[0];
Value indices = operands[1];
auto axis = gatherOp.axis().getSExtValue();
// Create view to write into result.
......@@ -987,10 +983,10 @@ namespace
auto convolOp = cast<NGConvolutionOp>(op);
// Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
Value* images = operands[0];
Value* filters = operands[1];
Value images = operands[0];
Value filters = operands[1];
auto strides = convolOp.strides();
auto padBelow = convolOp.padBelow();
auto padAbove = convolOp.padBelow();
......@@ -1014,10 +1010,10 @@ namespace
auto gConvOp = cast<NGGroupConvOp>(op);
ScopedContext scope(rewriter, gConvOp.getLoc());
// Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
Value* images = operands[0];
Value* filters = operands[1];
Value images = operands[0];
Value filters = operands[1];
auto strides = gConvOp.strides();
auto padBelow = gConvOp.padBelow();
auto padAbove = gConvOp.padBelow();
......@@ -1030,10 +1026,9 @@ namespace
ValueHandle lb = intrinsics::constant_index(0);
ValueHandle ub = intrinsics::constant_index(groups);
ValueHandle step = intrinsics::constant_index(1);
auto imagesType = images->getType().cast<MemRefType>();
auto filtersType = filters->getType().cast<MemRefType>();
auto imagesType = images.getType().cast<MemRefType>();
auto filtersType = filters.getType().cast<MemRefType>();
auto imagesShape = imagesType.getShape();
auto filtersShape = filtersType.getShape();
......@@ -1086,8 +1081,8 @@ namespace
}
// Use callback: Pooling, MatMul, Gemm, Softmax
static void castMemRef(SmallVector<mlir::Value*, 4> inputs,
SmallVector<mlir::Value*, 4>& outputs,
static void castMemRef(SmallVector<mlir::Value, 4>& inputs,
SmallVector<mlir::Value, 4>& outputs,
PatternRewriter& rewriter,
UnrankedMemRefType type)
{
......@@ -1123,22 +1118,21 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* src = operands[0];
Value* delta = operands[1];
Value src = operands[0];
Value delta = operands[1];
ArrayRef<Attribute> windowShape = pooling.windowShape().getValue();
ArrayRef<Attribute> windowStrides = pooling.windowMovementStrides().getValue();
ArrayRef<Attribute> padBelow = pooling.padBelow().getValue();
ArrayRef<Attribute> padAbove = pooling.padAbove().getValue();
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(src && delta && result, "Unexpected null values in MaxPoolBackprop Op");
auto resultTy = result->getType().dyn_cast<MemRefType>();
auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape();
auto srcTy = src->getType().dyn_cast<MemRefType>();
auto srcTy = src.getType().dyn_cast<MemRefType>();
auto srcShape = srcTy.getShape();
auto deltaTy = delta->getType().dyn_cast<MemRefType>();
auto deltaShape = deltaTy.getShape();
auto deltaTy = delta.getType().dyn_cast<MemRefType>();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(srcTy, "Unexpected non-memref src type");
NGRAPH_CHECK(deltaTy, "Unexpected non-memref delta type");
......@@ -1153,8 +1147,8 @@ namespace
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
SmallVector<mlir::Value*, 4> inputs = {src, delta, result};
SmallVector<mlir::Value*, 4> outputs;
SmallVector<mlir::Value, 4> inputs = {src, delta, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = pass.getCallDecl(
......@@ -1177,7 +1171,6 @@ namespace
}
else if (srcShape.size() == 5)
{
opAttrs attrs;
attrs.poolAttrs3d.includePaddingInAvgComputation = false;
for (auto i = 0; i < 3; i++)
{
......@@ -1192,7 +1185,7 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MAXPOOLBACKPROP), 64);
SmallVector<mlir::Value*, 4> args = {
SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
......@@ -1207,16 +1200,16 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value* rhs = operands[1];
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value lhs = operands[0];
Value rhs = operands[1];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in MatMulOp");
auto resultTy = result->getType().dyn_cast<MemRefType>();
auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape();
auto rhsTy = rhs->getType().dyn_cast<MemRefType>();
auto rhsTy = rhs.getType().dyn_cast<MemRefType>();
auto rhsShape = rhsTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
......@@ -1262,10 +1255,10 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MATMUL), 64);
SmallVector<mlir::Value*, 4> inputs = {lhs, rhs, result};
SmallVector<mlir::Value*, 4> outputs;
SmallVector<mlir::Value, 4> inputs = {lhs, rhs, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value*, 4> args = {
SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
......@@ -1281,18 +1274,18 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value* rhs = operands[1];
Value* bias = operands[2];
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value lhs = operands[0];
Value rhs = operands[1];
Value bias = operands[2];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && bias && result, "Unexpected null values in GemmOp");
auto resultTy = result->getType().dyn_cast<MemRefType>();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
auto resultTy = result.getType().dyn_cast<MemRefType>();
auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape();
auto rhsTy = rhs->getType().dyn_cast<MemRefType>();
auto rhsTy = rhs.getType().dyn_cast<MemRefType>();
auto rhsShape = rhsTy.getShape();
auto biasTy = bias->getType().dyn_cast<MemRefType>();
auto biasTy = bias.getType().dyn_cast<MemRefType>();
auto biasShape = biasTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
......@@ -1382,10 +1375,10 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::GEMM), 64);
SmallVector<mlir::Value*, 4> inputs = {lhs, rhs, bias, result};
SmallVector<mlir::Value*, 4> outputs;
SmallVector<mlir::Value, 4> inputs = {lhs, rhs, bias, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value*, 4> args = {
SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], outputs[3], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
......@@ -1401,13 +1394,13 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value lhs = operands[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && result, "Unexpected null values in SoftmaxOp");
auto resultTy = result->getType().dyn_cast<MemRefType>();
auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
......@@ -1436,10 +1429,10 @@ namespace
{},
rewriter);
SmallVector<mlir::Value*, 4> inputs = {lhs, result};
SmallVector<mlir::Value*, 4> outputs;
SmallVector<mlir::Value, 4> inputs = {lhs, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value*, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
......@@ -1450,20 +1443,20 @@ namespace
#undef REWRITER
/// End of pattern matchers
void lowerConvolution(Value* result,
Value* images,
Value* filters,
void lowerConvolution(Value result,
Value images,
Value filters,
ArrayAttr stridesAttr,
ArrayAttr padBelowAttr,
ArrayAttr padAboveAttr,
PatternRewriter& rewriter,
DialectLoweringPass& pass,
Location loc,
Value* cLb,
Value* cUb,
Value* kLb,
Value* kUb,
Value* gId)
Value cLb,
Value cUb,
Value kLb,
Value kUb,
Value gId)
{
// Let Images shape be [N, C_IN, D_1, ... D_f]
// Let Filters shape be [C_OUT, C_IN, F_1, ... F_f]
......@@ -1516,7 +1509,7 @@ namespace
auto strides = stridesAttr.getValue();
auto padBelow = padBelowAttr.getValue();
auto padAbove = padBelowAttr.getValue();
Type elemTy = images->getType().cast<MemRefType>().getElementType();
Type elemTy = images.getType().cast<MemRefType>().getElementType();
// Create views
MemRefView vRes(result), vImages(images), vFilters(filters);
......@@ -1767,7 +1760,7 @@ namespace
// if args : img dims, img lbs, img ubs
SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin();
std::advance(it, 2);
SmallVector<Value*, 4> affineIfArgs(it, imgIndices.end());
SmallVector<Value, 4> affineIfArgs(it, imgIndices.end());
affineIfArgs.insert(
affineIfArgs.end(), imgSpatialLbs.begin(), imgSpatialLbs.end());
affineIfArgs.insert(
......@@ -1811,19 +1804,19 @@ namespace
template <typename OP>
void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
auto loc = cast<OP>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
NGRAPH_CHECK(result.getType().isa<MemRefType>());
// Note that builder's current function is still the original function body.
// use getBlock to get the new block instead.
// get new operands
Value* lhs = operands[0];
Value lhs = operands[0];
ScopedContext scope(rewriter, loc);
// Views
......@@ -1839,8 +1832,8 @@ namespace
// Steps
auto steps = vLHS.getSteps();
NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
Type elemTy = lhs->getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(lhs.getType().isa<MemRefType>());
Type elemTy = lhs.getType().cast<MemRefType>().getElementType();
AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs);
......@@ -1860,16 +1853,16 @@ namespace
template <typename OP>
void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
auto loc = cast<OP>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
NGRAPH_CHECK(result.getType().isa<MemRefType>());
// get new operands
Value* lhs = operands[0];
Value* rhs = operands[1];
Value lhs = operands[0];
Value rhs = operands[1];
ScopedContext scope(rewriter, loc);
// Views
......@@ -1885,7 +1878,7 @@ namespace
// Steps
auto steps = vLHS.getSteps();
// element type of the operand
Type elemTy = result->getType().cast<MemRefType>().getElementType();
Type elemTy = result.getType().cast<MemRefType>().getElementType();
AffineLoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body
[&] {
......@@ -1976,7 +1969,7 @@ namespace
template <typename RedOp>
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
......@@ -1996,9 +1989,9 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
Value arg = operands[0];
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
// Views
MemRefView vRes(result), vArg(arg);
......@@ -2011,7 +2004,7 @@ namespace
auto argLbs = vArg.getLbs();
auto argUbs = vArg.getUbs();
Type resTy = result->getType().cast<MemRefType>().getElementType();
Type resTy = result.getType().cast<MemRefType>().getElementType();
// Generate loop nest that initializes result to lower bound of the axis to be reduced.
{
auto ivs = makeIndexHandles(vRes.rank());
......@@ -2029,7 +2022,7 @@ namespace
auto steps = vArg.getSteps();
SmallVector<IndexHandle, 8> nonRedIVs;
Type resTy = result->getType().cast<MemRefType>().getElementType();
Type resTy = result.getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(resTy.isa<IntegerType>(),
"Expected integer result type in index reduction");
......@@ -2069,7 +2062,7 @@ namespace
template <typename OP>
void lowerPooling(Operation* op,
ArrayRef<Value*> operands,
ArrayRef<Value> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
......@@ -2078,18 +2071,18 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* lhs = operands[0];
Value lhs = operands[0];
ArrayRef<Attribute> windowShape = pooling.windowShape().getValue();
ArrayRef<Attribute> windowStrides = pooling.windowMovementStrides().getValue();
ArrayRef<Attribute> padBelow = pooling.padBelow().getValue();
ArrayRef<Attribute> padAbove = pooling.padAbove().getValue();
Value* result = pass.buildOutputDefs(op, rewriter)[0];
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && result, "Unexpected null values in Pooling Op");
auto resultTy = result->getType().dyn_cast<MemRefType>();
auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>();
auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
......@@ -2120,8 +2113,8 @@ namespace
}
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
SmallVector<mlir::Value*, 4> inputs = {lhs, result};
SmallVector<mlir::Value*, 4> outputs;
SmallVector<mlir::Value, 4> inputs = {lhs, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc =
......@@ -2158,7 +2151,7 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(ty), 64);
SmallVector<mlir::Value*, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
......@@ -2211,71 +2204,6 @@ namespace
}
NGRAPH_UNREACHABLE("Unsupported type");
}
// Given a concat op, it will check if dst and operands have
// a valid buffer/offset assignment that will make this op
// valid in-place
bool isInPlaceConcat(mlir::Operation* op, DialectLoweringPass& pass)
{
NGRAPH_CHECK(isa<NGConcatOp>(op), "Expecting concat operation");
auto concat = cast<NGConcatOp>(op);
auto concatAxis = concat.concatenation_axis();
auto result = concat.getResult();
auto shape = (result->getType().cast<NGTensorType>()).getShape();
auto memAnalysis = pass.getMemAnalysis();
BufferInfo bufferInfo = memAnalysis->getBufferInfo(op);
if (!bufferInfo.isValid())
{
// no buffer assignment to dst, nothing to do
return false;
}
auto dstBufferId = bufferInfo.m_bufferId;
auto dstOffset = bufferInfo.m_offset;
LLVM_DEBUG(llvm::dbgs() << ">> Check in-place concat\n");
LLVM_DEBUG(op->dump());
for (auto i = 0; i < shape.size(); i++)
{
if (i == concatAxis)
{
break;
}
if (shape[i] != 1)
{
LLVM_DEBUG(llvm::dbgs() << "Axis FAIL. Skipping instruction\n");
return false;
}
}
LLVM_DEBUG(llvm::dbgs() << "Axis OK\n");
// Check if the buffer id and offsets are consistent with what's exepcted
LLVM_DEBUG(llvm::dbgs() << "Dst (id, offset) = (" << dstBufferId << ", " << dstOffset
<< ")\n");
// relative offset in the buffer
int opndOffset = 0;
for (auto opnd : op->getOperands())
{
bufferInfo = memAnalysis->getBufferInfo(opnd->getDefiningOp());
auto srcBufferId = bufferInfo.m_bufferId;
auto srcOffset = bufferInfo.m_offset;
LLVM_DEBUG(llvm::dbgs() << "Src (id, offset) = (" << srcBufferId << ", " << srcOffset
<< ")\n");
if (!bufferInfo.isValid() || srcBufferId != dstBufferId ||
srcOffset != (opndOffset + dstOffset))
{
// mismatch in buffer IDs or offsets
LLVM_DEBUG(llvm::dbgs() << "Buffer ID and Offsets FAIL. Skipping instruction\n");
return false;
}
auto tensorType = opnd->getType().cast<NGTensorType>();
opndOffset += tensorType.getNumElements();
}
LLVM_DEBUG(llvm::dbgs() << "Buffer ID and Offsets OK\n");
return true;
}
} // namespace
namespace mlir
......
......@@ -81,7 +81,6 @@ mlir::Type NGraphOpsDialect::parseEltType(mlir::DialectAsmParser& parser) const
{
// Process nGraph integer element types.
MLIRContext* context = getContext();
int width = 0;
bool isSigned = false;
llvm::SMLoc loc = parser.getCurrentLocation();
......
......@@ -171,7 +171,7 @@ def NGRNNCellOp :
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t, "
"Value X, Value W, Value R, Value H_t, "
"Attribute hiddenSize, ArrayAttr activations,"
"ArrayAttr activationAlpha, ArrayAttr activationBeta, Attribute clip", [{
tblgen_state.addOperands({X, W, R, H_t});
......@@ -192,7 +192,7 @@ def NGRNNCellOp :
void setClip(const Attribute& attr) { this->setAttr("clip", attr); }
// get bias operand if present
Value* B()
Value B()
{
auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
......@@ -263,7 +263,7 @@ def NGMVN :
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *data, ArrayAttr reductionAxes, Attribute normalizeVariance,"
"Value data, ArrayAttr reductionAxes, Attribute normalizeVariance,"
"Attribute eps", [{
tblgen_state.addOperands(data);
tblgen_state.addAttribute("reductionAxes", reductionAxes);
......@@ -363,7 +363,7 @@ def NGLSTMCellOp :
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t, Value* C_t,"
"Value X, Value W, Value R, Value H_t, Value C_t,"
"Attribute hiddenSize, ArrayAttr activations,"
"ArrayAttr activationAlpha, ArrayAttr activationBeta,"
"Attribute clip, Attribute inputForget", [{
......@@ -379,7 +379,7 @@ def NGLSTMCellOp :
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t, Value* C_t,"
"Value X, Value W, Value R, Value H_t, Value C_t,"
"Attribute hiddenSize",
[{
tblgen_state.addOperands({X, W, R, H_t, C_t});
......@@ -390,13 +390,13 @@ def NGLSTMCellOp :
let extraClassDeclaration = [{
// get bias operand if present
Value* B()
Value B()
{
auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
}
// get peephole weights operand if present
Value* P()
Value P()
{
auto varArgs = optionalArgs();
auto it = varArgs.begin();
......@@ -452,7 +452,7 @@ def NGLSTMSequenceOp :
void setActivationsBeta (const ArrayAttr& attr) { this->setAttr("activatiBeta", attr); }
void setClip(const Attribute& attr) { this->setAttr("clip", attr); }
Value* P()
Value P()
{
auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
......@@ -500,7 +500,7 @@ def NGGRUCellOp :
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t,"
"Value X, Value W, Value R, Value H_t,"
"Attribute hiddenSize, ArrayAttr activations,"
"ArrayAttr activationAlpha, ArrayAttr activationBeta,"
"Attribute clip, Attribute linearBeforeReset", [{
......@@ -515,7 +515,7 @@ def NGGRUCellOp :
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t,"
"Value X, Value W, Value R, Value H_t,"
"Attribute hiddenSize",
[{
tblgen_state.addOperands({X, W, R, H_t});
......@@ -532,7 +532,7 @@ def NGGRUCellOp :
void setLinearBeforeReset(const Attribute& attr) { this->setAttr("linearBeforeReset", attr); }
// get Bias operand if present
Value* P()
Value P()
{
auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
......@@ -557,7 +557,7 @@ def NGLayerNormOp :
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Attribute keepStats, Attribute beginNormAxis, Attribute epsilon", [{
"Value data, Attribute keepStats, Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands(data);
tblgen_state.addAttribute("keepStats", keepStats);
tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
......@@ -568,13 +568,13 @@ def NGLayerNormOp :
let extraClassDeclaration = [{
// get Scale operand if present
Value* Scale()
Value Scale()
{
auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
}
// get Bias operand if present
Value* Bias()
Value Bias()
{
auto varArgs = optionalArgs();
auto it = varArgs.begin();
......@@ -608,7 +608,7 @@ def NGLayerNormBackpropOp :
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Value *delta, Value *mean, Value *variance,"
"Value data, Value delta, Value mean, Value variance,"
"Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands({data, delta, mean, variance});
tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
......@@ -618,7 +618,7 @@ def NGLayerNormBackpropOp :
OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Value *delta, Value *scale,"
"Value data, Value delta, Value scale,"
"Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands({data, delta, scale});
tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
......@@ -628,7 +628,7 @@ def NGLayerNormBackpropOp :
OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Value *delta,"
"Value data, Value delta,"
"Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands({data, delta});
tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
......@@ -639,13 +639,13 @@ def NGLayerNormBackpropOp :
let extraClassDeclaration = [{
// get Mean operand if present
Value* Mean()
Value Mean()
{
auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
}
// get Variance operand if present
Value* Variance()
Value Variance()
{
auto varArgs = optionalArgs();
auto it = varArgs.begin();
......@@ -722,8 +722,8 @@ def NGGroupConvOp :
let builders = [
// Builder without padType
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *images,"
"Value *filters, ArrayAttr strides, ArrayAttr padBelow, ArrayAttr padAbove,"
"Builder *builder, OperationState &tblgen_state, Type res, Value images,"
"Value filters, ArrayAttr strides, ArrayAttr padBelow, ArrayAttr padAbove,"
"Attribute groups",
[{
tblgen_state.addOperands({images, filters});
......@@ -772,18 +772,18 @@ def NGGroupConvTransposeOp :
let builders = [
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, Attribute groups", [{
"Value images, Value filters, Attribute groups", [{
tblgen_state.addOperands({images, filters});
tblgen_state.addAttribute("groups", groups);
tblgen_state.addTypes(res);
}]>,
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters", [{
"Value images, Value filters", [{
tblgen_state.addOperands({images, filters});
tblgen_state.addTypes(res);
}]>,
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, ArrayAttr strides,"
"Value images, Value filters, ArrayAttr strides,"
"ArrayAttr outputPad, ArrayAttr outputShape,"
"Attribute groups", [{
tblgen_state.addOperands({images, filters});
......@@ -793,7 +793,7 @@ def NGGroupConvTransposeOp :
tblgen_state.addAttribute("groups", groups);
}]>,
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters,"
"Value images, Value filters,"
"ArrayAttr outputShape, Attribute groups", [{
tblgen_state.addOperands({images, filters});
tblgen_state.addAttribute("outputShape", outputShape);
......@@ -951,7 +951,7 @@ def NGConvBiasOp :
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, Value *bias, Attribute withRelu", [{
"Value images, Value filters, Value bias, Attribute withRelu", [{
tblgen_state.addOperands({images, filters, bias});
tblgen_state.addAttribute("withRelu", withRelu);
tblgen_state.addTypes(res);
......@@ -959,7 +959,7 @@ def NGConvBiasOp :
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, Value *bias", [{
"Value images, Value filters, Value bias", [{
tblgen_state.addOperands({images, filters, bias});
tblgen_state.addTypes(res);
}]>
......
......@@ -44,12 +44,12 @@ using namespace mlir;
/// Checks if all operands and results are of compatible shapes
template <typename T>
static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkResult = true)
static mlir::LogicalResult verifyCompatibleOperandsAndResults(T op, bool checkResult = true)
{
mlir::Type t0 = op->getOperation()->getOperand(0)->getType();
mlir::Type t0 = op.getOperation()->getOperand(0).getType();
mlir::NGTensorType opType0 = t0.cast<NGTensorType>();
Operation* opr = op->getOperation();
Operation* opr = op.getOperation();
auto i = 0;
for (auto operand : opr->getOperands())
{
......@@ -57,10 +57,10 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
{
continue;
}
mlir::Type t = operand->getType();
mlir::Type t = operand.getType();
mlir::NGTensorType opType = t.cast<NGTensorType>();
if (!opType.isCompatible(opType0))
return op->emitOpError("Incompatible operand shape");
return op.emitOpError("Incompatible operand shape");
i++;
}
......@@ -68,74 +68,74 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
{
for (auto result : opr->getResults())
{
mlir::Type t = result->getType();
mlir::Type t = result.getType();
mlir::NGTensorType resType = t.cast<NGTensorType>();
if (!resType.isCompatible(opType0))
return op->emitOpError("Incompatible operand shape");
return op.emitOpError("Incompatible operand shape");
}
}
return mlir::success();
}
template <typename T>
static mlir::LogicalResult verifyUnaryArithOp(T* op)
static mlir::LogicalResult verifyUnaryArithOp(T op)
{
return verifyCompatibleOperandsAndResults(op);
}
template <typename T>
static mlir::LogicalResult verifyBinaryArithOp(T* op)
static mlir::LogicalResult verifyBinaryArithOp(T op)
{
return verifyCompatibleOperandsAndResults(op);
}
template <typename T>
static mlir::LogicalResult verifyAxisReductionOp(T* op)
static mlir::LogicalResult verifyAxisReductionOp(T op)
{
return mlir::failure();
}
template <typename T>
static mlir::LogicalResult verifyLogicalReductionOp(T* op)
static mlir::LogicalResult verifyLogicalReductionOp(T op)
{
// TODO: verifyAxisReductionOp(op) + input and return element type.
return mlir::failure();
}
template <typename T>
static mlir::LogicalResult verifyIndexReductionOp(T* op)
static mlir::LogicalResult verifyIndexReductionOp(T op)
{
// TODO: verifyAxisReductionOp(op) + return element type + single axis.
return mlir::success();
}
template <typename T>
static mlir::LogicalResult verifyOp(T* op)
static mlir::LogicalResult verifyOp(T op)
{
return op->emitOpError("Unsupported verifier for this operation");
return op.emitOpError("Unsupported verifier for this operation");
}
template <>
mlir::LogicalResult verifyOp(NGDotOp* op)
mlir::LogicalResult verifyOp(NGDotOp op)
{
// TODO(dcab): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGConcatOp* op)
mlir::LogicalResult verifyOp(NGConcatOp op)
{
// TODO(amprocte): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGSelectOp* op)
mlir::LogicalResult verifyOp(NGSelectOp op)
{
mlir::Type t0 = op->getOperation()->getOperand(0)->getType();
mlir::Type t1 = op->getOperation()->getOperand(1)->getType();
mlir::Type t2 = op->getOperation()->getOperand(2)->getType();
mlir::Type r0 = op->getOperation()->getResult(0)->getType();
mlir::Type t0 = op.getOperation()->getOperand(0).getType();
mlir::Type t1 = op.getOperation()->getOperand(1).getType();
mlir::Type t2 = op.getOperation()->getOperand(2).getType();
mlir::Type r0 = op.getOperation()->getResult(0).getType();
NGTensorType opType0 = t0.cast<NGTensorType>();
NGTensorType opType1 = t1.cast<NGTensorType>();
......@@ -144,19 +144,19 @@ mlir::LogicalResult verifyOp(NGSelectOp* op)
// arg1 arg2 of same shape and elt type
if (!opType1.isCompatible(opType2))
return op->emitOpError("Incompatible operand shapes or types for select op");
return op.emitOpError("Incompatible operand shapes or types for select op");
// arg0 of same shape and elt type is bool
if (!opType0.isCompatibleShape(opType1) || !opType0.getElementType().isa<NGBoolType>())
return op->emitOpError("Incompatible shape for arg0 of select op");
return op.emitOpError("Incompatible shape for arg0 of select op");
// result is of same shape and elt type as arg1/2
if (!resType.isCompatible(opType1))
return op->emitOpError("Incompatible result shape or type for select op");
return op.emitOpError("Incompatible result shape or type for select op");
return mlir::success();
}
template <typename T>
static mlir::LogicalResult verifyCmpOp(T* op)
static mlir::LogicalResult verifyCmpOp(T op)
{
mlir::LogicalResult result = verifyCompatibleOperandsAndResults(op, false /*checkResult*/);
if (failed(result))
......@@ -164,75 +164,75 @@ static mlir::LogicalResult verifyCmpOp(T* op)
return result;
}
mlir::Type t0 = op->getOperation()->getOperand(0)->getType();
mlir::Type t0 = op.getOperation()->getOperand(0).getType();
mlir::NGTensorType opType0 = t0.cast<NGTensorType>();
mlir::Type r0 = op->getOperation()->getResult(0)->getType();
mlir::Type r0 = op.getOperation()->getResult(0).getType();
NGTensorType resType = r0.cast<NGTensorType>();
// result of same shape as input and has bool type
if (!resType.isCompatibleShape(opType0) ||
!resType.getElementType().cast<NGIntegerType>().isUInt8())
{
return op->emitOpError("Incompatible result shape or type for comparison op");
return op.emitOpError("Incompatible result shape or type for comparison op");
}
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGGatherOp* op)
mlir::LogicalResult verifyOp(NGGatherOp op)
{
Type ty = op->params()->getType();
Type ty = op.params().getType();
NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->indices()->getType();
ty = op.indices().getType();
NGTensorType indicesType = ty.cast<NGTensorType>();
// ensure axis < params rank
if (op->axis().getSExtValue() >= inputType.getRank())
return op->emitOpError("Gather axis is larger than input rank");
if (op.axis().getSExtValue() >= inputType.getRank())
return op.emitOpError("Gather axis is larger than input rank");
ty = indicesType.getElementType();
// ensure indices are I32 or I64
if (!ty.isa<NGIntegerType>())
return op->emitOpError("Indices tensor is not of Integer type");
return op.emitOpError("Indices tensor is not of Integer type");
NGIntegerType indicesEltType = ty.cast<NGIntegerType>();
if (!indicesEltType.isInt32() && !indicesEltType.isInt64())
return op->emitOpError("Indices tensor is not of I32 or I64 type");
return op.emitOpError("Indices tensor is not of I32 or I64 type");
mlir::Type r0 = op->res()->getType();
mlir::Type r0 = op.res().getType();
NGTensorType resType = r0.cast<NGTensorType>();
// ensure result is compatible with input
if (resType.getRank() != inputType.getRank() + indicesType.getRank() - 1)
return op->emitOpError("Incompatible result shape and/or type");
return op.emitOpError("Incompatible result shape and/or type");
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGConvolutionOp* op)
mlir::LogicalResult verifyOp(NGConvolutionOp op)
{
Type ty = op->images()->getType();
Type ty = op.images().getType();
NGTensorType imagesType = ty.cast<NGTensorType>();
Type imagesEt = imagesType.getElementType();
Shape imagesShape = imagesType.getShape();
ty = op->filters()->getType();
ty = op.filters().getType();
NGTensorType filtersType = ty.cast<NGTensorType>();
Type filtersEt = filtersType.getElementType();
Shape filtersShape = filtersType.getShape();
ty = op->res()->getType();
ty = op.res().getType();
NGTensorType resultType = ty.cast<NGTensorType>();
Shape resultShape = resultType.getShape();
ArrayAttr strides = op->strides();
ArrayAttr padBelow = op->padBelow();
ArrayAttr padAbove = op->padAbove();
ArrayAttr strides = op.strides();
ArrayAttr padBelow = op.padBelow();
ArrayAttr padAbove = op.padAbove();
unsigned imagesRank = imagesShape.size();
unsigned filtersRank = filtersShape.size();
......@@ -247,32 +247,32 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
// Identical filters and image element types
if (filtersEt != imagesEt)
{
return op->emitOpError("Incompatible image and filters types");
return op.emitOpError("Incompatible image and filters types");
}
// Verify image shape
if (imagesRank < 3)
{
return op->emitOpError("Image shape of rank below 3");
return op.emitOpError("Image shape of rank below 3");
}
// Verify strides and pads shapes
if (imageSpatialRank != stridesRank || imageSpatialRank != padBelowRank ||
imageSpatialRank != padAboveRank)
{
return op->emitOpError("Image spatial rank mismatches strides and/or padding ranks");
return op.emitOpError("Image spatial rank mismatches strides and/or padding ranks");
}
if (imageSpatialRank != filtersSpatialRank)
{
return op->emitOpError("Image and filters spatial ranks mismatch");
return op.emitOpError("Image and filters spatial ranks mismatch");
}
// Batch size is non-zero, and identical non-zero channel depth
if (imagesShape[0] <= 0 || filtersShape[0] <= 0 || imagesShape[1] != filtersShape[1] ||
imagesShape[1] <= 0)
{
return op->emitOpError("Image and filters have invalid shapes");
return op.emitOpError("Image and filters have invalid shapes");
}
for (auto attrs : llvm::zip(strides, padBelow, padAbove))
......@@ -283,7 +283,7 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
if (s <= 0)
{
return op->emitOpError("Window stride must be non-negative");
return op.emitOpError("Window stride must be non-negative");
}
stridesVal.push_back(s);
padBelowVal.push_back(pb);
......@@ -294,7 +294,7 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
if (resultRank != imagesRank || resultShape[0] != imagesShape[0] ||
resultShape[1] != filtersShape[0])
{
return op->emitOpError("Invalid result shape");
return op.emitOpError("Invalid result shape");
}
for (unsigned i = 0; i < resultRank - 2; i++)
{
......@@ -303,56 +303,42 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
stridesVal[i]);
if (resultShape[2 + i] != resDim)
{
return op->emitOpError("Invalid result spatial shape");
return op.emitOpError("Invalid result spatial shape");
}
}
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMatMulOp* op)
mlir::LogicalResult verifyOp(NGSoftMaxOp op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGGemmOp* op)
mlir::LogicalResult verifyOp(NGAvgPoolOp op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGSoftMaxOp* op)
mlir::LogicalResult verifyOp(NGAvgPoolBackpropOp op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGAvgPoolOp* op)
mlir::LogicalResult verifyOp(NGMaxPoolOp op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGAvgPoolBackpropOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMaxPoolOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMaxPoolBackpropOp* op)
mlir::LogicalResult verifyOp(NGMaxPoolBackpropOp op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
......
......@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO: Implement
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyUnaryArithOp(this); }];
let verifier = [{ return verifyUnaryArithOp(*this); }];
}
// Base class for arithmetic binary operations without side effects.
......@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO: Implement
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyBinaryArithOp(this); }];
let verifier = [{ return verifyBinaryArithOp(*this); }];
}
// Base class for comparison operations with verifier.
......@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO: Implement
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyCmpOp(this); }];
let verifier = [{ return verifyCmpOp(*this); }];
}
// Base class for ternary operations without side effects.
......@@ -133,7 +133,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
// TODO
let verifier = [{ return verifyAxisReductionOp(this); }];
let verifier = [{ return verifyAxisReductionOp(*this); }];
}
// Base class for terminator operations.
......
......@@ -61,14 +61,14 @@ def NGNotEqOp : NG_Cmp_Op<"not.equal", [OpVersion0]>;
// Other
def NGSelectOp : NG_Ternary_Op<"select", [OpVersion0]>
{
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
}
// Dot Product
def NGDotOp : NG_Binary_Op<"dot", [OpVersion0]>
{
// TODO: Add reduction axis attribute when needed.
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
}
// TODO(amprocte): Might be nice to rebase this on some sort of NG_Variadic_Op
......@@ -80,56 +80,56 @@ def NGConcatOp :
{
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
}
// Axis reduction operations.
def NGSumRedOp : NG_Axis_Reduction_Op<"sum.red", [OpVersion0]>
{
let summary = "Axis sum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
let verifier = [{ return verifyAxisReductionOp(*this); }];
}
def NGProdRedOp : NG_Axis_Reduction_Op<"prod.red", [OpVersion0]>
{
let summary = "Axis product reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
let verifier = [{ return verifyAxisReductionOp(*this); }];
}
def NGMinRedOp : NG_Axis_Reduction_Op<"min.red", [OpVersion0]>
{
let summary = "Axis minimum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
let verifier = [{ return verifyAxisReductionOp(*this); }];
}
def NGMaxRedOp : NG_Axis_Reduction_Op<"max.red", [OpVersion0]>
{
let summary = "Axis maximum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }];
let verifier = [{ return verifyAxisReductionOp(*this); }];
}
def NGArgMinRedOp : NG_Axis_Reduction_Op<"argmin.red", [OpVersion0]>
{
let summary = "Axis minimum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }];
let verifier = [{ return verifyIndexReductionOp(*this); }];
}
def NGArgMaxRedOp : NG_Axis_Reduction_Op<"argmax.red", [OpVersion0]>
{
let summary = "Axis maximum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }];
let verifier = [{ return verifyIndexReductionOp(*this); }];
}
def NGAllRedOp : NG_Axis_Reduction_Op<"all.red", [OpVersion0]>
{
let summary = "Axis logical AND reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }];
let verifier = [{ return verifyLogicalReductionOp(*this); }];
}
def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red", [OpVersion0]>
{
let summary = "Axis logical OR reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }];
let verifier = [{ return verifyLogicalReductionOp(*this); }];
}
// Gather
......@@ -147,7 +147,7 @@ def NGGatherOp :
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
}
// Convolution
......@@ -171,7 +171,7 @@ def NGConvolutionOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setStrides(ArrayAttr& arrayAttr) { this->setAttr("strides", arrayAttr); }
void setPadBelow(ArrayAttr& arrayAttr) { this->setAttr("padBelow", arrayAttr); }
......@@ -210,10 +210,10 @@ def NGAvgPoolOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg,"
"Builder *builder, OperationState &tblgen_state, Type res, Value arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove, BoolAttr includPadding, IntegerAttr padType", [{
tblgen_state.addOperands(arg);
......@@ -227,7 +227,7 @@ def NGAvgPoolOp :
}]>,
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg,"
"Builder *builder, OperationState &tblgen_state, Type res, Value arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove, BoolAttr includPadding", [{
tblgen_state.addOperands(arg);
......@@ -269,7 +269,7 @@ def NGAvgPoolBackpropOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setForwardArgShape(const ArrayAttr& arrayAttr) { this->setAttr("forwardArgShape", arrayAttr); }
......@@ -296,7 +296,7 @@ def NGBatchNormTrainingOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); }
......@@ -320,7 +320,7 @@ def NGBatchNormInferenceOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); }
}];
......@@ -345,7 +345,7 @@ def NGBatchNormTrainingBackPropOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); }
}];
......@@ -366,7 +366,7 @@ def NGBroadcastOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setAxisSet(const ArrayAttr& attr) { this->setAttr("axisSet", attr); }
void setShape(const ArrayAttr& attr) { this->setAttr("shape", attr); }
......@@ -387,7 +387,7 @@ def NGConstantOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
}
// MaxPool
......@@ -416,23 +416,10 @@ def NGMaxPoolOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let builders = [
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove, IntegerAttr padType", [{
tblgen_state.addOperands(arg);
tblgen_state.addAttribute("windowShape", windowShape);
tblgen_state.addAttribute("windowMovementStrides", windowMovementStrides);
tblgen_state.addAttribute("padBelow", padBelow);
tblgen_state.addAttribute("padAbove", padAbove);
tblgen_state.addAttribute("padType", padType);
tblgen_state.addTypes(res);
}]>,
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg,"
"Builder *builder, OperationState &tblgen_state, Type res, Value arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove", [{
tblgen_state.addOperands(arg);
......@@ -471,7 +458,7 @@ def NGMaxPoolBackpropOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setWindowShape(const ArrayAttr& arrayAttr) { this->setAttr("windowShape", arrayAttr); }
......@@ -510,12 +497,12 @@ def NGPadOp :
DefaultValuedAttr<PadModeEnumAttr, "MLIRPadMode::CONSTANT"> :$padMode)>
{
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let builders = [
// Builder without padMode
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, "
"Value *arg, Value *padValue, "
"Value arg, Value padValue, "
"ArrayAttr padBelow, ArrayAttr padAbove", [{
tblgen_state.addOperands(arg);
tblgen_state.addOperands(padValue);
......@@ -556,7 +543,7 @@ def NGReplaceSliceOp :
slice to be replaced.
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setLowerBounds(const ArrayAttr& arrayAttr) { this->setAttr("lowerBounds", arrayAttr); }
void setUpperBounds(const ArrayAttr& arrayAttr) { this->setAttr("upperBounds", arrayAttr); }
......@@ -584,7 +571,7 @@ def NGSliceOp :
slice to be replaced.
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setLowerBounds(const ArrayAttr& arrayAttr) { this->setAttr("lowerBounds", arrayAttr); }
void setUpperBounds(const ArrayAttr& arrayAttr) { this->setAttr("upperBounds", arrayAttr); }
......@@ -610,7 +597,7 @@ def NGReshapeOp :
Pi(a_i) = Pi(b_i)
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setAxisOrder(const ArrayAttr& arrayAttr) { this->setAttr("axisOrder", arrayAttr); }
void setShape(const ArrayAttr& arrayAttr) { this->setAttr("shape", arrayAttr); }
......@@ -630,7 +617,7 @@ def NGSoftMaxOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setAxes(const ArrayAttr& arrayAttr) { this->setAttr("axes", arrayAttr); }
}];
......@@ -657,7 +644,7 @@ def NGTopKOp :
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{
void setK(const Attribute& attr) { this->setAttr("k", attr); }
void setAxis(const Attribute& attr) { this->setAttr("axis", attr); }
......
......@@ -68,7 +68,7 @@ namespace
struct TensorInfo
{
// MLIR values this tensor maps to.
mlir::Value* m_value;
mlir::Value m_value;
};
private:
......@@ -84,7 +84,7 @@ namespace
mlir::Type getMlirType(const ngraph::Node* node);
TensorInfo getTensorValue(descriptor::Tensor* tensor);
void updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value);
void updateTensorValue(descriptor::Tensor* tensor, mlir::Value value);
template <typename Op>
static mlir::Operation* createOp(NgDialectConversionPass& NgDialectObj,
......@@ -176,7 +176,7 @@ void NgDialectConversionPass::runOnModule()
int i = 0;
for (auto input : kernelInputs)
{
mlir::Value* arg = function.getArgument(i);
auto arg = function.getArgument(i);
TensorInfo tensorInfo{arg};
m_tensorToValueMap.insert(TensorToInfo(input->get_output_tensor_ptr().get(), tensorInfo));
i++;
......@@ -264,7 +264,7 @@ mlir::Type NgDialectConversionPass::getMlirType(const ngraph::Node* node)
return getMlirType(outTensor);
}
void NgDialectConversionPass::updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value)
void NgDialectConversionPass::updateTensorValue(descriptor::Tensor* tensor, mlir::Value value)
{
NGRAPH_CHECK(m_tensorToValueMap.find(tensor) == m_tensorToValueMap.end(),
"tensor value already defined");
......@@ -307,7 +307,7 @@ void NgDialectConversionPass::buildNgDialect(mlir::FuncOp function)
{
for (auto i = 0; i < op->getNumResults(); i++)
{
mlir::Value* result = op->getResult(i);
auto result = op->getResult(i);
if (result)
{
updateTensorValue(np->get_output_tensor_ptr(i).get(), result);
......@@ -600,7 +600,6 @@ template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Softmax)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGSoftMaxOp>(ngNode, 1);
auto softmaxNode = static_cast<const ngraph::op::Softmax*>(ngNode);
auto softmaxOp = llvm::cast<mlir::NGSoftMaxOp>(op);
auto originArg = NgDialectObj.getOriginArg(ngNode->input_value(1).get_node());
......@@ -614,7 +613,7 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Softmax)
template <typename Op>
mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ngNode, int inNum)
{
std::vector<mlir::Value*> argValues;
std::vector<mlir::Value> argValues;
std::vector<mlir::Type> resTypes;
auto inputMap = m_compiledKernel->get_input_map();
std::shared_ptr<descriptor::Tensor> argTensor;
......@@ -650,7 +649,7 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
return (m_builder.create<Op,
ArrayRef<mlir::Type>,
ArrayRef<mlir::Value*>,
ArrayRef<mlir::Value>,
ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(m_context), resTypes, argValues, {/* no attrs */}))
.getOperation();
......@@ -663,7 +662,7 @@ const NgDialectConversionPass::MLIRCompOpMap NgDialectConversionPass::opDispatch
void NgDialectConversionPass::createReturn()
{
std::vector<mlir::Value*> valueList;
std::vector<mlir::Value> valueList;
for (auto output : m_compiledKernel->get_kernel_outputs())
{
valueList.push_back(getTensorValue(output->get_output_tensor_ptr().get()).m_value);
......
......@@ -140,7 +140,6 @@ void MLIRCPURuntime::execute()
void MLIRCPURuntime::cleanup()
{
// Free void double pointer arguments without freeing external tensor data.
int i = 0;
for (auto* arg : m_invokeArgs)
{
auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg));
......
......@@ -16,7 +16,7 @@
# Enable use of the lit tool that we build from MLIR repo.
set(LLVM_LIT ${LLVM_MAIN_SRC_DIR}/utils/lit/lit.py)
set(LLVM_DEFAULT_EXTERNAL_LIT ${MLIR_TOOLS_DIR}/llvm-lit)
set(LLVM_DEFAULT_EXTERNAL_LIT ${MLIR_LLVM_TOOLS_DIR}/llvm-lit)
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
......
......@@ -150,7 +150,8 @@ func @simple_dot(%arg0: !ng.tensor<16x8xf32>, %arg1: !ng.tensor<8x32xf32>) -> !n
// -----
// std.view
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1)
// CHECK: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK: %[[T1:[0-9]+]] = alloc() : memref<24xi8>
// CHECK-NEXT: %[[T2:[0-9]+]] = std.view %[[T1]][][] : memref<24xi8> to memref<3x2xf32, #[[MAP0]]>
// CHECK: affine.store %{{[0-9]+}}, %[[T2]][%{{.*}}, %{{.*}}] : memref<3x2xf32, #[[MAP0]]>
......@@ -198,12 +199,12 @@ func @convolution(%arg0: !ng.tensor<1x2x2x2xf32>, %arg1: !ng.tensor<2x2x1x1xf32>
// -----
//
// Group Convolution
// CHECK-DAG: #[[M0:.*]] = (d0) -> (d0 * 2)
// CHECK-DAG: #[[M1:.*]] = (d0) -> (d0 * 2 + 2)
// CHECK-DAG: #[[M2:.*]] = (d0) -> (d0)
// CHECK-DAG: #[[M3:.*]] = (d0) -> (d0 + 1)
// CHECK-DAG: #[[M8:.*]] = (d0, d1) -> (d0 + d1)
// CHECK-DAG: #[[M9:.*]] = (d0, d1) -> (d0 - d1 * 2)
// CHECK-DAG: #[[M0:.*]] = affine_map<(d0) -> (d0 * 2)>
// CHECK-DAG: #[[M1:.*]] = affine_map<(d0) -> (d0 * 2 + 2)>
// CHECK-DAG: #[[M2:.*]] = affine_map<(d0) -> (d0)>
// CHECK-DAG: #[[M3:.*]] = affine_map<(d0) -> (d0 + 1)>
// CHECK-DAG: #[[M8:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-DAG: #[[M9:.*]] = affine_map<(d0, d1) -> (d0 - d1 * 2)>
// CHECK-LABEL: func @groupConv
//
// Outer groups loops
......
// RUN: ngraph-opt %s --split-input-file --ngraph-memory-opt --ngraph-memory-opt-concat --ngraph-memory-opt-eltwise -convert-ngraph-to-affine | FileCheck %s
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1)
// CHECK: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-LABEL: test0
// CHECK: %[[B:.*]] = alloc() : memref<16xi8>
// CHECK: std.view %[[B]][][] : memref<16xi8> to memref<2x2xf32, #[[MAP0]]>
......@@ -17,8 +17,8 @@ func @test0(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tenso
// -----
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1)
// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1 + 4)
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 + 4)>
// CHECK-LABEL: test1
// CHECK: %[[B:.*]] = alloc() : memref<32xi8>
// CHECK: std.view %[[B]][][] : memref<32xi8> to memref<2x2xf32, #[[MAP0]]>
......@@ -35,10 +35,10 @@ func @test1(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tenso
// -----
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)
// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)
// CHECK-DAG: #[[MAP2:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)
// CHECK-DAG: #[[MAP3:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)>
// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)>
// CHECK-DAG: #[[MAP2:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)>
// CHECK-DAG: #[[MAP3:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)>
// CHECK-LABEL: test2
// CHECK: %[[B1:.*]] = alloc() : memref<32xi8>
// CHECK: std.view %[[B1]][][] : memref<32xi8> to memref<1x2x2xf32, #[[MAP0]]>
......@@ -66,13 +66,13 @@ func @test2(%arg0: !ng.tensor<1x2x2xf32>, %arg1: !ng.tensor<1x2x2xf32>) -> (!ng.
// -----
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)
// CHECK-DAG: #[[MAP8:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 8)
// CHECK-DAG: #[[MAP9:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 16)
// CHECK-DAG: #[[MAP10:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 24)
// CHECK-DAG: #[[MAP11:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)
// CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2 + 16)
// CHECK-DAG: #[[MAP13:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2)
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)>
// CHECK-DAG: #[[MAP8:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 8)>
// CHECK-DAG: #[[MAP9:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 16)>
// CHECK-DAG: #[[MAP10:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 24)>
// CHECK-DAG: #[[MAP11:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)>
// CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2 + 16)>
// CHECK-DAG: #[[MAP13:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2)>
// CHECK-LABEL: test3
// CHECK: %[[B:.*]] = alloc() : memref<128xi8>
// CHECK: std.view %[[B]][][] : memref<128xi8> to memref<1x4x2xf32, #[[MAP0]]>
......@@ -97,10 +97,10 @@ func @test3(%arg0: !ng.tensor<1x2x2xf32>, %arg1: !ng.tensor<1x2x2xf32>) -> !ng.t
// -----
//CHECK-DAG: #[[MAP4:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)
//CHECK-DAG: #[[MAP5:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)
//CHECK-DAG: #[[MAP6:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 8)
//CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 12 + d1 * 2 + d2)
//CHECK-DAG: #[[MAP4:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)>
//CHECK-DAG: #[[MAP5:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)>
//CHECK-DAG: #[[MAP6:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 8)>
//CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 12 + d1 * 2 + d2)>
// CHECK-LABEL: test4
//CHECK: %[[B1:.*]] = alloc() : memref<1x2x2xf32>
//CHECK: %[[B2:.*]] = alloc() : memref<48xi8>
......
......@@ -17,9 +17,9 @@
import lit.llvm
config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
config.mlir_obj_root = "@MLIR_BUILD_DIR@"
config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
config.llvm_tools_dir = "@MLIR_LLVM_TOOLS_DIR@"
config.mlir_obj_root = "@MLIR_LLVM_BUILD_DIR@"
config.mlir_tools_dir = "@MLIR_LLVM_TOOLS_DIR@"
config.suffixes = ['.mlir']
config.ngraph_mlir_tools_dir = "@NGRAPH_BUILD_BIN@"
......
......@@ -31,7 +31,6 @@ using namespace mlir;
OpBuilder createBuilder(MLIRContext* context)
{
auto module = ModuleOp::create(UnknownLoc::get(context));
auto funcType = FunctionType::get({}, {}, context);
auto function = FuncOp::create(UnknownLoc::get(context), "main", funcType);
function.addEntryBlock();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment