// cpu_runtime.hpp
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

// NOTE: This file follows nGraph format style.
// Follows nGraph naming convention for public APIs only, else MLIR naming convention.

#pragma once

#include <memory>
#include <vector>

#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/Types.h>

#include "contrib/mlir/backend/backend.hpp"
#include "contrib/mlir/runtime/runtime.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace ngmlir
        {
            /// Memref object as returned by MLIRCPURuntime::allocateMemrefDescriptor().
            /// That helper is documented as handling static shapes only for now, and the
            /// only state carried here is a single data pointer.
            struct StaticMemRef
            {
                // Untyped pointer to the underlying buffer.
                // NOTE(review): presumably points at the externally allocated tensor
                // memory bound by the runtime - confirm against the implementation file.
                void* data;
            };
            /// A CPU Runtime is an MLIR runtime that owns an MLIR context and a module
            /// The module should be in LLVM dialect and ready to be lowered via an MLIR
            /// ExecutionEngine. The runtime owns the context and must out-live any MLIR
            /// code Compilation and execution.
            class MLIRCPURuntime : public MLIRRuntime
            {
            public:
                /// Executes a pre-compiled subgraph
                void run(void* args) override;

            private:
                void run_internal(std::vector<void*>& externalTensors);
                // Bind external tensors to MLIR module entry point
                void bindArguments(std::vector<void*>& externalTensors);
                // Invokes an MLIR module entry point with bound arguments
                void execute();
                // Cleans up allocated args
                void cleanup();

                /// Helper to create memref arguments for MLIR function signature
                llvm::SmallVector<void*, 8> allocateMemrefArgs();

                /// Helper to allocate a mem ref object. Handles static shapes only for now.
63
                StaticMemRef* allocateMemrefDescriptor();
64 65 66 67 68 69 70 71 72 73 74

            private:
                // Pointers to externally allocated memory for sub-graph's input and output tensors.
                std::vector<void*>* m_externalTensors;
                // Arguments for the MLIR function generated for the nGraph sub-graph.
                llvm::SmallVector<void*, 8> m_invokeArgs;
                std::unique_ptr<mlir::ExecutionEngine> m_engine;
            };
        }
    }
}