//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "compiler.hpp"

#include "dialect/dialect.hpp"
#include "dialect/ops.hpp"
#include "dialect/type.hpp"
#include "lowerer.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/type/element_type.hpp"

#include <llvm/ADT/STLExtras.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/ErrorOr.h>
#include <llvm/Support/MemoryBuffer.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/LLVMIR/LLVMDialect.h>
#include <mlir/LLVMIR/Transforms.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Target/LLVMIR.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>

#include <memory>
#include <mutex>

using llvm::SmallVector;
using llvm::StringRef;
using llvm::make_unique;
using namespace ngraph::runtime::ngmlir;

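// Expands to the signature of the create_op<OpType> specialization that emits the MLIR op(s)
// for a given nGraph op. Specializations are defined below and registered in op_dispatcher.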
#define COMPILE_OP_DECL(op_name)                                                                   \
    create_op<op_name>(MLIRCompiler & compiler, const ngraph::Node* ng_node)

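// Stores the CompiledKernel sub-graph to compile and the externally allocated tensors backing
// its inputs and outputs.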
MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
                           const std::vector<void*>& external_tensors)
    : m_compiled_kernel(compiled_kernel)
    , m_external_tensors(external_tensors)
{
    NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
                  m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
                 "Number of arguments and outputs doesn't match number of tensors");
}

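// One-time, process-wide MLIR/LLVM initialization. Guarded by a mutex so it is safe to call
// from multiple threads.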
void MLIRCompiler::init_mlir()
{
    // Mutex to safely initialize MLIR.
    static std::mutex mlir_init_mutex;
    static bool initialized = false;

    std::unique_lock<std::mutex> lock(mlir_init_mutex);

    if (!initialized)
    {
        mlir::registerDialect<mlir::NGDialect>();
        // Register any LLVM command line options
        llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
        initialized = true;
    }
}

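// Runs the full pipeline: build the nGraph dialect module, lower it, optimize, bind the runtime
// arguments, JIT-execute the code and clean up temporaries.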
void MLIRCompiler::compile_and_run()
{
    build_ng_dialect_module();
    lower_ng_dialect();
    optimize();
    bind_arguments();
    execute();
    cleanup();
}

// Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
void MLIRCompiler::build_ng_dialect_module()
{
    // initialize an empty module
    m_module = make_unique<mlir::Module>(&m_context);

    TypeList args_type_list, result_type_list;

    // Retrieve input and output tensors.
    const auto& kernel_inputs = m_compiled_kernel->get_arguments();
    const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs();
    NGRAPH_CHECK(kernel_inputs.size() != 0, "Cannot have empty inputs list");
    NGRAPH_CHECK(kernel_outputs.size() != 0, "Cannot have empty outputs list");

    for (auto input : kernel_inputs)
    {
        args_type_list.push_back(get_mlir_type(input->get_output_tensor_ptr().get()));
    }

    for (auto output : kernel_outputs)
    {
        result_type_list.push_back(get_mlir_type(output->get_output_tensor_ptr().get()));
    }

    auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
    auto function =
        make_unique<mlir::Function>(mlir::UnknownLoc::get(&m_context), "main", func_type);
    function->addEntryBlock();

    // populate Tensor->Value maps
    int i = 0;
    for (auto input : kernel_inputs)
    {
        mlir::Value* arg = function->getArgument(i);
        TensorInfo tensor_info{arg};
        m_tensor_to_value_map.insert(
            TensorToInfo(input->get_output_tensor_ptr().get(), tensor_info));
        i++;
    }

    // create builder
    m_builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
    build_ng_dialect();
    m_module->getFunctions().push_back(function.release());
    if (failed(m_module->verify()))
    {
        NGRAPH_CHECK(false, "Invalid module after lowering to NG dialect");
    }

    dump_mlir_module("nGraph Dialect Dump:");
}

// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
    SmallVector<int64_t, 4> shape;
    for (auto d : tensor->get_shape())
    {
        shape.push_back(d);
    }

    return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
}

// Converts an nGraph element type into an MLIR type.
mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
{
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif

    switch (type.get_type_enum())
    {
    case ngraph::element::Type_t::undefined:
    case ngraph::element::Type_t::dynamic:
    default: NGRAPH_CHECK(false, "MLIR: Unsupported NGraph types"); break;
    case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context);
    case ngraph::element::Type_t::f16: return mlir::NGFloatType::getF16(&m_context);
    case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context);
    case ngraph::element::Type_t::f64: return mlir::NGFloatType::getF64(&m_context);
    case ngraph::element::Type_t::i8: return mlir::NGIntegerType::getInt8(&m_context);
    case ngraph::element::Type_t::u8:
    case ngraph::element::Type_t::boolean: return mlir::NGIntegerType::getUInt8(&m_context);
    case ngraph::element::Type_t::i16: return mlir::NGIntegerType::getInt16(&m_context);
    case ngraph::element::Type_t::u16: return mlir::NGIntegerType::getUInt16(&m_context);
    case ngraph::element::Type_t::i32: return mlir::NGIntegerType::getInt32(&m_context);
    case ngraph::element::Type_t::u32: return mlir::NGIntegerType::getUInt32(&m_context);
    case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context);
    case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context);
    }
    NGRAPH_CHECK(false, "Unreachable");
    return mlir::Type();

#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
}

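// Records the MLIR value that defines the given tensor. A tensor may be defined only once.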
void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
{
    NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
                 "tensor value already defined");
    TensorInfo tensor_info{value};
    m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
}

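// Returns the MLIR value previously recorded for the given tensor; fails if the tensor has no
// definition yet.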
MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tensor)
{
    auto it = m_tensor_to_value_map.find(tensor);

    NGRAPH_CHECK(it != m_tensor_to_value_map.end(), "Undefined tensor");

    return it->second;
}

// Lowers nGraph dialect to affine dialect.
void MLIRCompiler::lower_ng_dialect()
{
    mlir::PassManager pm;
    pm.addPass(mlir::createDialectLoweringPass(this));
    pm.addPass(mlir::createCanonicalizerPass());

    pm.run(m_module.get());

    if (failed(m_module->verify()))
    {
        NGRAPH_CHECK(false, "Incorrect module after dialect lowering");
    }

    dump_mlir_module("Affine Dialect Dump:");
}

// Receives affine dialect as input and applies affine and standard dialect based optimizations.
// Lowering from affine dialect to standard dialect happens along the way. Output consists of
// standard dialect only ops.
void MLIRCompiler::optimize()
{
    mlir::PassManager pm;
    // Lower affine ops
    pm.addPass(mlir::createLowerAffinePass());
    auto rr = pm.run(m_module.get());
    NGRAPH_CHECK(succeeded(rr), "Affine loop lowering failed");

    dump_mlir_module("Standard Dialect Dump:");
}

// MLIR builders
#define TI(x) std::type_index(typeid(x))

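// Emits nGraph dialect ops for every node in the sub-graph via the op_dispatcher table, then
// emits the return op.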
void MLIRCompiler::build_ng_dialect()
{
    const NodeVector& sub_graph = m_compiled_kernel->get_node_list();

    for (auto np : sub_graph)
    {
        auto it = op_dispatcher.find(TI(*np));
        if (it == op_dispatcher.end())
        {
            throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
                                 np->description() + "' operation"};
        }
        mlir::Value* mlir_value = it->second(*this, np.get());
        // Builders that produce multiple result values update the value map themselves and return null.
        if (mlir_value)
        {
            update_tensor_value(np->get_output_tensor_ptr().get(), mlir_value);
        }
    }
    create_return();
}

namespace ngraph
{
    namespace runtime
    {
        namespace ngmlir
        {
            template <>
            mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
            {
                return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
            }

            template <>
            mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
            {
                return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
            }
        }
    }
}

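// Dispatch table mapping nGraph op types to their MLIR builders. The set of supported ops is
// generated from ops_supported.inc.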
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
#include "ops_supported.inc"
};

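// Generic builder for binary ops: looks up the MLIR values of both operands and emits a single
// dialect op whose result type is derived from the node's output tensor.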
template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
{
    auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
    auto rhs = ng_node->get_argument(1)->get_output_tensor_ptr();
    auto lhs_v = get_tensor_value(lhs.get()).m_value;
    auto rhs_v = get_tensor_value(rhs.get()).m_value;
    auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
    return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v, rhs_v)
        .getResult();
}

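// Emits the NGReturnOp returning the values of all kernel outputs.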
void MLIRCompiler::create_return()
{
    std::vector<mlir::Value*> value_list;
    for (auto output : m_compiled_kernel->get_kernel_outputs())
    {
        value_list.push_back(get_tensor_value(output->get_output_tensor_ptr().get()).m_value);
    }
    m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}

// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// as well as helper arguments (such as the memory manager pointer) used inside the function.
void MLIRCompiler::bind_arguments()
{
    NGRAPH_CHECK(m_module, "MLIR module is not ready.");

    mlir::Function* func = m_module->getNamedFunction("main");
    NGRAPH_CHECK(func && !func->getBlocks().empty(), "Function not found");

    // Create a list with a type-erased double pointer for each invocation argument.
    // We currently use 'allocateMemRefArguments', which creates a
    // SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
    // actual pointer to the data.

    // create MemRef args
    auto expected_arguments = allocate_memref_args(func);
    NGRAPH_CHECK(expected_arguments.size(), "Arguments can't be created");
    m_invoke_args = std::move(expected_arguments);

    NGRAPH_CHECK(m_invoke_args.size() == m_external_tensors.size(),
                 "Number of external tensors doesn't match number of function arguments");

    // Assign external tensor pointers to invocation arguments.
    for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i)
    {
        ((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)m_external_tensors[i];
    }

    // Add pointer to memory manager
    // malloc here since that's what allocateMemRefArguments uses
    // TODO (nmostafa): Better way of doing this ? Use builder allocator ?
    MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*)));
    NGRAPH_CHECK(mem_mgr_arg != nullptr);
    *mem_mgr_arg = &get_mem_mgr();
    // Make sure the memory manager ptr lands at the expected argument index.
    NGRAPH_CHECK(m_invoke_args.size() == get_mem_mgr_arg_id(func));
    m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
}

// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void MLIRCompiler::execute()
{
    NGRAPH_CHECK(m_module, "MLIR module is not ready.");

    // Lower Standard dialect to LLVM dialect.
    auto converter = mlir::createStdToLLVMConverter();
    auto r = converter->convert(m_module.get());
    (void)r;
    NGRAPH_CHECK(succeeded(r), "Standard to LLVM dialect conversion failed");

    dump_mlir_module("LLVM-IR Dialect Dump:");

    // Initialize LLVM targets.
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();

    // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
    // don't run MLIR passes that were already run. We also pass a default transformer to run
    // LLVM optimizations at level 3.
    auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
    auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
    NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
    m_engine = std::move(maybeEngine.get());

    // Invoke the JIT-compiled function with the arguments. Note that, for API
    // uniformity reasons, it takes a list of type-erased pointers to arguments.
    // Note that the 'invoke' method is overloaded with a parameter pack version.
    // Make sure the MutableArrayRef version is invoked.
    auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
    NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
}

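// Releases per-invocation resources: argument descriptors, the function builder and memory
// allocated for JIT'ed code temporaries.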
void MLIRCompiler::cleanup()
{
    // Free void double pointer arguments without freeing external tensor data.
    for (auto* arg : m_invoke_args)
    {
        free(arg);
    }

    // Free MLIR function builder.
    if (m_builder)
    {
        m_builder.reset(nullptr);
    }

    // Free allocated memory for JIT'ed code temps
    m_mem_mgr.freeAll();
}

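// Allocates a MemRef descriptor for every MemRef-typed function argument; arguments of other
// types are skipped.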
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args(mlir::Function* func)
{
    SmallVector<void*, 8> args;
    args.reserve(func->getNumArguments());
    for (const auto& arg : func->getArguments())
    {
        auto descriptor = allocate_memref_descriptor(arg->getType());

        if (!descriptor)
        {
            continue;
        }
        args.push_back(descriptor);
    }
    return args;
}

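// Allocates a StaticFloatMemRef descriptor for a statically shaped MemRef type. Returns nullptr
// if the type is not a MemRef.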
mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type type)
{
    auto memRefType = type.dyn_cast<mlir::MemRefType>();
    if (!memRefType)
    {
        return nullptr;
    }
    NGRAPH_CHECK(memRefType.getNumDynamicDims() == 0, "No support for dynamic shapes");

    // We only use StaticFloatMemRef because that's what MLIR currently offers.
    // We should expand this with different types and dynamic MemRefs
    auto* descriptor =
        reinterpret_cast<mlir::StaticFloatMemRef*>(malloc(sizeof(mlir::StaticFloatMemRef)));
    NGRAPH_CHECK(descriptor != nullptr, "NULL MemRef descriptor");
    descriptor->data = nullptr;
    return descriptor;
}

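// Dumps the current MLIR module to debug output when the NGRAPH_MLIR_DUMP_ALL environment
// variable is set.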
void MLIRCompiler::dump_mlir_module(const std::string msg)
{
    if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
    {
        llvm::dbgs() << "*** " << msg << " ***\n";
        m_module->dump();
        llvm::dbgs() << "\n\n";
    }
}