"""
 Copyright (c) 2018-2019 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import logging as log
Alexey Suhov's avatar
Alexey Suhov committed
18

openvino-pushbot's avatar
openvino-pushbot committed
19
import numpy as np
Alexey Suhov's avatar
Alexey Suhov committed
20 21

from mo.ops.op import PermuteAttrs
22
from mo.graph.graph import Node
openvino-pushbot's avatar
openvino-pushbot committed
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41


def part_sizes_to_indices(part_sizes: list):
    """
    Calculates indices of splits in the array based on part sizes for the split.
    Output list can be used as the second argument for np.split function.

    The running sum of ``part_sizes`` gives the end offset of every part. The last offset equals
    the size of the original array and is dropped because np.split does not need it.

    Args:
        part_sizes: sizes of consecutive parts along the split axis (list or 1-D array).

    Returns:
        np.ndarray of int64 split points, one fewer than the number of parts
        (empty when there are zero or one parts).
    """
    # An explicit integer dtype keeps the result usable as np.split indices even for empty input,
    # where np.cumsum would otherwise produce a float array.
    offsets = np.cumsum(np.asarray(part_sizes, dtype=np.int64))
    indices = offsets[:-1]
    log.debug("part_sizes: {}   -->   indices: {}".format(part_sizes, indices))
    return indices


def split(input_data_node: Node, node: Node, axis: int, part_sizes: list):
    """
    Partial inference of generic split node.

    Sets the shape of every output data node connected to ``node``. When the input value is known,
    it also sets the split values. The normalized axis is stored in ``node.axis``.

    Args:
        @input_data_node: input tensor node, subject to split
        @node: node of one of the Split types
        @axis: split dimension index; a negative value counts from the last dimension
        @part_sizes: sizes of all pieces that we split to (Python list or NumPy array). At most one
            element may be -1; its size is then deduced from the input shape. The caller's
            object is not modified.

    Returns:
        None. Problems are reported with log.error, and the outputs are left untouched.
    """
    if input_data_node.shape is None:
        return

    rank = input_data_node.shape.size

    # normalize axis
    if axis < 0:
        axis = rank + axis

    if axis < 0 or axis >= rank:
        log.error('Model is incorrect: axis for split node is out of range')
        return

    # Work on a private copy. `== -1` must be element-wise even when a Python list is passed, and
    # the caller's array may be the value of another graph node that must not be mutated.
    part_sizes = np.array(part_sizes)

    undef_indices = np.flatnonzero(part_sizes == -1)
    if undef_indices.size > 1:
        log.error('Desired split part sizes have more than one -1 element -- cannot deduce real sizes for them')
        return

    if undef_indices.size == 1:
        undef_index = undef_indices[0]
        part_sizes[undef_index] = 0
        deduced_dim = input_data_node.shape[axis] - np.add.reduce(part_sizes)
        if deduced_dim < 0:
            log.error('Just deduced dimension for the split has negative value that means that split input shape and '
                      'desired parts are not compatible')
            return
        # the -1 placeholder takes whatever is left of the split dimension
        part_sizes[undef_index] = deduced_dim

    all_parts_size = np.add.reduce(part_sizes)
    if all_parts_size != input_data_node.shape[axis]:
        log.error("input.shape[{}] = {}  !=  {} = sum of all parts in part_sizes".format(axis,
                                                                                         input_data_node.shape[axis],
                                                                                         all_parts_size))
        return

    splitted = None
    if input_data_node.value is not None:
        splitted = np.split(input_data_node.value, part_sizes_to_indices(part_sizes), axis)

    # not all outputs from the split could be used so it is necessary to iterate over output edges and infer shape for
    # necessary nodes only
    for _, _, edge_attrs in node.graph.out_edges(node.id, data=True):
        out_port = edge_attrs['out']
        out_node = node.out_node(out_port)

        new_out_shape = input_data_node.shape.copy()
        new_out_shape[axis] = part_sizes[out_port]
        out_node.shape = new_out_shape
        if splitted is not None:
            out_node.value = splitted[out_port]
            assert all(out_node.value.shape == out_node.shape)

    # A pre-existing node.axis may be negative (e.g. -1); compare it in normalized form.
    if node.has_valid('axis'):
        node_axis = node.axis + rank if node.axis < 0 else node.axis
        assert node_axis == axis
    node.axis = axis
    # WARNING: != 4 is supposed to work for NHWC to NCHW translation only.
    # if other global permutations happen this will fail
    # TODO: redesign it to have this logic built in NHWC to NCHW translation pass; it requires
    #       additional attributes with layout to be propagated through the network
    if len(input_data_node.shape) != 4 and node.has_valid('dim_attrs') and 'axis' in node.dim_attrs:
        log.warning('Removed "axis" attribute from the scope of the model relayout pass because len(input.shape) == {} '
                    '!= 4 for node {}'.format(len(input_data_node.shape), node.soft_get('name')))
        node.dim_attrs.remove('axis')
        assert 'axis' not in node.dim_attrs
    log.debug('output shapes after split: {}'.format([v.shape for v in node.out_nodes().values()]))


def tf_split_infer(node):
    """
    Partial infer of split node similar to Split op of TF.

    Expects exactly two inputs: port 0 holds the scalar split dimension (its value must be known),
    and port 1 holds the tensor to split into ``node.num_split`` equally sized parts. Unless one of
    the early checks below fails, the function does four things:
    - runs the generic split inference;
    - removes the split-dimension edge;
    - sets ``node['input_port'] = 1``;
    - registers ``axis`` for permutation relative to input 1.
    """
    # Two inputs: [split_dim, input]
    assert len(node.in_nodes()) == 2, 'Node "{}" must have exactly two inputs'.format(node.soft_get('name'))
    split_dim = node.in_node(0).value
    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    assert split_dim.ndim == 0, 'The split dimension for node "{}" must be a scalar.'.format(node.soft_get('name'))
    split_dim = split_dim.item()
    input_node = node.in_node(1)

    if input_node.shape is None:
        log.error('Input shape for node {} is not defined'.format(node.soft_get('name')))
        return

    log.debug('input shape for split: {}, should be split along {} dim'.format(input_node.shape, split_dim))
    split_dim_size = input_node.shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    log.debug('split_dim_size = {}, node.num_split = {}, div = {}, typeof div = {}'.format(
        split_dim_size, node.num_split, split_dim_size / node.num_split, type(split_dim_size / node.num_split)))
    # divisibility was checked above; integer division avoids float rounding for large dimensions
    part_size = int(split_dim_size // node.num_split)
    split(input_node, node, split_dim, [part_size] * node.num_split)
    node.graph.remove_edge(node.in_node(0).id, node.id)
    node['input_port'] = 1

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:1')])

def tf_split_v_infer(node: Node):
    """
    Partial infer of split node similar to SplitV op of TF.

    Supports two forms:
    - three inputs ``[input, size_splits, split_dim]``. The two parameter edges are removed once
      the split dimension is known.
    - one input, with ``axis`` and ``size_splits`` already stored as node attributes.

    Any other combination of inputs and attributes is left for another pass. After inference,
    ``axis`` is registered for permutation relative to input 0.
    """
    num_inputs = len(node.in_nodes())

    if num_inputs == 1 and not (node.has_valid('axis') and node.has_valid('size_splits')):
        return

    if num_inputs == 3 and (node.has_valid('axis') or node.has_valid('size_splits')):
        return

    # Three inputs: [input, size_splits, split_dim]
    if num_inputs == 3:
        split_dim = node.in_node(2).value
        size_splits = node.in_node(1).value
    else:
        split_dim = node.axis
        size_splits = node.size_splits

    # checked before touching split_dim.ndim so a missing value is reported instead of crashing
    if split_dim is None:
        log.error('split_dim value for node {} is None. Cannot do shape inference.'.format(node.soft_get('name')))
        return

    if num_inputs == 3:
        assert split_dim.ndim == 0
        split_dim = split_dim.item()
        node.graph.remove_edge(node.in_node(1).id, node.id)
        node.graph.remove_edge(node.in_node(2).id, node.id)

    input_node = node.in_node(0)
    if input_node.shape is None or size_splits is None:
        log.error('input shape or size of splits are not defined for node {}'.format(node.soft_get('name')))
        return

    log.debug('split_dim = {}, input.shape = {}, size_splits.value = {}'.format(split_dim, input_node.shape,
                                                                               size_splits))

    # split_dim is a scalar axis index at this point
    split(input_node, node, split_dim, size_splits)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])

def tf_unpack_infer(node: Node):
    """
    Partial infer of Unpack node similar to Unpack op of TF.

    Splits the single input along ``node.axis`` into ``num_split`` parts. ``num_split`` must equal
    the size of that dimension (it is filled in when absent), so every part has size 1 along the
    axis. Unsupported configurations are reported with log.debug and leave the node untouched.
    """
    if len(node.in_nodes()) != 1:
        log.debug('Unpack node "{}" must have one input.'.format(node.soft_get('name')))
        return

    in_shape = node.in_node().shape
    if in_shape is None:
        log.debug('Unpack node "{}" input node shape is not defined.'.format(node.soft_get('name')))
        return

    split_dim = node.axis
    log.debug('input shape for unpack: {}, should be split along {} dim'.format(in_shape, split_dim))
    split_dim_size = in_shape[split_dim]
    log.debug('split_dim_size type = {}'.format(type(split_dim_size)))

    # num_split may be missing entirely or None; in both cases it defaults to the dimension size
    if not node.has_valid('num_split'):
        node.num_split = split_dim_size
    elif node.num_split != split_dim_size:
        log.debug('The unpack where num to unpack is not equal to the size of the dimension to unpack is not supported')
        return

    if split_dim_size % node.num_split != 0:
        log.error("split_dim cannot be evenly divided by a given number of parts")
        return

    split(node.in_node(), node, split_dim, [int(split_dim_size // node.num_split)] * node.num_split)
    # node shapes will be squeezed in the separate pass