// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>
#include <vector>

#include "inference_engine.hpp"
#include "mkldnn_dims.h"
#include <mkldnn.hpp>
#include <string>
#include <mkldnn_types.h>
#include <functional>

namespace MKLDNNPlugin {

class MKLDNNMemoryDesc {
public:
    MKLDNNMemoryDesc(): desc({}, mkldnn::memory::data_type::f32, mkldnn::memory::format::format_undef) {}
    explicit MKLDNNMemoryDesc(const InferenceEngine::TensorDesc& tDesc);
    explicit MKLDNNMemoryDesc(const mkldnn::memory::desc& desc): desc(desc) {}
    MKLDNNMemoryDesc(mkldnn::memory::dims dims, mkldnn::memory::data_type dataType, mkldnn::memory::format format);

    mkldnn::memory::format getFormat() const {
        return static_cast<mkldnn::memory::format>(desc.data.format);
    }

    mkldnn::memory::data_type getDataType() const {
        return static_cast<mkldnn::memory::data_type>(desc.data.data_type);
    }

    MKLDNNDims getDims() const {
        return MKLDNNDims(desc.data.dims, desc.data.ndims);
    }

    bool blocksExtended() const;
    operator bool() const {
        return getFormat() != mkldnn::memory::format::any && getFormat() != mkldnn::memory::format::format_undef;
    }

    bool operator == (const MKLDNNMemoryDesc& rhs) const;
    bool operator != (const MKLDNNMemoryDesc& rhs) const;

    operator mkldnn::memory::desc() const;
    operator InferenceEngine::TensorDesc() const;

private:
    mkldnn::memory::desc desc;
};


// Forward declaration so the shared-pointer alias below can be introduced
// before the class definition.
class MKLDNNMemory;

// Shared-ownership handle to an MKLDNNMemory object.
using MKLDNNMemoryPtr = std::shared_ptr<MKLDNNMemory>;

/**
 * @brief Wrapper that holds an mkldnn::memory primitive (via shared_ptr) together
 *        with an mkldnn::engine.
 *
 * The inline accessors dereference the primitive pointer without checking it.
 * NOTE(review): presumably the primitive is only allocated by Create(); confirm
 * before calling accessors on a freshly constructed object.
 */
class MKLDNNMemory {
public:
    /** @brief Constructs the wrapper for the given engine (defined out of line). */
    explicit MKLDNNMemory(const mkldnn::engine& eng);

    /** @return a reference to the held memory primitive. */
    const mkldnn::memory& GetPrimitive() const {
        const mkldnn::memory& primitive = *prim;
        return primitive;
    }

    /** @return the shared pointer that owns the memory primitive. */
    const std::shared_ptr<mkldnn::memory>& GetPrimitivePtr() const {
        return prim;
    }

    /** @return the memory descriptor taken from the primitive's primitive descriptor. */
    mkldnn::memory::desc GetDescriptor() const {
        return GetPrimitiveDescriptor().desc();
    }

    /** @return the primitive descriptor of the held memory primitive. */
    mkldnn::memory::primitive_desc GetPrimitiveDescriptor() const {
        return prim->get_primitive_desc();
    }

    /** @return the data handle reported by the held memory primitive. */
    void* GetData() const {
        void* handle = prim->get_data_handle();
        return handle;
    }

    /** @return the data_type field of the current descriptor as a C++ API enum value. */
    mkldnn::memory::data_type GetDataType() const {
        const auto descriptor = GetDescriptor();
        return static_cast<mkldnn::memory::data_type>(descriptor.data.data_type);
    }

    // Defined out of line. NOTE(review): unit of the returned size (bytes vs. elements)
    // is not visible from this header - confirm against the implementation.
    size_t GetSize() const;

    /** @return the format field of the current descriptor as a C++ API enum value. */
    mkldnn::memory::format GetFormat() const {
        const auto descriptor = GetDescriptor();
        return static_cast<mkldnn::memory::format>(descriptor.data.format);
    }

    /** @return the first `ndims` entries of the descriptor's dims array, as a vector. */
    mkldnn::memory::dims GetDims() const {
        const auto descriptor = GetDescriptor();
        const auto* first = descriptor.data.dims;
        const auto* last = first + descriptor.data.ndims;

        return std::vector<int>(first, last);
    }

    // Creation overloads, defined out of line. `data` is optional and defaults to nullptr.
    // NOTE(review): whether `data` is copied or wrapped is not visible here - confirm.
    void Create(mkldnn::memory::dims dims,
                mkldnn::memory::data_type data_type,
                mkldnn::memory::format format,
                const void* data = nullptr);

    void Create(const mkldnn::memory::desc& desc, const void* data = nullptr);

    // Data-loading overloads, defined out of line. Both are const member functions.
    // NOTE(review): `ftz` presumably stands for "flush to zero" - confirm in the implementation.
    void SetData(mkldnn::memory::data_type dataType,
                 mkldnn::memory::format format,
                 const void* data,
                 size_t size,
                 bool ftz = true) const;
    void SetData(const MKLDNNMemory& memory, bool ftz = true) const;

    // Defined out of line; presumably zero-fills the underlying buffer - confirm.
    void FillZero();

    // Static format helpers; all are defined out of line.
    static bool IsPlainFormat(mkldnn::memory::format format);
    static mkldnn::memory::format GetPlainFormat(mkldnn::memory::dims dims);
    static bool isConsistant(mkldnn::memory::dims dims, mkldnn::memory::format format);
    static mkldnn::memory::format Convert(const InferenceEngine::Layout layout);

    static std::string formatToString(mkldnn::memory::format fmt);

    // Takes the descriptor by non-const reference, so it may modify it in place.
    static void CreateBlockingDesc(mkldnn::memory::desc& desc);

private:
    // The held memory primitive; dereferenced by the accessors above.
    std::shared_ptr<mkldnn::memory> prim;
    // Engine associated with this memory object.
    mkldnn::engine eng;
};


}  // namespace MKLDNNPlugin