Unverified Commit e5258c0a authored by Diego Caballero's avatar Diego Caballero Committed by GitHub

[MLIR] Update MLIR repo (#4335)

* [MLIR] Update MLIR repo

* Nagy's fix

* Changes related to mlir-opt

* Update MLIR commit

* Update MLIR commit and callbacks.

* Disable 'noalias' attribute.

It will be re-introduced in a follow-up commit.

* Remove '__mlir' prefix in callback test

* Address feedback

* Fix EDSC includes

* Move MLIR repo forward

* Update type converter code

* Address feedback
Co-authored-by: 's avatarAmy Zhuang <amyzhuang97@gmail.com>
parent fdd8db66
...@@ -19,7 +19,7 @@ include(ExternalProject) ...@@ -19,7 +19,7 @@ include(ExternalProject)
set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git) set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
# Change these commit IDs to move to latest stable versions # Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID 96400ae) set(MLIR_LLVM_COMMIT_ID 376c6853)
# MLIR environment variables. Some of them are used by LIT tool. # MLIR environment variables. Some of them are used by LIT tool.
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <llvm/ADT/DenseSet.h> #include <llvm/ADT/DenseSet.h>
#include <map> #include <map>
#include <mlir/EDSC/Builders.h> #include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h> #include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h> #include <mlir/IR/AffineExpr.h>
#include <mlir/IR/IntegerSet.h> #include <mlir/IR/IntegerSet.h>
......
...@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect() ...@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect()
void MLIRCPUBackend::lowerStandardDialect() void MLIRCPUBackend::lowerStandardDialect()
{ {
mlir::PassManager pm(&m_context); mlir::PassManager pm(&m_context);
pm.addPass(mlir::createLowerToLLVMPass()); pm.addPass(mlir::createLowerToLLVMPass(
/*useAlloca=*/false, /*useBarePtrCallConv=*/false, /*emitCWrappers=*/true));
// Apply any generic pass manager command line options. // Apply any generic pass manager command line options.
mlir::applyPassManagerCLOptions(pm); mlir::applyPassManagerCLOptions(pm);
......
...@@ -23,9 +23,7 @@ ...@@ -23,9 +23,7 @@
#include "contrib/mlir/core/ngraph_dialect/type.hpp" #include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include <llvm/IR/Module.h> #include <llvm/IR/Module.h>
#include <mlir/EDSC/Builders.h> #include <mlir/Dialect/AffineOps/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/IntegerSet.h> #include <mlir/IR/IntegerSet.h>
#include <mlir/IR/MLIRContext.h> #include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h> #include <mlir/IR/StandardTypes.h>
......
...@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA, ...@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
} }
} }
extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type) extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t index, OpType type)
{ {
auto unrankedMemRefInput = reinterpret_cast<UnrankedMemRef*>(input); auto unrankedMemRefInput = reinterpret_cast<UnrankedMemRef*>(input);
auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output); auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output);
...@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, ...@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index,
} }
} }
extern "C" void extern "C" void _mlir_ciface_callback_2_inputs(
__mlir_callback_2_inputs(void* input0, void* input1, void* output, size_t index, OpType type) void* input0, void* input1, void* output, size_t index, OpType type)
{ {
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0); auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1); auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1);
...@@ -780,7 +780,7 @@ extern "C" void ...@@ -780,7 +780,7 @@ extern "C" void
} }
} }
extern "C" void __mlir_callback_3_inputs( extern "C" void _mlir_ciface_callback_3_inputs(
void* input0, void* input1, void* input2, void* output, size_t index, OpType type) void* input0, void* input1, void* input2, void* output, size_t index, OpType type)
{ {
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0); auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
......
...@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args) ...@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
{ {
NGRAPH_CHECK(m_module, "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("main"); auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("_mlir_ciface_main");
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found"); NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments // Set external arguments
...@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute() ...@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute()
// uniformity reasons, it takes a list of type-erased pointers to arguments. // uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version. // Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked. // Make sure the MutableArrayRef version is invoked.
auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs)); auto invocationResult =
m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs));
if (clDumpObjectFile) if (clDumpObjectFile)
{ {
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o" m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue()); : clObjectFilename.getValue());
} }
NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n"); NGRAPH_CHECK(!invocationResult, "JIT invocation of '_mlir_ciface_main' failed\n");
} }
void MLIRCPURuntime::cleanup() void MLIRCPURuntime::cleanup()
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
set(LIBS set(LIBS
mlir_backend mlir_backend
MLIROptMain MLIROptLib
MLIRPass MLIRPass
MLIRParser MLIRParser
LLVMSupport LLVMSupport
......
...@@ -21,10 +21,21 @@ ...@@ -21,10 +21,21 @@
#include "contrib/mlir/core/ngraph_dialect/dialect.hpp" #include "contrib/mlir/core/ngraph_dialect/dialect.hpp"
#include <llvm/Support/CommandLine.h> #include <mlir/Dialect/AffineOps/AffineOps.h>
#include <llvm/Support/Debug.h> #include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/LoopOps/LoopOps.h>
#include <mlir/Dialect/StandardOps/Ops.h>
#include <mlir/Dialect/VectorOps/VectorOps.h>
#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h> #include <mlir/IR/MLIRContext.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/LocationSnapshot.h>
#include <mlir/Transforms/Passes.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/Debug.h>
using namespace mlir;
static llvm::cl::opt<bool> clPrintIRAfterAll( static llvm::cl::opt<bool> clPrintIRAfterAll(
"ngraph-print-ir-after-all", "ngraph-print-ir-after-all",
...@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll( ...@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll(
void ngraph::runtime::ngmlir::initializeNGraphMLIR() void ngraph::runtime::ngmlir::initializeNGraphMLIR()
{ {
// Initialize a dialect only once. // Initialize MLIR dialects and passes only once.
// We currently have no way to query if a dialect is previously static bool init_once = []() {
// registered. So using a global flag instead. // In-tree Dialects.
static bool init = false; registerDialect<AffineOpsDialect>();
if (!init) registerDialect<LLVM::LLVMDialect>();
{ registerDialect<loop::LoopOpsDialect>();
mlir::registerDialect<mlir::NGraphOpsDialect>(); registerDialect<StandardOpsDialect>();
init = true; registerDialect<vector::VectorOpsDialect>();
}
// nGraph dialects.
registerDialect<mlir::NGraphOpsDialect>();
// In-tree passes.
// No-op to avoid DCE on the following pass initializations.
if (std::getenv("bar") != (char*)-1)
return false;
createCanonicalizerPass();
createCSEPass();
createVectorizePass({});
createLoopUnrollPass();
createLoopUnrollAndJamPass();
createSimplifyAffineStructuresPass();
createLoopFusionPass();
createLoopInvariantCodeMotionPass();
createAffineLoopInvariantCodeMotionPass();
createPipelineDataTransferPass();
createLowerAffinePass();
createLoopTilingPass(0);
createLoopCoalescingPass();
createAffineDataCopyGenerationPass(0, 0);
createMemRefDataFlowOptPass();
createStripDebugInfoPass();
createPrintOpStatsPass();
createInlinerPass();
createSymbolDCEPass();
createLocationSnapshotPass({});
return true;
}();
(void)init_once;
} }
void ngraph::runtime::ngmlir::dumpMlirModule(const std::string msg, mlir::ModuleOp module) void ngraph::runtime::ngmlir::dumpMlirModule(const std::string msg, mlir::ModuleOp module)
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32> // CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> { func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32> %0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
...@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) - ...@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32> // CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32> // CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> { func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> %0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> () "ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
...@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: ...@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2:
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32> // CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32> // CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> { func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> %0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
...@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> ! ...@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> { func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> %0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
...@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> ...@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> { func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> %0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
...@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3 ...@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> { func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> %0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> () "ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
...@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x ...@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32> // CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> { func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> %0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment