Commit 28551264 authored by nmostafa's avatar nmostafa

Fix JIT invocation ABI

parent 502a880d
...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git) ...@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git) set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions # Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID ebaa3eb1) set(MLIR_LLVM_COMMIT_ID 0845ac7331e)
set(MLIR_COMMIT_ID 9a856bce) set(MLIR_COMMIT_ID 1f7893e0)
# MLIR environment variables. Some of them are used by LIT tool. # MLIR environment variables. Some of them are used by LIT tool.
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project) set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
......
...@@ -58,7 +58,7 @@ ...@@ -58,7 +58,7 @@
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h> #include <llvm/Support/TargetSelect.h>
#include <llvm/Target/TargetMachine.h> #include <llvm/Target/TargetMachine.h>
#include <mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h> #include <mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h> #include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h> #include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h> #include <mlir/Dialect/LLVMIR/LLVMDialect.h>
...@@ -736,11 +736,10 @@ void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors) ...@@ -736,11 +736,10 @@ void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors)
m_externalTensors = &externalTensors; m_externalTensors = &externalTensors;
// Create list with a type-erased double pointer for each invocation arguments. // Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemrefArgs', which creates a // We currently use 'allocateMemrefArgs', which creates the arguments list per call ABI (see
// SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the // comment below).
// actual pointer to the data. // StaticFloatMemref is just a struct with the actual pointer to the data.
// create MemRef args
auto expectedArguments = allocateMemrefArgs(); auto expectedArguments = allocateMemrefArgs();
NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created"); NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created");
m_invokeArgs = std::move(expectedArguments); m_invokeArgs = std::move(expectedArguments);
...@@ -751,7 +750,8 @@ void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors) ...@@ -751,7 +750,8 @@ void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors)
// Assign external tensor pointers to invocation arguments. // Assign external tensor pointers to invocation arguments.
for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i) for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
{ {
((mlir::StaticFloatMemRef*)m_invokeArgs[i])->data = (float*)(*m_externalTensors)[i]; auto* memRefArg = *((mlir::StaticFloatMemRef**)m_invokeArgs[i]);
memRefArg->data = (float*)(*m_externalTensors)[i];
} }
} }
...@@ -787,13 +787,23 @@ void MLIRCompiler::cleanup() ...@@ -787,13 +787,23 @@ void MLIRCompiler::cleanup()
} }
} }
// The current call ABI takes a single arg pointer (argPtr) pointing to a list of args.
// Each arg is a pointer to a StaticFloatMemRef which contains a data pointer
//
// The args are laid out as follows
// argPtr-> arg[0]-> StaticFloatMemRef -> <data>
// arg[1]-> StaticFloatMemRef -> <data>
// ...
SmallVector<void*, 8> MLIRCompiler::allocateMemrefArgs() SmallVector<void*, 8> MLIRCompiler::allocateMemrefArgs()
{ {
SmallVector<void*, 8> args; SmallVector<void*, 8> args;
for (auto i = 0; i < m_externalTensors->size(); i++) for (auto i = 0; i < m_externalTensors->size(); i++)
{ {
auto descriptor = allocateMemrefDescriptor(); auto descriptor = allocateMemrefDescriptor();
args.push_back(descriptor); mlir::StaticFloatMemRef** arg =
reinterpret_cast<mlir::StaticFloatMemRef**>(malloc(sizeof(mlir::StaticFloatMemRef*)));
*arg = descriptor;
args.push_back(arg);
} }
return args; return args;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment