//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "mlir_subgraph_extraction.hpp"

#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"

using namespace ngraph::descriptor;
using namespace ngraph::op;
using namespace ngraph::pass;

// Shorthand for the std::type_index of a type; used to compare op types below.
#define TI(x) std::type_index(typeid(x))

bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
{
    // Create a CompiledKernel for all the ops in the function, except Parameters and Results.
    NodeVector ck_ops;
    for (auto op : func->get_ordered_ops())
    {
        // All ops must be supported by MLIR compiler
        if (!is_supported_mlir_op(op))
        {
            return false;
        }

        if (TI(Parameter) != TI(*op) && TI(Result) != TI(*op))
        {
            ck_ops.push_back(op);
        }
    }

    // All function parameters become arguments of the CompiledKernel.
    NodeVector ck_args;
    for (auto& param : func->get_parameters())
    {
        ck_args.push_back(param);
    }

    // get_subgraph_outputs returns the subgraph nodes whose values are consumed outside
    // the subgraph. Only single-output subgraphs are outlined for now.
    NodeVector ck_outputs = get_subgraph_outputs(ck_ops, {} /*exclusions*/);
    if (ck_outputs.size() != 1)
    {
        return false;
    }

    // CompiledKernel is a fused op that wraps the subgraph's nodes, outputs, and
    // arguments so the MLIR backend can compile and execute it as a single unit.
    auto ck = std::make_shared<CompiledKernel>(ck_ops, ck_outputs, ck_args);

    // Connect CompiledKernel to output nodes by replacing the output descriptors of the output
    // nodes.
    for (size_t i = 0, end = ck_outputs.size(); i < end; ++i)
    {
        auto& output_descs = ck_outputs[i]->get_outputs();
        NGRAPH_CHECK(output_descs.size() == 1, "Unexpected multiple output descriptors");
        auto& out_desc = output_descs[0];

        // 'replace_output' invalidates iterator of the original container. Use a copy instead.
        const std::set<descriptor::Input*> input_descs = out_desc.get_inputs();

        for (descriptor::Input* in_desc : input_descs)
        {
            in_desc->replace_output(ck, i);
        }
    }

    return true;
}
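
// A minimal usage sketch (assumed wiring, based on the standard ngraph::pass::Manager
// API; not taken from this file): the pass runs over a function and, on success,
// outlines its body into a single CompiledKernel node.
//
//   ngraph::pass::Manager pass_manager;
//   pass_manager.register_pass<ngraph::pass::MLIRSubgraphExtractionPass>();
//   pass_manager.run_passes(func);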

bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
{
    // Parameters and Results are always supported; they remain outside the extracted kernel.
    if (TI(Parameter) == TI(*node) || TI(Result) == TI(*node))
    {
        return true;
    }

    // Is the op supported by the MLIR backend?
    if (m_supported_ops.find(TI(*node)) == m_supported_ops.end())
    {
        return false;
    }

    // Check invariants expected by the MLIR backend.

    // Dot is supported for 2D inputs only.
    if (TI(ngraph::op::Dot) == TI(*node))
    {
        if (node->get_input_shape(0).size() != 2 || node->get_input_shape(1).size() != 2)
        {
            return false;
        }
    }

    if (TI(ngraph::op::ArgMin) == TI(*node) || TI(ngraph::op::ArgMax) == TI(*node))
    {
        // TODO: Remove this check once MLIR supports floating-point comparison.
        if (!node->input(0).get_element_type().is_integral())
        {
            return false;
        }
    }
    return true;
}
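
// Example of the Dot invariant above (hypothetical shapes, standard nGraph builder
// calls): a Dot whose two inputs are both 2D passes the check; anything else is
// rejected and the whole function is left un-outlined.
//
//   auto a = std::make_shared<op::Parameter>(element::f32, Shape{4, 8});
//   auto b = std::make_shared<op::Parameter>(element::f32, Shape{8, 16});
//   auto dot = std::make_shared<op::Dot>(a, b); // 2D x 2D: supported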

// The set of ops supported by the MLIR backend, built via X-macro expansion of the
// MLIR_OP(...) entries in ops_supported.inc.
const std::set<std::type_index> MLIRSubgraphExtractionPass::m_supported_ops{
#define MLIR_OP(OP) TI(ngraph::op::OP),
#include "contrib/mlir/ops_supported.inc"
};
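
// A minimal sketch of what ops_supported.inc is expected to contain (illustrative
// entries, not the actual list; the real file lives under contrib/mlir). Each
// MLIR_OP line names one supported op, and the X-macro above turns it into a
// type_index element of the set:
//
//   MLIR_OP(Add)
//   MLIR_OP(Dot)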