"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import numpy as np

from mo.front.common.partial_infer.split import tf_split_infer
from mo.front.tf.extractors.concat import tf_concat_ext
from mo.front.tf.extractors.const import tf_const_ext
from mo.front.tf.extractors.eltwise import make_tf_eltwise
from mo.front.tf.extractors.fused_bn import tf_fused_bn_extractor
from mo.front.tf.extractors.lrn import tf_lrn_ext
from mo.front.tf.extractors.matmul import tf_matmul_ext, tf_batchmatmul_ext
from mo.front.tf.extractors.native_tf import native_tf_node_extractor
from mo.front.tf.extractors.pack import tf_pack_ext
from mo.front.tf.extractors.random_uniform import tf_random_uniform_ext
from mo.front.tf.extractors.space_to_batch import tf_space_to_batch_ext, tf_batch_to_space_ext
from mo.front.tf.extractors.split import tf_split_ext
from mo.front.tf.extractors.unpack import tf_unpack_ext
from mo.front.tf.extractors.utils import get_tf_node_port
from mo.graph.graph import Node


def get_tf_edges(node: Node):
    """
    Build the list of input edges for a TF node.

    Every entry of ``node.pb.input`` is turned into one edge tuple
    ``(source node name, node.id, attributes)``; the edge points from the
    producer of the data to this node. In the attributes, 'in' is the position
    of the entry in ``node.pb.input`` (input port of this node) and 'out' is
    the port returned by ``get_tf_node_port`` for the producer.

    :param node: graph node whose ``pb.input`` lists the names of its inputs.
    :return: list of edge tuples, one per input, in input order.
    """
    edges = []
    for dst_port, tensor_name in enumerate(node.pb.input):
        producer, producer_port = get_tf_node_port(tensor_name)
        # a leading '^' flags the edge as a control flow edge; the marker is
        # not part of the producer name, so it is cut off
        is_control_flow = producer[0] == '^'
        if is_control_flow:
            producer = producer[1:]
        edge_attrs = {
            'in': dst_port,
            'out': producer_port,
            # debug anchor for a framework tensor name and port (original, unstripped name)
            'fw_tensor_debug_info': [(tensor_name, producer_port)],
            'in_attrs': ['in', 'control_flow_edge', 'permutation'],
            'out_attrs': ['out', 'permutation'],
            'data_attrs': ['fw_tensor_debug_info'],
            'control_flow_edge': is_control_flow
        }
        edges.append((producer, node.id, edge_attrs))
    return edges


def node_pb_arg(pb_extractor: callable):
    """
    Adapt an extractor that works on a protobuf message to one that works on a node.

    :param pb_extractor: callable taking the ``pb`` attribute of a node.
    :return: callable taking a node and returning ``pb_extractor(node.pb)``.
    """
    def extract_from_node(node):
        return pb_extractor(node.pb)

    return extract_from_node


# Registry of attribute extractors keyed by the TF operation type.
# Every value is a callable that takes a graph node; it is produced by node_pb_arg,
# so the wrapped extractor itself only receives node.pb.
# tf_op_extractor looks entries up through a map from lower-cased op names to these keys.
tf_op_extractors = {
    # yields no attributes; tf_op_extractor returns early for nodes with this op
    'TFCustomSubgraphCall': node_pb_arg(lambda pb: None),
    'LRN': node_pb_arg(tf_lrn_ext),
    'Split': node_pb_arg(lambda pb: tf_split_ext(pb, tf_split_infer)),
    'FusedBatchNorm': node_pb_arg(tf_fused_bn_extractor),
    'ConcatV2': node_pb_arg(tf_concat_ext),
    'MatMul': node_pb_arg(tf_matmul_ext),
    # both BatchMatMul versions are handled by the same extractor
    'BatchMatMul': node_pb_arg(tf_batchmatmul_ext),
    'BatchMatMulV2': node_pb_arg(tf_batchmatmul_ext),
    'Pack': node_pb_arg(tf_pack_ext),
    'Unpack': node_pb_arg(tf_unpack_ext),
    'Const': node_pb_arg(tf_const_ext),
    # Identity, ReadVariableOp and PlaceholderWithDefault are all registered as a
    # pass-through eltwise (lambda v: v) carrying the attribute 'identity': True
    'Identity': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
    'RandomUniform': node_pb_arg(tf_random_uniform_ext),
    'SpaceToBatchND': node_pb_arg(tf_space_to_batch_ext),
    'BatchToSpaceND': node_pb_arg(tf_batch_to_space_ext),
    'ReadVariableOp': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
    'PlaceholderWithDefault': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True}))
}


def common_tf_fields(node: Node):
    """
    Build the attributes that are set for every TF operation node.

    :param node: graph node; its ``pb.name`` and ``pb.op`` are read.
    :return: new dict with the keys 'kind', 'name', 'op' and 'precision'.
    """
    pb = node.pb
    fields = dict(kind='op', name=pb.name, op=pb.op)
    # TODO use real precision derived from the model
    fields['precision'] = 'FP32'
    return fields


def tf_op_extractor(node: Node, lowered_keys_map: dict):
    """
    Extract attributes of a TF node.

    Nodes with op 'TFCustomSubgraphCall' and nodes without a valid 'pb' are
    reported as supported and their current graph attributes are returned as is.
    For any other node the common fields are written into the graph attributes
    of the node, then the op specific extractor (if one is registered for the
    lower-cased op name) is applied. The result is based on the output of
    native_tf_node_extractor, with the common and op specific attributes
    taking precedence.

    :param node: graph node to extract attributes for.
    :param lowered_keys_map: map from a lower-cased op name to a key of tf_op_extractors.
    :return: tuple (supported, attributes); supported is True only when an op
             specific extractor returned a non-empty result (or on the early exit).
    """
    # all required attributes for the 'TFCustomSubgraphCall' are set during their initialization
    is_custom_subgraph = node.has('op') and node.op == 'TFCustomSubgraphCall'
    if is_custom_subgraph or not node.has_valid('pb'):
        return True, node.graph.node[node.id]

    extracted = common_tf_fields(node)
    # the common fields go to the graph right away; op specific ones are only returned
    node.graph.node[node.id].update(extracted)

    is_supported = False
    lowered_op = extracted['op'].lower()
    if lowered_op in lowered_keys_map:
        registered_op = lowered_keys_map[lowered_op]
        assert registered_op in tf_op_extractors
        op_attrs = tf_op_extractors[registered_op](node)
        if op_attrs:
            extracted.update(op_attrs)
            is_supported = True

    # native attributes are the base; everything extracted above overrides them
    merged = native_tf_node_extractor(node.pb)
    merged.update(extracted)
    return is_supported, merged