"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
import argparse
import logging as log

import numpy as np

from extensions.front.freeze_placeholder_value import FreezePlaceholderValue
from mo.front.caffe import custom_layers_mapping
from mo.front.caffe import loader
from mo.front.caffe.extractor import caffe_extractor, common_caffe_fields, caffe_type_extractors
from mo.front.common.register_custom_ops import check_for_duplicates
from mo.front.common.register_custom_ops import update_extractors_with_extensions
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.extractor import extract_node_attrs, add_output_ops, create_tensor_nodes, remove_output_ops, \
    add_input_ops, user_data_repack
from mo.graph.graph import print_graph_stat
from mo.middle.passes.conv import convert_muladd_to_scaleshift_or_power, \
    convert_matmul_to_fully_connected, batch_norm_fuse, convert_bias, convert_add_to_scaleshift, \
    convert_mul_to_scaleshift, \
    convert_multi_input_conv
from mo.middle.passes.eliminate import graph_clean_up, remove_op_nodes
from mo.middle.passes.fusing.decomposition import convert_bn_to_mul_add, convert_scale_shift_to_mul_add
from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
from mo.middle.passes.fusing.resnet_optimization import stride_optimization
from mo.middle.passes.infer import add_mean_scale_values, scale_input, override_placeholder_shapes, mark_outputs, \
    partial_infer, convert_mul_add_to_power, override_batch
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
from mo.middle.passes.pool import mean_to_avgpool
from mo.middle.passes.shape import reverse_input_channels, fuse_sequence_of_reshapes
from mo.middle.passes.shared_weights_duplication import duplicate_shared_weights
from mo.pipeline.common import prepare_emit_ir
from mo.utils import class_registration
from mo.utils.error import Error
from mo.utils.find_inputs import find_inputs
from mo.utils.utils import refer_to_faq_msg


def driver(argv: argparse.Namespace, proto_file_name: str, model_file_name: str, output_model_name: str, outputs: list,
           output_dir: str,
           scale: float,
           user_shapes: [None, list, np.array] = None, mean_scale_values: [dict, list] = (), mean_file: str = "",
           mean_file_offsets: tuple = None,
           custom_layers_mapping_path: str = None):
    """Run the Caffe-to-IR conversion pipeline.

    Loads the prototxt/caffemodel pair, builds the internal graph, applies
    the front/middle/back replacement passes and fusing optimizations, and
    emits the IR into ``output_dir``.

    :param argv: parsed command-line arguments controlling optional passes
        (fusing, resnet optimization, batch override, etc.)
    :param proto_file_name: path to the .prototxt topology description
    :param model_file_name: path to the .caffemodel weights file
    :param output_model_name: base name for the emitted IR files (also used
        as graph name when the prototxt does not define one)
    :param outputs: user-requested output node names
    :param output_dir: directory where the IR is written
    :param scale: scale value applied to the network input
    :param user_shapes: user-provided overrides for input shapes
    :param mean_scale_values: per-input mean/scale values to embed
    :param mean_file: optional mean file (single-input topologies only)
    :param mean_file_offsets: offsets used when parsing the mean file
    :param custom_layers_mapping_path: XML file mapping custom Caffe layers
    :return: 0 on success
    :raises Error: when the model files cannot be read or parsed, or a mean
        file is supplied for a multi-input topology
    """
    try:
        proto, model = loader.load_caffe_proto_model(proto_file_name, model_file_name)
    except Error:
        # Already a user-friendly Model Optimizer error -- propagate as-is.
        raise
    except Exception as e:
        raise Error('Model Optimizer is not able to read {}. Possible reasons: '.format(proto_file_name) +
                    '1. your caffemodel contains custom layers that are not supported in Model Optimizer by default. ' +
                    '2. your prototxt does not have a valid structure, e.g you downloaded it as html. ' +
                    'In particular the first unknown field is {} '.format(str(e).split(' ')[-1]) +
                    'After you made sure that prototxt has a valid structure and still see this issue, then ' +
                    'you need to generate a python parser for caffe.proto that was used when the model ' +
                    'was created. ' +
                    'Run "python3 generate_caffe_pb2.py --input_proto ${PATH_TO_CAFFE}/src/caffe/proto/caffe.proto". ' +
                    refer_to_faq_msg(1)) from e

    # Register extension-provided extractors; the flags may be absent from
    # argv depending on how the tool was invoked, hence the defaults.
    update_extractors_with_extensions(
        caffe_type_extractors,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'disable_flattening_optional_params', False)
    )

    try:
        graph, original_shapes = loader.caffe_pb_to_nx(proto, model)
    except ValueError as e:
        raise Error('Invalid prototxt file: value error {}. ' +
                    refer_to_faq_msg(11), str(e)) from e

    log.debug("After caffe_pb_to_nx")
    print_graph_stat(graph)

    # Stash source-file info on the graph object for later passes.
    graph.proto_path = proto_file_name
    graph.caffemodel_path = model_file_name
    graph.name = getattr(proto, 'name', None) or output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'caffe'

    extract_node_attrs(graph, lambda node: (True, common_caffe_fields(node)))

    log.debug("After adding specific nodes for outputs")
    print_graph_stat(graph)

    # Merge user-supplied custom layer mappings into the extractor registry.
    # NOTE(review): 'enable_flattening_nested_params' here differs from the
    # 'disable_flattening_optional_params' flag used above -- confirm intended.
    custom_layers_map = custom_layers_mapping.load_layers_xml(custom_layers_mapping_path)
    custom_layers_mapping.update_extractors(
        caffe_type_extractors,
        custom_layers_map,
        getattr(argv, 'disable_omitting_optional', False),
        getattr(argv, 'enable_flattening_nested_params', False)
    )

    extract_node_attrs(graph, lambda node: caffe_extractor(node, check_for_duplicates(caffe_type_extractors)))

    log.debug("After extract_node_attr")
    print_graph_stat(graph)

    # Reconcile user-specified inputs/outputs with the actual graph and
    # optionally freeze placeholders to constant values.
    user_shapes, outputs, _ = user_data_repack(graph, user_shapes, outputs, None)
    if argv.freeze_placeholder_with_value is not None:
        FreezePlaceholderValue.enabled = True
        FreezePlaceholderValue.replacement_dict = argv.freeze_placeholder_with_value
        class_registration.update_registration([FrontReplacementSubgraph])
    graph, output_op_nodes = add_output_ops(graph, outputs)
    graph, input_op_nodes = add_input_ops(graph, user_shapes, True)
    override_placeholder_shapes(graph, user_shapes)
    override_batch(graph, argv.batch)
    graph_clean_up(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)

    graph = create_tensor_nodes(graph)

    log.debug("After create_tensor_nodes")
    print_graph_stat(graph)

    remove_op_nodes(graph, {'op': 'Identity'})
    remove_output_ops(graph)
    graph_clean_up(graph)

    log.debug("After removing specific nodes for output")
    print_graph_stat(graph)

    # you need to pass required network outputs here
    # but we don't have a way yet, so just passing all discovered sinks
    mark_outputs(graph)
    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    graph = partial_infer(graph)
    log.debug("After partial_infer")
    print_graph_stat(graph)

    duplicate_shared_weights(graph)

    # Second add_input_ops pass runs after shape inference (is_front=False).
    graph, input_op_nodes = add_input_ops(graph, user_shapes, False)
    graph_clean_up(graph)

    scale_input(graph, scale)

    add_mean_scale_values(graph, mean_scale_values)

    log.debug("Split multi input convolutions")
    convert_multi_input_conv(graph)

    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    # Dropout and train-phase-only nodes are inference no-ops.
    remove_op_nodes(graph, {'op': 'Dropout'})
    remove_op_nodes(graph, {'phase': 0})
    graph_clean_up(graph)

    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    mean_to_avgpool(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    if not argv.disable_fusing:
        convert_bn_to_mul_add(graph)
        graph_clean_up(graph)

        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    # Normalize remaining elementwise ops into IR-friendly layer types.
    convert_muladd_to_scaleshift_or_power(graph)
    convert_matmul_to_fully_connected(graph)
    batch_norm_fuse(graph)
    convert_mul_add_to_power(graph)
    convert_bias(graph)
    convert_add_to_scaleshift(graph)  # scale = 1
    convert_mul_to_scaleshift(graph)  # biases = 0

    graph_clean_up(graph)
    log.debug("After graph_cleanup")
    print_graph_stat(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    fuse_sequence_of_reshapes(graph)

    # A mean file can only be applied when the topology has exactly one input.
    input_names = find_inputs(graph)
    mf = []
    try:
        if mean_file and len(original_shapes) == 1:
            mf = loader.parse_mean(mean_file, original_shapes[input_names[0]], mean_file_offsets)
        elif mean_file:
            raise Error('Mean file for topologies with multiple inputs is not supported. ' +
                        refer_to_faq_msg(9))
    except ValueError as e:
        raise Error('Cannot load or process mean file: value error {}. ' +
                    refer_to_faq_msg(10), str(e)) from e

    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    mean_data=mf,
                    input_names=input_names)
    return 0