""" Copyright (c) 2018-2019 Intel Corporation Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import unittest import numpy as np from mo.front.common.partial_infer.split import tf_split_infer, tf_unpack_infer, tf_split_v_infer, split from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Node from mo.utils.unittest.graph import build_graph, build_graph_with_edge_attrs class TestTFSplitInfer(unittest.TestCase): graph = None def setUp(self): self.graph = build_graph({'split_dim': {'value': None, 'kind': 'data'}, 'data_to_split': {'value': None, 'shape': None, 'kind': 'data'}, 'split_node': {'kind': 'op', 'op': 'Split', 'num_split': 3, 'axis': None}, 'out_data_1': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_2': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_3': {'value': None, 'shape': None, 'kind': 'data'}, }, [('split_dim', 'split_node'), ('data_to_split', 'split_node'), ('split_node', 'out_data_1'), ('split_node', 'out_data_2'), ('split_node', 'out_data_3'), ]) def test_tf_split_infer(self): split_node = Node(self.graph, 'split_node') self.graph.node['split_dim']['value'] = np.array(1) self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30]) tf_split_infer(split_node) exp_shape = int64_array([2, 4, 25, 30]) for out_node in split_node.out_nodes().values(): self.assertTrue(np.all(exp_shape == out_node.shape)) self.assertEqual(1, split_node.input_port) def test_tf_split_infer_negative_index(self): split_node = Node(self.graph, 'split_node') self.graph.node['split_dim']['value'] = np.array(-3) self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30]) tf_split_infer(split_node) exp_shape = int64_array([2, 4, 25, 30]) for out_node in split_node.out_nodes().values(): self.assertTrue(np.all(exp_shape == out_node.shape)) self.assertEqual(1, split_node.input_port) def test_tf_split_infer_unknown_index(self): split_node = Node(self.graph, 'split_node') self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30]) tf_split_infer(split_node) for out_node in split_node.out_nodes().values(): self.assertIsNone(out_node.shape) def test_tf_split_infer_input_shape_is_None(self): split_node = Node(self.graph, 'split_node') self.graph.node['split_dim']['value'] = np.array(1) tf_split_infer(split_node) for out_node in split_node.out_nodes().values(): self.assertIsNone(out_node.shape) def test_tf_split_infer_wrong_num_split(self): split_node = Node(self.graph, 'split_node') self.graph.node['split_dim']['value'] = np.array(0) self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30]) tf_split_infer(split_node) for out_node in split_node.out_nodes().values(): self.assertIsNone(out_node.shape) class TestTFSplitVInfer(unittest.TestCase): graph = None def setUp(self): self.graph = build_graph({'data_to_split': {'value': None, 'shape': None, 'kind': 'data'}, 'size_splits': {'value': [3, 5, 4], 'kind': 'data'}, 'split_dim': {'value': None, 'kind': 'data'}, 'split_node': {'kind': 'op', 'op': 'Split', 'axis': None}, 'out_data_1': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_2': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_3': {'value': None, 'shape': None, 'kind': 'data'}, }, [('data_to_split', 'split_node'), ('size_splits', 'split_node'), ('split_dim', 'split_node'), ('split_node', 'out_data_1'), ('split_node', 'out_data_2'), ('split_node', 'out_data_3'), ]) def test_tf_split_infer_three_inputs(self): split_node = Node(self.graph, 'split_node') self.graph.node['split_dim']['value'] = np.array(1) self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30]) tf_split_v_infer(split_node) exp_shape = [int64_array([2, 3, 25, 30]), int64_array([2, 5, 25, 30]), int64_array([2, 4, 25, 30])] for ind, out_node in split_node.out_nodes().items(): self.assertTrue(np.all(exp_shape[ind] == out_node.shape)) def test_tf_split_infer_undef_size(self): split_node = Node(self.graph, 'split_node') self.graph.node['split_dim']['value'] = np.array(1) self.graph.node['data_to_split']['shape'] = int64_array([2, 12, 25, 30]) self.graph.node['size_splits']['value'] = np.array([3, 2, -1]) tf_split_v_infer(split_node) exp_shape = [int64_array([2, 3, 25, 30]), int64_array([2, 2, 25, 30]), int64_array([2, 7, 25, 30])] for ind, out_node in split_node.out_nodes().items(): self.assertTrue(np.all(exp_shape[ind] == out_node.shape)) class TestTFUnpack(unittest.TestCase): graph = None def setUp(self): self.graph = build_graph({'data_to_split': {'value': None, 'shape': None, 'kind': 'data'}, 'unpack': {'kind': 'op', 'op': 'Split', 'num_split': 3, 'axis': None}, 'out_data_1': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_2': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_3': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_4': {'value': None, 'shape': None, 'kind': 'data'}, }, [('data_to_split', 'unpack'), ('unpack', 'out_data_1'), ('unpack', 'out_data_2'), ('unpack', 'out_data_3'), ]) def test_tf_unpack_infer(self): unpack_node = Node(self.graph, 'unpack') self.graph.node['unpack']['axis'] = np.array(1) self.graph.node['data_to_split']['shape'] = int64_array([2, 3, 25, 30]) tf_unpack_infer(unpack_node) exp_shape = int64_array([2, 1, 25, 30]) for out_node in unpack_node.out_nodes().values(): self.assertTrue(np.all(exp_shape == out_node.shape)) def test_tf_unpack_infer_default_number_of_pieces(self): unpack_node = Node(self.graph, 'unpack') self.graph.node['unpack']['axis'] = np.array(1) self.graph.node['unpack']['num_split'] = None self.graph.node['data_to_split']['shape'] = int64_array([2, 3, 25, 30]) tf_unpack_infer(unpack_node) exp_shape = int64_array([2, 1, 25, 30]) for out_node in unpack_node.out_nodes().values(): self.assertTrue(np.all(exp_shape == out_node.shape)) def test_tf_unpack_infer_not_supported(self): # the case when the size of the dimension being unpacked is not equal to number of pieces is not supported unpack_node = Node(self.graph, 'unpack') self.graph.node['unpack']['axis'] = np.array(1) self.graph.node['data_to_split']['shape'] = int64_array([2, 6, 25, 30]) tf_unpack_infer(unpack_node) for out_node in unpack_node.out_nodes().values(): self.assertIsNone(out_node.shape) class TestSplitFunc(unittest.TestCase): graph = None def setUp(self): self.graph = build_graph_with_edge_attrs( {'data_to_split': {'value': None, 'shape': int64_array([2, 12, 25, 44]), 'kind': 'data'}, 'split_node': {'kind': 'op', 'op': 'Split', 'axis': None}, 'out_data_2': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_5': {'value': None, 'shape': None, 'kind': 'data'}, 'out_data_7': {'value': None, 'shape': None, 'kind': 'data'}, }, [('data_to_split', 'split_node', {'in': 0}), ('split_node', 'out_data_2', {'out': 2}), ('split_node', 'out_data_5', {'out': 5}), ('split_node', 'out_data_7', {'out': 7}), ]) def test_split_non_sequential_output_port(self): split(Node(self.graph, 'data_to_split'), Node(self.graph, 'split_node'), -1, [3, 2, 7, 5, 6, 4, 9, 8]) self.assertTrue(np.all(Node(self.graph, 'out_data_2').shape == [2, 12, 25, 7])) self.assertTrue(np.all(Node(self.graph, 'out_data_5').shape == [2, 12, 25, 4])) self.assertTrue(np.all(Node(self.graph, 'out_data_7').shape == [2, 12, 25, 8])) def test_split_value_infer_non_sequential_output_port(self): data_node = Node(self.graph, 'data_to_split') value = np.array(range(2 * 12 * 25 * 44)).reshape(data_node.shape) data_node.value = value.copy() split(data_node, Node(self.graph, 'split_node'), -1, [3, 2, 7, 5, 6, 4, 9, 8]) self.assertTrue(np.all(Node(self.graph, 'out_data_2').shape == [2, 12, 25, 7])) self.assertTrue(np.all(Node(self.graph, 'out_data_5').shape == [2, 12, 25, 4])) self.assertTrue(np.all(Node(self.graph, 'out_data_7').shape == [2, 12, 25, 8])) self.assertTrue(np.all(Node(self.graph, 'out_data_2').value == value[:, :, :, 5:12])) self.assertTrue(np.all(Node(self.graph, 'out_data_5').value == value[:, :, :, 23:27])) self.assertTrue(np.all(Node(self.graph, 'out_data_7').value == value[:, :, :, 36:]))