"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import numpy as np

from mo.front.caffe.extractors.utils import get_spatial_attr, get_list_from_container, weights_biases
from mo.front.common.extractors.utils import layout_attrs
from mo.front.extractor import FrontExtractorOp
from mo.ops.convolution import Convolution
from mo.utils.error import Error


class ConvFrontExtractor(FrontExtractorOp):
    """
    Extractor for Caffe layers handled under the 'convolution' op name.

    Reads the layer's convolution_param, builds the convolution attributes and stores
    them on the node through Convolution.update_node_stat.
    """
    op = 'convolution'
    enabled = True

    @staticmethod
    def extract(node):
        """
        Fills the node with convolution attributes taken from its protobuf layer
        Args:
            node: graph node; node.pb is the layer definition, node.model_pb is handed to
                weights_biases together with the convolution parameters

        Returns:
            value of the 'enabled' flag of this extractor class

        Raises:
            Error: if node.pb is empty
        """
        layer_pb = node.pb
        layer_model_pb = node.model_pb

        if not layer_pb:
            raise Error('Protobuf layer can not be empty')

        param = layer_pb.convolution_param

        # A layer with more than one bottom is marked as 'ConvND', otherwise as 'Conv2D'
        if len(layer_pb.bottom) > 1:
            conv_type = 'ConvND'
        else:
            conv_type = 'Conv2D'

        attrs = conv_create_attrs(conv_set_params(param, conv_type))
        attrs['op'] = conv_type
        attrs['get_group'] = lambda n: n.group
        attrs['get_output_feature_dim'] = lambda n: n.output

        # Weights and biases are embedded as plain attributes here;
        # a special pass is expected to move them to separate nodes later
        blobs = weights_biases(param.bias_term, layer_model_pb,
                               start_index=len(layer_pb.bottom), proto=param)
        attrs.update(blobs)
        attrs.update(layout_attrs())

        # Store the collected attributes on the node
        Convolution.update_node_stat(node, attrs)
        return __class__.enabled


class DeconvFrontExtractor(FrontExtractorOp):
    """
    Extractor for Caffe layers handled under the 'deconvolution' op name.

    Reads the layer's convolution_param, builds the attributes with the 'Deconv2D' type
    and stores them on the node through Convolution.update_node_stat.
    """
    op = 'deconvolution'
    enabled = True

    @staticmethod
    def extract(node):
        """
        Fills the node with deconvolution attributes taken from its protobuf layer
        Args:
            node: graph node; node.pb is the layer definition, node.model_pb is handed to
                weights_biases

        Returns:
            value of the 'enabled' flag of this extractor class

        Raises:
            Error: if node.pb is empty
        """
        layer_pb = node.pb
        layer_model_pb = node.model_pb

        if not layer_pb:
            raise Error('Protobuf layer can not be empty')

        param = layer_pb.convolution_param

        attrs = conv_create_attrs(conv_set_params(param, 'Deconv2D'))
        attrs['type'] = 'Deconvolution'
        attrs['op'] = 'Deconv2D'
        attrs['get_group'] = lambda n: n.group
        attrs['get_output_feature_dim'] = lambda n: n.output
        # These two replace the values 1 and 0 that conv_create_attrs sets
        attrs['input_feature_channel'] = 0
        attrs['output_feature_channel'] = 1

        # Weights and biases are embedded as plain attributes here;
        # a special pass is expected to move them to separate nodes later
        blobs = weights_biases(param.bias_term, layer_model_pb)
        attrs.update(blobs)
        attrs.update(layout_attrs())

        # Store the collected attributes on the node
        Convolution.update_node_stat(node, attrs)
        return __class__.enabled


def conv_create_attrs(params):
    """
    Builds the dictionary of convolution attributes from the parsed layer parameters
    Args:
        params: dictionary with the keys
            padding: two-element sequence
            dilate: two-element sequence
            stride: two-element sequence
            kernel: two-element sequence
            group: stored as is
            output: stored as is
            bias_term: stored as is
            (type_str may be present but is not read here)
    Returns:
        dictionary with all necessary convolution attributes; array-valued attributes
        are numpy arrays of dtype int64

    """
    def as_int64(values):
        return np.array(values, dtype=np.int64)

    bias_term = params['bias_term']

    # Every two-element parameter is written in swapped order: element [1] first, then [0]
    padding = params['padding']
    spatial_pad = [[padding[1], padding[1]], [padding[0], padding[0]]]

    dilate = params['dilate']
    dilation = as_int64([1, 1, dilate[1], dilate[0]])

    step = params['stride']
    stride = as_int64([1, 1, step[1], step[0]])

    group = params['group']
    output = params['output']

    kernel = params['kernel']
    kernel_spatial = as_int64([kernel[1], kernel[0]])

    attrs = {}
    attrs['bias_addable'] = True
    attrs['bias_term'] = bias_term
    # Two leading zero pairs precede the spatial pads in the full 'pad' attribute
    attrs['pad'] = as_int64([[0, 0], [0, 0]] + spatial_pad)
    attrs['pad_spatial_shape'] = as_int64(spatial_pad)
    attrs['dilation'] = dilation
    attrs['output_spatial_shape'] = None
    attrs['output_shape'] = None
    attrs['stride'] = stride
    attrs['group'] = group
    attrs['output'] = output
    attrs['kernel_spatial'] = kernel_spatial
    attrs['kernel_spatial_idx'] = as_int64([2, 3])
    attrs['reshape_kernel'] = True
    attrs['input_feature_channel'] = 1
    attrs['output_feature_channel'] = 0
    return attrs


def conv_set_params(conv_param, conv_type):
    """
    Collects the convolution parameters of a layer into a plain dictionary
    Args:
        conv_param: convolution parameters of the layer; must expose num_output and
            bias_term and be readable by get_spatial_attr / get_list_from_container
        conv_type: type string that is stored unchanged under 'type_str'
    Returns:
        dictionary with the keys type_str, padding, dilate, stride, kernel, group,
        output and bias_term

    """
    # Spatial attributes are resolved by get_spatial_attr from the given defaults
    kernel = get_spatial_attr([0, 0], 'kernel_size', 'kernel', conv_param)
    padding = get_spatial_attr([0, 0], 'pad', 'pad', conv_param)
    stride = get_spatial_attr([1, 1], 'stride', 'stride', conv_param)

    # Only the first listed dilation is used, for both spatial positions
    dilation_values = get_list_from_container(conv_param, 'dilation', int)
    if len(dilation_values) > 0:
        dilate = [dilation_values[0], dilation_values[0]]
    else:
        dilate = [1, 1]

    # The group count stays 1 unless the first listed value differs from 1
    group = 1
    group_values = get_list_from_container(conv_param, 'group', int)
    if len(group_values) > 0 and group_values[0] != 1:
        group = group_values[0]

    result = {}
    result['type_str'] = conv_type
    result['padding'] = padding
    result['dilate'] = dilate
    result['stride'] = stride
    result['kernel'] = kernel
    result['group'] = group
    result['output'] = conv_param.num_output
    result['bias_term'] = conv_param.bias_term
    return result