//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once

#include "lowerer.hpp"
#include "memory_manager.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"

// TODO(dcab): Revisit and do fw decl when possible.
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/Types.h>
#include <mlir/StandardOps/Ops.h>

namespace mlir
{
    struct StaticFloatMemRef;
}

namespace ngraph
{
    namespace op
    {
        class CompiledKernel;
    }
    namespace runtime
    {
        namespace ngmlir
        {
            class MLIRCompiler
            {
            public:
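                /// Performs one-time global MLIR initialization; expected to be called
                /// once before compiling any sub-graph.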
                static void init_mlir();

            public:
                using TensorList = std::vector<descriptor::Tensor*>;
                using TypeList = llvm::SmallVector<mlir::Type, 4>;

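                /// Creates a compiler for the given sub-graph. \p external_tensors holds
                /// pointers to the externally allocated buffers for the sub-graph's
                /// inputs and outputs.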
                MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
                             const std::vector<void*>& external_tensors);

                /// Compiles and runs a subgraph in MLIR
                void compile_and_run();

                /// Returns the memory manager used by this sub-graph compiler
                MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
                /// Returns memory manager pointer argument ID in call interface
                unsigned get_mem_mgr_arg_id(mlir::Function* func)
                {
                    return func->getNumArguments() - 1;
                }

            private:
                struct TensorInfo
                {
                    mlir::Value* m_value; /* mlir value this tensor maps to */
                    // More info here ?
                };

            private:
                // Builds the MLIR module and entry function from the sub-graph.
                void build_module();
                // Lowers the nGraph dialect to lower-level MLIR dialects.
                void lower_dialect();
                // Runs optimization passes on the lowered module.
                void optimize();
                // Packs pointers to the external tensors and the memory manager as
                // arguments for the JIT'ed function.
                void bind_arguments();
                // JIT-compiles the module and invokes the entry function.
                void execute();
                // Releases per-run state after execution.
                void cleanup();

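                // Helpers to map nGraph element types and tensors to MLIR types and SSA values.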
                mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
                mlir::Type get_mlir_type(const element::Type& type);
                TensorInfo get_tensor_value(descriptor::Tensor* tensor);
                void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);

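                // Emits nGraph dialect ops for the sub-graph's nodes into the MLIR function.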
                void build_ng_dialect();

                template <typename OP>
                static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node)
                {
                    throw std::runtime_error("Unimplemented op '" + ng_node->description() +
                                             "' in MLIR Compiler");
                }

                template <typename BinOp>
                mlir::Value* create_binary_op(const ngraph::Node* ng_node);

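                // Emits the return op for the sub-graph's output values.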
                void create_return();

                /// Helper to create memref arguments for MLIR function signature
                llvm::SmallVector<void*, 8> allocate_memref_args(mlir::Function* func);

                /// Helper to allocate a memref descriptor. Handles static shapes only for now.
                mlir::StaticFloatMemRef* allocate_memref_descriptor(mlir::Type type);

            private:
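                // MLIR state: context, module under construction, IR builder, and JIT
                // execution engine.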
                mlir::MLIRContext m_context;
                std::unique_ptr<mlir::Module> m_module;
                std::unique_ptr<mlir::FuncBuilder> m_builder;
                std::unique_ptr<mlir::ExecutionEngine> m_engine;

                using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>;
                using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>;
                using MLIRCompOpFunction =
                    std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>;
                using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;

                // Sub-graph to be compiled and executed with MLIR.
                const ngraph::op::CompiledKernel* m_compiled_kernel;

                // Pointers to externally allocated memory for sub-graph's input and output tensors.
                const std::vector<void*>& m_external_tensors;
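                // Pointers packed as arguments for invoking the JIT'ed function.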
                llvm::SmallVector<void*, 8> m_invoke_args;

                // Maps each tensor to the MLIR value it represents in the IR.
                // Used during nGraph dialect generation.
                TensorToInfoMap m_tensor_to_value_map;
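                // Dispatch table mapping nGraph node types to the functions that emit their MLIR ops.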
                static const MLIRCompOpMap op_dispatcher;

                // Memory manager for temp allocations inside JIT'ed code
                MLIRMemMgr m_mem_mgr;
            };
        }
    }
}