reshape.py 3.25 KB
Newer Older
openvino-pushbot's avatar
openvino-pushbot committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
"""
 Copyright (c) 2018 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
16
import math
openvino-pushbot's avatar
openvino-pushbot committed
17 18

import networkx as nx
19
import numpy as np
openvino-pushbot's avatar
openvino-pushbot committed
20 21 22

from mo.front.common.partial_infer.elemental import single_output_infer
from mo.front.common.partial_infer.reshape import tf_reshape_shape_infer
23
from mo.graph.graph import Node
openvino-pushbot's avatar
openvino-pushbot committed
24
from mo.ops.op import Op
25
from mo.utils.error import Error
openvino-pushbot's avatar
openvino-pushbot committed
26 27 28 29 30 31 32 33 34 35 36


class Reshape(Op):
    """Reshape operation.

    Generic inference goes through ``single_output_infer``: the output shape
    is produced by ``tf_reshape_shape_infer`` and the output value by
    reshaping the input value with ``np.reshape``. ``kaldi_infer`` provides a
    separate, Kaldi-specific shape inference that depends on the neighbouring
    Convolution/Pooling layers.
    """
    op = 'Reshape'
    enabled = True

    def __init__(self, graph: nx.MultiDiGraph, attrs: dict):
        super().__init__(graph, {
            'kind': 'op',
            'type': __class__.op,
            'op': __class__.op,
            # The second argument infers the output shape; the third reshapes the
            # input value to that shape (value propagation). When the latter is
            # invoked is decided by single_output_infer, which is not visible here.
            'infer': lambda node: single_output_infer(
                node,
                tf_reshape_shape_infer,
                lambda n: np.reshape(n.in_node().value, n.out_node().shape),
            ),
        }, attrs)

    def supported_attrs(self):
        """Return the attributes exposed by this op.

        'dim' is rendered as a comma-separated string of its elements.
        """
        return [('dim', lambda node: ','.join(map(str, node['dim'])))]

    @staticmethod
    def kaldi_infer(node: Node):
        """Infer the output shape of a Kaldi Reshape node.

        Kaldi Reshape hugely depends on the Convolution/Pooling layers that
        precede or succeed it, so the output shape is chosen by case:

        * producer is Convolution or Pooling: ``[batch, prod(input_shape[1:])]``;
        * consumer is Convolution:
          ``[batch, ceil(spatial / patch_stride), 1, patch_stride]``;
        * consumer is Pooling:
          ``[batch, pool_stride, 1, ceil(spatial / pool_stride)]``.

        For the consumer cases the input must be NH (2D) or NCHW (4D);
        ``spatial`` is ``input_shape[1]`` for NH and ``input_shape[2]`` for NCHW.
        If no case applies, the node is left unchanged and None is returned.

        Args:
            node: the Reshape node.

        Raises:
            Error: if the producer is not Convolution/Pooling and the input is
                neither 2D nor 4D.
        """
        in_node = node.in_node().in_node()  # prev_layer_node -> data -> this_node
        input_shape = node.in_node().shape
        batch = input_shape[0]

        if in_node.op in ('Convolution', 'Pooling'):
            output_spatial = np.array([batch, np.prod(input_shape[1:])], dtype=np.int64)
            return Reshape.set_shape_and_dim(node, output_spatial)

        # Supports ONLY NCHW and NH layouts
        if len(input_shape) not in (4, 2):
            raise Error('Reshape in Kaldi support only 1d or 3d shapes')
        # Take a scalar rather than the 1-element slice input_shape[2:3]: the slice
        # would later be converted to a float inside math.ceil(), a size-1-array
        # conversion that NumPy deprecates. The resulting value is the same.
        spatial_shape = input_shape[2] if len(input_shape) == 4 else input_shape[1]

        out_node = node.out_node().out_node()  # this_node -> data -> next_layer_node
        if out_node.op == 'Convolution':
            output_spatial = np.array(
                [batch, math.ceil(spatial_shape / out_node.patch_stride), 1, out_node.patch_stride], dtype=np.int64)
            return Reshape.set_shape_and_dim(node, output_spatial)
        if out_node.op == 'Pooling':
            # NOTE(review): this tests `pool_step`, writes `stride`, and then reads
            # `pool_stride`; behavior kept as-is — confirm the intended attribute names.
            if out_node.pool_step is None:
                out_node.stride = np.array([1, 1, out_node.window[-1], out_node.window[-1]], dtype=np.int64)
            output_spatial = np.array(
                [batch, out_node.pool_stride, 1, math.ceil(spatial_shape / out_node.pool_stride)], dtype=np.int64)
            return Reshape.set_shape_and_dim(node, output_spatial)
        # NOTE(review): any other consumer falls through here without setting the
        # output shape; confirm whether that case should raise instead.
        return None

    @staticmethod
    def set_shape_and_dim(node: Node, reshape_dim: np.ndarray):
        """Record ``reshape_dim`` as the node's 'dim' attribute and its output shape.

        The output shape receives its own copy of the array, so an in-place
        modification of one of them cannot silently change the other.

        Args:
            node: the Reshape node to update.
            reshape_dim: the target shape.
        """
        Reshape.update_node_stat(node, {'dim': reshape_dim})
        node.out_node().shape = np.array(reshape_dim)