"""
 Copyright (C) 2018-2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import networkx as nx
import numpy as np

from mo.graph.graph import Node, Graph
from mo.ops.op import Op
from mo.utils.utils import symm_match_shapes


class TensorArrayGather(Op):
    """Model Optimizer operation class for ``TensorArrayGatherV3``.

    Shape inference (``array_infer``) reads the ``element_shape`` and ``size``
    recorded on the TensorArray node referenced by input 0 and produces an
    output of shape ``[size] + element_shape``.
    """
    op = "TensorArrayGatherV3"

    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'type': __class__.op,
            'op': __class__.op,
            'infer': TensorArrayGather.array_infer,
        }
        super().__init__(graph, mandatory_props, attrs)

    @staticmethod
    def array_infer(node: Node):
        """Infer the output shape of a TensorArrayGatherV3 node.

        The node must have exactly three inputs. Only input 0 is used here:
        its value is the id of the TensorArray node whose ``element_shape``
        and ``size`` attributes determine the result. Inputs 1 and 2
        (presumably the gather indices and the flow tensor) do not affect
        inference.

        Side effects:
            * if the TensorArray node has no non-empty element shape yet,
              ``node.element_shape`` is recorded on it;
            * every consumer of ``node`` gets ``shape`` set to
              ``[size] + element_shape`` and ``value`` set to None.

        :param node: the TensorArrayGatherV3 node to infer.
        :raises AssertionError: on a wrong input count, an element shape that
            disagrees with the TensorArray's, an unsupported partially-defined
            element shape, or a missing/non-positive TensorArray size.
        """
        assert len(node.in_nodes()) == 3, \
            'TensorArrayGather node {} must have exactly 3 inputs'.format(node.id)

        # Input 0 carries the id of the TensorArray node as its value.
        handle = node.in_node(0)
        ta_node = Node(node.graph, str(handle.value))

        if (ta_node.has_valid('element_shape') and ta_node.element_shape is not None
                and len(ta_node.element_shape) > 0):
            # The TensorArray already has an element shape: ours must agree with it.
            assert symm_match_shapes(ta_node['element_shape'], node.element_shape), \
                'Element shape of {} does not match the element shape of TensorArray {}'.format(
                    node.id, ta_node.id)
        else:
            # Otherwise take the element shape from this node and record it on the TensorArray.
            ta_node['element_shape'] = node.element_shape
        data_shape = ta_node['element_shape']

        # -1 entries mark undefined dimensions; the only partially-defined shape
        # accepted is the 2-D form [-1, N]. len() is used (rather than
        # ndarray.size) so element shapes stored as plain lists/tuples also work.
        assert -1 not in data_shape or (
                len(data_shape) == 2 and data_shape[0] == -1 and data_shape[1] != -1), \
            'Unsupported element shape {} for TensorArray {}'.format(data_shape, ta_node.id)

        assert ta_node.has_valid('size'), 'TensorArray {} has no valid size'.format(ta_node.id)
        size = ta_node['size']
        assert size > 0, \
            'TensorArray {} must have a positive size, got {}'.format(ta_node.id, size)

        # Output shape: the TensorArray size followed by the element shape.
        for _, out_node in node.graph.out_edges(node.id):
            # Build a fresh array per consumer so output shapes are never aliased.
            node.graph.node[out_node]['shape'] = np.array([size, *data_shape])
            # No constant value is computed for the gathered tensor.
            node.graph.node[out_node]['value'] = None