import logging as log

import numpy as np

from extensions.back.InsertLayoutPropagationTransposes import is_input_data_in_correct_layout, \
    is_output_data_in_correct_layout
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import PermuteAttrs
from mo.utils.error import Error

def apply_nhwc_to_nchw_permutation(graph: Graph):
    """Attach the NHWC -> NCHW permutation to the in/out edges of data nodes.

    Nothing is done when the graph layout is already 'NCHW'. A data node is skipped when it has the
    `nchw_layout` flag set, or when any of its incoming/outgoing edges already carries a 'permutation'
    attribute.

    :param graph: graph whose data-node edges are annotated in place
    """
    if graph.graph['layout'] == 'NCHW':
        return

    for node in graph.get_data_nodes():
        if node.has_and_set('nchw_layout'):
            continue

        # A permutation already present on any incident edge takes precedence over the default one
        has_in_permutation = any('permutation' in node.graph.get_edge_data(in_node.id, node.id)[0]
                                 for in_node in node.in_nodes())
        has_out_permutation = any('permutation' in node.graph.get_edge_data(node.id, out_node.id)[0]
                                  for out_node in node.out_nodes())
        if has_in_permutation or has_out_permutation:
            continue

        # NHWC to NCHW permutation for N dims, where N = len(node.shape)
        permutation = PermuteAttrs().get_nhwc_to_nchw_permutation(len(node.shape))

        # Set the permutation on all in/out edges
        for in_node in node.in_nodes():
            PermuteAttrs.set_permutation(in_node, node, permutation)
        for out_node in node.out_nodes():
            PermuteAttrs.set_permutation(node, out_node, permutation)

def merge_nodes_permutations(graph: Graph):
    """Merge the per-edge permutations of every data node into a single node attribute.

    For each data node, the 'permutation' attributes of all incoming and outgoing edges are collected.
    A None permutation is treated as the identity order. If the collected permutations are not all
    equal an Error is raised; otherwise the first collected permutation is stored as node['permutation'].

    :param graph: graph whose data nodes are updated in place
    :raises Error: if the permutations requested for a data node differ
    """
    for node in graph.get_data_nodes():
        edge_attrs = [node.graph.get_edge_data(in_node.id, node.id)[0] for in_node in node.in_nodes()]
        edge_attrs += [node.graph.get_edge_data(node.id, out_node.id)[0] for out_node in node.out_nodes()]
        permutations = [attrs['permutation'] for attrs in edge_attrs if 'permutation' in attrs]
        if not permutations:
            continue

        # A None permutation means "no reordering", i.e. the identity order
        final_permutations = [p.perm if p is not None else int64_array(np.arange(len(node.shape)))
                              for p in permutations]

        if not all(np.array_equal(final_permutations[0], perm) for perm in final_permutations):
            # Report the normalized arrays: formatting `p.perm` directly would itself fail on None entries
            raise Error('Permutations requested for {} data node are not equal! List of permutations: {}'
                        ''.format(node.name, final_permutations))

        assert not node.has_valid('permutation') or np.array_equal(node.permutation, permutations[0])
        node['permutation'] = permutations[0]

def permute_data_nodes_attrs(graph: Graph):
    """Apply each data node's stored 'permutation' to its shape and (if present) its value.

    Data nodes without a valid permutation, with an empty permutation, or whose producer output is
    reported as already being in the correct layout are left untouched.

    :param graph: graph whose data nodes are updated in place
    """
    for node in graph.get_data_nodes():
        if not node.has_valid('permutation'):
            continue

        # There are data nodes without an input operation node inside the tensor iterator
        if len(node.in_nodes()) != 0:
            producer = node.in_node(0)
            edge_attrs = graph.get_edge_data(producer.id, node.id)[0]
            if is_output_data_in_correct_layout(producer, edge_attrs['out']):
                log.debug('Do not permute data node attrs for node "{}" output port "{}"'.format(producer.id,
                                                                                                  edge_attrs['out']))
                continue

        # Apply permutation for shape and value if exists
        perm = node.permutation.perm
        if len(perm) == 0:
            continue
        node.shape = np.array(node.shape)[perm]
        if node.has_valid('value'):
            assert len(node.value.shape) == len(perm), \
                'Node {} has shape {} and permutation {} that does not match. Their lengths should be equal' \
                ''.format(node.name, node.value.shape, perm)
            node.value = np.array(node.value.transpose(perm))

def permute_op_nodes_attrs(graph: Graph):
    """Permute the attributes of every op node that has 'permute_attrs' and no `nchw_layout` flag.

    :param graph: graph whose op nodes are updated in place
    :raises Error: wrapping any exception raised while permuting a node's attributes
    """
    for node in graph.get_op_nodes():
        if node.has_valid('permute_attrs') and not node.has_and_set('nchw_layout'):
            try:
                node.permute_attrs.permute_attrs(node)
            except Exception as e:
                # Chain the original exception so its traceback is not lost
                raise Error('Can\'t permute attrs for node {}. Error message: {}'.format(node.id, e)) from e

def permute_input_data(graph: Graph):
    """Run the 'input_permutation' callbacks stored on op-node input edges of an NHWC graph.

    Each 'input_permutation' edge attribute is a (callable, "direction:port") pair. The callable is
    invoked as `permutation(node, port_info, in_port)`. The call is made only when the input data is
    not already in the correct layout and the referenced port's data has rank >= 4.

    :param graph: graph to process; nothing is done unless its layout is 'NHWC'
    """
    if graph.graph['layout'] != 'NHWC':
        return
    for node in graph.get_op_nodes():
        # Skip edges where the attribute exists but is None (unpacking None would raise TypeError)
        input_permutations = [(in_port, edge_attrs['input_permutation'])
                              for in_port, edge_attrs in node.in_edges().items()
                              if edge_attrs.get('input_permutation') is not None]
        for in_port, (permutation, port_info) in input_permutations:
            direction, port = port_info.split(':')
            port = int(port)
            port_to_check = node.in_port(port) if direction == 'input' else node.out_port(port)
            if not is_input_data_in_correct_layout(node, in_port) and len(port_to_check.data.get_shape()) >= 4:
                permutation(node, port_info, in_port)

def _find_convolutions_for_channel_reversal(graph: Graph):
    """Walk down from 4D Parameter nodes and collect the convolutions whose weights need flipping.

    Returns a tuple (convolutions, flip_passthrough). `convolutions` are Convolution consumers whose
    weights have valid 'input_channel_dim', 'shape' and 'value' and exactly 3 input channels.
    `flip_passthrough` are ScaleShift/BatchNormalization nodes traversed on the way.
    """
    candidates = set()
    for node in graph.nodes():
        node = Node(graph, node)
        if node.has_valid('type') and node.type == 'Parameter' and len(node.out_nodes()) == 1 and \
                node.out_node(0).shape.size == 4:
            candidates.add(node)
    log.debug('reverse_input_channels found candidates: {}'.format([c.node for c in candidates]))

    # Track down to the first convolutions
    convolutions = set()
    flip_passthrough = set()
    while candidates:
        op_node = candidates.pop()
        assert len(op_node.out_nodes()) == 1
        tensor_node = op_node.out_node(0)
        for consumer in tensor_node.out_nodes():
            if (consumer.has_valid('type') and
                    consumer.type == 'Convolution' and
                    consumer.in_node(1).has_valid('input_channel_dim') and
                    consumer.in_node(1).has_valid('shape') and
                    consumer.in_node(1).shape[consumer.in_node(1).input_channel_dim] == 3 and
                    consumer.in_node(1).has_valid('value')):
                convolutions.add(consumer)
            # TODO Use more reliable way
            elif len(consumer.out_nodes()) == 1 and np.all(consumer.out_node().shape == tensor_node.shape):
                # Shape-preserving op: keep searching below it
                candidates.add(consumer)
                if consumer.has_valid('type') and (
                        consumer.type == 'ScaleShift' or consumer.type == 'BatchNormalization'):
                    flip_passthrough.add(consumer)
            else:
                log.debug('Stop searching of conv candidate for channel reversing at node {}'.format(consumer.id))
    return convolutions, flip_passthrough


def _flip_passthrough_blobs(node):
    """Flip every constant input (ports 1..N-1) of a ScaleShift/BatchNormalization along its 3-sized axis."""
    log.debug("Applying flip for ScaleShift: {}".format(node.name))
    assert node.has_valid('type') and (node.type == 'ScaleShift' or node.type == 'BatchNormalization')
    for blob in [node.in_node(i) for i in range(1, len(node.in_nodes()))]:
        assert blob.has_valid('value')
        non_one_dimensions = np.where(blob.shape != 1)[0]
        assert len(non_one_dimensions) == 1
        assert blob.shape[non_one_dimensions[0]] == 3
        blob.value = np.flip(blob.value, non_one_dimensions[0])


def _flip_depthwise_bottom_conv(conv):
    """Flip the group order in the weights of the Conv2D that follows a DepthwiseConv2dNative `conv`.

    Flipping input channels for DepthwiseConv2dNative alone doesn't do the complete thing; the next
    convolution's input channels must be flipped in groups too. A FakeQuantize between the two is
    skipped. If no Conv2D follows, an error is logged and nothing is changed.
    """
    log.debug('out nodes: {}'.format(conv.out_node()))
    bottoms = conv.out_node().out_nodes()
    if len(bottoms) == 1 and bottoms[0].op == 'FakeQuantize':
        bottoms = bottoms[0].out_node().out_nodes()
    log.debug('bottoms: {}'.format(bottoms))
    # Guard before indexing: `bottoms` may be empty
    if bottoms:
        log.debug('assumed conv: name = {}, op = {}'.format(bottoms[0].name, bottoms[0].op))
    if not bottoms or bottoms[0].op != 'Conv2D':
        log.error(
            'Reverse input channels are not applied: there is no Conv2D after DepthwiseConv2dNative to ' +
            'complete the flip')
        return

    bottom_conv = bottoms[0]
    ngroups = conv.group
    log.debug('ngroups = {}'.format(ngroups))
    bottom_channel_dim = bottom_conv.channel_dims[0]
    log.debug('bottom_channel_dim = {}'.format(bottom_channel_dim))
    bottom_channels = bottom_conv.in_node(0).shape[bottom_channel_dim]
    log.debug('bottom_channels = {}'.format(bottom_channels))
    assert bottom_channels % ngroups == 0
    multiplier = int(bottom_channels // ngroups)
    log.debug('multiplier = {}'.format(multiplier))

    bottom_weights = bottom_conv.in_node(1)
    channel_dim = bottom_weights.input_channel_dim
    src_shape = list(bottom_weights.value.shape)
    log.debug('weights shape = {}'.format(src_shape))
    assert src_shape[channel_dim] == bottom_channels
    # Temporarily reshape: input-channel axis -> ngroups, plus a trailing axis of size `multiplier`;
    # flip along the input-channel axis, then restore the original shape.
    # NOTE(review): with row-major reshape this is a clean (ngroups, multiplier) split of the channel
    # axis only when input_channel_dim is the last non-unit axis -- confirm for the weight layout here.
    reorder_shape = list(src_shape)
    reorder_shape[channel_dim] = ngroups
    reorder_shape.append(multiplier)
    log.debug('tmp_shape_for_reorder = {}'.format(reorder_shape))
    bottom_weights.value = np.flip(bottom_weights.value.reshape(tuple(reorder_shape)), channel_dim)
    log.debug('back to shape = {}'.format(tuple(src_shape)))
    bottom_weights.value = bottom_weights.value.reshape(tuple(src_shape))
    log.debug('final shape of weights = {}'.format(bottom_weights.value.shape))
    log.debug('shape as attr = {}'.format(bottom_weights.shape))


def reverse_input_channels(graph: Graph):
    """
    Searches for all type=Parameter nodes with 4D output tensors,
    tracks tensors down through non-shape-changing ops to the first type=Convolution or other channel-dependent nodes
    and reverses input channels in convolution weights.

    ScaleShift/BatchNormalization nodes met on the way get their constant inputs flipped as well. For a
    DepthwiseConv2dNative convolution the following Conv2D's weights are additionally flipped in groups.

    :param graph: graph whose weight constants are modified in place
    """
    convolutions, flip_passthrough = _find_convolutions_for_channel_reversal(graph)

    if not convolutions:
        log.error('Reverse input channels are not applied -- appropriate convolutions were not found')

    for node in flip_passthrough:
        _flip_passthrough_blobs(node)

    for conv in convolutions:
        if conv.op == 'DepthwiseConv2dNative':
            _flip_depthwise_bottom_conv(conv)

        weights = conv.in_node(1)
        old_shape = weights.shape  # captured before the attribute is overwritten below
        weights.value = np.flip(weights.value, weights.input_channel_dim)
        weights.shape = int64_array(weights.value.shape)
        log.debug('Applied reversing input channels for weights of convolution {}'.format(conv.id))
        log.debug('Shape was (shape){}, (value.shape){}'.format(old_shape, weights.value.shape))
        log.debug('Flipped dim: {}'.format(weights.input_channel_dim))