"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from extensions.back.CreateConstNodes import CreateConstNodesReplacement
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
Alexey Suhov's avatar
Alexey Suhov committed
18
from mo.utils.error import Error, FrameworkError
19
from mo.utils.logger import log_step
openvino-pushbot's avatar
openvino-pushbot committed
20 21 22 23 24 25
from mo.utils.utils import refer_to_faq_msg

# Fail early with an actionable message when the mxnet package is missing:
# the ImportError is converted into the project's Error type with a FAQ reference.
try:
    import mxnet
except ImportError:
    raise Error('Module mxnet was not found. Please install appropriate version of mxnet via install_prerequisites '
                'script.' + refer_to_faq_msg(52))
openvino-pushbot's avatar
openvino-pushbot committed
27 28 29

import argparse

30
from mo.front.extractor import extract_node_attrs, remove_output_ops
openvino-pushbot's avatar
openvino-pushbot committed
31 32 33
from mo.front.mxnet.extractor import mxnet_op_extractor
from mo.front.mxnet.loader import symbol2nx, load_symbol_def
from mo.middle.passes.fusing.decomposition import convert_batch_norm, convert_scale_shift_to_mul_add
34 35
from mo.middle.passes.conv import convert_muladd_to_scaleshift, \
    convert_add_or_mul_to_scaleshift, fuse_pad, convert_matmul_to_fully_connected
36
from mo.middle.passes.eliminate import graph_clean_up, remove_const_ops
openvino-pushbot's avatar
openvino-pushbot committed
37 38 39 40 41
from mo.middle.passes.fusing.fuse_linear_ops import fuse_linear_ops
from mo.middle.passes.fusing.fuse_linear_seq import fuse_mul_add_sequence
from mo.middle.passes.fusing.mark_unfused_nodes import mark_unfused_nodes
from mo.middle.passes.fusing.resnet_optimization import stride_optimization
from mo.middle.passes.mean_scale_values import move_scaleshift_to_preprocess
42 43
from mo.middle.passes.shape import reverse_input_channels, merge_nodes_permutations, permute_data_nodes_attrs, \
    permute_op_nodes_attrs
openvino-pushbot's avatar
openvino-pushbot committed
44 45 46 47 48
from mo.pipeline.common import prepare_emit_ir
from mo.front.mxnet.nd_to_params import save_params_file
from mo.front.common.register_custom_ops import update_extractors_with_extensions
from mo.front.mxnet.extractor import mxnet_op_extractors
from mo.utils import class_registration
Alexey Suhov's avatar
Alexey Suhov committed
49 50
from mo.utils.cli_parser import get_meta_info
from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
openvino-pushbot's avatar
openvino-pushbot committed
51 52


53
def driver(argv: argparse.Namespace, input_model: str, output_model_name: str, output_dir: str) -> int:
    """
    Run the MXNet conversion pipeline: load the model, transform the graph, emit the IR.

    The stages are announced through ``log_step`` in this order:
    LOAD, FRONT, MIDDLE, BACK, EMIT.

    :param argv: parsed command-line parameters; the attributes read here are
                 ``steps``, ``input_symbol``, ``input``, ``nd_prefix_name``,
                 ``pretrained_model_name``, ``legacy_mxnet_model``, ``save_params_from_nd``,
                 ``generate_experimental_IR_V10``, ``generate_deprecated_IR_V2``,
                 ``finegrain_fusing``, ``disable_fusing``, ``disable_resnet_optimization``,
                 ``reverse_input_channels``, ``move_to_preprocess`` and ``data_type``.
    :param input_model: model location passed to ``load_symbol_def``.
    :param output_model_name: name assigned to the graph and passed to ``prepare_emit_ir``.
    :param output_dir: output directory passed to ``prepare_emit_ir``.
    :return: always 0.
    :raises FrameworkError: when ``load_symbol_def`` raises ``ValueError`` or
                            ``mxnet.base.MXNetError``.
    """
    log_step(argv.steps, 'LOAD')
    meta_info = get_meta_info(argv)

    try:
        model_nodes, model_params, model_name, iteration_number = load_symbol_def(
            input_model,
            argv.input_symbol,
            argv.input,
            argv.nd_prefix_name,
            argv.pretrained_model_name,
            argv.legacy_mxnet_model,
        )
    except (ValueError, mxnet.base.MXNetError) as e:
        # The message template and its arguments are handed over separately, exactly as before.
        raise FrameworkError(
            'The following error happened while loading mxnet model {}: {}. ' +
            refer_to_faq_msg(53),
            input_model,
            str(e)
        ) from e

    # Parameters are saved to a file only when all three related options are set.
    if argv.nd_prefix_name and argv.pretrained_model_name and argv.save_params_from_nd:
        save_params_file(model_name, model_params._arg_params, model_params._aux_params, iteration_number)

    update_extractors_with_extensions(mxnet_op_extractors)
    graph = symbol2nx(model_nodes, model_params, argv.input)
    graph.check_empty_graph('symbol2nx. It may happen due to problems with loaded model')

    # Plain attribute assignment replaces the direct graph.__setattr__('name', ...) call.
    graph.name = output_model_name
    graph.graph['layout'] = 'NCHW'
    graph.graph['cmd_params'] = argv
    graph.graph['fw'] = 'mxnet'
    graph.graph['feature_dim'] = 1 if graph.graph['layout'] == 'NCHW' else 3

    # IR version: 10 or 6 depending on the experimental flag, overridden by 2 for the deprecated flag.
    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        version = 10
    else:
        version = 6
    graph.graph['ir_version'] = 2 if argv.generate_deprecated_IR_V2 else version

    extract_node_attrs(graph, mxnet_op_extractor)

    # --------------------------------- LOAD END ------------------------------------------------------
    log_step(argv.steps, 'FRONT')
    class_registration.apply_replacements(graph, class_registration.ClassType.FRONT_REPLACER)
    log_step(argv.steps, 'MIDDLE')
    class_registration.apply_replacements(graph, class_registration.ClassType.MIDDLE_REPLACER)

    fuse_pad(graph)

    # Mark nodes with attr 'can_be_fused': False to disable fusing for specified nodes
    mark_unfused_nodes(graph, argv.finegrain_fusing)

    # Converting FusedBatchNorm layer to Mul->Add->Mul->Add sequence
    convert_batch_norm(graph)
    graph_clean_up(graph)

    if not argv.disable_fusing:
        # Converting ScaleShift layer to Mul->Add
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)

        # Fusing the sequences of Mul/Add operations
        fuse_mul_add_sequence(graph)
        graph_clean_up(graph)

        # Fusing linear operation to Convolution
        fuse_linear_ops(graph)
        graph_clean_up(graph)

    if not argv.disable_resnet_optimization:
        stride_optimization(graph)

    # fuse_pad is run a second time here, after the fusing and stride passes above.
    fuse_pad(graph)

    # Converting Mul->Add to ScaleShift node
    convert_muladd_to_scaleshift(graph)
    graph_clean_up(graph)

    convert_add_or_mul_to_scaleshift(graph)  # scale = 1
    graph_clean_up(graph)

    if argv.reverse_input_channels:
        reverse_input_channels(graph)

    if argv.move_to_preprocess:
        move_scaleshift_to_preprocess(graph)
        graph_clean_up(graph)

    EltwiseInputNormalize().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, convert_matmul_to_fully_connected)

    merge_nodes_permutations(graph)
    permute_data_nodes_attrs(graph)
    permute_op_nodes_attrs(graph)

    graph_clean_up(graph)
    log_step(argv.steps, 'BACK')
    class_registration.apply_replacements(graph, class_registration.ClassType.BACK_REPLACER)

    for_graph_and_each_sub_graph_recursively(graph, remove_const_ops)
    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    for_graph_and_each_sub_graph_recursively(graph, remove_output_ops)

    log_step(argv.steps, 'EMIT')
    prepare_emit_ir(graph=graph, data_type=argv.data_type, output_dir=output_dir, output_model_name=output_model_name,
                    meta_info=meta_info)
    return 0