// Copyright (C) 2018 Intel Corporation
//
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "perf_count.h"
#include <vector>
#include <utility>
#include <mkldnn_types.h>
#include <ie_common.h>
#include <mkldnn.hpp>

namespace MKLDNNPlugin {

/**
 * @brief Thin wrapper around an ordered list of integer dimension extents.
 *
 * The extents are stored in a std::vector<int>. The class offers conversions
 * from/to InferenceEngine::SizeVector and to mkldnn::memory::dims, plus a few
 * helpers for computing element counts.
 */
class MKLDNNDims {
public:
    /** @brief Creates an empty dimension list (ndims() == 0). */
    MKLDNNDims() = default;

    /**
     * @brief Builds the list from an Inference Engine size vector.
     * @param size source extents; every element is converted to int
     */
    explicit MKLDNNDims(const InferenceEngine::SizeVector& size) : dims(size.begin(), size.end()) {}

    /**
     * @brief Builds the list as a copy of the given extents.
     * @param dim source extents
     */
    explicit MKLDNNDims(const std::vector<int>& dim) : dims(dim) {}

    /**
     * @brief Builds the list from the first dnn_ndims entries of an MKL-DNN dims array.
     * @param dnn_dims  array holding the extents
     * @param dnn_ndims number of leading entries of dnn_dims to copy
     */
    MKLDNNDims(const mkldnn_dims_t dnn_dims, int dnn_ndims) : dims(dnn_dims, dnn_dims + dnn_ndims) {}

    /** @brief Builds the list from a braced list of int extents. */
    explicit MKLDNNDims(std::initializer_list<int> ilist) : dims(ilist) {}

    /** @brief Builds the list from a braced list of size_t extents; each one is converted to int. */
    explicit MKLDNNDims(std::initializer_list<size_t > ilist) : dims(ilist.begin(), ilist.end()) {}

    /**
     * @brief Converts the stored extents to an Inference Engine size vector.
     * @return a new SizeVector with one element per stored extent, in the same order
     */
    InferenceEngine::SizeVector ToSizeVector() const {
        // Range construction sizes the result once instead of growing it per element.
        return InferenceEngine::SizeVector(dims.begin(), dims.end());
    }

    /** @return the number of stored extents */
    int ndims() const {
        return static_cast<int>(dims.size());
    }

    /**
     * @brief Product of all stored extents.
     * @return the product, or 1 when the list is empty
     */
    int size() const {
        return size(0);
    }

    /**
     * @brief Product of the extents at positions [start, ndims()).
     * @param start index of the first extent to include
     * @return the product accumulated in an int; 1 when start is negative or
     *         not less than ndims()
     */
    int size(int start) const {
        int total = 1;

        // A negative start never entered the loop in the signed/unsigned
        // comparison this replaces; keep that result but make it explicit.
        if (start < 0) {
            return total;
        }

        for (auto i = static_cast<size_t>(start); i < dims.size(); ++i) {
            total *= dims[i];
        }

        return total;
    }

    /** @brief Appends one extent to the end of the list. */
    void push_back(int val) {
        dims.push_back(val);
    }

    /** @brief Implicit conversion to the MKL-DNN dims type; returns a copy of the stored extents. */
    operator mkldnn::memory::dims() const {
        return dims;
    }

    /**
     * @brief Element-wise equality.
     * @return true when both lists have the same length and identical extents
     */
    bool operator == (const MKLDNNDims& rhs) const {
        // std::vector's operator== already checks size before comparing elements.
        return dims == rhs.dims;
    }

    /** @return the negation of operator== */
    bool operator != (const MKLDNNDims& rhs) const {
        return !(*this == rhs);
    }

    /**
     * @brief Mutable access to the extent at position idx.
     * @note No bounds check is performed.
     */
    int& operator[](int idx) {
        return dims[idx];
    }

    /**
     * @brief Read-only access to the extent at position idx.
     * @note No bounds check is performed.
     */
    int operator[](int idx) const {
        return dims[idx];
    }

private:
    /// Dimension extents, in the order they were supplied.
    std::vector<int> dims;
};

}  // namespace MKLDNNPlugin