Commit 28551264 authored by nmostafa's avatar nmostafa

Fix JIT invocation ABI

parent 502a880d
......@@ -20,8 +20,8 @@ set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
set(MLIR_REPO_URL https://github.com/tensorflow/mlir.git)
# Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID ebaa3eb1)
set(MLIR_COMMIT_ID 9a856bce)
set(MLIR_LLVM_COMMIT_ID 0845ac7331e)
set(MLIR_COMMIT_ID 1f7893e0)
# MLIR environment variables. Some of them are used by LIT tool.
set(MLIR_PROJECT_ROOT ${CMAKE_CURRENT_BINARY_DIR}/mlir_project)
......
......@@ -58,7 +58,7 @@
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Target/TargetMachine.h>
#include <mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h>
#include <mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h>
#include <mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
......@@ -736,11 +736,10 @@ void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors)
m_externalTensors = &externalTensors;
// Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemrefArgs', which creates a
// SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
// actual pointer to the data.
// We currently use 'allocateMemrefArgs', which creates the arguments list per call ABI (see
// comment below).
// StaticFloatMemref is just a struct with the actual pointer to the data.
// create MemRef args
auto expectedArguments = allocateMemrefArgs();
NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created");
m_invokeArgs = std::move(expectedArguments);
......@@ -751,7 +750,8 @@ void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors)
// Assign external tensor pointers to invocation arguments.
for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
{
((mlir::StaticFloatMemRef*)m_invokeArgs[i])->data = (float*)(*m_externalTensors)[i];
auto* memRefArg = *((mlir::StaticFloatMemRef**)m_invokeArgs[i]);
memRefArg->data = (float*)(*m_externalTensors)[i];
}
}
......@@ -787,13 +787,23 @@ void MLIRCompiler::cleanup()
}
}
// The current call ABI takes a single arg pointer (argPtr) pointing to a list of args.
// Each arg is a pointer to a StaticFloatMemRef which contains a data pointer
//
// The args are laid out as follows
// argPtr-> arg[0]-> StaticFloatMemRef -> <data>
// arg[1]-> StaticFloatMemRef -> <data>
// ...
SmallVector<void*, 8> MLIRCompiler::allocateMemrefArgs()
{
SmallVector<void*, 8> args;
for (auto i = 0; i < m_externalTensors->size(); i++)
{
auto descriptor = allocateMemrefDescriptor();
args.push_back(descriptor);
mlir::StaticFloatMemRef** arg =
reinterpret_cast<mlir::StaticFloatMemRef**>(malloc(sizeof(mlir::StaticFloatMemRef*)));
*arg = descriptor;
args.push_back(arg);
}
return args;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment