Commit a643213d authored by Diego Caballero, committed by Scott Cyphers

[MLIR] Update build system to use LLVM mono-repo for MLIR (#4188)

* [MLIR] Update build system to use LLVM mono-repo for MLIR

* [MLIR] LLVM mono-repo conflicts

* Disable lit tests

* Fix formatting

* Fix memopt tests

* PR fix

* Fix view test
Co-authored-by: Nagy Mostafa <nagy.mostafa@gmail.com>
parent 1e13ad94
@@ -17,11 +17,9 @@
 include(ExternalProject)
 
 set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
-set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
 
 # Change these commit IDs to move to latest stable versions
-set(MLIR_LLVM_COMMIT_ID c36773c7)
-set(MLIR_COMMIT_ID 606e96a1)
+set(MLIR_LLVM_COMMIT_ID d6295255)
 
 # MLIR environment variables. Some of them are used by LIT tool.
@@ -32,9 +30,10 @@ else()
 endif()
 
 set(MLIR_LLVM_ROOT ${MLIR_PROJECT_ROOT}/llvm-projects)
-set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm/projects/mlir)
-set(MLIR_BUILD_DIR ${MLIR_LLVM_ROOT}/build)
-set(MLIR_TOOLS_DIR ${MLIR_BUILD_DIR}/bin)
+set(MLIR_LLVM_SOURCE_DIR ${MLIR_LLVM_ROOT}/llvm)
+set(MLIR_SOURCE_DIR ${MLIR_LLVM_ROOT}/mlir)
+set(MLIR_LLVM_BUILD_DIR ${MLIR_PROJECT_ROOT}/build)
+set(MLIR_LLVM_TOOLS_DIR ${MLIR_LLVM_BUILD_DIR}/bin)
 set(NGRAPH_LIT_TEST_SRC_DIR ${CMAKE_SOURCE_DIR}/test/mlir)
 set(NGRAPH_LIT_TEST_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/test/mlir)
@@ -48,17 +47,13 @@ if (NOT NGRAPH_USE_PREBUILT_MLIR)
     execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
                     WORKING_DIRECTORY "${MLIR_PROJECT_ROOT}")
 
-    # clone and build llvm
+    # Clone and build llvm + mlir.
     execute_process(COMMAND "${CMAKE_COMMAND}" --build . --target ext_mlir_llvm
                     WORKING_DIRECTORY "${MLIR_PROJECT_ROOT}")
-
-    # clone and build mlir
-    execute_process(COMMAND "${CMAKE_COMMAND}" --build . --target ext_mlir
-                    WORKING_DIRECTORY "${MLIR_PROJECT_ROOT}")
 endif()
 
 # Enable modules for LLVM.
-set(LLVM_DIR "${MLIR_BUILD_DIR}/lib/cmake/llvm"
+set(LLVM_DIR "${MLIR_LLVM_BUILD_DIR}/lib/cmake/llvm"
     CACHE PATH "Path to LLVM cmake modules")
 list(APPEND CMAKE_MODULE_PATH "${LLVM_DIR}")
 include(AddLLVM)
@@ -71,7 +66,7 @@ message(STATUS "Using modules in: ${LLVM_DIR}")
 message(STATUS "LLVM RTTI is ${LLVM_ENABLE_RTTI}")
 
 set(MLIR_SRC_INCLUDE_PATH ${MLIR_SOURCE_DIR}/include)
-set(MLIR_BIN_INCLUDE_PATH ${MLIR_BUILD_DIR}/projects/mlir/include)
+set(MLIR_BIN_INCLUDE_PATH ${MLIR_LLVM_BUILD_DIR}/tools/mlir/include)
 set(MLIR_INCLUDE_PATHS ${MLIR_SRC_INCLUDE_PATH};${MLIR_BIN_INCLUDE_PATH})
 set(MLIR_LLVM_INCLUDE_PATH ${LLVM_INCLUDE_DIRS})
...
@@ -20,22 +20,6 @@ include(ExternalProject)
 
 project(mlir-fetch NONE)
 
-ExternalProject_Add(
-    ext_mlir_llvm
-    PREFIX mlir_llvm
-    GIT_REPOSITORY @MLIR_LLVM_REPO_URL@
-    GIT_TAG @MLIR_LLVM_COMMIT_ID@
-    CONFIGURE_COMMAND ""
-    BUILD_COMMAND ""
-    INSTALL_COMMAND ""
-    UPDATE_COMMAND ""
-    SOURCE_DIR @MLIR_LLVM_ROOT@
-    DOWNLOAD_NO_PROGRESS TRUE
-    EXCLUDE_FROM_ALL TRUE
-)
-
-set(MLIR_DEPENDS ext_mlir_llvm)
-
 include(ProcessorCount)
 ProcessorCount(N)
 if(N EQUAL 0)
@@ -43,22 +27,18 @@ if(N EQUAL 0)
 endif()
 
 ExternalProject_Add(
-    ext_mlir
-    PREFIX mlir
-    DEPENDS ${MLIR_DEPENDS}
-    GIT_REPOSITORY @MLIR_REPO_URL@
-    GIT_TAG @MLIR_COMMIT_ID@
-    CONFIGURE_COMMAND ""
-    CMAKE_GENERATOR "@CMAKE_GENERATOR@"
+    ext_mlir_llvm
+    PREFIX mlir_llvm
+    GIT_REPOSITORY @MLIR_LLVM_REPO_URL@
+    GIT_TAG @MLIR_LLVM_COMMIT_ID@
+    CMAKE_GENERATOR @CMAKE_GENERATOR@
     CMAKE_GENERATOR_PLATFORM @CMAKE_GENERATOR_PLATFORM@
     CMAKE_GENERATOR_TOOLSET @CMAKE_GENERATOR_TOOLSET@
+    CONFIGURE_COMMAND @CMAKE_COMMAND@ @MLIR_LLVM_SOURCE_DIR@ -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=host -DLLVM_ENABLE_RTTI=ON -DCMAKE_BUILD_TYPE=@CMAKE_BUILD_TYPE@
-    BUILD_COMMAND @CMAKE_COMMAND@ ../llvm -DLLVM_BUILD_EXAMPLES=ON -DLLVM_TARGETS_TO_BUILD=host -DLLVM_ENABLE_RTTI=ON -DCMAKE_BUILD_TYPE=@CMAKE_BUILD_TYPE@
-    COMMAND @CMAKE_COMMAND@ --build . --target check-mlir -- -j${N}
+    BUILD_COMMAND @CMAKE_COMMAND@ --build . --target check-mlir -- -j${N}
     INSTALL_COMMAND ""
-    UPDATE_COMMAND ""
-    SOURCE_DIR @MLIR_SOURCE_DIR@
-    BINARY_DIR @MLIR_BUILD_DIR@
+    SOURCE_DIR @MLIR_LLVM_ROOT@
+    BINARY_DIR @MLIR_LLVM_BUILD_DIR@
     STAMP_DIR "@MLIR_PROJECT_ROOT@/mlir/stamp"
     DOWNLOAD_NO_PROGRESS TRUE
     EXCLUDE_FROM_ALL TRUE
...
@@ -68,16 +68,16 @@ namespace
     {
     public:
         /// Initialize the relationship for a number of syms
-        void init(std::unordered_set<Value*>& symbols);
+        void init(DenseSet<Value>& symbols);
 
         /// Checks if values a and b can alias
-        bool canAlias(Value* a, Value* b);
-        void insertNoAlias(Value* a, Value* b);
+        bool canAlias(Value a, Value b);
+        void insertNoAlias(Value a, Value b);
 
     private:
         using BV = llvm::BitVector;
-        std::unordered_map<Value*, unsigned> m_valueToIdx;
-        std::unordered_map<unsigned, Value*> m_idxToValue;
-        std::unordered_map<Value*, BV*> m_valueToSet;
+        DenseMap<Value, unsigned> m_valueToIdx;
+        DenseMap<unsigned, Value> m_idxToValue;
+        DenseMap<Value, BV*> m_valueToSet;
         SmallVector<BV, 10> m_sets;
     };
@@ -86,16 +86,15 @@ namespace
     class LivenessAnalysis
     {
     public:
-        bool isLive(Value* v);
-        void setLive(Value* v);
-        void kill(Value* v);
-        void getLiveValues(llvm::SmallVectorImpl<Value*>& values);
-        void reset();
+        bool isLive(Value v);
+        void setLive(Value v);
+        void kill(Value v);
+        void getLiveValues(llvm::SmallVectorImpl<Value>& values);
 
     private:
         unsigned m_maxIdx = 0;
         SmallVector<bool, 10> m_liveness;
-        std::unordered_map<Value*, unsigned> m_valueToIdx;
+        DenseMap<Value, unsigned> m_valueToIdx;
     };
 
     // Memory Assignment analysis
@@ -121,7 +120,7 @@ namespace
         void processDestructiveInPlace(mlir::Operation* op);
         void processConcat(mlir::Operation* op);
         bool isSafeInPlace(mlir::Operation* op);
-        bool isInputOrOutputValue(mlir::Value* value);
+        bool isInputOrOutputValue(mlir::Value value);
         LivenessAnalysis m_liveness;
         AliasRelation m_aliasRelation;
         std::unordered_map<std::string, bool> m_inplaceOps;
@@ -132,7 +131,7 @@ namespace
     // helpers
     // Determines the buffer size a value needs based on its type
     // offset is where that value should start in the buffer
-    static unsigned getBufferSizeForOperand(mlir::Value* value, int offset);
+    static unsigned getBufferSizeForOperand(mlir::Value value, int offset);
 
     // Go backwards over instructions
     //
@@ -184,14 +183,14 @@ namespace
         auto& block = *(blocks.begin());
         // count number of syms in the code and initialize alias relationship
-        std::unordered_set<Value*> syms;
+        DenseSet<Value> syms;
 
         for (auto it = block.begin(); it != block.end(); it++)
         {
             Operation* op = &(*it);
             for (auto it : op->getResults())
             {
-                Value* v = it;
+                Value v = it;
                 if (syms.find(v) == syms.end())
                 {
                     syms.insert(v);
@@ -199,7 +198,7 @@ namespace
             }
             for (auto it : op->getOperands())
             {
-                Value* v = it;
+                Value v = it;
                 if (syms.find(v) == syms.end())
                 {
                     syms.insert(v);
@@ -245,7 +244,7 @@ namespace
         // concat on the highest non-one axis
         auto concatAxis = concat.concatenation_axis();
         auto result = concat.getResult();
-        auto shape = (result->getType().cast<NGTensorType>()).getShape();
+        auto shape = (result.getType().cast<NGTensorType>()).getShape();
         std::vector<int> opndOffsets;
         BufferInfo bufferInfo;
         int bufferId = -1, baseOffset = 0;
@@ -288,7 +287,7 @@ namespace
             else
             {
                 auto opnd = op->getOperand(i - 1);
-                auto tensorType = opnd->getType().cast<NGTensorType>();
+                auto tensorType = opnd.getType().cast<NGTensorType>();
                 opndOffset += tensorType.getNumElements();
                 opndOffsets.push_back(opndOffset);
             }
@@ -306,7 +305,7 @@ namespace
         for (auto i = 0; i < op->getNumOperands(); i++)
         {
             auto opnd = op->getOperand(i);
-            auto defOp = opnd->getDefiningOp();
+            auto defOp = opnd.getDefiningOp();
             NGRAPH_CHECK(defOp != nullptr, "Defining operation expected");
             // calculate expected absolute offset in the buffer
             bufferOffset = baseOffset + opndOffsets[i];
@@ -357,7 +356,7 @@ namespace
         // For now, assign only if all srcs have no prior assignments
         for (auto opnd : op->getOperands())
         {
-            if (m_memAnalysis->getBufferInfo(opnd->getDefiningOp()).isValid())
+            if (m_memAnalysis->getBufferInfo(opnd.getDefiningOp()).isValid())
             {
                 return;
             }
@@ -381,7 +380,7 @@ namespace
         for (auto i = 0; i < op->getNumOperands(); i++)
        {
             auto opnd = op->getOperand(i);
-            auto defOp = opnd->getDefiningOp();
+            auto defOp = opnd.getDefiningOp();
             NGRAPH_CHECK(defOp != nullptr, "Defining operation expected");
             auto opndOffset = baseOffset + opndOffsets[i];
             m_memAnalysis->setBufferInfo(defOp, {bufferId, opndOffset});
@@ -392,7 +391,7 @@ namespace
     void MemoryAssignment::processDestructiveInPlace(mlir::Operation* op)
     {
         NGRAPH_CHECK(op->getNumResults() == 1, "Destructive in-place with multi-def ?");
-        Value* use = nullptr;
+        Value use = nullptr;
         int useCount = -1;
 
         if (isInputOrOutputValue(op->getResult(0)))
@@ -405,11 +404,7 @@ namespace
         {
             if (!m_liveness.isLive(opnd) && !isInputOrOutputValue(opnd))
             {
-                int uses = 0;
-                for (auto& i : opnd->getUses())
-                {
-                    uses++;
-                }
+                int uses = std::distance(opnd.getUses().begin(), opnd.getUses().end());
                 if (useCount == -1 || uses < useCount)
                 {
                     use = opnd;
@@ -428,28 +423,28 @@ namespace
             // attach a new buffer id, and 0 offset on both src and result
             bufferInfo = {m_bufferId++, 0};
             m_memAnalysis->setBufferInfo(op, bufferInfo);
-            m_memAnalysis->setBufferInfo(use->getDefiningOp(), bufferInfo);
+            m_memAnalysis->setBufferInfo(use.getDefiningOp(), bufferInfo);
         }
         else
         {
             // copy result buffer id and offset to src
-            m_memAnalysis->setBufferInfo(use->getDefiningOp(), bufferInfo);
+            m_memAnalysis->setBufferInfo(use.getDefiningOp(), bufferInfo);
         }
         auto bufferSize = 0;
         bufferSize = getBufferSizeForOperand(op->getResult(0), bufferInfo.m_offset);
         m_memAnalysis->setBufferSize(bufferInfo.m_bufferId, bufferSize);
         // update aliasing info
         // use value cannot alias any live value
-        SmallVector<Value*, 10> liveValues;
+        SmallVector<Value, 10> liveValues;
         m_liveness.getLiveValues(liveValues);
         for (auto& value : liveValues)
         {
             m_aliasRelation.insertNoAlias(use, value);
         }
     }
 
-    bool MemoryAssignment::isInputOrOutputValue(mlir::Value* value)
+    bool MemoryAssignment::isInputOrOutputValue(mlir::Value value)
     {
-        auto defOp = value->getDefiningOp();
+        auto defOp = value.getDefiningOp();
         // If no defining op, then this is a block arg, skip operand
         //
         // TODO: This check is assuming single BB function, improve to handle control-flow.
@@ -464,7 +459,7 @@ namespace
         //
         // TODO: Improve to support control flow. Track value use-chain along branches/block-args,
         // if we hit a use in a return, it is an output value.
-        for (auto& use : value->getUses())
+        for (auto& use : value.getUses())
         {
             auto useOp = use.getOwner();
             if (isa<NGReturnOp>(useOp))
@@ -482,7 +477,7 @@ namespace
         return it != m_inplaceOps.end() ? it->second : false;
     }
 
-    void AliasRelation::init(std::unordered_set<Value*>& symbols)
+    void AliasRelation::init(DenseSet<Value>& symbols)
     {
         unsigned numSyms = symbols.size();
         m_sets.resize(numSyms);
@@ -503,13 +498,13 @@ namespace
         }
     }
 
-    bool AliasRelation::canAlias(Value* a, Value* b)
+    bool AliasRelation::canAlias(Value a, Value b)
     {
         // check if a and b are in the same set
         return m_valueToSet[a] != m_valueToSet[b];
     }
 
-    void AliasRelation::insertNoAlias(Value* a, Value* b)
+    void AliasRelation::insertNoAlias(Value a, Value b)
     {
         // union the two sets that a and b belong to
         // update the maps accordingly
@@ -535,14 +530,7 @@ namespace
         }
     }
 
-    void LivenessAnalysis::reset()
-    {
-        m_valueToIdx.clear();
-        m_liveness.clear();
-        m_maxIdx = 0;
-    }
-
-    void LivenessAnalysis::getLiveValues(llvm::SmallVectorImpl<Value*>& values)
+    void LivenessAnalysis::getLiveValues(llvm::SmallVectorImpl<Value>& values)
     {
         for (auto& entry : m_valueToIdx)
         {
@@ -553,7 +541,7 @@ namespace
         }
     }
 
-    bool LivenessAnalysis::isLive(Value* v)
+    bool LivenessAnalysis::isLive(Value v)
     {
         auto it = m_valueToIdx.find(v);
         if (it == m_valueToIdx.end())
@@ -563,7 +551,7 @@ namespace
         return m_liveness[it->second];
     }
 
-    void LivenessAnalysis::setLive(Value* v)
+    void LivenessAnalysis::setLive(Value v)
    {
         auto it = m_valueToIdx.find(v);
         if (it == m_valueToIdx.end())
@@ -578,7 +566,7 @@ namespace
         }
     }
 
-    void LivenessAnalysis::kill(Value* v)
+    void LivenessAnalysis::kill(Value v)
     {
         auto it = m_valueToIdx.find(v);
         if (it == m_valueToIdx.end())
@@ -589,9 +577,9 @@ namespace
         m_liveness[it->second] = false;
     }
 
     // helpers
-    unsigned getBufferSizeForOperand(mlir::Value* value, int offset)
+    unsigned getBufferSizeForOperand(mlir::Value value, int offset)
     {
-        auto tensorType = value->getType().dyn_cast<NGTensorType>();
+        auto tensorType = value.getType().dyn_cast<NGTensorType>();
         NGRAPH_CHECK(tensorType, "Invalid type to find buffer size for");
         unsigned bufferSize = offset * std::ceil(tensorType.getElementBitWidth() / 8);
...
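Note on the recurring Value* to Value rewrite above: with the mono-repo MLIR, mlir::Value is a pointer-sized handle over internal storage rather than an object reached through Value*, so it is passed by value, compared by identity, and can key DenseMap/DenseSet directly. The snippet below is a minimal self-contained toy model of that handle pattern, not MLIR itself; ValueImpl and ValueHash are invented for illustration, and std::unordered_map stands in for DenseMap:

#include <cstdio>
#include <functional>
#include <unordered_map>

// Toy stand-in for MLIR's internal value storage.
struct ValueImpl
{
    const char* name;
};

// Pointer-sized handle, mirroring post-mono-repo mlir::Value:
// copied freely, compared by identity, usable as a hash-map key.
class Value
{
public:
    Value(ValueImpl* impl = nullptr) : m_impl(impl) {}
    const char* getName() const { return m_impl->name; } // dot call, was v->...
    explicit operator bool() const { return m_impl != nullptr; }
    bool operator==(Value other) const { return m_impl == other.m_impl; }
    ValueImpl* m_impl;
};

// Hash the wrapped pointer; DenseMapInfo<mlir::Value> plays this role for DenseMap.
struct ValueHash
{
    size_t operator()(Value v) const { return std::hash<ValueImpl*>()(v.m_impl); }
};

int main()
{
    ValueImpl a{"a"}, b{"b"};
    std::unordered_map<Value, unsigned, ValueHash> valueToIdx; // cf. m_valueToIdx
    valueToIdx[Value(&a)] = 0; // keyed by the handle itself, not by Value*
    valueToIdx[Value(&b)] = 1;
    std::printf("%s -> %u\n", Value(&a).getName(), valueToIdx[Value(&a)]);
    return 0;
}

Hashing and equality both go through the wrapped pointer, which is why the analysis code above could swap container types without touching the surrounding logic.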
...@@ -86,7 +86,7 @@ namespace ...@@ -86,7 +86,7 @@ namespace
} \ } \
\ \
PatternMatchResult matchAndRewrite(Operation* op, \ PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \ ArrayRef<Value> operands, \
ConversionPatternRewriter& rewriter) const override; \ ConversionPatternRewriter& rewriter) const override; \
}; };
...@@ -104,7 +104,7 @@ namespace ...@@ -104,7 +104,7 @@ namespace
/// Hook for derived classes to implement combined matching and rewriting. /// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult matchAndRewrite(Operation* op, PatternMatchResult matchAndRewrite(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override ConversionPatternRewriter& rewriter) const override
{ {
auto funcOp = cast<FuncOp>(op); auto funcOp = cast<FuncOp>(op);
...@@ -153,19 +153,19 @@ namespace ...@@ -153,19 +153,19 @@ namespace
// Helpers // Helpers
template <typename RedOp> template <typename RedOp>
void lowerIndexReduction(Operation* op, void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass); DialectLoweringPass& pass);
template <typename OP> template <typename OP>
void lowerBinaryElementwise(Operation* op, void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass); DialectLoweringPass& pass);
template <typename OP> template <typename OP>
void lowerUnaryElementwise(Operation* op, void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass); DialectLoweringPass& pass);
...@@ -184,32 +184,30 @@ namespace ...@@ -184,32 +184,30 @@ namespace
// cLb/Ub : Values representing bounds on channel dim in image (C_IN) // cLb/Ub : Values representing bounds on channel dim in image (C_IN)
// kLb/Ub : Values representing bounds on numFilters dim in filters (C_OUT) // kLb/Ub : Values representing bounds on numFilters dim in filters (C_OUT)
// gId : Value representing induction variable for the outer loop // gId : Value representing induction variable for the outer loop
void lowerConvolution(Value* result, void lowerConvolution(Value result,
Value* images, Value images,
Value* filters, Value filters,
ArrayAttr stridesAttr, ArrayAttr stridesAttr,
ArrayAttr padBelowAttr, ArrayAttr padBelowAttr,
ArrayAttr padAboveAttr, ArrayAttr padAboveAttr,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass, DialectLoweringPass& pass,
Location loc, Location loc,
Value* cLb = nullptr, Value cLb = nullptr,
Value* cUb = nullptr, Value cUb = nullptr,
Value* kLb = nullptr, Value kLb = nullptr,
Value* kUb = nullptr, Value kUb = nullptr,
Value* gId = nullptr); Value gId = nullptr);
template <typename OP> template <typename OP>
void lowerPooling(Operation* op, void lowerPooling(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass); DialectLoweringPass& pass);
ValueHandle createZeroConstant(mlir::Type type); ValueHandle createZeroConstant(mlir::Type type);
ValueHandle createOneConstant(mlir::Type type); ValueHandle createOneConstant(mlir::Type type);
bool isInPlaceConcat(mlir::Operation* op, DialectLoweringPass& pass);
/// Conversion from types in the nGraph dialect to the Standard dialect. /// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter class NGraphTypeConverter : public TypeConverter
{ {
...@@ -228,10 +226,10 @@ namespace ...@@ -228,10 +226,10 @@ namespace
public: public:
void runOnModule() override; void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter); SmallVector<Value, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
/// Allocates a linear buffer for a temporary memref that shares its /// Allocates a linear buffer for a temporary memref that shares its
/// underlying memory. Used in conjunction with createTempMemref /// underlying memory. Used in conjunction with createTempMemref
Value* createTempBuffer(int bufferId, PatternRewriter& rewriter); Value createTempBuffer(int bufferId, PatternRewriter& rewriter);
/// Creates an allocation or view of a memref. /// Creates an allocation or view of a memref.
/// type MemRef Type /// type MemRef Type
/// buffer Optional buffer value to create view over /// buffer Optional buffer value to create view over
...@@ -239,8 +237,7 @@ namespace ...@@ -239,8 +237,7 @@ namespace
/// ///
/// If buffer is null it allocates a Memref directly and Offset is ignored. /// If buffer is null it allocates a Memref directly and Offset is ignored.
/// If not, it creates a view over the pre-allocated buffer at the given offset. /// If not, it creates a view over the pre-allocated buffer at the given offset.
Value* Value createTempMemref(Type type, Value buffer, unsigned offset, PatternRewriter& rewriter);
createTempMemref(Type type, Value* buffer, unsigned offset, PatternRewriter& rewriter);
/// Inserts dealloc Ops for each temporary allocated by AllocOp /// Inserts dealloc Ops for each temporary allocated by AllocOp
void insertDeallocs(PatternRewriter& rewriter); void insertDeallocs(PatternRewriter& rewriter);
NGraphTypeConverter& getTypeConverter() { return typeConverter; } NGraphTypeConverter& getTypeConverter() { return typeConverter; }
...@@ -261,11 +258,11 @@ namespace ...@@ -261,11 +258,11 @@ namespace
private: private:
NGraphTypeConverter typeConverter; NGraphTypeConverter typeConverter;
// List of temporary memrefs to deallocate at end of function // List of temporary memrefs to deallocate at end of function
SmallVector<Value*, 4> memRefsToDealloc; SmallVector<Value, 4> memRefsToDealloc;
// Ops maybe assigned mem-refs in previous memory optimization passes. // Ops maybe assigned mem-refs in previous memory optimization passes.
// Track pre-assigned buffers for each Value and re-use it if one is available. // Track pre-assigned buffers for each Value and re-use it if one is available.
using IdToMemRefMap = std::unordered_map<unsigned, Value*>; using IdToMemRefMap = std::unordered_map<unsigned, Value>;
IdToMemRefMap m_id_to_memref; IdToMemRefMap m_id_to_memref;
MemoryAnalysis* m_memAnalysis; MemoryAnalysis* m_memAnalysis;
// TODO: Workaround for findOutputValues and buildOutputDefs. See NGCPU-470. // TODO: Workaround for findOutputValues and buildOutputDefs. See NGCPU-470.
...@@ -344,7 +341,7 @@ namespace ...@@ -344,7 +341,7 @@ namespace
FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName); FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found"); NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");
SmallVector<Value*, 4> outputList; SmallVector<Value, 4> outputList;
unsigned outputCount = 0; unsigned outputCount = 0;
unsigned inputCount = f.getType().getNumInputs(); unsigned inputCount = f.getType().getNumInputs();
// we find out output values by looking at returned values // we find out output values by looking at returned values
...@@ -354,7 +351,7 @@ namespace ...@@ -354,7 +351,7 @@ namespace
{ {
// annotate instructions defining outputs with the arg idx of the output // annotate instructions defining outputs with the arg idx of the output
auto outputValue = ret.getOperand(i); auto outputValue = ret.getOperand(i);
auto op = outputValue->getDefiningOp(); auto op = outputValue.getDefiningOp();
op->setAttr( op->setAttr(
"graphOutputIdx", "graphOutputIdx",
...@@ -365,13 +362,13 @@ namespace ...@@ -365,13 +362,13 @@ namespace
}); });
} }
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op, SmallVector<Value, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
PatternRewriter& rewriter) PatternRewriter& rewriter)
{ {
FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName); FuncOp f = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found"); NGRAPH_CHECK(f, "FuncOp '" + funcName + "' not found");
SmallVector<Value*, 4> newResults; SmallVector<Value, 4> newResults;
for (auto origResult : op->getResults()) for (auto origResult : op->getResults())
{ {
// find output arg if this operation produces any sub-graph outputs // find output arg if this operation produces any sub-graph outputs
...@@ -390,11 +387,11 @@ namespace ...@@ -390,11 +387,11 @@ namespace
// the linear buffer. // the linear buffer.
// If two memrefs are defined via 2 Views over the same buffer, then they share and // If two memrefs are defined via 2 Views over the same buffer, then they share and
// will re-use the same buffer. // will re-use the same buffer.
auto tensorType = origResult->getType().cast<NGTensorType>(); auto tensorType = origResult.getType().cast<NGTensorType>();
Value* newResult = nullptr; Value newResult = nullptr;
auto bufferInfo = m_memAnalysis->getBufferInfo(op); auto bufferInfo = m_memAnalysis->getBufferInfo(op);
Type memRefType = typeConverter.convertType(tensorType); Type memRefType = typeConverter.convertType(tensorType);
Value* bufferValue = nullptr; Value bufferValue = nullptr;
if (!bufferInfo.isValid()) if (!bufferInfo.isValid())
{ {
...@@ -427,7 +424,7 @@ namespace ...@@ -427,7 +424,7 @@ namespace
return newResults; return newResults;
} }
Value* DialectLoweringPass::createTempBuffer(int bufferId, PatternRewriter& rewriter) Value DialectLoweringPass::createTempBuffer(int bufferId, PatternRewriter& rewriter)
{ {
unsigned sizeInBytes = getMemAnalysis()->getBufferSize(bufferId); unsigned sizeInBytes = getMemAnalysis()->getBufferSize(bufferId);
NGRAPH_CHECK(bufferId >= 0, "Invalid buffer id to allocate"); NGRAPH_CHECK(bufferId >= 0, "Invalid buffer id to allocate");
...@@ -438,7 +435,7 @@ namespace ...@@ -438,7 +435,7 @@ namespace
MemRefType::get({sizeInBytes}, IntegerType::get(8, rewriter.getContext()), {}); MemRefType::get({sizeInBytes}, IntegerType::get(8, rewriter.getContext()), {});
// TODO: Set alignment // TODO: Set alignment
Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), bufferType); Value alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), bufferType);
memRefsToDealloc.push_back(alloc); memRefsToDealloc.push_back(alloc);
...@@ -455,8 +452,8 @@ namespace ...@@ -455,8 +452,8 @@ namespace
return alloc; return alloc;
} }
Value* DialectLoweringPass::createTempMemref(Type type, Value DialectLoweringPass::createTempMemref(Type type,
Value* buffer, Value buffer,
unsigned offset, unsigned offset,
PatternRewriter& rewriter) PatternRewriter& rewriter)
{ {
...@@ -481,14 +478,14 @@ namespace ...@@ -481,14 +478,14 @@ namespace
auto map = makeStridedLinearLayoutMap(strides, offset, rewriter.getContext()); auto map = makeStridedLinearLayoutMap(strides, offset, rewriter.getContext());
MemRefType newMemRefType = MemRefType::get(shape, memRefType.getElementType(), map); MemRefType newMemRefType = MemRefType::get(shape, memRefType.getElementType(), map);
auto viewOp = rewriter.create<mlir::ViewOp>( auto viewOp = rewriter.create<mlir::ViewOp>(
buffer->getDefiningOp()->getLoc(), newMemRefType, buffer, llvm::None); buffer.getDefiningOp()->getLoc(), newMemRefType, buffer, llvm::None);
return viewOp.getResult(); return viewOp.getResult();
} }
// No buffer, create an atomic memref without underlying buffer // No buffer, create an atomic memref without underlying buffer
NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported"); NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");
Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType); Value alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
memRefsToDealloc.push_back(alloc); memRefsToDealloc.push_back(alloc);
return alloc; return alloc;
} }
...@@ -501,9 +498,9 @@ namespace ...@@ -501,9 +498,9 @@ namespace
NGRAPH_CHECK(func, "FuncOp '" + funcName + "' not found"); NGRAPH_CHECK(func, "FuncOp '" + funcName + "' not found");
unsigned int argIdx = 0; unsigned int argIdx = 0;
for (auto* arg : func.getArguments()) for (auto arg : func.getArguments())
{ {
if (arg->getType().isa<MemRefType>()) if (arg.getType().isa<MemRefType>())
{ {
func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext())); func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
} }
...@@ -526,7 +523,6 @@ namespace ...@@ -526,7 +523,6 @@ namespace
PatternRewriter& rewriter) PatternRewriter& rewriter)
{ {
auto module = getModule(); auto module = getModule();
auto* context = getModule().getContext();
auto callBackFunc = module.lookupSymbol<mlir::FuncOp>(name); auto callBackFunc = module.lookupSymbol<mlir::FuncOp>(name);
if (!callBackFunc) if (!callBackFunc)
{ {
...@@ -583,7 +579,7 @@ namespace ...@@ -583,7 +579,7 @@ namespace
#define REWRITER(OP) \ #define REWRITER(OP) \
PatternMatchResult OP##Conversion::matchAndRewrite( \ PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) const Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const
REWRITER(NGAddOp) REWRITER(NGAddOp)
{ {
...@@ -675,12 +671,12 @@ namespace ...@@ -675,12 +671,12 @@ namespace
auto loc = cast<NGReluOp>(op).getLoc(); auto loc = cast<NGReluOp>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0]; auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>()); NGRAPH_CHECK(result.getType().isa<MemRefType>());
// Note that builder's current function is still the original function body. // Note that builder's current function is still the original function body.
// use getBlock to get the new block instead. // use getBlock to get the new block instead.
// get new operands // get new operands
Value* lhs = operands[0]; Value lhs = operands[0];
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Views // Views
...@@ -696,8 +692,8 @@ namespace ...@@ -696,8 +692,8 @@ namespace
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
NGRAPH_CHECK(lhs->getType().isa<MemRefType>()); NGRAPH_CHECK(lhs.getType().isa<MemRefType>());
Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType(); Type elemTy = lhs.getType().dyn_cast<MemRefType>().getElementType();
AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] { AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs); ValueHandle val = iLHS(ivs);
...@@ -723,14 +719,14 @@ namespace ...@@ -723,14 +719,14 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* lhs = operands[0]; Value lhs = operands[0];
Value* rhs = operands[1]; Value rhs = operands[1];
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp"); NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in DotOp");
auto resultTy = result->getType().dyn_cast<MemRefType>(); auto resultTy = result.getType().dyn_cast<MemRefType>();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>(); auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto rhsTy = rhs->getType().dyn_cast<MemRefType>(); auto rhsTy = rhs.getType().dyn_cast<MemRefType>();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type"); NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type"); NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
NGRAPH_CHECK(rhsTy, "Unexpected non-memref RHS type"); NGRAPH_CHECK(rhsTy, "Unexpected non-memref RHS type");
...@@ -792,7 +788,7 @@ namespace ...@@ -792,7 +788,7 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Create Value for result, and extract type info. // Create Value for result, and extract type info.
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp"); NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
// Create view to write into result. // Create view to write into result.
...@@ -870,11 +866,11 @@ namespace ...@@ -870,11 +866,11 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Get operands // Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in GatherOp"); NGRAPH_CHECK(result, "Unexpected null result in GatherOp");
Value* params = operands[0]; Value params = operands[0];
Value* indices = operands[1]; Value indices = operands[1];
auto axis = gatherOp.axis().getSExtValue(); auto axis = gatherOp.axis().getSExtValue();
// Create view to write into result. // Create view to write into result.
...@@ -987,10 +983,10 @@ namespace ...@@ -987,10 +983,10 @@ namespace
auto convolOp = cast<NGConvolutionOp>(op); auto convolOp = cast<NGConvolutionOp>(op);
// Get operands // Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in Convolution Op"); NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
Value* images = operands[0]; Value images = operands[0];
Value* filters = operands[1]; Value filters = operands[1];
auto strides = convolOp.strides(); auto strides = convolOp.strides();
auto padBelow = convolOp.padBelow(); auto padBelow = convolOp.padBelow();
auto padAbove = convolOp.padBelow(); auto padAbove = convolOp.padBelow();
...@@ -1014,10 +1010,10 @@ namespace ...@@ -1014,10 +1010,10 @@ namespace
auto gConvOp = cast<NGGroupConvOp>(op); auto gConvOp = cast<NGGroupConvOp>(op);
ScopedContext scope(rewriter, gConvOp.getLoc()); ScopedContext scope(rewriter, gConvOp.getLoc());
// Get operands // Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in Convolution Op"); NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
Value* images = operands[0]; Value images = operands[0];
Value* filters = operands[1]; Value filters = operands[1];
auto strides = gConvOp.strides(); auto strides = gConvOp.strides();
auto padBelow = gConvOp.padBelow(); auto padBelow = gConvOp.padBelow();
auto padAbove = gConvOp.padBelow(); auto padAbove = gConvOp.padBelow();
...@@ -1030,10 +1026,9 @@ namespace ...@@ -1030,10 +1026,9 @@ namespace
ValueHandle lb = intrinsics::constant_index(0); ValueHandle lb = intrinsics::constant_index(0);
ValueHandle ub = intrinsics::constant_index(groups); ValueHandle ub = intrinsics::constant_index(groups);
ValueHandle step = intrinsics::constant_index(1);
auto imagesType = images->getType().cast<MemRefType>(); auto imagesType = images.getType().cast<MemRefType>();
auto filtersType = filters->getType().cast<MemRefType>(); auto filtersType = filters.getType().cast<MemRefType>();
auto imagesShape = imagesType.getShape(); auto imagesShape = imagesType.getShape();
auto filtersShape = filtersType.getShape(); auto filtersShape = filtersType.getShape();
...@@ -1086,8 +1081,8 @@ namespace ...@@ -1086,8 +1081,8 @@ namespace
} }
// Use callback: Pooling, MatMul, Gemm, Softmax // Use callback: Pooling, MatMul, Gemm, Softmax
static void castMemRef(SmallVector<mlir::Value*, 4> inputs, static void castMemRef(SmallVector<mlir::Value, 4>& inputs,
SmallVector<mlir::Value*, 4>& outputs, SmallVector<mlir::Value, 4>& outputs,
PatternRewriter& rewriter, PatternRewriter& rewriter,
UnrankedMemRefType type) UnrankedMemRefType type)
{ {
...@@ -1123,22 +1118,21 @@ namespace ...@@ -1123,22 +1118,21 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* src = operands[0]; Value src = operands[0];
Value* delta = operands[1]; Value delta = operands[1];
ArrayRef<Attribute> windowShape = pooling.windowShape().getValue(); ArrayRef<Attribute> windowShape = pooling.windowShape().getValue();
ArrayRef<Attribute> windowStrides = pooling.windowMovementStrides().getValue(); ArrayRef<Attribute> windowStrides = pooling.windowMovementStrides().getValue();
ArrayRef<Attribute> padBelow = pooling.padBelow().getValue(); ArrayRef<Attribute> padBelow = pooling.padBelow().getValue();
ArrayRef<Attribute> padAbove = pooling.padAbove().getValue(); ArrayRef<Attribute> padAbove = pooling.padAbove().getValue();
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(src && delta && result, "Unexpected null values in MaxPoolBackprop Op"); NGRAPH_CHECK(src && delta && result, "Unexpected null values in MaxPoolBackprop Op");
auto resultTy = result->getType().dyn_cast<MemRefType>(); auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape(); auto resultShape = resultTy.getShape();
auto srcTy = src->getType().dyn_cast<MemRefType>(); auto srcTy = src.getType().dyn_cast<MemRefType>();
auto srcShape = srcTy.getShape(); auto srcShape = srcTy.getShape();
auto deltaTy = delta->getType().dyn_cast<MemRefType>(); auto deltaTy = delta.getType().dyn_cast<MemRefType>();
auto deltaShape = deltaTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type"); NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(srcTy, "Unexpected non-memref src type"); NGRAPH_CHECK(srcTy, "Unexpected non-memref src type");
NGRAPH_CHECK(deltaTy, "Unexpected non-memref delta type"); NGRAPH_CHECK(deltaTy, "Unexpected non-memref delta type");
...@@ -1153,8 +1147,8 @@ namespace ...@@ -1153,8 +1147,8 @@ namespace
auto int64Ty = rewriter.getIntegerType(64); auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0); auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
SmallVector<mlir::Value*, 4> inputs = {src, delta, result}; SmallVector<mlir::Value, 4> inputs = {src, delta, result};
SmallVector<mlir::Value*, 4> outputs; SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy); castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = pass.getCallDecl( FuncOp callBackFunc = pass.getCallDecl(
...@@ -1177,7 +1171,6 @@ namespace ...@@ -1177,7 +1171,6 @@ namespace
} }
else if (srcShape.size() == 5) else if (srcShape.size() == 5)
{ {
opAttrs attrs;
attrs.poolAttrs3d.includePaddingInAvgComputation = false; attrs.poolAttrs3d.includePaddingInAvgComputation = false;
for (auto i = 0; i < 3; i++) for (auto i = 0; i < 3; i++)
{ {
...@@ -1192,7 +1185,7 @@ namespace ...@@ -1192,7 +1185,7 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64); rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>( auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MAXPOOLBACKPROP), 64); rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MAXPOOLBACKPROP), 64);
SmallVector<mlir::Value*, 4> args = { SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg}; outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args); rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
...@@ -1207,16 +1200,16 @@ namespace ...@@ -1207,16 +1200,16 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* lhs = operands[0]; Value lhs = operands[0];
Value* rhs = operands[1]; Value rhs = operands[1];
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in MatMulOp"); NGRAPH_CHECK(lhs && rhs && result, "Unexpected null values in MatMulOp");
auto resultTy = result->getType().dyn_cast<MemRefType>(); auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape(); auto resultShape = resultTy.getShape();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>(); auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape(); auto lhsShape = lhsTy.getShape();
auto rhsTy = rhs->getType().dyn_cast<MemRefType>(); auto rhsTy = rhs.getType().dyn_cast<MemRefType>();
auto rhsShape = rhsTy.getShape(); auto rhsShape = rhsTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type"); NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type"); NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
...@@ -1262,10 +1255,10 @@ namespace ...@@ -1262,10 +1255,10 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64); rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>( auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MATMUL), 64); rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MATMUL), 64);
SmallVector<mlir::Value*, 4> inputs = {lhs, rhs, result}; SmallVector<mlir::Value, 4> inputs = {lhs, rhs, result};
SmallVector<mlir::Value*, 4> outputs; SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy); castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value*, 4> args = { SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg}; outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args); rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
...@@ -1281,18 +1274,18 @@ namespace ...@@ -1281,18 +1274,18 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* lhs = operands[0]; Value lhs = operands[0];
Value* rhs = operands[1]; Value rhs = operands[1];
Value* bias = operands[2]; Value bias = operands[2];
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && rhs && bias && result, "Unexpected null values in GemmOp"); NGRAPH_CHECK(lhs && rhs && bias && result, "Unexpected null values in GemmOp");
auto resultTy = result->getType().dyn_cast<MemRefType>(); auto resultTy = result.getType().dyn_cast<MemRefType>();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>(); auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape(); auto lhsShape = lhsTy.getShape();
auto rhsTy = rhs->getType().dyn_cast<MemRefType>(); auto rhsTy = rhs.getType().dyn_cast<MemRefType>();
auto rhsShape = rhsTy.getShape(); auto rhsShape = rhsTy.getShape();
auto biasTy = bias->getType().dyn_cast<MemRefType>(); auto biasTy = bias.getType().dyn_cast<MemRefType>();
auto biasShape = biasTy.getShape(); auto biasShape = biasTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type"); NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type"); NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
...@@ -1382,10 +1375,10 @@ namespace ...@@ -1382,10 +1375,10 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64); rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>( auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::GEMM), 64); rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::GEMM), 64);
SmallVector<mlir::Value*, 4> inputs = {lhs, rhs, bias, result}; SmallVector<mlir::Value, 4> inputs = {lhs, rhs, bias, result};
SmallVector<mlir::Value*, 4> outputs; SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy); castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value*, 4> args = { SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], outputs[3], attrsIndexArg, opTypeArg}; outputs[0], outputs[1], outputs[2], outputs[3], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args); rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
...@@ -1401,13 +1394,13 @@ namespace ...@@ -1401,13 +1394,13 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* lhs = operands[0]; Value lhs = operands[0];
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && result, "Unexpected null values in SoftmaxOp"); NGRAPH_CHECK(lhs && result, "Unexpected null values in SoftmaxOp");
auto resultTy = result->getType().dyn_cast<MemRefType>(); auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape(); auto resultShape = resultTy.getShape();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>(); auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape(); auto lhsShape = lhsTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type"); NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type"); NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
...@@ -1436,10 +1429,10 @@ namespace ...@@ -1436,10 +1429,10 @@ namespace
{}, {},
rewriter); rewriter);
SmallVector<mlir::Value*, 4> inputs = {lhs, result}; SmallVector<mlir::Value, 4> inputs = {lhs, result};
SmallVector<mlir::Value*, 4> outputs; SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy); castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value*, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg}; SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args); rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
...@@ -1450,20 +1443,20 @@ namespace ...@@ -1450,20 +1443,20 @@ namespace
#undef REWRITER #undef REWRITER
/// End of pattern matchers /// End of pattern matchers
void lowerConvolution(Value* result, void lowerConvolution(Value result,
Value* images, Value images,
Value* filters, Value filters,
ArrayAttr stridesAttr, ArrayAttr stridesAttr,
ArrayAttr padBelowAttr, ArrayAttr padBelowAttr,
ArrayAttr padAboveAttr, ArrayAttr padAboveAttr,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass, DialectLoweringPass& pass,
Location loc, Location loc,
Value* cLb, Value cLb,
Value* cUb, Value cUb,
Value* kLb, Value kLb,
Value* kUb, Value kUb,
Value* gId) Value gId)
{ {
// Let Images shape be [N, C_IN, D_1, ... D_f] // Let Images shape be [N, C_IN, D_1, ... D_f]
// Let Filters shape be [C_OUT, C_IN, F_1, ... F_f] // Let Filters shape be [C_OUT, C_IN, F_1, ... F_f]
...@@ -1516,7 +1509,7 @@ namespace ...@@ -1516,7 +1509,7 @@ namespace
auto strides = stridesAttr.getValue(); auto strides = stridesAttr.getValue();
auto padBelow = padBelowAttr.getValue(); auto padBelow = padBelowAttr.getValue();
auto padAbove = padBelowAttr.getValue(); auto padAbove = padBelowAttr.getValue();
Type elemTy = images->getType().cast<MemRefType>().getElementType(); Type elemTy = images.getType().cast<MemRefType>().getElementType();
// Create views // Create views
MemRefView vRes(result), vImages(images), vFilters(filters); MemRefView vRes(result), vImages(images), vFilters(filters);
...@@ -1767,7 +1760,7 @@ namespace ...@@ -1767,7 +1760,7 @@ namespace
// if args : img dims, img lbs, img ubs // if args : img dims, img lbs, img ubs
SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin(); SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin();
std::advance(it, 2); std::advance(it, 2);
SmallVector<Value*, 4> affineIfArgs(it, imgIndices.end()); SmallVector<Value, 4> affineIfArgs(it, imgIndices.end());
affineIfArgs.insert( affineIfArgs.insert(
affineIfArgs.end(), imgSpatialLbs.begin(), imgSpatialLbs.end()); affineIfArgs.end(), imgSpatialLbs.begin(), imgSpatialLbs.end());
affineIfArgs.insert( affineIfArgs.insert(
...@@ -1811,19 +1804,19 @@ namespace ...@@ -1811,19 +1804,19 @@ namespace
template <typename OP> template <typename OP>
void lowerUnaryElementwise(Operation* op, void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass) DialectLoweringPass& pass)
{ {
auto loc = cast<OP>(op).getLoc(); auto loc = cast<OP>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0]; auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>()); NGRAPH_CHECK(result.getType().isa<MemRefType>());
// Note that builder's current function is still the original function body. // Note that builder's current function is still the original function body.
// use getBlock to get the new block instead. // use getBlock to get the new block instead.
// get new operands // get new operands
Value* lhs = operands[0]; Value lhs = operands[0];
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Views // Views
...@@ -1839,8 +1832,8 @@ namespace ...@@ -1839,8 +1832,8 @@ namespace
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
NGRAPH_CHECK(lhs->getType().isa<MemRefType>()); NGRAPH_CHECK(lhs.getType().isa<MemRefType>());
Type elemTy = lhs->getType().cast<MemRefType>().getElementType(); Type elemTy = lhs.getType().cast<MemRefType>().getElementType();
AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] { AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs); ValueHandle val = iLHS(ivs);
...@@ -1860,16 +1853,16 @@ namespace ...@@ -1860,16 +1853,16 @@ namespace
template <typename OP> template <typename OP>
void lowerBinaryElementwise(Operation* op, void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass) DialectLoweringPass& pass)
{ {
auto loc = cast<OP>(op).getLoc(); auto loc = cast<OP>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0]; auto result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>()); NGRAPH_CHECK(result.getType().isa<MemRefType>());
// get new operands // get new operands
Value* lhs = operands[0]; Value lhs = operands[0];
Value* rhs = operands[1]; Value rhs = operands[1];
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Views // Views
...@@ -1885,7 +1878,7 @@ namespace ...@@ -1885,7 +1878,7 @@ namespace
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
// element type of the operand // element type of the operand
Type elemTy = result->getType().cast<MemRefType>().getElementType(); Type elemTy = result.getType().cast<MemRefType>().getElementType();
AffineLoopNestBuilder(pivs, lbs, ubs, steps)( AffineLoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body // single stmt body
[&] { [&] {
...@@ -1976,7 +1969,7 @@ namespace ...@@ -1976,7 +1969,7 @@ namespace
template <typename RedOp> template <typename RedOp>
void lowerIndexReduction(Operation* op, void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass) DialectLoweringPass& pass)
{ {
...@@ -1996,9 +1989,9 @@ namespace ...@@ -1996,9 +1989,9 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* arg = operands[0]; Value arg = operands[0];
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
// Views // Views
MemRefView vRes(result), vArg(arg); MemRefView vRes(result), vArg(arg);
...@@ -2011,7 +2004,7 @@ namespace ...@@ -2011,7 +2004,7 @@ namespace
auto argLbs = vArg.getLbs(); auto argLbs = vArg.getLbs();
auto argUbs = vArg.getUbs(); auto argUbs = vArg.getUbs();
Type resTy = result->getType().cast<MemRefType>().getElementType(); Type resTy = result.getType().cast<MemRefType>().getElementType();
// Generate loop nest that initializes result to lower bound of the axis to be reduced. // Generate loop nest that initializes result to lower bound of the axis to be reduced.
{ {
auto ivs = makeIndexHandles(vRes.rank()); auto ivs = makeIndexHandles(vRes.rank());
...@@ -2029,7 +2022,7 @@ namespace ...@@ -2029,7 +2022,7 @@ namespace
auto steps = vArg.getSteps(); auto steps = vArg.getSteps();
SmallVector<IndexHandle, 8> nonRedIVs; SmallVector<IndexHandle, 8> nonRedIVs;
Type resTy = result->getType().cast<MemRefType>().getElementType(); Type resTy = result.getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(resTy.isa<IntegerType>(), NGRAPH_CHECK(resTy.isa<IntegerType>(),
"Expected integer result type in index reduction"); "Expected integer result type in index reduction");
...@@ -2069,7 +2062,7 @@ namespace ...@@ -2069,7 +2062,7 @@ namespace
template <typename OP> template <typename OP>
void lowerPooling(Operation* op, void lowerPooling(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& pass) DialectLoweringPass& pass)
{ {
...@@ -2078,18 +2071,18 @@ namespace ...@@ -2078,18 +2071,18 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* lhs = operands[0]; Value lhs = operands[0];
ArrayRef<Attribute> windowShape = pooling.windowShape().getValue(); ArrayRef<Attribute> windowShape = pooling.windowShape().getValue();
ArrayRef<Attribute> windowStrides = pooling.windowMovementStrides().getValue(); ArrayRef<Attribute> windowStrides = pooling.windowMovementStrides().getValue();
ArrayRef<Attribute> padBelow = pooling.padBelow().getValue(); ArrayRef<Attribute> padBelow = pooling.padBelow().getValue();
ArrayRef<Attribute> padAbove = pooling.padAbove().getValue(); ArrayRef<Attribute> padAbove = pooling.padAbove().getValue();
Value* result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(lhs && result, "Unexpected null values in Pooling Op"); NGRAPH_CHECK(lhs && result, "Unexpected null values in Pooling Op");
auto resultTy = result->getType().dyn_cast<MemRefType>(); auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape(); auto resultShape = resultTy.getShape();
auto lhsTy = lhs->getType().dyn_cast<MemRefType>(); auto lhsTy = lhs.getType().dyn_cast<MemRefType>();
auto lhsShape = lhsTy.getShape(); auto lhsShape = lhsTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type"); NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type"); NGRAPH_CHECK(lhsTy, "Unexpected non-memref LHS type");
...@@ -2120,8 +2113,8 @@ namespace ...@@ -2120,8 +2113,8 @@ namespace
} }
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0); auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
SmallVector<mlir::Value*, 4> inputs = {lhs, result}; SmallVector<mlir::Value, 4> inputs = {lhs, result};
SmallVector<mlir::Value*, 4> outputs; SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy); castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = FuncOp callBackFunc =
...@@ -2158,7 +2151,7 @@ namespace ...@@ -2158,7 +2151,7 @@ namespace
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64); rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>( auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(ty), 64); rewriter.getUnknownLoc(), static_cast<int64_t>(ty), 64);
SmallVector<mlir::Value*, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg}; SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args); rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
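Unlike the elementwise paths, pooling is not lowered to loops here: both memrefs are erased to UnrankedMemRefType and a CallOp is emitted into a runtime callback, with the attribute-table index and the op kind passed as i64 constants. A hedged sketch of the C-side shape such a callback could take (the symbol name and descriptor layout below are assumptions, not taken from this commit):

    #include <cstdint>

    // An unranked memref lowers to a (rank, descriptor pointer) pair,
    // which lets one extern "C" symbol serve pooling ops of any rank.
    extern "C" void pooling_callback(int64_t inRank, void* inDesc,
                                     int64_t outRank, void* outDesc,
                                     int64_t attrsIndex, int64_t opType);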
...@@ -2211,71 +2204,6 @@ namespace ...@@ -2211,71 +2204,6 @@ namespace
} }
NGRAPH_UNREACHABLE("Unsupported type"); NGRAPH_UNREACHABLE("Unsupported type");
} }
// Given a concat op, check whether dst and the operands have
// a buffer/offset assignment that makes this op
// valid in-place.
bool isInPlaceConcat(mlir::Operation* op, DialectLoweringPass& pass)
{
NGRAPH_CHECK(isa<NGConcatOp>(op), "Expecting concat operation");
auto concat = cast<NGConcatOp>(op);
auto concatAxis = concat.concatenation_axis();
auto result = concat.getResult();
auto shape = (result->getType().cast<NGTensorType>()).getShape();
auto memAnalysis = pass.getMemAnalysis();
BufferInfo bufferInfo = memAnalysis->getBufferInfo(op);
if (!bufferInfo.isValid())
{
// no buffer assignment to dst, nothing to do
return false;
}
auto dstBufferId = bufferInfo.m_bufferId;
auto dstOffset = bufferInfo.m_offset;
LLVM_DEBUG(llvm::dbgs() << ">> Check in-place concat\n");
LLVM_DEBUG(op->dump());
for (auto i = 0; i < shape.size(); i++)
{
if (i == concatAxis)
{
break;
}
if (shape[i] != 1)
{
LLVM_DEBUG(llvm::dbgs() << "Axis FAIL. Skipping instruction\n");
return false;
}
}
LLVM_DEBUG(llvm::dbgs() << "Axis OK\n");
// Check if the buffer id and offsets are consistent with what's expected
LLVM_DEBUG(llvm::dbgs() << "Dst (id, offset) = (" << dstBufferId << ", " << dstOffset
<< ")\n");
// relative offset in the buffer
int opndOffset = 0;
for (auto opnd : op->getOperands())
{
bufferInfo = memAnalysis->getBufferInfo(opnd->getDefiningOp());
auto srcBufferId = bufferInfo.m_bufferId;
auto srcOffset = bufferInfo.m_offset;
LLVM_DEBUG(llvm::dbgs() << "Src (id, offset) = (" << srcBufferId << ", " << srcOffset
<< ")\n");
if (!bufferInfo.isValid() || srcBufferId != dstBufferId ||
srcOffset != (opndOffset + dstOffset))
{
// mismatch in buffer IDs or offsets
LLVM_DEBUG(llvm::dbgs() << "Buffer ID and Offsets FAIL. Skipping instruction\n");
return false;
}
auto tensorType = opnd->getType().cast<NGTensorType>();
opndOffset += tensorType.getNumElements();
}
LLVM_DEBUG(llvm::dbgs() << "Buffer ID and Offsets OK\n");
return true;
}
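A worked instance of the contiguity rule in the isInPlaceConcat check deleted above, with made-up buffer numbers: every source must sit in dst's buffer at the running element offset, which the loop accumulates per operand from getNumElements().

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main()
    {
        int64_t dstOffset = 16;                     // dst at (id, 16)
        std::vector<int64_t> srcOffsets = {16, 20}; // two 1x4 operands
        int64_t opndOffset = 0;
        for (int64_t src : srcOffsets)
        {
            assert(src == dstOffset + opndOffset);  // srcOffset check
            opndOffset += 4;                        // numElements of 1x4
        }
        return 0;
    }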
} // namespace } // namespace
namespace mlir namespace mlir
......
...@@ -81,7 +81,6 @@ mlir::Type NGraphOpsDialect::parseEltType(mlir::DialectAsmParser& parser) const ...@@ -81,7 +81,6 @@ mlir::Type NGraphOpsDialect::parseEltType(mlir::DialectAsmParser& parser) const
{ {
// Process nGraph integer element types. // Process nGraph integer element types.
MLIRContext* context = getContext(); MLIRContext* context = getContext();
int width = 0;
bool isSigned = false; bool isSigned = false;
llvm::SMLoc loc = parser.getCurrentLocation(); llvm::SMLoc loc = parser.getCurrentLocation();
......
...@@ -171,7 +171,7 @@ def NGRNNCellOp : ...@@ -171,7 +171,7 @@ def NGRNNCellOp :
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t, " "Value X, Value W, Value R, Value H_t, "
"Attribute hiddenSize, ArrayAttr activations," "Attribute hiddenSize, ArrayAttr activations,"
"ArrayAttr activationAlpha, ArrayAttr activationBeta, Attribute clip", [{ "ArrayAttr activationAlpha, ArrayAttr activationBeta, Attribute clip", [{
tblgen_state.addOperands({X, W, R, H_t}); tblgen_state.addOperands({X, W, R, H_t});
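The quoted OpBuilder strings in these .td files are spliced verbatim into the parameter lists of the generated build() methods, which is why every signature must track the Value* to Value rename by hand. Roughly the declaration TableGen would emit for the builder above (shape assumed, not copied from generated output):

    // Hypothetical generated declaration for NGRNNCellOp:
    static void build(Builder *builder, OperationState &tblgen_state,
                      Type res, Value X, Value W, Value R, Value H_t,
                      Attribute hiddenSize, ArrayAttr activations,
                      ArrayAttr activationAlpha, ArrayAttr activationBeta,
                      Attribute clip);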
...@@ -192,7 +192,7 @@ def NGRNNCellOp : ...@@ -192,7 +192,7 @@ def NGRNNCellOp :
void setClip(const Attribute& attr) { this->setAttr("clip", attr); } void setClip(const Attribute& attr) { this->setAttr("clip", attr); }
// get bias operand if present // get bias operand if present
Value* B() Value B()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr; return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
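These optional-operand accessors (B(), P(), Scale(), Bias(), Mean(), Variance() below) can keep returning nullptr after the switch because mlir::Value is constructible from nullptr and contextually convertible to bool, so a null Value tests false exactly as a null Value* did. A tiny illustration (helper name hypothetical):

    #include "mlir/IR/Value.h"

    // A default- or nullptr-constructed Value is null and tests false.
    static bool hasOptionalOperand(mlir::Value v)
    {
        return static_cast<bool>(v);
    }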
...@@ -263,7 +263,7 @@ def NGMVN : ...@@ -263,7 +263,7 @@ def NGMVN :
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *data, ArrayAttr reductionAxes, Attribute normalizeVariance," "Value data, ArrayAttr reductionAxes, Attribute normalizeVariance,"
"Attribute eps", [{ "Attribute eps", [{
tblgen_state.addOperands(data); tblgen_state.addOperands(data);
tblgen_state.addAttribute("reductionAxes", reductionAxes); tblgen_state.addAttribute("reductionAxes", reductionAxes);
...@@ -363,7 +363,7 @@ def NGLSTMCellOp : ...@@ -363,7 +363,7 @@ def NGLSTMCellOp :
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t, Value* C_t," "Value X, Value W, Value R, Value H_t, Value C_t,"
"Attribute hiddenSize, ArrayAttr activations," "Attribute hiddenSize, ArrayAttr activations,"
"ArrayAttr activationAlpha, ArrayAttr activationBeta," "ArrayAttr activationAlpha, ArrayAttr activationBeta,"
"Attribute clip, Attribute inputForget", [{ "Attribute clip, Attribute inputForget", [{
...@@ -379,7 +379,7 @@ def NGLSTMCellOp : ...@@ -379,7 +379,7 @@ def NGLSTMCellOp :
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t, Value* C_t," "Value X, Value W, Value R, Value H_t, Value C_t,"
"Attribute hiddenSize", "Attribute hiddenSize",
[{ [{
tblgen_state.addOperands({X, W, R, H_t, C_t}); tblgen_state.addOperands({X, W, R, H_t, C_t});
...@@ -390,13 +390,13 @@ def NGLSTMCellOp : ...@@ -390,13 +390,13 @@ def NGLSTMCellOp :
let extraClassDeclaration = [{ let extraClassDeclaration = [{
// get bias operand if present // get bias operand if present
Value* B() Value B()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr; return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
} }
// get peephole weights operand if present // get peephole weights operand if present
Value* P() Value P()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
auto it = varArgs.begin(); auto it = varArgs.begin();
...@@ -452,7 +452,7 @@ def NGLSTMSequenceOp : ...@@ -452,7 +452,7 @@ def NGLSTMSequenceOp :
void setActivationsBeta (const ArrayAttr& attr) { this->setAttr("activatiBeta", attr); } void setActivationsBeta (const ArrayAttr& attr) { this->setAttr("activatiBeta", attr); }
void setClip(const Attribute& attr) { this->setAttr("clip", attr); } void setClip(const Attribute& attr) { this->setAttr("clip", attr); }
Value* P() Value P()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr; return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
...@@ -500,7 +500,7 @@ def NGGRUCellOp : ...@@ -500,7 +500,7 @@ def NGGRUCellOp :
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t," "Value X, Value W, Value R, Value H_t,"
"Attribute hiddenSize, ArrayAttr activations," "Attribute hiddenSize, ArrayAttr activations,"
"ArrayAttr activationAlpha, ArrayAttr activationBeta," "ArrayAttr activationAlpha, ArrayAttr activationBeta,"
"Attribute clip, Attribute linearBeforeReset", [{ "Attribute clip, Attribute linearBeforeReset", [{
...@@ -515,7 +515,7 @@ def NGGRUCellOp : ...@@ -515,7 +515,7 @@ def NGGRUCellOp :
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *X, Value* W, Value* R, Value* H_t," "Value X, Value W, Value R, Value H_t,"
"Attribute hiddenSize", "Attribute hiddenSize",
[{ [{
tblgen_state.addOperands({X, W, R, H_t}); tblgen_state.addOperands({X, W, R, H_t});
...@@ -532,7 +532,7 @@ def NGGRUCellOp : ...@@ -532,7 +532,7 @@ def NGGRUCellOp :
void setLinearBeforeReset(const Attribute& attr) { this->setAttr("linearBeforeReset", attr); } void setLinearBeforeReset(const Attribute& attr) { this->setAttr("linearBeforeReset", attr); }
// get Bias operand if present // get Bias operand if present
Value* P() Value P()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr; return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
...@@ -557,7 +557,7 @@ def NGLayerNormOp : ...@@ -557,7 +557,7 @@ def NGLayerNormOp :
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res," "Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Attribute keepStats, Attribute beginNormAxis, Attribute epsilon", [{ "Value data, Attribute keepStats, Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands(data); tblgen_state.addOperands(data);
tblgen_state.addAttribute("keepStats", keepStats); tblgen_state.addAttribute("keepStats", keepStats);
tblgen_state.addAttribute("beginNormAxis", beginNormAxis); tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
...@@ -568,13 +568,13 @@ def NGLayerNormOp : ...@@ -568,13 +568,13 @@ def NGLayerNormOp :
let extraClassDeclaration = [{ let extraClassDeclaration = [{
// get Scale operand if present // get Scale operand if present
Value* Scale() Value Scale()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr; return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
} }
// get Bias operand if present // get Bias operand if present
Value* Bias() Value Bias()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
auto it = varArgs.begin(); auto it = varArgs.begin();
...@@ -608,7 +608,7 @@ def NGLayerNormBackpropOp : ...@@ -608,7 +608,7 @@ def NGLayerNormBackpropOp :
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res," "Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Value *delta, Value *mean, Value *variance," "Value data, Value delta, Value mean, Value variance,"
"Attribute beginNormAxis, Attribute epsilon", [{ "Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands({data, delta, mean, variance}); tblgen_state.addOperands({data, delta, mean, variance});
tblgen_state.addAttribute("beginNormAxis", beginNormAxis); tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
...@@ -618,7 +618,7 @@ def NGLayerNormBackpropOp : ...@@ -618,7 +618,7 @@ def NGLayerNormBackpropOp :
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res," "Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Value *delta, Value *scale," "Value data, Value delta, Value scale,"
"Attribute beginNormAxis, Attribute epsilon", [{ "Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands({data, delta, scale}); tblgen_state.addOperands({data, delta, scale});
tblgen_state.addAttribute("beginNormAxis", beginNormAxis); tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
...@@ -628,7 +628,7 @@ def NGLayerNormBackpropOp : ...@@ -628,7 +628,7 @@ def NGLayerNormBackpropOp :
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res," "Builder *builder, OperationState &tblgen_state, ArrayRef<Type> res,"
"Value *data, Value *delta," "Value data, Value delta,"
"Attribute beginNormAxis, Attribute epsilon", [{ "Attribute beginNormAxis, Attribute epsilon", [{
tblgen_state.addOperands({data, delta}); tblgen_state.addOperands({data, delta});
tblgen_state.addAttribute("beginNormAxis", beginNormAxis); tblgen_state.addAttribute("beginNormAxis", beginNormAxis);
...@@ -639,13 +639,13 @@ def NGLayerNormBackpropOp : ...@@ -639,13 +639,13 @@ def NGLayerNormBackpropOp :
let extraClassDeclaration = [{ let extraClassDeclaration = [{
// get Mean operand if present // get Mean operand if present
Value* Mean() Value Mean()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr; return varArgs.begin() != varArgs.end() ? *varArgs.begin() : nullptr;
} }
// get Variance operand if present // get Variance operand if present
Value* Variance() Value Variance()
{ {
auto varArgs = optionalArgs(); auto varArgs = optionalArgs();
auto it = varArgs.begin(); auto it = varArgs.begin();
...@@ -722,8 +722,8 @@ def NGGroupConvOp : ...@@ -722,8 +722,8 @@ def NGGroupConvOp :
let builders = [ let builders = [
// Builder without padType // Builder without padType
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *images," "Builder *builder, OperationState &tblgen_state, Type res, Value images,"
"Value *filters, ArrayAttr strides, ArrayAttr padBelow, ArrayAttr padAbove," "Value filters, ArrayAttr strides, ArrayAttr padBelow, ArrayAttr padAbove,"
"Attribute groups", "Attribute groups",
[{ [{
tblgen_state.addOperands({images, filters}); tblgen_state.addOperands({images, filters});
...@@ -772,18 +772,18 @@ def NGGroupConvTransposeOp : ...@@ -772,18 +772,18 @@ def NGGroupConvTransposeOp :
let builders = [ let builders = [
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res," OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, Attribute groups", [{ "Value images, Value filters, Attribute groups", [{
tblgen_state.addOperands({images, filters}); tblgen_state.addOperands({images, filters});
tblgen_state.addAttribute("groups", groups); tblgen_state.addAttribute("groups", groups);
tblgen_state.addTypes(res); tblgen_state.addTypes(res);
}]>, }]>,
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res," OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters", [{ "Value images, Value filters", [{
tblgen_state.addOperands({images, filters}); tblgen_state.addOperands({images, filters});
tblgen_state.addTypes(res); tblgen_state.addTypes(res);
}]>, }]>,
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res," OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, ArrayAttr strides," "Value images, Value filters, ArrayAttr strides,"
"ArrayAttr outputPad, ArrayAttr outputShape," "ArrayAttr outputPad, ArrayAttr outputShape,"
"Attribute groups", [{ "Attribute groups", [{
tblgen_state.addOperands({images, filters}); tblgen_state.addOperands({images, filters});
...@@ -793,7 +793,7 @@ def NGGroupConvTransposeOp : ...@@ -793,7 +793,7 @@ def NGGroupConvTransposeOp :
tblgen_state.addAttribute("groups", groups); tblgen_state.addAttribute("groups", groups);
}]>, }]>,
OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res," OpBuilder<"Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters," "Value images, Value filters,"
"ArrayAttr outputShape, Attribute groups", [{ "ArrayAttr outputShape, Attribute groups", [{
tblgen_state.addOperands({images, filters}); tblgen_state.addOperands({images, filters});
tblgen_state.addAttribute("outputShape", outputShape); tblgen_state.addAttribute("outputShape", outputShape);
...@@ -951,7 +951,7 @@ def NGConvBiasOp : ...@@ -951,7 +951,7 @@ def NGConvBiasOp :
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, Value *bias, Attribute withRelu", [{ "Value images, Value filters, Value bias, Attribute withRelu", [{
tblgen_state.addOperands({images, filters, bias}); tblgen_state.addOperands({images, filters, bias});
tblgen_state.addAttribute("withRelu", withRelu); tblgen_state.addAttribute("withRelu", withRelu);
tblgen_state.addTypes(res); tblgen_state.addTypes(res);
...@@ -959,7 +959,7 @@ def NGConvBiasOp : ...@@ -959,7 +959,7 @@ def NGConvBiasOp :
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res," "Builder *builder, OperationState &tblgen_state, Type res,"
"Value *images, Value *filters, Value *bias", [{ "Value images, Value filters, Value bias", [{
tblgen_state.addOperands({images, filters, bias}); tblgen_state.addOperands({images, filters, bias});
tblgen_state.addTypes(res); tblgen_state.addTypes(res);
}]> }]>
......
...@@ -44,12 +44,12 @@ using namespace mlir; ...@@ -44,12 +44,12 @@ using namespace mlir;
/// Checks if all operands and results are of compatible shapes /// Checks if all operands and results are of compatible shapes
template <typename T> template <typename T>
static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkResult = true) static mlir::LogicalResult verifyCompatibleOperandsAndResults(T op, bool checkResult = true)
{ {
mlir::Type t0 = op->getOperation()->getOperand(0)->getType(); mlir::Type t0 = op.getOperation()->getOperand(0).getType();
mlir::NGTensorType opType0 = t0.cast<NGTensorType>(); mlir::NGTensorType opType0 = t0.cast<NGTensorType>();
Operation* opr = op->getOperation(); Operation* opr = op.getOperation();
auto i = 0; auto i = 0;
for (auto operand : opr->getOperands()) for (auto operand : opr->getOperands())
{ {
...@@ -57,10 +57,10 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR ...@@ -57,10 +57,10 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
{ {
continue; continue;
} }
mlir::Type t = operand->getType(); mlir::Type t = operand.getType();
mlir::NGTensorType opType = t.cast<NGTensorType>(); mlir::NGTensorType opType = t.cast<NGTensorType>();
if (!opType.isCompatible(opType0)) if (!opType.isCompatible(opType0))
return op->emitOpError("Incompatible operand shape"); return op.emitOpError("Incompatible operand shape");
i++; i++;
} }
...@@ -68,74 +68,74 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR ...@@ -68,74 +68,74 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
{ {
for (auto result : opr->getResults()) for (auto result : opr->getResults())
{ {
mlir::Type t = result->getType(); mlir::Type t = result.getType();
mlir::NGTensorType resType = t.cast<NGTensorType>(); mlir::NGTensorType resType = t.cast<NGTensorType>();
if (!resType.isCompatible(opType0)) if (!resType.isCompatible(opType0))
return op->emitOpError("Incompatible operand shape"); return op.emitOpError("Incompatible operand shape");
} }
} }
return mlir::success(); return mlir::success();
} }
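The signature change from T* to T in these verifiers follows from the same migration: ODS-generated op classes are value types wrapping an Operation*, so they are passed by value, and emitOpError/getOperation become '.' calls. A minimal sketch under that assumption:

    // Hypothetical verifier in the migrated style; T is an ODS op class.
    template <typename T>
    static mlir::LogicalResult verifyHasOperands(T op)
    {
        if (op.getOperation()->getNumOperands() == 0)
            return op.emitOpError("expected at least one operand");
        return mlir::success();
    }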
template <typename T> template <typename T>
static mlir::LogicalResult verifyUnaryArithOp(T* op) static mlir::LogicalResult verifyUnaryArithOp(T op)
{ {
return verifyCompatibleOperandsAndResults(op); return verifyCompatibleOperandsAndResults(op);
} }
template <typename T> template <typename T>
static mlir::LogicalResult verifyBinaryArithOp(T* op) static mlir::LogicalResult verifyBinaryArithOp(T op)
{ {
return verifyCompatibleOperandsAndResults(op); return verifyCompatibleOperandsAndResults(op);
} }
template <typename T> template <typename T>
static mlir::LogicalResult verifyAxisReductionOp(T* op) static mlir::LogicalResult verifyAxisReductionOp(T op)
{ {
return mlir::failure(); return mlir::failure();
} }
template <typename T> template <typename T>
static mlir::LogicalResult verifyLogicalReductionOp(T* op) static mlir::LogicalResult verifyLogicalReductionOp(T op)
{ {
// TODO: verifyAxisReductionOp(op) + input and return element type. // TODO: verifyAxisReductionOp(op) + input and return element type.
return mlir::failure(); return mlir::failure();
} }
template <typename T> template <typename T>
static mlir::LogicalResult verifyIndexReductionOp(T* op) static mlir::LogicalResult verifyIndexReductionOp(T op)
{ {
// TODO: verifyAxisReductionOp(op) + return element type + single axis. // TODO: verifyAxisReductionOp(op) + return element type + single axis.
return mlir::success(); return mlir::success();
} }
template <typename T> template <typename T>
static mlir::LogicalResult verifyOp(T* op) static mlir::LogicalResult verifyOp(T op)
{ {
return op->emitOpError("Unsupported verifier for this operation"); return op.emitOpError("Unsupported verifier for this operation");
} }
template <> template <>
mlir::LogicalResult verifyOp(NGDotOp* op) mlir::LogicalResult verifyOp(NGDotOp op)
{ {
// TODO(dcab): Improve verification: proper shapes, etc. // TODO(dcab): Improve verification: proper shapes, etc.
return mlir::success(); return mlir::success();
} }
template <> template <>
mlir::LogicalResult verifyOp(NGConcatOp* op) mlir::LogicalResult verifyOp(NGConcatOp op)
{ {
// TODO(amprocte): Improve verification: proper shapes, etc. // TODO(amprocte): Improve verification: proper shapes, etc.
return mlir::success(); return mlir::success();
} }
template <> template <>
mlir::LogicalResult verifyOp(NGSelectOp* op) mlir::LogicalResult verifyOp(NGSelectOp op)
{ {
mlir::Type t0 = op->getOperation()->getOperand(0)->getType(); mlir::Type t0 = op.getOperation()->getOperand(0).getType();
mlir::Type t1 = op->getOperation()->getOperand(1)->getType(); mlir::Type t1 = op.getOperation()->getOperand(1).getType();
mlir::Type t2 = op->getOperation()->getOperand(2)->getType(); mlir::Type t2 = op.getOperation()->getOperand(2).getType();
mlir::Type r0 = op->getOperation()->getResult(0)->getType(); mlir::Type r0 = op.getOperation()->getResult(0).getType();
NGTensorType opType0 = t0.cast<NGTensorType>(); NGTensorType opType0 = t0.cast<NGTensorType>();
NGTensorType opType1 = t1.cast<NGTensorType>(); NGTensorType opType1 = t1.cast<NGTensorType>();
...@@ -144,19 +144,19 @@ mlir::LogicalResult verifyOp(NGSelectOp* op) ...@@ -144,19 +144,19 @@ mlir::LogicalResult verifyOp(NGSelectOp* op)
// arg1 arg2 of same shape and elt type // arg1 arg2 of same shape and elt type
if (!opType1.isCompatible(opType2)) if (!opType1.isCompatible(opType2))
return op->emitOpError("Incompatible operand shapes or types for select op"); return op.emitOpError("Incompatible operand shapes or types for select op");
// arg0 of same shape and elt type is bool // arg0 of same shape and elt type is bool
if (!opType0.isCompatibleShape(opType1) || !opType0.getElementType().isa<NGBoolType>()) if (!opType0.isCompatibleShape(opType1) || !opType0.getElementType().isa<NGBoolType>())
return op->emitOpError("Incompatible shape for arg0 of select op"); return op.emitOpError("Incompatible shape for arg0 of select op");
// result is of same shape and elt type as arg1/2 // result is of same shape and elt type as arg1/2
if (!resType.isCompatible(opType1)) if (!resType.isCompatible(opType1))
return op->emitOpError("Incompatible result shape or type for select op"); return op.emitOpError("Incompatible result shape or type for select op");
return mlir::success(); return mlir::success();
} }
template <typename T> template <typename T>
static mlir::LogicalResult verifyCmpOp(T* op) static mlir::LogicalResult verifyCmpOp(T op)
{ {
mlir::LogicalResult result = verifyCompatibleOperandsAndResults(op, false /*checkResult*/); mlir::LogicalResult result = verifyCompatibleOperandsAndResults(op, false /*checkResult*/);
if (failed(result)) if (failed(result))
...@@ -164,75 +164,75 @@ static mlir::LogicalResult verifyCmpOp(T* op) ...@@ -164,75 +164,75 @@ static mlir::LogicalResult verifyCmpOp(T* op)
return result; return result;
} }
mlir::Type t0 = op->getOperation()->getOperand(0)->getType(); mlir::Type t0 = op.getOperation()->getOperand(0).getType();
mlir::NGTensorType opType0 = t0.cast<NGTensorType>(); mlir::NGTensorType opType0 = t0.cast<NGTensorType>();
mlir::Type r0 = op->getOperation()->getResult(0)->getType(); mlir::Type r0 = op.getOperation()->getResult(0).getType();
NGTensorType resType = r0.cast<NGTensorType>(); NGTensorType resType = r0.cast<NGTensorType>();
// result of same shape as input and has bool type // result of same shape as input and has bool type
if (!resType.isCompatibleShape(opType0) || if (!resType.isCompatibleShape(opType0) ||
!resType.getElementType().cast<NGIntegerType>().isUInt8()) !resType.getElementType().cast<NGIntegerType>().isUInt8())
{ {
return op->emitOpError("Incompatible result shape or type for comparison op"); return op.emitOpError("Incompatible result shape or type for comparison op");
} }
return mlir::success(); return mlir::success();
} }
template <> template <>
mlir::LogicalResult verifyOp(NGGatherOp* op) mlir::LogicalResult verifyOp(NGGatherOp op)
{ {
Type ty = op->params()->getType(); Type ty = op.params().getType();
NGTensorType inputType = ty.cast<NGTensorType>(); NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->indices()->getType(); ty = op.indices().getType();
NGTensorType indicesType = ty.cast<NGTensorType>(); NGTensorType indicesType = ty.cast<NGTensorType>();
// ensure axis < params rank // ensure axis < params rank
if (op->axis().getSExtValue() >= inputType.getRank()) if (op.axis().getSExtValue() >= inputType.getRank())
return op->emitOpError("Gather axis is larger than input rank"); return op.emitOpError("Gather axis is larger than input rank");
ty = indicesType.getElementType(); ty = indicesType.getElementType();
// ensure indices are I32 or I64 // ensure indices are I32 or I64
if (!ty.isa<NGIntegerType>()) if (!ty.isa<NGIntegerType>())
return op->emitOpError("Indices tensor is not of Integer type"); return op.emitOpError("Indices tensor is not of Integer type");
NGIntegerType indicesEltType = ty.cast<NGIntegerType>(); NGIntegerType indicesEltType = ty.cast<NGIntegerType>();
if (!indicesEltType.isInt32() && !indicesEltType.isInt64()) if (!indicesEltType.isInt32() && !indicesEltType.isInt64())
return op->emitOpError("Indices tensor is not of I32 or I64 type"); return op.emitOpError("Indices tensor is not of I32 or I64 type");
mlir::Type r0 = op->res()->getType(); mlir::Type r0 = op.res().getType();
NGTensorType resType = r0.cast<NGTensorType>(); NGTensorType resType = r0.cast<NGTensorType>();
// ensure result is compatible with input // ensure result is compatible with input
if (resType.getRank() != inputType.getRank() + indicesType.getRank() - 1) if (resType.getRank() != inputType.getRank() + indicesType.getRank() - 1)
return op->emitOpError("Incompatible result shape and/or type"); return op.emitOpError("Incompatible result shape and/or type");
return mlir::success(); return mlir::success();
} }
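A quick worked instance of the rank rule enforced above: the result rank must equal params rank plus indices rank minus one.

    #include <cassert>

    int main()
    {
        // Gathering rows {0, 2} of a 4x3 params tensor with rank-1
        // indices: 2 + 1 - 1 == 2, e.g. a 2x3 result.
        int paramsRank = 2, indicesRank = 1;
        assert(paramsRank + indicesRank - 1 == 2);
        return 0;
    }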
template <> template <>
mlir::LogicalResult verifyOp(NGConvolutionOp* op) mlir::LogicalResult verifyOp(NGConvolutionOp op)
{ {
Type ty = op->images()->getType(); Type ty = op.images().getType();
NGTensorType imagesType = ty.cast<NGTensorType>(); NGTensorType imagesType = ty.cast<NGTensorType>();
Type imagesEt = imagesType.getElementType(); Type imagesEt = imagesType.getElementType();
Shape imagesShape = imagesType.getShape(); Shape imagesShape = imagesType.getShape();
ty = op->filters()->getType(); ty = op.filters().getType();
NGTensorType filtersType = ty.cast<NGTensorType>(); NGTensorType filtersType = ty.cast<NGTensorType>();
Type filtersEt = filtersType.getElementType(); Type filtersEt = filtersType.getElementType();
Shape filtersShape = filtersType.getShape(); Shape filtersShape = filtersType.getShape();
ty = op->res()->getType(); ty = op.res().getType();
NGTensorType resultType = ty.cast<NGTensorType>(); NGTensorType resultType = ty.cast<NGTensorType>();
Shape resultShape = resultType.getShape(); Shape resultShape = resultType.getShape();
ArrayAttr strides = op->strides(); ArrayAttr strides = op.strides();
ArrayAttr padBelow = op->padBelow(); ArrayAttr padBelow = op.padBelow();
ArrayAttr padAbove = op->padAbove(); ArrayAttr padAbove = op.padAbove();
unsigned imagesRank = imagesShape.size(); unsigned imagesRank = imagesShape.size();
unsigned filtersRank = filtersShape.size(); unsigned filtersRank = filtersShape.size();
...@@ -247,32 +247,32 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op) ...@@ -247,32 +247,32 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
// Identical filters and image element types // Identical filters and image element types
if (filtersEt != imagesEt) if (filtersEt != imagesEt)
{ {
return op->emitOpError("Incompatible image and filters types"); return op.emitOpError("Incompatible image and filters types");
} }
// Verify image shape // Verify image shape
if (imagesRank < 3) if (imagesRank < 3)
{ {
return op->emitOpError("Image shape of rank below 3"); return op.emitOpError("Image shape of rank below 3");
} }
// Verify strides and pads shapes // Verify strides and pads shapes
if (imageSpatialRank != stridesRank || imageSpatialRank != padBelowRank || if (imageSpatialRank != stridesRank || imageSpatialRank != padBelowRank ||
imageSpatialRank != padAboveRank) imageSpatialRank != padAboveRank)
{ {
return op->emitOpError("Image spatial rank mismatches strides and/or padding ranks"); return op.emitOpError("Image spatial rank mismatches strides and/or padding ranks");
} }
if (imageSpatialRank != filtersSpatialRank) if (imageSpatialRank != filtersSpatialRank)
{ {
return op->emitOpError("Image and filters spatial ranks mismatch"); return op.emitOpError("Image and filters spatial ranks mismatch");
} }
// Batch size is non-zero, and identical non-zero channel depth // Batch size is non-zero, and identical non-zero channel depth
if (imagesShape[0] <= 0 || filtersShape[0] <= 0 || imagesShape[1] != filtersShape[1] || if (imagesShape[0] <= 0 || filtersShape[0] <= 0 || imagesShape[1] != filtersShape[1] ||
imagesShape[1] <= 0) imagesShape[1] <= 0)
{ {
return op->emitOpError("Image and filters have invalid shapes"); return op.emitOpError("Image and filters have invalid shapes");
} }
for (auto attrs : llvm::zip(strides, padBelow, padAbove)) for (auto attrs : llvm::zip(strides, padBelow, padAbove))
...@@ -283,7 +283,7 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op) ...@@ -283,7 +283,7 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
if (s <= 0) if (s <= 0)
{ {
return op->emitOpError("Window stride must be non-negative"); return op.emitOpError("Window stride must be non-negative");
} }
stridesVal.push_back(s); stridesVal.push_back(s);
padBelowVal.push_back(pb); padBelowVal.push_back(pb);
...@@ -294,7 +294,7 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op) ...@@ -294,7 +294,7 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
if (resultRank != imagesRank || resultShape[0] != imagesShape[0] || if (resultRank != imagesRank || resultShape[0] != imagesShape[0] ||
resultShape[1] != filtersShape[0]) resultShape[1] != filtersShape[0])
{ {
return op->emitOpError("Invalid result shape"); return op.emitOpError("Invalid result shape");
} }
for (unsigned i = 0; i < resultRank - 2; i++) for (unsigned i = 0; i < resultRank - 2; i++)
{ {
...@@ -303,56 +303,42 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op) ...@@ -303,56 +303,42 @@ mlir::LogicalResult verifyOp(NGConvolutionOp* op)
stridesVal[i]); stridesVal[i]);
if (resultShape[2 + i] != resDim) if (resultShape[2 + i] != resDim)
{ {
return op->emitOpError("Invalid result spatial shape"); return op.emitOpError("Invalid result spatial shape");
} }
} }
return mlir::success(); return mlir::success();
} }
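The spatial loop above checks each result dimension against a computed output extent. Assuming the usual convolution size formula (the exact expression is elided in this view), out = (in + padBelow + padAbove - filter) / stride + 1:

    #include <cassert>

    int main()
    {
        // 32-wide image, 3-wide filter, pads 1/1, stride 2:
        // (32 + 1 + 1 - 3) / 2 + 1 == 16 with integer division.
        int in = 32, filter = 3, padBelow = 1, padAbove = 1, stride = 2;
        int out = (in + padBelow + padAbove - filter) / stride + 1;
        assert(out == 16);
        return 0;
    }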
template <> template <>
mlir::LogicalResult verifyOp(NGMatMulOp* op) mlir::LogicalResult verifyOp(NGSoftMaxOp op)
{ {
// TODO(ayzhuang): Improve verification: proper shapes, etc. // TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success(); return mlir::success();
} }
template <> template <>
mlir::LogicalResult verifyOp(NGGemmOp* op) mlir::LogicalResult verifyOp(NGAvgPoolOp op)
{ {
// TODO(ayzhuang): Improve verification: proper shapes, etc. // TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success(); return mlir::success();
} }
template <> template <>
mlir::LogicalResult verifyOp(NGSoftMaxOp* op) mlir::LogicalResult verifyOp(NGAvgPoolBackpropOp op)
{ {
// TODO(ayzhuang): Improve verification: proper shapes, etc. // TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success(); return mlir::success();
} }
template <> template <>
mlir::LogicalResult verifyOp(NGAvgPoolOp* op) mlir::LogicalResult verifyOp(NGMaxPoolOp op)
{ {
// TODO(ayzhuang): Improve verification: proper shapes, etc. // TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success(); return mlir::success();
} }
template <> template <>
mlir::LogicalResult verifyOp(NGAvgPoolBackpropOp* op) mlir::LogicalResult verifyOp(NGMaxPoolBackpropOp op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMaxPoolOp* op)
{
// TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGMaxPoolBackpropOp* op)
{ {
// TODO(ayzhuang): Improve verification: proper shapes, etc. // TODO(ayzhuang): Improve verification: proper shapes, etc.
return mlir::success(); return mlir::success();
......
...@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -78,7 +78,7 @@ class NG_Unary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyUnaryArithOp(this); }]; let verifier = [{ return verifyUnaryArithOp(*this); }];
} }
// Base class for arithmetic binary operations without side effects. // Base class for arithmetic binary operations without side effects.
...@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -98,7 +98,7 @@ class NG_Binary_Arith_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyBinaryArithOp(this); }]; let verifier = [{ return verifyBinaryArithOp(*this); }];
} }
// Base class for comparison operations with verifier. // Base class for comparison operations with verifier.
...@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -109,7 +109,7 @@ class NG_Cmp_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO: Implement // TODO: Implement
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyCmpOp(this); }]; let verifier = [{ return verifyCmpOp(*this); }];
} }
// Base class for ternary operations without side effects. // Base class for ternary operations without side effects.
...@@ -133,7 +133,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -133,7 +133,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
// TODO // TODO
let verifier = [{ return verifyAxisReductionOp(this); }]; let verifier = [{ return verifyAxisReductionOp(*this); }];
} }
// Base class for terminator operations. // Base class for terminator operations.
......
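The [{ ... }] verifier bodies in these TableGen files are pasted into a member function of the generated op class, where this is a pointer to the concrete op; dereferencing it is what hands the op to the free verify* templates by value, matching their new by-value T signatures. Roughly what one expansion looks like (shape assumed):

    // Hypothetical expansion for one generated op class:
    mlir::LogicalResult ConcreteOp::verify()
    {
        return verifyUnaryArithOp(*this); // op passed by value
    }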
...@@ -61,14 +61,14 @@ def NGNotEqOp : NG_Cmp_Op<"not.equal", [OpVersion0]>; ...@@ -61,14 +61,14 @@ def NGNotEqOp : NG_Cmp_Op<"not.equal", [OpVersion0]>;
// Other // Other
def NGSelectOp : NG_Ternary_Op<"select", [OpVersion0]> def NGSelectOp : NG_Ternary_Op<"select", [OpVersion0]>
{ {
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
} }
// Dot Product // Dot Product
def NGDotOp : NG_Binary_Op<"dot", [OpVersion0]> def NGDotOp : NG_Binary_Op<"dot", [OpVersion0]>
{ {
// TODO: Add reduction axis attribute when needed. // TODO: Add reduction axis attribute when needed.
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
} }
// TODO(amprocte): Might be nice to rebase this on some sort of NG_Variadic_Op // TODO(amprocte): Might be nice to rebase this on some sort of NG_Variadic_Op
...@@ -80,56 +80,56 @@ def NGConcatOp : ...@@ -80,56 +80,56 @@ def NGConcatOp :
{ {
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
} }
// Axis reduction operations. // Axis reduction operations.
def NGSumRedOp : NG_Axis_Reduction_Op<"sum.red", [OpVersion0]> def NGSumRedOp : NG_Axis_Reduction_Op<"sum.red", [OpVersion0]>
{ {
let summary = "Axis sum reduction of a tensor."; let summary = "Axis sum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }]; let verifier = [{ return verifyAxisReductionOp(*this); }];
} }
def NGProdRedOp : NG_Axis_Reduction_Op<"prod.red", [OpVersion0]> def NGProdRedOp : NG_Axis_Reduction_Op<"prod.red", [OpVersion0]>
{ {
let summary = "Axis product reduction of a tensor."; let summary = "Axis product reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }]; let verifier = [{ return verifyAxisReductionOp(*this); }];
} }
def NGMinRedOp : NG_Axis_Reduction_Op<"min.red", [OpVersion0]> def NGMinRedOp : NG_Axis_Reduction_Op<"min.red", [OpVersion0]>
{ {
let summary = "Axis minimum reduction of a tensor."; let summary = "Axis minimum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }]; let verifier = [{ return verifyAxisReductionOp(*this); }];
} }
def NGMaxRedOp : NG_Axis_Reduction_Op<"max.red", [OpVersion0]> def NGMaxRedOp : NG_Axis_Reduction_Op<"max.red", [OpVersion0]>
{ {
let summary = "Axis maximum reduction of a tensor."; let summary = "Axis maximum reduction of a tensor.";
let verifier = [{ return verifyAxisReductionOp(this); }]; let verifier = [{ return verifyAxisReductionOp(*this); }];
} }
def NGArgMinRedOp : NG_Axis_Reduction_Op<"argmin.red", [OpVersion0]> def NGArgMinRedOp : NG_Axis_Reduction_Op<"argmin.red", [OpVersion0]>
{ {
let summary = "Axis minimum index reduction of a tensor."; let summary = "Axis minimum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }]; let verifier = [{ return verifyIndexReductionOp(*this); }];
} }
def NGArgMaxRedOp : NG_Axis_Reduction_Op<"argmax.red", [OpVersion0]> def NGArgMaxRedOp : NG_Axis_Reduction_Op<"argmax.red", [OpVersion0]>
{ {
let summary = "Axis maximum index reduction of a tensor."; let summary = "Axis maximum index reduction of a tensor.";
let verifier = [{ return verifyIndexReductionOp(this); }]; let verifier = [{ return verifyIndexReductionOp(*this); }];
} }
def NGAllRedOp : NG_Axis_Reduction_Op<"all.red", [OpVersion0]> def NGAllRedOp : NG_Axis_Reduction_Op<"all.red", [OpVersion0]>
{ {
let summary = "Axis logical AND reduction of a boolean tensor."; let summary = "Axis logical AND reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }]; let verifier = [{ return verifyLogicalReductionOp(*this); }];
} }
def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red", [OpVersion0]> def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red", [OpVersion0]>
{ {
let summary = "Axis logical OR reduction of a boolean tensor."; let summary = "Axis logical OR reduction of a boolean tensor.";
let verifier = [{ return verifyLogicalReductionOp(this); }]; let verifier = [{ return verifyLogicalReductionOp(*this); }];
} }
// Gather // Gather
...@@ -147,7 +147,7 @@ def NGGatherOp : ...@@ -147,7 +147,7 @@ def NGGatherOp :
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
} }
// Convolution // Convolution
...@@ -171,7 +171,7 @@ def NGConvolutionOp : ...@@ -171,7 +171,7 @@ def NGConvolutionOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setStrides(ArrayAttr& arrayAttr) { this->setAttr("strides", arrayAttr); } void setStrides(ArrayAttr& arrayAttr) { this->setAttr("strides", arrayAttr); }
void setPadBelow(ArrayAttr& arrayAttr) { this->setAttr("padBelow", arrayAttr); } void setPadBelow(ArrayAttr& arrayAttr) { this->setAttr("padBelow", arrayAttr); }
...@@ -210,10 +210,10 @@ def NGAvgPoolOp : ...@@ -210,10 +210,10 @@ def NGAvgPoolOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg," "Builder *builder, OperationState &tblgen_state, Type res, Value arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides," "ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove, BoolAttr includPadding, IntegerAttr padType", [{ "ArrayAttr padBelow, ArrayAttr padAbove, BoolAttr includPadding, IntegerAttr padType", [{
tblgen_state.addOperands(arg); tblgen_state.addOperands(arg);
...@@ -227,7 +227,7 @@ def NGAvgPoolOp : ...@@ -227,7 +227,7 @@ def NGAvgPoolOp :
}]>, }]>,
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg," "Builder *builder, OperationState &tblgen_state, Type res, Value arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides," "ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove, BoolAttr includPadding", [{ "ArrayAttr padBelow, ArrayAttr padAbove, BoolAttr includPadding", [{
tblgen_state.addOperands(arg); tblgen_state.addOperands(arg);
...@@ -269,7 +269,7 @@ def NGAvgPoolBackpropOp : ...@@ -269,7 +269,7 @@ def NGAvgPoolBackpropOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setForwardArgShape(const ArrayAttr& arrayAttr) { this->setAttr("forwardArgShape", arrayAttr); } void setForwardArgShape(const ArrayAttr& arrayAttr) { this->setAttr("forwardArgShape", arrayAttr); }
...@@ -296,7 +296,7 @@ def NGBatchNormTrainingOp : ...@@ -296,7 +296,7 @@ def NGBatchNormTrainingOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); } void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); }
...@@ -320,7 +320,7 @@ def NGBatchNormInferenceOp : ...@@ -320,7 +320,7 @@ def NGBatchNormInferenceOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); } void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); }
}]; }];
...@@ -345,7 +345,7 @@ def NGBatchNormTrainingBackPropOp : ...@@ -345,7 +345,7 @@ def NGBatchNormTrainingBackPropOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); } void setEpsilon(const Attribute& attr) { this->setAttr("epsilon", attr); }
}]; }];
...@@ -366,7 +366,7 @@ def NGBroadcastOp : ...@@ -366,7 +366,7 @@ def NGBroadcastOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setAxisSet(const ArrayAttr& attr) { this->setAttr("axisSet", attr); } void setAxisSet(const ArrayAttr& attr) { this->setAttr("axisSet", attr); }
void setShape(const ArrayAttr& attr) { this->setAttr("shape", attr); } void setShape(const ArrayAttr& attr) { this->setAttr("shape", attr); }
...@@ -387,7 +387,7 @@ def NGConstantOp : ...@@ -387,7 +387,7 @@ def NGConstantOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
} }
// MaxPool // MaxPool
...@@ -416,23 +416,10 @@ def NGMaxPoolOp : ...@@ -416,23 +416,10 @@ def NGMaxPoolOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let builders = [ let builders = [
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg," "Builder *builder, OperationState &tblgen_state, Type res, Value arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove, IntegerAttr padType", [{
tblgen_state.addOperands(arg);
tblgen_state.addAttribute("windowShape", windowShape);
tblgen_state.addAttribute("windowMovementStrides", windowMovementStrides);
tblgen_state.addAttribute("padBelow", padBelow);
tblgen_state.addAttribute("padAbove", padAbove);
tblgen_state.addAttribute("padType", padType);
tblgen_state.addTypes(res);
}]>,
OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, Value *arg,"
"ArrayAttr windowShape, ArrayAttr windowMovementStrides," "ArrayAttr windowShape, ArrayAttr windowMovementStrides,"
"ArrayAttr padBelow, ArrayAttr padAbove", [{ "ArrayAttr padBelow, ArrayAttr padAbove", [{
tblgen_state.addOperands(arg); tblgen_state.addOperands(arg);
...@@ -471,7 +458,7 @@ def NGMaxPoolBackpropOp : ...@@ -471,7 +458,7 @@ def NGMaxPoolBackpropOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setWindowShape(const ArrayAttr& arrayAttr) { this->setAttr("windowShape", arrayAttr); } void setWindowShape(const ArrayAttr& arrayAttr) { this->setAttr("windowShape", arrayAttr); }
...@@ -510,12 +497,12 @@ def NGPadOp : ...@@ -510,12 +497,12 @@ def NGPadOp :
DefaultValuedAttr<PadModeEnumAttr, "MLIRPadMode::CONSTANT"> :$padMode)> DefaultValuedAttr<PadModeEnumAttr, "MLIRPadMode::CONSTANT"> :$padMode)>
{ {
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let builders = [ let builders = [
// Builder without padMode // Builder without padMode
OpBuilder< OpBuilder<
"Builder *builder, OperationState &tblgen_state, Type res, " "Builder *builder, OperationState &tblgen_state, Type res, "
"Value *arg, Value *padValue, " "Value arg, Value padValue, "
"ArrayAttr padBelow, ArrayAttr padAbove", [{ "ArrayAttr padBelow, ArrayAttr padAbove", [{
tblgen_state.addOperands(arg); tblgen_state.addOperands(arg);
tblgen_state.addOperands(padValue); tblgen_state.addOperands(padValue);
...@@ -556,7 +543,7 @@ def NGReplaceSliceOp : ...@@ -556,7 +543,7 @@ def NGReplaceSliceOp :
slice to be replaced. slice to be replaced.
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setLowerBounds(const ArrayAttr& arrayAttr) { this->setAttr("lowerBounds", arrayAttr); } void setLowerBounds(const ArrayAttr& arrayAttr) { this->setAttr("lowerBounds", arrayAttr); }
void setUpperBounds(const ArrayAttr& arrayAttr) { this->setAttr("upperBounds", arrayAttr); } void setUpperBounds(const ArrayAttr& arrayAttr) { this->setAttr("upperBounds", arrayAttr); }
...@@ -584,7 +571,7 @@ def NGSliceOp : ...@@ -584,7 +571,7 @@ def NGSliceOp :
slice to be extracted. slice to be extracted.
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setLowerBounds(const ArrayAttr& arrayAttr) { this->setAttr("lowerBounds", arrayAttr); } void setLowerBounds(const ArrayAttr& arrayAttr) { this->setAttr("lowerBounds", arrayAttr); }
void setUpperBounds(const ArrayAttr& arrayAttr) { this->setAttr("upperBounds", arrayAttr); } void setUpperBounds(const ArrayAttr& arrayAttr) { this->setAttr("upperBounds", arrayAttr); }
...@@ -610,7 +597,7 @@ def NGReshapeOp : ...@@ -610,7 +597,7 @@ def NGReshapeOp :
Pi(a_i) = Pi(b_i) Pi(a_i) = Pi(b_i)
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setAxisOrder(const ArrayAttr& arrayAttr) { this->setAttr("axisOrder", arrayAttr); } void setAxisOrder(const ArrayAttr& arrayAttr) { this->setAttr("axisOrder", arrayAttr); }
void setShape(const ArrayAttr& arrayAttr) { this->setAttr("shape", arrayAttr); } void setShape(const ArrayAttr& arrayAttr) { this->setAttr("shape", arrayAttr); }
...@@ -630,7 +617,7 @@ def NGSoftMaxOp : ...@@ -630,7 +617,7 @@ def NGSoftMaxOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setAxes(const ArrayAttr& arrayAttr) { this->setAttr("axes", arrayAttr); } void setAxes(const ArrayAttr& arrayAttr) { this->setAttr("axes", arrayAttr); }
}]; }];
...@@ -657,7 +644,7 @@ def NGTopKOp : ...@@ -657,7 +644,7 @@ def NGTopKOp :
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }]; let verifier = [{ return verifyOp(*this); }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
void setK(const Attribute& attr) { this->setAttr("k", attr); } void setK(const Attribute& attr) { this->setAttr("k", attr); }
void setAxis(const Attribute& attr) { this->setAttr("axis", attr); } void setAxis(const Attribute& attr) { this->setAttr("axis", attr); }
......
...@@ -68,7 +68,7 @@ namespace ...@@ -68,7 +68,7 @@ namespace
struct TensorInfo struct TensorInfo
{ {
// MLIR values this tensor maps to. // MLIR values this tensor maps to.
mlir::Value* m_value; mlir::Value m_value;
}; };
private: private:
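Holding mlir::Value directly in TensorInfo (and, below, in the tensor-to-value map) is safe after the migration because Value is a one-pointer handle that is cheap to copy. A reduced sketch of the idea, with a simplified map type standing in for the pass's members:

    #include "mlir/IR/Value.h"
    #include <unordered_map>

    // Hypothetical simplification: tensor keys map to Value handles
    // stored by value rather than through Value*.
    using TensorToValue = std::unordered_map<const void*, mlir::Value>;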
...@@ -84,7 +84,7 @@ namespace ...@@ -84,7 +84,7 @@ namespace
mlir::Type getMlirType(const ngraph::Node* node); mlir::Type getMlirType(const ngraph::Node* node);
TensorInfo getTensorValue(descriptor::Tensor* tensor); TensorInfo getTensorValue(descriptor::Tensor* tensor);
void updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value); void updateTensorValue(descriptor::Tensor* tensor, mlir::Value value);
template <typename Op> template <typename Op>
static mlir::Operation* createOp(NgDialectConversionPass& NgDialectObj, static mlir::Operation* createOp(NgDialectConversionPass& NgDialectObj,
@@ -176,7 +176,7 @@ void NgDialectConversionPass::runOnModule()
int i = 0;
for (auto input : kernelInputs)
{
-mlir::Value* arg = function.getArgument(i);
+auto arg = function.getArgument(i);
TensorInfo tensorInfo{arg};
m_tensorToValueMap.insert(TensorToInfo(input->get_output_tensor_ptr().get(), tensorInfo));
i++;
@@ -264,7 +264,7 @@ mlir::Type NgDialectConversionPass::getMlirType(const ngraph::Node* node)
return getMlirType(outTensor);
}
-void NgDialectConversionPass::updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value)
+void NgDialectConversionPass::updateTensorValue(descriptor::Tensor* tensor, mlir::Value value)
{
NGRAPH_CHECK(m_tensorToValueMap.find(tensor) == m_tensorToValueMap.end(),
"tensor value already defined");
@@ -307,7 +307,7 @@ void NgDialectConversionPass::buildNgDialect(mlir::FuncOp function)
{
for (auto i = 0; i < op->getNumResults(); i++)
{
-mlir::Value* result = op->getResult(i);
+auto result = op->getResult(i);
if (result)
{
updateTensorValue(np->get_output_tensor_ptr(i).get(), result);
@@ -600,7 +600,6 @@ template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Softmax)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGSoftMaxOp>(ngNode, 1);
-auto softmaxNode = static_cast<const ngraph::op::Softmax*>(ngNode);
auto softmaxOp = llvm::cast<mlir::NGSoftMaxOp>(op);
auto originArg = NgDialectObj.getOriginArg(ngNode->input_value(1).get_node());
@@ -614,7 +613,7 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Softmax)
template <typename Op>
mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ngNode, int inNum)
{
-std::vector<mlir::Value*> argValues;
+std::vector<mlir::Value> argValues;
std::vector<mlir::Type> resTypes;
auto inputMap = m_compiledKernel->get_input_map();
std::shared_ptr<descriptor::Tensor> argTensor;
@@ -650,7 +649,7 @@ mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ng
return (m_builder.create<Op,
ArrayRef<mlir::Type>,
-ArrayRef<mlir::Value*>,
+ArrayRef<mlir::Value>,
ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(m_context), resTypes, argValues, {/* no attrs */}))
.getOperation();
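The explicit template arguments above pin down the generic `OpBuilder::create<OpTy, Args...>` overload; the only change is the operand list's element type, `ArrayRef<mlir::Value>` instead of `ArrayRef<mlir::Value*>`. A hedged, standalone sketch of the same call shape (names are illustrative, not the repo's exact helpers):

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Create an op of type Op with the given result types and operands, passing
// operands as a value-semantic ArrayRef<mlir::Value>, and hand back the
// generic Operation* handle.
template <typename Op>
mlir::Operation* createGeneric(mlir::OpBuilder& builder,
                               llvm::ArrayRef<mlir::Type> resTypes,
                               llvm::ArrayRef<mlir::Value> argValues)
{
    return builder
        .create<Op,
                llvm::ArrayRef<mlir::Type>,
                llvm::ArrayRef<mlir::Value>,
                llvm::ArrayRef<mlir::NamedAttribute>>(
            builder.getUnknownLoc(), resTypes, argValues, {/* no attrs */})
        .getOperation();
}
```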
@@ -663,7 +662,7 @@ const NgDialectConversionPass::MLIRCompOpMap NgDialectConversionPass::opDispatch
void NgDialectConversionPass::createReturn()
{
-std::vector<mlir::Value*> valueList;
+std::vector<mlir::Value> valueList;
for (auto output : m_compiledKernel->get_kernel_outputs())
{
valueList.push_back(getTensorValue(output->get_output_tensor_ptr().get()).m_value);
...
@@ -140,7 +140,6 @@ void MLIRCPURuntime::execute()
void MLIRCPURuntime::cleanup()
{
// Free void double pointer arguments without freeing external tensor data.
-int i = 0;
for (auto* arg : m_invokeArgs)
{
auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg));
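The deleted `int i = 0;` was an unused counter. For context, a hedged sketch of the cleanup loop this hunk belongs to: each invoke argument is a heap-allocated slot holding a `StaticMemRef*`, and only the runtime-owned descriptor and slot are freed, never the external tensor buffer. The `StaticMemRef` layout below is illustrative rather than the repo's exact definition, and the `free` calls assume C-style allocation, as the original loop implies:

```cpp
#include <cstdlib>
#include <vector>

// Illustrative stand-in for the runtime's memref descriptor; the real struct
// carries layout information alongside the borrowed data pointer.
struct StaticMemRef
{
    void* data; // external tensor storage; not owned by the runtime
};

void cleanup(std::vector<void*>& invokeArgs)
{
    // Free void double-pointer arguments without freeing external tensor data.
    for (auto* arg : invokeArgs)
    {
        auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg));
        free(memRefArg); // descriptor allocated by the runtime
        free(arg);       // the double-pointer slot itself
    }
    invokeArgs.clear();
}
```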
...
@@ -16,7 +16,7 @@
# Enable use of the lit tool that we build from MLIR repo.
set(LLVM_LIT ${LLVM_MAIN_SRC_DIR}/utils/lit/lit.py)
-set(LLVM_DEFAULT_EXTERNAL_LIT ${MLIR_TOOLS_DIR}/llvm-lit)
+set(LLVM_DEFAULT_EXTERNAL_LIT ${MLIR_LLVM_TOOLS_DIR}/llvm-lit)
configure_lit_site_cfg(
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
...
@@ -150,7 +150,8 @@ func @simple_dot(%arg0: !ng.tensor<16x8xf32>, %arg1: !ng.tensor<8x32xf32>) -> !n
// -----
// std.view
-// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1)
+// CHECK: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK: %[[T1:[0-9]+]] = alloc() : memref<24xi8>
// CHECK-NEXT: %[[T2:[0-9]+]] = std.view %[[T1]][][] : memref<24xi8> to memref<3x2xf32, #[[MAP0]]>
// CHECK: affine.store %{{[0-9]+}}, %[[T2]][%{{.*}}, %{{.*}}] : memref<3x2xf32, #[[MAP0]]>
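The FileCheck updates in these .mlir tests track an upstream syntax change only: affine maps in the textual IR are now spelled `affine_map<(d0, d1) -> (d0 * 2 + d1)>` rather than the bare `(d0, d1) -> (d0 * 2 + d1)`; the maps themselves are unchanged. A minimal sketch of building the same layout map through MLIR's C++ API (assuming MLIR headers are available):

```cpp
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

// Build (d0, d1) -> (d0 * 2 + d1), the row-major layout map for an Nx2
// memref, programmatically rather than from the textual form.
mlir::AffineMap buildLayoutMap(mlir::MLIRContext* ctx)
{
    auto d0 = mlir::getAffineDimExpr(0, ctx);
    auto d1 = mlir::getAffineDimExpr(1, ctx);
    return mlir::AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                d0 * 2 + d1);
}
```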
@@ -198,12 +199,12 @@ func @convolution(%arg0: !ng.tensor<1x2x2x2xf32>, %arg1: !ng.tensor<2x2x1x1xf32>
// -----
//
// Group Convolution
-// CHECK-DAG: #[[M0:.*]] = (d0) -> (d0 * 2)
-// CHECK-DAG: #[[M1:.*]] = (d0) -> (d0 * 2 + 2)
-// CHECK-DAG: #[[M2:.*]] = (d0) -> (d0)
-// CHECK-DAG: #[[M3:.*]] = (d0) -> (d0 + 1)
-// CHECK-DAG: #[[M8:.*]] = (d0, d1) -> (d0 + d1)
-// CHECK-DAG: #[[M9:.*]] = (d0, d1) -> (d0 - d1 * 2)
+// CHECK-DAG: #[[M0:.*]] = affine_map<(d0) -> (d0 * 2)>
+// CHECK-DAG: #[[M1:.*]] = affine_map<(d0) -> (d0 * 2 + 2)>
+// CHECK-DAG: #[[M2:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[M3:.*]] = affine_map<(d0) -> (d0 + 1)>
+// CHECK-DAG: #[[M8:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[M9:.*]] = affine_map<(d0, d1) -> (d0 - d1 * 2)>
// CHECK-LABEL: func @groupConv
//
// Outer groups loops
...
// RUN: ngraph-opt %s --split-input-file --ngraph-memory-opt --ngraph-memory-opt-concat --ngraph-memory-opt-eltwise -convert-ngraph-to-affine | FileCheck %s
-// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1)
+// CHECK: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-LABEL: test0
// CHECK: %[[B:.*]] = alloc() : memref<16xi8>
// CHECK: std.view %[[B]][][] : memref<16xi8> to memref<2x2xf32, #[[MAP0]]>
@@ -17,8 +17,8 @@ func @test0(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tenso
// -----
-// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1)
-// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1 + 4)
+// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 + 4)>
// CHECK-LABEL: test1
// CHECK: %[[B:.*]] = alloc() : memref<32xi8>
// CHECK: std.view %[[B]][][] : memref<32xi8> to memref<2x2xf32, #[[MAP0]]>
@@ -35,10 +35,10 @@ func @test1(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tenso
// -----
-// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)
-// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)
-// CHECK-DAG: #[[MAP2:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)
-// CHECK-DAG: #[[MAP3:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)
+// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)>
+// CHECK-DAG: #[[MAP1:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)>
+// CHECK-DAG: #[[MAP2:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)>
+// CHECK-DAG: #[[MAP3:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)>
// CHECK-LABEL: test2
// CHECK: %[[B1:.*]] = alloc() : memref<32xi8>
// CHECK: std.view %[[B1]][][] : memref<32xi8> to memref<1x2x2xf32, #[[MAP0]]>
@@ -66,13 +66,13 @@ func @test2(%arg0: !ng.tensor<1x2x2xf32>, %arg1: !ng.tensor<1x2x2xf32>) -> (!ng.
// -----
-// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)
-// CHECK-DAG: #[[MAP8:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 8)
-// CHECK-DAG: #[[MAP9:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 16)
-// CHECK-DAG: #[[MAP10:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 24)
-// CHECK-DAG: #[[MAP11:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)
-// CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2 + 16)
-// CHECK-DAG: #[[MAP13:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2)
+// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2)>
+// CHECK-DAG: #[[MAP8:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 8)>
+// CHECK-DAG: #[[MAP9:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 16)>
+// CHECK-DAG: #[[MAP10:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 8 + d1 * 2 + d2 + 24)>
+// CHECK-DAG: #[[MAP11:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2)>
+// CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 2 + d2 + 16)>
+// CHECK-DAG: #[[MAP13:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2)>
// CHECK-LABEL: test3
// CHECK: %[[B:.*]] = alloc() : memref<128xi8>
// CHECK: std.view %[[B]][][] : memref<128xi8> to memref<1x4x2xf32, #[[MAP0]]>
@@ -97,10 +97,10 @@ func @test3(%arg0: !ng.tensor<1x2x2xf32>, %arg1: !ng.tensor<1x2x2xf32>) -> !ng.t
// -----
-//CHECK-DAG: #[[MAP4:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)
-//CHECK-DAG: #[[MAP5:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)
-//CHECK-DAG: #[[MAP6:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 8)
-//CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = (d0, d1, d2) -> (d0 * 12 + d1 * 2 + d2)
+//CHECK-DAG: #[[MAP4:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 4)>
+//CHECK-DAG: #[[MAP5:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2)>
+//CHECK-DAG: #[[MAP6:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 4 + d1 * 2 + d2 + 8)>
+//CHECK-DAG: #[[MAP12:[a-zA-Z0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 12 + d1 * 2 + d2)>
// CHECK-LABEL: test4
//CHECK: %[[B1:.*]] = alloc() : memref<1x2x2xf32>
//CHECK: %[[B2:.*]] = alloc() : memref<48xi8>
...
@@ -17,9 +17,9 @@
import lit.llvm
-config.llvm_tools_dir = "@MLIR_TOOLS_DIR@"
-config.mlir_obj_root = "@MLIR_BUILD_DIR@"
-config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
+config.llvm_tools_dir = "@MLIR_LLVM_TOOLS_DIR@"
+config.mlir_obj_root = "@MLIR_LLVM_BUILD_DIR@"
+config.mlir_tools_dir = "@MLIR_LLVM_TOOLS_DIR@"
config.suffixes = ['.mlir']
config.ngraph_mlir_tools_dir = "@NGRAPH_BUILD_BIN@"
...
@@ -31,7 +31,6 @@ using namespace mlir;
OpBuilder createBuilder(MLIRContext* context)
{
-auto module = ModuleOp::create(UnknownLoc::get(context));
auto funcType = FunctionType::get({}, {}, context);
auto function = FuncOp::create(UnknownLoc::get(context), "main", funcType);
function.addEntryBlock();
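The removed `ModuleOp` was created but never used by this test helper. A hedged sketch of the remaining pattern, positioning a builder at the new function's entry block (assumes MLIR of this commit's vintage, where `FuncOp` and `FunctionType::get({}, {}, context)` live in the top-level `mlir` namespace):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

OpBuilder createBuilder(MLIRContext* context)
{
    // A nullary function with no results is enough to host test IR.
    auto funcType = FunctionType::get({}, {}, context);
    auto function = FuncOp::create(UnknownLoc::get(context), "main", funcType);

    // addEntryBlock() returns the fresh block; start inserting at its front.
    Block* entry = function.addEntryBlock();
    return OpBuilder(entry, entry->begin());
}
```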
...