// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

/**
 * @brief a header file with common functions for graph transformation
 * @file graph_transformer.h
 */

#pragma once

#include <map>
#include <vector>
#include <string>
#include <ie_icnn_network.hpp>
#include <details/caseless.hpp>
#include "cnn_network_impl.hpp"

namespace InferenceEngine {

/**
 * @brief TBD
 */
class INFERENCE_ENGINE_API_CLASS(ConstTransformer) {
public:
    explicit ConstTransformer(details::CNNNetworkImpl* _network) {
        if (!_network) THROW_IE_EXCEPTION << "[ERROR]: Failed to init ConstTransformer with null pointer of network";
        network = _network;
        cnnNetwork = CNNNetwork(network);
    }

    /**
     * @brief calculates const layers, combines const subgraph into a single const layers
     */
    void foldConstSubgraphs();

    /**
      * @brief folds Const Subgraphs and removes second input of Reshape-like layers (Interp, Gather, Resample, ...)
      */
    void fullTrim();

protected:
    /**
     * @brief collect all const layers with marking if it defines shape (1 - for shape, 0 - otherwise)
     */
    virtual const std::map<std::string, bool> getConstLayers(const std::vector<CNNLayerPtr>& sortedLayers);

    /**
     * @brief TBD
     */
    virtual const BlobMap
        getConstData(const std::map<std::string, bool>& constLayers, const std::vector<CNNLayerPtr>& sortedLayers);

    /**
     * @brief TBD
     */
    virtual std::vector<std::string>
    foldConstSubgraphsInternal(const std::map<std::string, bool>& constLayers, const BlobMap& constData,
                               const std::vector<CNNLayerPtr>& sortedLayers);

    /**
     * @brief TBD
     */
    virtual void trimShapeInputs(const std::vector<std::string>& constLayers);

private:
    const details::caseless_set<std::string> shapeTaking = {"Reshape", "Resample", "Interp"};
    details::CNNNetworkImpl* network;
    CNNNetwork cnnNetwork;
};

}  // namespace InferenceEngine