Commit 17e66dc5 authored by Alexey Suhov's avatar Alexey Suhov Committed by openvino-pushbot

Added unit tests and readme for model optimizer (#79)

* added unit tests
* added readme for model optimizer
* added a list of supported IE plugins
parent 30594bb3
## Repository components
The Inference Engine can infer models in different formats with various input and output formats.
The open source version of Inference Engine includes the following plugins:
| PLUGIN | DEVICE TYPES |
| ---------------------| -------------|
| CPU plugin | Intel® Xeon® with Intel® AVX2 and AVX512, Intel® Core™ Processors with Intel® AVX2, Intel® Atom® Processors with Intel® SSE |
| GPU plugin | Intel® Processor Graphics, including Intel® HD Graphics and Intel® Iris® Graphics |
| GNA plugin | Intel® Speech Enabling Developer Kit, Amazon Alexa* Premium Far-Field Developer Kit, Intel® Pentium® Silver processor J5005, Intel® Celeron® processor J4005, Intel® Core™ i3-8121U processor |
| Heterogeneous plugin | Heterogeneous plugin enables computing for inference on one network on several Intel® devices. |
Inference Engine plugins for Intel® FPGA and Intel® Movidius™ Neural Compute Stick are distributed only in a binary form as a part of [Intel® Distribution of OpenVINO™](https://software.intel.com/en-us/openvino-toolkit).
## Build on Linux\* Systems
The software was validated on:
......
## Project structure
Project structure:
<pre>
|-- root
|-- extensions
|-- front/ - graph transformations during front phase
|-- middle/ - graph transformations during middle phase (after partial inference)
|-- end/ - graph transformations during back phase (before IR generation)
|-- ops/ - Model Optimizer operation classes
|-- mo
|-- back - Back-End logic: contains IR emitting logic
|-- front - Front-End logic: contains matching between Framework-specific layers and IR specific,
calculation of output shapes for each registered layer
|-- graph - Graph utilities to work with internal IR representation
|-- middle - Graph transformations - optimizations of the model
|-- ops - Model Optimizer operation classes
|-- pipeline - Sequence of steps required to create IR for each framework
|-- utils - Utility functions
|-- tf_call_ie_layer - Sources for TensorFlow fallback in Inference Engine during model inference
|-- mo.py - Centralized entry point that can be used for any supported framework
|-- mo_caffe.py - Entry point particularly for Caffe
|-- mo_mxnet.py - Entry point particularly for MXNet
|-- mo_tf.py - Entry point particularly for TensorFlow
</pre>
## Prerequisites
Model Optimizer requires:
1. Python 3.4 or newer
## Installation instructions
1. Go to the Model Optimizer folder
2. Create virtual environment and activate it. This option is strongly recommended as it creates a Python sandbox and
dependencies for Model Optimizer do not influence global Python configuration, installed libraries etc. At the same
time, special flag ensures that system-wide Python libraries are also available in this sandbox. Skip this
step only if you do want to install all Model Optimizer dependencies globally:
* Create environment:
<pre>virtualenv -p /usr/bin/python3.6 .env3 --system-site-packages</pre>
* Activate it:
<pre>. .env3/bin/activate</pre>
3. Install dependencies. If you want to convert models only from particular framework, you should use one of
available <code>requirements_*.txt</code> files corresponding to the framework of choice. For example, for Caffe use
<code>requirements_caffe.txt</code> and so on. When you decide to switch later to other frameworks, please install dependencies
for them using the same mechanism:
<pre>
pip3 install -r requirements.txt
</pre>
## Command-Line Interface (CLI)
The following short examples are framework-dependent. Please read the complete help
with --help option for details across all frameworks:
<pre>
python3 mo.py --help
</pre>
There are several scripts that convert a model:
1. <code>mo.py</code> -- universal entry point that can convert a model from any supported framework
2. <code>mo_caffe.py</code> -- dedicated script for Caffe models conversion
3. <code>mo_mxnet.py</code> -- dedicated script for MXNet models conversion
4. <code>mo_tf.py</code> -- dedicated script for TensorFlow models conversion
5. <code>mo_onnx.py</code> -- dedicated script for ONNX models conversion
6. <code>mo_kaldi.py</code> -- dedicated script for Kaldi models conversion
<code>mo.py</code> can deduce original framework where input model was trained by an extension of
the model file. Or <code>--framework</code> option can be used for this purpose if model files
don't have standard extensions (<code>.pb</code> - for TensorFlow models, <code>.params</code> - for MXNet models,
<code>.caffemodel</code> - for Caffe models). So, the following commands are equivalent::
<pre>
python3 mo.py --input_model /user/models/model.pb
python3 mo.py --framework tf --input_model /user/models/model.pb
</pre>
The following examples illustrate the shortest command lines to convert a model per
framework.
### Convert TensorFlow model
To convert a frozen TensorFlow model contained in binary file <code>model-file.pb</code>, run
dedicated entry point <code>mo_tf.py</code>:
python3 mo_tf.py --input_model model-file.pb
### Convert Caffe model
To convert a Caffe model contained in <code>model-file.prototxt</code> and <code>model-file.caffemodel</code> run
dedicated entry point <code>mo_caffe.py</code>:
<pre>
python3 mo_caffe.py --input_model model-file.caffemodel
</pre>
### Convert MXNet model
To Convert an MXNet model in <code>model-file-symbol.json</code> and <code>model-file-0000.params</code> run
dedicated entry point <code>mo_mxnet.py</code>:
<pre>
python3 mo_mxnet.py --input_model model-file
</pre>
> **NOTE**: for TensorFlow* all Placeholder ops are represented as Input layers in the final IR.
### Convert ONNX* model
The Model Optimizer assumes that you have an ONNX model that was directly downloaded from a public repository or converted from any framework that supports exporting to the ONNX format.
Use the mo_onnx.py script to simply convert a model with the path to the input model .onnx file:
<pre>
python3 mo_onnx.py --input_model model-file.onnx
</pre>
Input channels re-ordering, scaling, subtraction of mean values and other preprocessing features
are not applied by default. To pass necessary values to Model Optimizer, please run <code>mo.py</code>
(or <code>mo_tf.py</code>, <code>mo_caffe.py</code>, <code>mo_mxnet.py</code>) with <code>--help</code> and
examine all available options.
## Working with Inference Engine
To the moment, Inference Engine is the only consumer of IR models that Model Optimizer produces.
The whole workflow and more documentation on the structure of IR are documented in the Developer Guide
of Inference Engine. Note that sections about running Model Optimizer refer to the old version
of the tool and can not be applied to the current version of Model Optimizer.
### How to run unit-tests
1. Run tests with:
<pre>
python -m unittest discover -p "*_test.py" [-s PATH_TO_DIR]
</pre>
---
\* Other names and brands may be claimed as the property of others.
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.back.PermuteForReshape import PermuteForReshape
from mo.graph.graph import Node
from mo.ops.op import PermuteAttrs
from mo.utils.unittest.graph import build_graph_with_attrs, compare_graphs
class ReshapeToPermuteTest(unittest.TestCase):
nodes = [
('input_data', {'kind': 'data', 'shape': None}),
('reshape', {'kind': 'op', 'op': 'Squeeze', 'type': 'Reshape', 'dim': None}),
('reshape_data', {'kind': 'data'}),
]
edges = [
('input_data', 'reshape'),
('reshape', 'reshape_data'),
]
permute_nodes = [
('permute', {'kind': 'op', 'op': 'Permute'}),
('permute_data', {'kind': 'data', 'shape': None})
]
permute_edges = [
('input_data', 'permute'),
('permute', 'permute_data'),
('permute_data', 'reshape'),
]
def test_from3D_to3D(self):
input_shape = np.array([2, 3, 4])
new_shape = np.array([2, 3, 4])
graph = build_graph_with_attrs(
nodes_with_attrs=self.nodes,
edges_with_attrs=self.edges,
update_nodes_attributes=[('input_data', {'shape': input_shape}),
('reshape', {'dim': new_shape}),
('reshape_data', {'shape': new_shape})]
)
graph.graph['layout'] = 'NHWC'
# add permute attrs to reshape
reshape = Node(graph, 'reshape')
PermuteAttrs.create_permute_attrs(reshape, attrs=[('dim', 'output:0')])
tested_pattern = PermuteForReshape()
tested_pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph, last_node='reshape_data')
self.assertTrue(flag, resp)
def test_from4D_to3D(self):
input_shape = np.array([1, 2, 3, 4])
new_shape = np.array([3, 4, 2])
nhwc_shape = np.array([1, 3, 4, 2])
graph = build_graph_with_attrs(
nodes_with_attrs=self.nodes,
edges_with_attrs=self.edges,
update_nodes_attributes=[('input_data', {'shape': input_shape}),
('reshape', {'dim': new_shape}),
('reshape_data', {'shape': new_shape})]
)
graph.graph['layout'] = 'NHWC'
# add permute attrs to reshape
reshape = Node(graph, 'reshape')
PermuteAttrs.create_permute_attrs(reshape, attrs=[('dim', 'output:0')])
tested_pattern = PermuteForReshape()
tested_pattern.find_and_replace_pattern(graph)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=self.nodes + self.permute_nodes,
edges_with_attrs=self.edges[1:] + self.permute_edges,
update_nodes_attributes=[('input_data', {'shape': input_shape}),
('reshape', {'dim': new_shape}),
('reshape_data', {'shape': new_shape}),
('permute_data', {'shape': nhwc_shape})]
)
# check graphs equality
(flag, resp) = compare_graphs(graph, graph_ref, last_node='reshape_data')
self.assertTrue(flag, resp)
# check righ order in new permutation node
permute_order = graph.node['reshape/Permute_']['order']
self.assertTrue(np.all(permute_order == np.array([0, 2, 3, 1]))) # from NCHW to NHWC
def test_from_5D_to_3D(self):
input_shape = np.array([1, 2, 1, 3, 4]) # NCDHW 1 1 3 4 2
new_shape = np.array([3, 4, 2])
nhwc_shape = np.array([1, 1, 3, 4, 2])
graph = build_graph_with_attrs(
nodes_with_attrs=self.nodes,
edges_with_attrs=self.edges,
update_nodes_attributes=[('input_data', {'shape': input_shape}),
('reshape', {'dim': new_shape}),
('reshape_data', {'shape': new_shape})]
)
graph.graph['layout'] = 'NHWC'
# add permute attrs to reshape
reshape = Node(graph, 'reshape')
PermuteAttrs.create_permute_attrs(reshape, attrs=[('dim', 'output:0')])
tested_pattern = PermuteForReshape()
tested_pattern.find_and_replace_pattern(graph)
graph_ref = build_graph_with_attrs(
nodes_with_attrs=self.nodes + self.permute_nodes,
edges_with_attrs=self.edges[1:] + self.permute_edges,
update_nodes_attributes=[('input_data', {'shape': input_shape}),
('reshape', {'dim': new_shape}),
('reshape_data', {'shape': new_shape}),
('permute_data', {'shape': nhwc_shape})]
)
# check graphs equality
(flag, resp) = compare_graphs(graph, graph_ref, last_node='reshape_data')
self.assertTrue(flag, resp)
# check righ order in new permutation node
permute_order = graph.node['reshape/Permute_']['order']
self.assertTrue(np.all(permute_order == np.array([0, 2, 3, 4, 1]))) # from NCDHW to NDHWC
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.back.ShufflenetReLUReorder import ShufflenetReLUReorder
from mo.utils.unittest.graph import build_graph, compare_graphs
# The dictionary with nodes attributes used to build various graphs. A key is the name of the node and the value is the
# dictionary with node attributes.
nodes_attributes = {
'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# ReLU
'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'},
'relu_1_data': {'value': None, 'shape': None, 'kind': 'data'},
# Reshape layers
'reshape_1': {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'},
'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'reshape_2': {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'},
'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'},
'reshape_3': {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'},
'reshape_3_data': {'value': None, 'shape': None, 'kind': 'data'},
# Transpose layer
'transpose_1': {'type': 'Permute', 'kind': 'op', 'op': 'Transpose'},
'transpose_1_data': {'value': None, 'shape': None, 'kind': 'data'},
# Conv layer
'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'Conv2d'},
'conv_1_data': {'value': None, 'shape': None, 'kind': 'data'},
}
class ShufflenetReLUReorderTests(unittest.TestCase):
def test_1(self):
graph = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'relu_1'),
('relu_1', 'relu_1_data'),
('relu_1_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'transpose_1'),
('transpose_1', 'transpose_1_data'),
('transpose_1_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'conv_1'),
('conv_1', 'conv_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 112])},
'relu_1_data': {'shape': np.array([1, 227, 227, 112])},
'reshape_1_data': {'shape': np.array([227, 227, 4, 28])},
'transpose_1': {'order': np.array([0, 1, 3, 2])},
'transpose_1_data': {'shape': np.array([227, 227, 28, 4])},
'reshape_2_data': {'shape': np.array([1, 227, 227, 112])},
'conv_1_data': {'shape': np.array([1, 227, 227, 112])},
'conv_1': {'pad': np.array([1, 1])}
})
graph.graph['layout'] = 'NHWC'
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'transpose_1'),
('transpose_1', 'transpose_1_data'),
('transpose_1_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'relu_1'),
('relu_1', 'relu_1_data'),
('relu_1_data', 'conv_1'),
('conv_1', 'conv_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 112])},
'relu_1_data': {'shape': np.array([1, 227, 227, 112])},
'reshape_1_data': {'shape': np.array([227, 227, 4, 28])},
'transpose_1': {'order': np.array([0, 1, 3, 2])},
'transpose_1_data': {'shape': np.array([227, 227, 28, 4])},
'reshape_2_data': {'shape': np.array([1, 227, 227, 112])},
'conv_1_data': {'shape': np.array([1, 227, 227, 112])},
})
pattern = ShufflenetReLUReorder()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_2_neg(self):
graph = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'transpose_1'),
('transpose_1', 'transpose_1_data'),
('transpose_1_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'conv_1'),
('conv_1', 'conv_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 112])},
'relu_1_data': {'shape': np.array([1, 227, 227, 112])},
'reshape_1_data': {'shape': np.array([227, 227, 4, 28])},
'transpose_1': {'order': np.array([0, 1, 3, 2])},
'transpose_1_data': {'shape': np.array([227, 227, 28, 4])},
'reshape_2_data': {'shape': np.array([1, 227, 227, 112])},
'conv_1_data': {'shape': np.array([1, 227, 227, 112])},
})
graph.graph['layout'] = 'NHWC'
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'transpose_1'),
('transpose_1', 'transpose_1_data'),
('transpose_1_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'conv_1'),
('conv_1', 'conv_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 112])},
'relu_1_data': {'shape': np.array([1, 227, 227, 112])},
'reshape_1_data': {'shape': np.array([227, 227, 4, 28])},
'transpose_1': {'order': np.array([0, 1, 3, 2])},
'transpose_1_data': {'shape': np.array([227, 227, 28, 4])},
'reshape_2_data': {'shape': np.array([1, 227, 227, 112])},
'conv_1_data': {'shape': np.array([1, 227, 227, 112])},
})
pattern = ShufflenetReLUReorder()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'conv_1_data', check_op_attrs=True)
self.assertTrue(flag, resp)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.back.TileReshaper import TileReshaper
from mo.ops.tile import Tile
from mo.utils.unittest.graph import build_graph, compare_graphs
# The dictionary with nodes attributes used to build various graphs. A key is the name of the node and the value is the
# dictionary with node attributes.
nodes_attributes = {
'previous_data': {'shape': np.array([1, 1, 101]), 'kind': 'data'},
'tile': {'type': 'Tile', 'kind': 'op', 'axis': 1, 'tiles': 16, 'infer': Tile.infer},
'tile_data': {'shape': np.array([1, 16, 101]), 'kind': 'data'},
'next_op': {'kind': 'op', 'op': 'SomeOp'},
}
edge_attributes = [
('previous_data', 'tile'),
('tile', 'tile_data'),
('tile_data', 'next_op'),
]
nodes_attributes_ref = {
'previous_data': {'kind': 'data', 'shape': np.array([1, 1, 101])},
'reshape_op_before': {'type': 'Reshape', 'kind': 'op', 'dim': [1, 1, 101, 1]},
'reshape_data_before': {'kind': 'data', 'shape': np.array([1, 1, 101, 1])},
'tile': {'type': 'Tile', 'kind': 'op', 'infer': Tile.infer, 'axis': 1, 'tiles': 16},
'tile_data': {'shape': np.array([1, 16, 101, 1]), 'kind': 'data'},
'reshape_op_after': {'type': 'Reshape', 'kind': 'op', 'dim': [1, 16, 101]},
'reshape_data_after': {'kind': 'data', 'shape': np.array([1, 16, 101])},
'next_op': {'kind': 'op', 'op': 'SomeOp'},
}
edge_attributes_ref = [
('previous_data', 'reshape_op_before'),
('reshape_op_before', 'reshape_data_before'),
('reshape_data_before', 'tile'),
('tile', 'tile_data'),
('tile_data', 'reshape_op_after'),
('reshape_op_after', 'reshape_data_after'),
('reshape_data_after', 'next_op')
]
class TileReshaperTests(unittest.TestCase):
def test_tile_reshaper(self):
graph = build_graph(nodes_attributes, edge_attributes)
graph_ref = build_graph(nodes_attributes_ref, edge_attributes_ref)
pattern = TileReshaper()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'next_op', check_op_attrs=True)
self.assertTrue(flag, resp)
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.back.insert_compatibility_l2normalization import CompatibilityL2NormalizationPattern
from mo.utils.unittest.graph import build_graph
class CompatibilityL2NormalizationPatternTest(unittest.TestCase):
nodes = {
'input_node': {
'kind': 'data'
},
'l2norm_node': {
'op': 'Normalize',
'kind': 'op',
'type': 'Normalize',
},
'output_node': {
'kind': 'data'
}
}
def test_insert_data(self):
graph = build_graph(self.nodes, [('input_node', 'l2norm_node'), ('l2norm_node', 'output_node')],
{'input_node': {'shape': np.array([1, 10])},
})
CompatibilityL2NormalizationPattern().find_and_replace_pattern(graph)
self.assertEqual(len(graph.nodes()), 4)
self.assertEqual(graph.node['l2norm_node_weights']['name'], 'l2norm_node_weights')
self.assertEqual(len(graph.node['l2norm_node_weights']['value']), 10)
expect_value = np.full([10], 1.0, np.float32)
for i, val in enumerate(expect_value):
self.assertEqual(graph.node['l2norm_node_weights']['value'][i], val)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
from mo.utils.unittest.graph import build_graph
class KaldiRemoveMemoryOutputTest(unittest.TestCase):
nodes = {
'input_node': {
'kind': 'data'
},
'memory_node': {
'op': 'Memory',
'kind': 'op'
},
'output_node': {
'kind': 'data'
}
}
def test_remove_out_data_for_memory(self):
graph = build_graph(self.nodes, [('input_node', 'memory_node')])
# Need for matching in pattern. The edge memory_node->out_node must contain only the attribute 'out' = 0
# build_graph creates edge memory_node->out_node with attributes 'in' and 'out'
graph.add_node('output_node', is_output=True, **self.nodes['output_node'])
graph.add_edge('memory_node', 'output_node', out=0)
KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
self.assertNotIn('output_node', graph.node)
def test_do_not_remove_out_data_for_memory(self):
graph = build_graph(self.nodes, [('input_node', 'memory_node')])
graph.add_node('output_node', **self.nodes['output_node'])
graph.add_edge('memory_node', 'output_node', out=0)
KaldiRemoveMemoryOutputBackReplacementPattern().find_and_replace_pattern(graph)
self.assertIn('output_node', graph.node)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.back.remove_last_softmax_pattern import RemoveLastSoftMaxPattern
from mo.utils.unittest.graph import build_graph
class KaldiRemoveLastSoftMaxTest(unittest.TestCase):
nodes = {
'input_node': {
'kind': 'data'
},
'softmax_node': {
'op': 'SoftMax',
'kind': 'op'
},
'output_node': {
'kind': 'data'
}
}
def test_remove_last_SoftMax(self):
graph = build_graph(self.nodes, [
('input_node', 'softmax_node'),
('softmax_node', 'output_node')
], {'output_node': {'is_output': True}})
RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
self.assertNotIn('softmax_node', graph.node)
def test_do_not_remove_no_last_SoftMax(self):
graph = build_graph(self.nodes, [
('input_node', 'softmax_node'),
('softmax_node', 'output_node')
])
RemoveLastSoftMaxPattern().find_and_replace_pattern(graph)
self.assertIn('softmax_node', graph.node)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.accum_ext import AccumFrontExtractor
from extensions.ops.accum import AccumOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeAccumProtoLayer:
def __init__(self, val):
self.accum_param = val
class TestAccumExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['Accum'] = AccumOp
def test_accum_no_pb_no_ml(self):
self.assertRaises(AttributeError, AccumFrontExtractor.extract, None)
@patch('extensions.front.caffe.accum_ext.collect_attributes')
def test_accum_ext(self, collect_attributes_mock):
params = {
'top_height': 200,
'top_width': 300,
'size_divisible_by': 3,
'have_reference': 'False',
}
collect_attributes_mock.return_value = {
**params,
'have_reference': 0
}
fake_pl = FakeAccumProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
AccumFrontExtractor.extract(fake_node)
exp_res = {
'type': "Accum",
'top_height': 200,
'top_width': 300,
'size_divisible_by': 3,
'have_reference': 0,
'infer': AccumOp.accum_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.argmax_ext import ArgMaxFrontExtractor
from extensions.ops.argmax import ArgMaxOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeArgMaxProtoLayer:
def __init__(self, val):
self.argmax_param = val
class TestArgMaxExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['ArgMax'] = ArgMaxOp
def test_argmax_no_pb_no_ml(self):
self.assertRaises(AttributeError, ArgMaxFrontExtractor.extract, None)
@patch('extensions.front.caffe.argmax_ext.merge_attrs')
def test_argmax_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'out_max_val': True,
'top_k': 100,
'axis': 2
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakeArgMaxProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
ArgMaxFrontExtractor.extract(fake_node)
exp_res = {
'out_max_val': True,
'top_k': 100,
'axis': 2,
'infer': ArgMaxOp.argmax_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.caffe.axpy import AxpyToEltwise
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph_with_edge_attrs
class TestAxpyReplacer(unittest.TestCase):
def test_axpy(self):
nodes = {
'node_1': {'kind': 'op', 'type': 'Identity', 'op': 'Placeholder'},
'node_2': {'kind': 'op', 'type': 'Identity', 'op': 'Placeholder'},
'node_3': {'kind': 'op', 'type': 'Identity', 'op': 'Placeholder'},
'axpy': {'type': 'Axpy', 'kind': 'op', 'op': 'Axpy'},
'node_4': {'kind': 'op', 'type': 'Identity', 'op': 'Placeholder'}}
edges = [
('node_1', 'axpy', {'in': 0}),
('node_2', 'axpy', {'in': 1}),
('node_3', 'axpy', {'in': 2}),
('axpy', 'node_4', {'in': 0})]
graph = build_graph_with_edge_attrs(nodes, edges)
node = Node(graph, 'axpy')
replacer = AxpyToEltwise()
replacer.replace_op(graph, node)
scale_node = [node for node, attrs in list(graph.nodes(data=True)) if attrs['type'] == 'ScaleShift']
self.assertEqual(len(scale_node), 1)
add_node = [node for node, attrs in list(graph.nodes(data=True)) if attrs['type'] == 'Eltwise']
self.assertEqual(len(add_node), 1)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
import unittest
from extensions.front.caffe.bn import BNToScaleShift
from mo.graph.graph import Node
from mo.utils.unittest.extractors import FakeParam
from mo.utils.unittest.graph import build_graph_with_edge_attrs
class FakeBNProtoLayer:
def __init__(self, val):
self.bn_param = val
class FakeBNBinLayer:
def __init__(self, val):
self.blobs = val
class TestBNReplacer(unittest.TestCase):
def test_bn(self):
bn_pb = FakeBNProtoLayer(FakeParam('eps', 0.0001))
mean = [1, 2.5, 3]
var = [0.5, 0.1, 1.2]
scale = [2.3, 3.4, 4.5]
shift = [0.8, 0.6, 0.4]
bn_bin = FakeBNBinLayer([FakeParam('data', mean),
FakeParam('data', var),
FakeParam('data', scale),
FakeParam('data', shift)])
nodes = {
'node_1': {'kind': 'op', 'type': 'Identity', 'op': 'Placeholder'},
'bn': {'type': 'BN', 'kind': 'op', 'op': 'BN',
'pb': bn_pb,
'model_pb': bn_bin},
'node_2': {'kind': 'op', 'type': 'Identity', 'op': 'Placeholder'}}
edges = [
('node_1', 'bn', {'in': 0}),
('bn', 'node_2', {'in': 0})]
graph = build_graph_with_edge_attrs(nodes, edges)
node = Node(graph, 'bn')
replacer = BNToScaleShift()
replacer.replace_op(graph, node)
scale_node = [node for node, attrs in list(graph.nodes(data=True)) if attrs['type'] == 'ScaleShift']
self.assertEqual(len(scale_node), 1)
scale_ref = np.array([1.11796412, 3.2272172, 4.74282367])
shift_ref = np.array([-2.07131747, -10.87253847, -20.14270653])
for i in range(len(mean)):
self.assertAlmostEqual(graph.node[scale_node[0]]['scale'][i], scale_ref[i])
self.assertAlmostEqual(graph.node[scale_node[0]]['bias'][i], shift_ref[i])
This diff is collapsed.
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.correlation_ext import CorrelationFrontExtractor
from extensions.ops.correlation import CorrelationOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeCorrProtoLayer:
def __init__(self, val):
self.correlation_param = val
class TestCorrelationExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['Correlation'] = CorrelationOp
def test_da_no_pb_no_ml(self):
self.assertRaises(AttributeError, CorrelationFrontExtractor.extract, None)
@patch('extensions.front.caffe.correlation_ext.merge_attrs')
def test_resample_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'pad': 20,
'kernel_size': 1,
'max_displacement': 20,
'stride_1': 1,
'stride_2': 2,
'single_direction': 0,
'do_abs': False,
'correlation_type': 'caffe.CorrelationParameter.MULTIPLY'
}
merge_attrs_mock.return_value = {
**params,
'test': 54,
'test2': 'test3'
}
fake_pl = FakeCorrProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
CorrelationFrontExtractor.extract(fake_node)
exp_res = {
'type': "Correlation",
'pad': 20,
'kernel_size': 1,
'max_displacement': 20,
'stride_1': 1,
'stride_2': 2,
'single_direction': 0,
'do_abs': False,
'correlation_type': 'caffe.CorrelationParameter.MULTIPLY',
'infer': CorrelationOp.corr_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.ctcgreedydecoder_ext import CTCGreedyDecoderFrontExtractor
from extensions.ops.ctc_greedy_decoder import CTCGreedyDecoderOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeCTCGreedyDecoderProtoLayer:
def __init__(self, val):
self.ctc_decoder_param = val
class TestCTCGreedyDecoderExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['CTCGreedyDecoder'] = CTCGreedyDecoderOp
def test_ctcgreedydecoder_no_pb_no_ml(self):
self.assertRaises(AttributeError, CTCGreedyDecoderFrontExtractor.extract, None)
@patch('extensions.front.caffe.ctcgreedydecoder_ext.merge_attrs')
def test_ctcgreedydecoder_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'ctc_merge_repeated': True
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakeCTCGreedyDecoderProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
CTCGreedyDecoderFrontExtractor.extract(fake_node)
exp_res = {
'type': "CTCGreedyDecoder",
'ctc_merge_repeated': 1,
'infer': CTCGreedyDecoderOp.ctc_greedy_decoder_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
import numpy as np
from extensions.front.caffe.data_augmentation_ext import DataAugmentationFrontExtractor
from extensions.ops.data_augmentation import DataAugmentationOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeDAProtoLayer:
def __init__(self, val):
self.augmentation_param = val
class TestDA(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['DataAugmentation'] = DataAugmentationOp
def test_da_no_pb_no_ml(self):
self.assertRaises(AttributeError, DataAugmentationFrontExtractor.extract, None)
@patch('extensions.front.caffe.data_augmentation_ext.merge_attrs')
def test_da_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'crop_width': 0,
'crop_height': 0,
'write_augmented': "",
'max_multiplier': 255.0,
'augment_during_test': True,
'recompute_mean': 0,
'write_mean': "",
'mean_per_pixel': False,
'mean': 0,
'mode': "add",
'bottomwidth': 0,
'bottomheight': 0,
'num': 0,
'chromatic_eigvec': [0.0]
}
merge_attrs_mock.return_value = {
**params,
'test': 54,
'test2': 'test3'
}
fake_pl = FakeDAProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
DataAugmentationFrontExtractor.extract(fake_node)
exp_res = {
'type': 'DataAugmentation',
'op': 'DataAugmentation',
'crop_width': 0,
'crop_height': 0,
'write_augmented': "",
'max_multiplier': 255.0,
'augment_during_test': 1,
'recompute_mean': 0,
'write_mean': "",
'mean_per_pixel': 0,
'mean': 0,
'mode': "add",
'bottomwidth': 0,
'bottomheight': 0,
'num': 0,
'chromatic_eigvec': [0.0],
'infer': DataAugmentationOp.data_augmentation_infer
}
for key in exp_res.keys():
if key in ('chromatic_eigvec',):
np.testing.assert_equal(exp_res[key], fake_node[key])
else:
self.assertEqual(exp_res[key], fake_node[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.grn_ext import GRNFrontExtractor
from extensions.ops.grn import GRNOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.ops.op import Op
class FakeGRNProtoLayer:
def __init__(self, val):
self.grn_param = val
class TestGRNExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['GRN'] = GRNOp
def test_grn_no_pb_no_ml(self):
self.assertRaises(AttributeError, GRNFrontExtractor.extract, None)
@patch('extensions.front.caffe.grn_ext.merge_attrs')
def test_grn_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'bias': 0.7
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakeGRNProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
GRNFrontExtractor.extract(fake_node)
exp_res = {
'type': "GRN",
'bias': 0.7,
'infer': copy_shape_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.interp_ext import InterpFrontExtractor
from extensions.ops.interp import InterpOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeInterpProtoLayer:
def __init__(self, val):
self.interp_param = val
class TestInterpExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['Interp'] = InterpOp
def test_interp_no_pb_no_ml(self):
self.assertRaises(AttributeError, InterpFrontExtractor.extract, None)
@patch('extensions.front.caffe.interp_ext.merge_attrs')
def test_interp_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'height': 1.1,
'width': 2.2,
'zoom_factor': 3.3,
'shrink_factor': 4.4,
'pad_beg': 5.5,
'pad_end': 6.6
}
merge_attrs_mock.return_value = {
**params,
'test': 54,
'test2': 'test3'
}
fake_pl = FakeInterpProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
InterpFrontExtractor.extract(fake_node)
exp_res = {
'type': "Interp",
'height': 1.1,
'width': 2.2,
'zoom_factor': 3.3,
'shrink_factor': 4.4,
'pad_beg': 5.5,
'pad_end': 6.6,
'infer': InterpOp.interp_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.normalize_ext import NormalizeFrontExtractor
from extensions.ops.normalize import NormalizeOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.ops.op import Op
class FakeNormalizeProtoLayer:
def __init__(self, val):
self.norm_param = val
class TestNormalizeExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['Normalize'] = NormalizeOp
def test_normalize_no_pb_no_ml(self):
self.assertRaises(AttributeError, NormalizeFrontExtractor.extract, None)
@patch('extensions.front.caffe.normalize_ext.collect_attributes')
def test_normalize_ext_ideal_numbers(self, collect_attributes_mock):
params = {
'across_spatial': 1,
'channel_shared': 0,
'eps': 0.00001
}
collect_attributes_mock.return_value = {
**params
}
fake_pl = FakeNormalizeProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
NormalizeFrontExtractor.extract(fake_node)
exp_res = {
'type': "Normalize",
'across_spatial': 1,
'channel_shared': 0,
'eps': 0.00001,
'infer': copy_shape_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.caffe.pooling_ext import PoolingFrontExtractor
from mo.front.common.extractors.utils import layout_attrs
from mo.ops.pooling import Pooling
from mo.utils.unittest.extractors import PB, FakeMultiParam
class FakeProtoLayer:
def __init__(self, val):
self.pooling_param = val
class TestPooling(unittest.TestCase):
def test_pooling_ext_global(self):
params = {
'kernel_size': 1,
'stride': 2,
'pad': 3,
'pool': 0,
'global_pooling': 1,
'ceil_mode': 1
}
node = PB({'pb': FakeProtoLayer(FakeMultiParam(params))})
PoolingFrontExtractor.extract(node)
res = node
exp_res = {
'window': np.array([1, 1, 0, 0], dtype=np.int64),
'stride': np.array([1, 1, 1, 1], dtype=np.int64),
'pad': np.array([[0, 0], [0, 0], [0, 0], [0, 0]], dtype=np.int64),
'pad_spatial_shape': np.array([[0, 0], [0, 0]], dtype=np.int64),
'pool_method': 'max',
'exclude_pad': 'true',
'infer': Pooling.infer,
'global_pool': 1,
'output_spatial_shape': None,
'pooling_convention': 'full',
'rounding_type': 'ceil'
}
exp_res.update(layout_attrs())
for i in exp_res.keys():
if i in ('window', 'stride',
'pad', 'pad_spatial_shape',
'spatial_dims', 'batch_dims',
'channel_dims'):
np.testing.assert_array_equal(res[i], exp_res[i])
else:
self.assertEqual(res[i], exp_res[i])
def test_pooling_ext(self):
params = {
'kernel_size': 1,
'stride': 2,
'pad': 3,
'pool': 1,
'global_pooling': 0,
'ceil_mode': 0
}
node = PB({'pb': FakeProtoLayer(FakeMultiParam(params))})
PoolingFrontExtractor.extract(node)
res = node
exp_res = {
'window': np.array([1, 1, 1, 1], dtype=np.int64),
'stride': np.array([1, 1, 2, 2], dtype=np.int64),
'pad': np.array([[0, 0], [0, 0], [3, 3], [3, 3]], dtype=np.int64),
'pad_spatial_shape': np.array([[3, 3], [3, 3]], dtype=np.int64),
'pool_method': 'avg',
'exclude_pad': 'false',
'infer': Pooling.infer,
'global_pool': 0,
'output_spatial_shape': None,
'pooling_convention': 'valid'
}
exp_res.update(layout_attrs())
for i in exp_res.keys():
if i in ('window', 'stride',
'pad', 'pad_spatial_shape',
'spatial_dims', 'batch_dims',
'channel_dims'):
np.testing.assert_array_equal(res[i], exp_res[i])
else:
self.assertEqual(res[i], exp_res[i])
def test_pooling_ext_exception(self):
params = {
'kernel_size': 1,
'stride': 2,
'pad': 3,
'pool': 3,
'global_pooling': 1
}
node = PB({'pb': FakeProtoLayer(FakeMultiParam(params))})
self.assertRaises(ValueError, PoolingFrontExtractor.extract, node)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.power_file_ext import PowerFileFrontExtractor
from extensions.ops.power_file import PowerFileOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.ops.op import Op
class FakePowerFileProtoLayer:
def __init__(self, val):
self.power_file_param = val
class TestPowerFileExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['PowerFile'] = PowerFileOp
def test_power_file_no_pb_no_ml(self):
self.assertRaises(AttributeError, PowerFileFrontExtractor.extract, None)
@patch('extensions.front.caffe.power_file_ext.collect_attributes')
def test_mvn_ext_ideal_numbers(self, collect_attributes_mock):
params = {
'normalize_variance': 'True',
'across_channels': 'False',
'eps': 1e-9
}
collect_attributes_mock.return_value = {
'shift_file': 'some_file_path'
}
fake_pl = FakePowerFileProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
PowerFileFrontExtractor.extract(fake_node)
exp_res = {
'type': "PowerFile",
'shift_file': 'some_file_path',
'infer': copy_shape_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.prelu_ext import PreluFrontExtractor
from extensions.ops.prelu import PreluOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakePReLUProtoLayer:
def __init__(self, val):
self.prelu_param = val
class TestPreluExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['PReLU'] = PreluOp
def test_prelu_no_pb_no_ml(self):
self.assertRaises(AttributeError, PreluFrontExtractor.extract, None)
@patch('extensions.front.caffe.prelu_ext.merge_attrs')
def test_reogyolo_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'channel_shared': False
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakePReLUProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
PreluFrontExtractor.extract(fake_node)
exp_res = {
'type': 'PReLU',
'op': 'PReLU',
'channel_shared': 0,
'infer': PreluOp.prelu_shape_infer,
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
import numpy as np
from extensions.front.caffe.priorbox_clustered_ext import PriorBoxClusteredFrontExtractor
from extensions.ops.priorbox_clustered import PriorBoxClusteredOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakePriorBoxClusteredProtoLayer:
def __init__(self, val):
self.prior_box_param = val
class TestPriorBoxClusteredExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['PriorBoxClustered'] = PriorBoxClusteredOp
def test_priorboxclustered_no_pb_no_ml(self):
self.assertRaises(AttributeError, PriorBoxClusteredFrontExtractor.extract, None)
@patch('extensions.front.caffe.priorbox_clustered_ext.merge_attrs')
def test_priorboxclustered_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'width': '30.0',
'height': '60.0',
'clip': False,
'flip': True,
'variance': np.array(['0.2', '0.3', '0.2', '0.3']),
'img_size': '300',
'img_h': '0',
'img_w': '0',
'step': '0,5',
'step_h': '0',
'step_w': '0',
'offset': '0.6'
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakePriorBoxClusteredProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
PriorBoxClusteredFrontExtractor.extract(fake_node)
exp_res = {
'op': 'PriorBoxClustered',
'type': 'PriorBoxClustered',
'width': '30.0',
'height': '60.0',
'clip': 0,
'flip': 1,
'variance': np.array(['0.2', '0.3', '0.2', '0.3']),
'img_size': '300',
'img_h': '0',
'img_w': '0',
'step': '0,5',
'step_h': '0',
'step_w': '0',
'offset': '0.6'
}
for key in exp_res.keys():
if key in ['width', 'height', 'variance']:
np.testing.assert_equal(fake_node[key], exp_res[key])
else:
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
import numpy as np
from extensions.front.caffe.priorbox_ext import PriorBoxFrontExtractor
from extensions.ops.priorbox import PriorBoxOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakePriorBoxProtoLayer:
def __init__(self, val):
self.prior_box_param = val
class TestPriorBoxExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['PriorBox'] = PriorBoxOp
def test_priorbox_no_pb_no_ml(self):
self.assertRaises(AttributeError, PriorBoxFrontExtractor.extract, None)
@patch('extensions.front.caffe.priorbox_ext.merge_attrs')
def test_priorbox_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'clip': False,
'flip': True,
'min_size': np.array([]),
'max_size': np.array([]),
'aspect_ratio': np.array([2, 3]),
'variance': np.array(['0.2', '0.3', '0.2', '0.3']),
'img_size': '300',
'img_h': '0',
'img_w': '0',
'step': '0,5',
'step_h': '0',
'step_w': '0',
'offset': '0.6'
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakePriorBoxProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
PriorBoxFrontExtractor.extract(fake_node)
exp_res = {
'op': 'PriorBox',
'type': 'PriorBox',
'clip': 0,
'variance': np.array(['0.2', '0.3', '0.2', '0.3']),
'img_size': '300',
'img_h': '0',
'img_w': '0',
'step': '0,5',
'step_h': '0',
'step_w': '0',
'offset': '0.6'
}
for key in exp_res.keys():
if key in ['width', 'height', 'variance']:
np.testing.assert_equal(fake_node[key], exp_res[key])
else:
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.proposal_ext import ProposalFrontExtractor
from extensions.ops.proposal import ProposalOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeProposalProtoLayer:
def __init__(self, val):
self.proposal_param = val
class TestProposalExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['Proposal'] = ProposalOp
def test_proposal_no_pb_no_ml(self):
self.assertRaises(AttributeError, ProposalFrontExtractor.extract, None)
@patch('extensions.front.caffe.proposal_ext.merge_attrs')
def test_proposal_ext_ideal_numbers(self, merge_attrs):
params = {
'feat_stride': 1,
'base_size': 16,
'min_size': 16,
'ratio': 1,
'scale': 2,
'pre_nms_topn': 6000,
'post_nms_topn': 300,
'nms_thresh': 0.7
}
merge_attrs.return_value = {
**params
}
fake_pl = FakeProposalProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
ProposalFrontExtractor.extract(fake_node)
exp_res = {
'type': "Proposal",
'feat_stride': 1,
'base_size': 16,
'min_size': 16,
'ratio': 1,
'scale': 2,
'pre_nms_topn': 6000,
'post_nms_topn': 300,
'nms_thresh': 0.7,
'infer': ProposalOp.proposal_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.proposal_python_ext import ProposalPythonFrontExtractor
from extensions.ops.proposal import ProposalOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeProposalPythonProtoLayer:
def __init__(self, val):
self.python_param = val
class TestProposalPythonExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['Proposal'] = ProposalOp
def test_proposal_no_pb_no_ml(self):
self.assertRaises(AttributeError, ProposalPythonFrontExtractor.extract, None)
@patch('mo.front.caffe.extractors.utils.merge_attrs')
def test_proposal_ext_ideal_numbers(self, merge_attrs):
params = {
'param_str': "'feat_stride': 16"
}
merge_attrs.return_value = params
fake_pl = FakeProposalPythonProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
ProposalPythonFrontExtractor.extract(fake_node)
exp_res = {
'type': "Proposal",
'feat_stride': 16,
'base_size': 16,
'min_size': 16,
'ratio': [0.5, 1, 2],
'scale': [8, 16, 32],
'pre_nms_topn': 6000,
'post_nms_topn': 300,
'nms_thresh': 0.7,
'infer': ProposalOp.proposal_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
@patch('mo.front.caffe.extractors.utils.merge_attrs')
def test_proposal_ext_scales(self, merge_attrs):
params = {
'param_str': "'feat_stride': 16, 'scales': [1,2,3], 'ratios':[5, 6,7]"
}
merge_attrs.return_value = params
fake_pl = FakeProposalPythonProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
ProposalPythonFrontExtractor.extract(fake_node)
exp_res = {
'type': "Proposal",
'feat_stride': 16,
'base_size': 16,
'min_size': 16,
'ratio': [5, 6, 7],
'scale': [1, 2, 3],
'pre_nms_topn': 6000,
'post_nms_topn': 300,
'nms_thresh': 0.7,
'infer': ProposalOp.proposal_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
@patch('mo.front.caffe.extractors.utils.merge_attrs')
def test_proposal_ext_scale(self, merge_attrs):
params = {
'param_str': "'feat_stride': 16, 'scale': [1,2,3], 'ratio':[5, 6,7]"
}
merge_attrs.return_value = params
fake_pl = FakeProposalPythonProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
ProposalPythonFrontExtractor.extract(fake_node)
exp_res = {
'type': "Proposal",
'feat_stride': 16,
'base_size': 16,
'min_size': 16,
'ratio': [5, 6, 7],
'scale': [1, 2, 3],
'pre_nms_topn': 6000,
'post_nms_topn': 300,
'nms_thresh': 0.7,
'infer': ProposalOp.proposal_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.psroipooling_ext import PSROIPoolingFrontExtractor
from extensions.ops.psroipooling import PSROIPoolingOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakePSROIPoolingProtoLayer:
def __init__(self, val):
self.psroi_pooling_param = val
class TestPSROIPoolingExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['PSROIPooling'] = PSROIPoolingOp
def test_psroipooling_no_pb_no_ml(self):
self.assertRaises(AttributeError, PSROIPoolingFrontExtractor.extract, None)
@patch('extensions.front.caffe.psroipooling_ext.merge_attrs')
def test_psroipooling_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'spatial_scale': 4,
'output_dim': 20,
'group_size': 5,
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakePSROIPoolingProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
PSROIPoolingFrontExtractor.extract(fake_node)
exp_res = {
'type': "PSROIPooling",
'spatial_scale': 4,
'output_dim': 20,
'group_size': 5,
'infer': PSROIPoolingOp.psroipooling_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.regionyolo_ext import RegionYoloFrontExtractor
from extensions.ops.regionyolo import RegionYoloOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeRegionYoloProtoLayer:
def __init__(self, val, val_f):
self.region_yolo_param = val
self.flatten_param = val_f
class TestReorgYoloExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['RegionYolo'] = RegionYoloOp
def test_reogyolo_no_pb_no_ml(self):
self.assertRaises(AttributeError, RegionYoloFrontExtractor.extract, None)
@patch('extensions.front.caffe.regionyolo_ext.merge_attrs')
def test_reogyolo_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'coords': 4,
'classes': 20,
'num': 5,
'do_softmax': 1,
'anchors': 5,
'mask': 5,
}
params_flatten = {
'axis': 1,
'end_axis': -1
}
merge_attrs_mock.return_value = {
**params,
**params_flatten
}
fake_pl = FakeRegionYoloProtoLayer(FakeMultiParam(params), FakeMultiParam(params_flatten))
fake_node = FakeNode(fake_pl, None)
RegionYoloFrontExtractor.extract(fake_node)
exp_res = {
'type': "RegionYolo",
'coords': 4,
'classes': 20,
'num': 5,
'axis': 1,
'end_axis': -1,
'do_softmax': 1,
'anchors': 5,
'mask': 5,
'infer': RegionYoloOp.regionyolo_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.reorgyolo_ext import ReorgYoloFrontExtractor
from extensions.ops.reorgyolo import ReorgYoloOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeReorgYoloProtoLayer:
def __init__(self, val):
self.reorg_yolo_param = val
class TestReorgYoloExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['ReorgYolo'] = ReorgYoloOp
def test_elu_no_pb_no_ml(self):
self.assertRaises(AttributeError, ReorgYoloFrontExtractor.extract, None)
@patch('extensions.front.caffe.reorgyolo_ext.merge_attrs')
def test_elu_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'stride': 2
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakeReorgYoloProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
ReorgYoloFrontExtractor.extract(fake_node)
exp_res = {
'type': "ReorgYolo",
'stride': 2,
'infer': ReorgYoloOp.reorgyolo_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.resample_ext import ResampleFrontExtractor
from extensions.ops.resample import ResampleOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeResampleProtoLayer:
def __init__(self, val):
self.resample_param = val
class TestResampleExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['Resample'] = ResampleOp
def test_da_no_pb_no_ml(self):
self.assertRaises(AttributeError, ResampleFrontExtractor.extract, None)
@patch('extensions.front.caffe.resample_ext.merge_attrs')
def test_resample_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'antialias': True,
'height': 384,
'width': 512,
'type': 2,
'factor': 1.0,
}
merge_attrs_mock.return_value = {
'antialias': True,
'height': 384,
'width': 512,
'type': 'caffe.ResampleParameter.LINEAR',
'factor': 1.0,
}
fake_pl = FakeResampleProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
ResampleFrontExtractor.extract(fake_node)
exp_res = {
'op': 'Resample',
'antialias': 1,
'height': 384,
'width': 512,
'resample_type': 'caffe.ResampleParameter.LINEAR',
'factor': 1.0,
'infer': ResampleOp.resample_infer
}
for key in exp_res.keys():
self.assertEqual(exp_res[key], fake_node[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.simplernms_ext import SimplerNMSFrontExtractor
from extensions.ops.simplernms import SimplerNMSOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeSimplerNMSProtoLayer:
def __init__(self, val):
self.simpler_nms_param = val
class TestSimplerNMSExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['SimplerNMS'] = SimplerNMSOp
def test_simplernms_no_pb_no_ml(self):
self.assertRaises(AttributeError, SimplerNMSFrontExtractor.extract, None)
@patch('extensions.front.caffe.simplernms_ext.merge_attrs')
def test_simplernms_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'cls_threshold': 0.5,
'max_num_proposals': 300,
'iou_threshold': 0.7,
'min_bbox_size': 16,
'feat_stride': 16,
'pre_nms_topn': 6000,
'post_nms_topn': 150,
'scale': [1, 2, 3]
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakeSimplerNMSProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
SimplerNMSFrontExtractor.extract(fake_node)
exp_res = {
'cls_threshold': 0.5,
'max_num_proposals': 300,
'iou_threshold': 0.7,
'min_bbox_size': 16,
'feat_stride': 16,
'pre_nms_topn': 6000,
'post_nms_topn': 150,
'scale': [1, 2, 3],
'infer': SimplerNMSOp.simplernms_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from unittest.mock import patch
from extensions.front.caffe.spatial_transformer_ext import SpatialTransformFrontExtractor
from extensions.ops.spatial_transformer import SpatialTransformOp
from mo.utils.unittest.extractors import FakeMultiParam
from mo.utils.unittest.graph import FakeNode
from mo.ops.op import Op
class FakeSpatialTransformProtoLayer:
def __init__(self, val):
self.st_param = val
class TestSpatialTransformExt(unittest.TestCase):
@classmethod
def setUpClass(cls):
Op.registered_ops['SpatialTransformer'] = SpatialTransformOp
def test_st_no_pb_no_ml(self):
self.assertRaises(AttributeError, SpatialTransformFrontExtractor.extract, None)
@patch('extensions.front.caffe.spatial_transformer_ext.merge_attrs')
def test_st_ext_ideal_numbers(self, merge_attrs_mock):
params = {
'transform_type': "ffff",
'sampler_type': "gggg",
'output_H': 56,
'output_W': 78,
'to_compute_dU': True,
'theta_1_1': 0.1,
'theta_1_2': 0.2,
'theta_1_3': 0.3,
'theta_2_1': 0.4,
'theta_2_2': 0.5,
'theta_2_3': 0.6
}
merge_attrs_mock.return_value = {
**params
}
fake_pl = FakeSpatialTransformProtoLayer(FakeMultiParam(params))
fake_node = FakeNode(fake_pl, None)
SpatialTransformFrontExtractor.extract(fake_node)
exp_res = {
'type': "SpatialTransformer",
'transform_type': "ffff",
'sampler_type': "gggg",
'output_H': 56,
'output_W': 78,
'to_compute_dU': 1,
'theta_1_1': 0.1,
'theta_1_2': 0.2,
'theta_1_3': 0.3,
'theta_2_1': 0.4,
'theta_2_2': 0.5,
'theta_2_3': 0.6,
'infer': SpatialTransformOp.sp_infer
}
for key in exp_res.keys():
self.assertEqual(fake_node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.eltwise_n import EltwiseNReplacement
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TestAddNFrontReplacement(unittest.TestCase):
def test_replase_eltwise_n(self):
graph = build_graph(
{'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
'add_n': {'value': None, 'operation': 'sum', 'type': None, 'kind': 'op', 'op': 'EltwiseN'},
'node_4': {'type': 'Identity', 'value': None, 'kind': 'op'},
},
[('node_1', 'node_2'),
('node_2', 'add_n'),
('node_3', 'add_n'),
('add_n', 'node_4'), ],
)
add_n_node = Node(graph, 'add_n')
rep_op = EltwiseNReplacement()
rep_op.replace_op(graph, add_n_node)
eltwise_nodes = [node for node, attrs in list(graph.nodes(data=True)) if attrs['type'] == 'Eltwise']
self.assertEqual(len(eltwise_nodes), 1)
def test_replase_eltwise_n_2(self):
graph = build_graph(
{'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_4': {'type': 'Identity', 'value': None, 'kind': 'op'},
'add_n': {'value': None, 'operation': 'sum', 'type': None, 'kind': 'op', 'op': 'EltwiseN'},
'node_5': {'type': 'Identity', 'value': None, 'kind': 'op'},
},
[('node_1', 'node_2'),
('node_2', 'add_n'),
('node_3', 'add_n'),
('node_4', 'add_n'),
('add_n', 'node_5'), ],
)
add_n_node = Node(graph, 'add_n')
rep_op = EltwiseNReplacement()
rep_op.replace_op(graph, add_n_node)
eltwise_nodes = [node for node, attrs in list(graph.nodes(data=True)) if attrs['type'] == 'Eltwise']
self.assertEqual(len(eltwise_nodes), 2)
def test_replase_eltwise_n_3(self):
graph = build_graph(
{'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_4': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_5': {'type': 'Identity', 'value': None, 'kind': 'op'},
'add_n': {'value': None, 'operation': 'sum', 'type': None, 'kind': 'op', 'op': 'EltwiseN'},
'node_6': {'type': 'Identity', 'value': None, 'kind': 'op'},
},
[('node_1', 'node_2'),
('node_2', 'add_n'),
('node_3', 'add_n'),
('node_4', 'add_n'),
('node_5', 'add_n'),
('add_n', 'node_6'), ],
)
add_n_node = Node(graph, 'add_n')
rep_op = EltwiseNReplacement()
rep_op.replace_op(graph, add_n_node)
eltwise_nodes = [node for node, attrs in list(graph.nodes(data=True)) if attrs['type'] == 'Eltwise']
self.assertEqual(len(eltwise_nodes), 3)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.freeze_placeholder_value import FreezePlaceholderValue
from mo.utils.unittest.graph import build_graph
nodes_bool = {
'0': {'name': 'input1', 'kind': 'op', 'op': 'Placeholder', 'data_type': bool, 'shape': np.array([])},
'1': {'name': 'input2', 'kind': 'op', 'op': 'Placeholder', 'data_type': bool, 'shape': np.array([])},
'2': {'name': 'node_1', 'kind': 'op', 'op': 'NotPlaceholder'},
'3': {'name': 'node_2', 'kind': 'op', 'op': 'NotPlaceholder'},
'4': {'name': 'node_3', 'kind': 'op', 'op': 'NotPlaceholder'},
'5': {'name': 'node_4', 'kind': 'op', 'op': 'NotPlaceholder'},
'6': {'name': 'output1', 'kind': 'op', 'op': 'OpOutput', 'is_output': True},
'7': {'name': 'output2', 'kind': 'op', 'op': 'OpOutput', 'is_output': True}
}
edges = {
('0', '2'),
('2', '3'),
('4', '6'),
('1', '5'),
('5', '7')
}
class TestFreezePlaceholderValue(unittest.TestCase):
def test_freeze_true(self):
graph = build_graph(nodes_bool, edges)
graph.graph['fw'] = 'tf'
tested_class = FreezePlaceholderValue()
tested_class.replacement_dict = {'input1': 'True'}
before_pattern = graph.nodes()
tested_class.find_and_replace_pattern(graph=graph)
after_pattern = graph.nodes()
# number of nodes in the grpaph didn't change
self.assertEqual(len(before_pattern), len(after_pattern))
# reach new placeholder
try:
new_ph_dict = graph.node[[u for u, v in graph.in_edges('2')][0]]
except Exception as e:
self.fail("Can't get frozen placeholder. Broken edge. Additional information: {}".format(e))
# check value
self.assertEqual('value' in new_ph_dict, True)
self.assertEqual(new_ph_dict['value'], True)
def test_freeze_false(self):
graph = build_graph(nodes_bool, edges)
graph.graph['fw'] = 'tf'
tested_class = FreezePlaceholderValue()
tested_class.replacement_dict = {'input1': 'False'}
before_pattern = graph.nodes()
tested_class.find_and_replace_pattern(graph=graph)
after_pattern = graph.nodes()
# number of nodes in the grpaph didn't change
self.assertEqual(len(before_pattern), len(after_pattern))
# reach new placeholder
try:
new_ph_dict = graph.node[[u for u, v in graph.in_edges('2')][0]]
except Exception as e:
self.fail("Can't get frozen placeholder. Broken edge. Additional information: {}".format(e))
# check value
self.assertEqual('value' in new_ph_dict, True)
self.assertEqual(new_ph_dict['value'], False)
def test_freeze_both(self):
graph = build_graph(nodes_bool, edges)
graph.graph['fw'] = 'tf'
tested_class = FreezePlaceholderValue()
tested_class.replacement_dict = {'input1': 'False', 'input2': 'True'}
before_pattern = graph.nodes()
tested_class.find_and_replace_pattern(graph=graph)
after_pattern = graph.nodes()
# number of nodes in the graph didn't change
self.assertEqual(len(before_pattern), len(after_pattern))
# reach new placeholder
try:
new_ph_dict_1 = graph.node[[u for u, v in graph.in_edges('2')][0]]
new_ph_dict_2 = graph.node[[u for u, v in graph.in_edges('5')][0]]
except Exception as e:
self.fail("Can't get frozen placeholder. Broken edge. Additional information: {}".format(e))
# check value
self.assertEqual('value' in new_ph_dict_1, True)
self.assertEqual('value' in new_ph_dict_2, True)
self.assertEqual(new_ph_dict_1['value'], False)
self.assertEqual(new_ph_dict_2['value'], True)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.image_scaler import ImageScaler
from mo.utils.unittest.graph import build_graph, compare_graphs
nodes_attributes = {
'placeholder_1': {'shape': None, 'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# ImageScaler operation
'im_scaler': {'type': None, 'kind': 'op', 'op': 'ImageScaler'},
'im_scaler_data': {'value': None, 'shape': None, 'kind': 'data'},
# Test operation
'last': {'type': None, 'value': None, 'kind': 'op', 'op': None},
'last_data': {'value': None, 'shape': None, 'kind': 'data'},
# Mul and Add operations
'mul_1': {'type': None, 'value': None, 'kind': 'op', 'op': 'Mul'},
'mul_1_w': {'value': None, 'shape': None, 'kind': 'op', 'op': 'Const'},
'mul_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'add_1': {'type': None, 'value': None, 'kind': 'op', 'op': 'Add'},
'add_1_w': {'value': None, 'shape': None, 'kind': 'op', 'op': 'Const'},
'add_1_data': {'value': None, 'shape': None, 'kind': 'data'},
}
class ImageScalerTest(unittest.TestCase):
def test_image_scaler_test1(self):
graph = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'im_scaler'),
('im_scaler', 'im_scaler_data'),
('im_scaler_data', 'last'),
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
'im_scaler': {'scale': np.array(1.0), 'bias': np.reshape(np.array([1, 2, 3]), [3, 1, 1])},
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'add_1'),
('add_1_w', 'add_1'),
('add_1', 'add_1_data'),
('add_1_data', 'last')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
'add_1_w': {'shape': np.array([3, 1, 1]),
'value': np.reshape(np.array([1, 2, 3]), [3, 1, 1])},
}, nodes_with_edges_only=True)
graph.graph['layout'] = 'NCHW'
replacer = ImageScaler()
replacer.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'last')
self.assertTrue(flag, resp)
def test_image_scaler_test2(self):
graph = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'im_scaler'),
('im_scaler', 'im_scaler_data'),
('im_scaler_data', 'last'),
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
'im_scaler': {'scale': np.array(2.0), 'bias': np.reshape(np.array([0, 0, 0]), [3, 1, 1])},
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'mul_1'),
('mul_1_w', 'mul_1'),
('mul_1', 'mul_1_data'),
('mul_1_data', 'last')
],
{'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
'mul_1_w': {'shape': np.array(2.0).shape, 'value': np.array(2.0)},
}, nodes_with_edges_only=True)
graph.graph['layout'] = 'NCHW'
replacer = ImageScaler()
replacer.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'last')
self.assertTrue(flag, resp)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import networkx as nx
from extensions.front.instance_normalization import InstanceNormalization
from mo.utils.unittest.graph import build_graph
from mo.middle.pattern_match import node_match
class TestInstanceNormalization(unittest.TestCase):
def test_default(self):
nodes = {
'input': {'kind': 'op', 'op': 'AnyOp'},
'scale': {'kind': 'op', 'op': 'AnyOp'},
'B': {'kind': 'op', 'op': 'AnyOp'},
'node': {'kind': 'op', 'op': 'InstanceNormalization', 'epsilon': 0.123},
}
edges = [
('input', 'node'),
('scale', 'node'),
('B', 'node'),
]
graph = build_graph(nodes, edges)
tested_class = InstanceNormalization()
tested_class.find_and_replace_pattern(graph)
ref_nodes = {
'input': {'op': 'AnyOp'},
'scale': {'op': 'AnyOp'},
'B': {'op': 'AnyOp'},
'mvn': {'kind': 'op', 'op': 'MVN', 'name': 'node/InstanceNormalization/MVN', 'eps': 0.123},
'mul': {'kind': 'op', 'op': 'Mul', 'name': 'node/InstanceNormalization/Mul'},
'add': {'kind': 'op', 'op': 'Add', 'name': 'node/InstanceNormalization/Add'},
}
ref_edges = [
('input', 'mvn'),
('mvn', 'mul'),
('scale', 'mul'),
('mul', 'add'),
('B', 'add'),
]
ref_graph = build_graph(ref_nodes, ref_edges)
self.assertTrue(nx.is_isomorphic(graph, ref_graph, node_match))
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.kaldi.replace_splice_node_pattern import ReplaceSpliceNodePattern
from mo.graph.graph import Node
from mo.utils.unittest.graph import build_graph
class ReplaceSpliceNodePatternTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.nodes_attributes = {
'in_node': {'kind': 'op', 'op': 'Input', 'shape': [1, 13]},
'slice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 5)}
}
cls.graph = build_graph(cls.nodes_attributes,
[('in_node', 'slice')])
ReplaceSpliceNodePattern().find_and_replace_pattern(cls.graph)
def test_memory(self):
memory_nodes = [node for node in self.graph.nodes(data=True) if node[1]['op'] == 'Memory']
self.assertEqual(len(memory_nodes), 2)
for memory_node in memory_nodes:
node = Node(self.graph, memory_node[0])
if len(node.in_nodes()):
self.assertEqual(node.index, 0)
elif len(node.out_nodes()):
self.assertEqual(node.index, 1)
self.assertEqual(memory_nodes[0][1]['id'], memory_nodes[1][1]['id'])
def test_crop(self):
crop_node = [node for node in self.graph.nodes(data=True) if node[1]['op'] == 'Crop']
self.assertEqual(len(crop_node), 1)
crop_node = Node(self.graph, crop_node[0][0])
self.assertEqual(crop_node.offset, [13])
self.assertEqual(crop_node.dim, [13 * 9])
def test_concat(self):
concat_node = [node for node in self.graph.nodes(data=True) if node[1]['op'] == 'Concat']
self.assertEqual(len(concat_node), 1)
crop_node = Node(self.graph, concat_node[0][0])
self.assertEqual(crop_node.axis, 1)
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.mxnet.check_softmax_node_inputs import CheckSoftmaxNodeInputs
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TestCheckSoftmaxNodeInputs(unittest.TestCase):
def test_remove_softmax_output_input(self):
graph = build_graph(
{'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'softmax': {'type': 'SoftmaxOutput', 'value': None, 'kind': 'op', 'op': 'SoftmaxOutput'},
},
[('node_1', 'softmax'),
('node_2', 'softmax')
])
pattern = CheckSoftmaxNodeInputs()
pattern.find_and_replace_pattern(graph)
node_softmax = Node(graph, 'softmax')
self.assertEqual(len(node_softmax.in_nodes()), 1)
node_input1 = node_softmax.in_node(0)
self.assertEqual(node_input1.name, 'node_1')
def test_remove_softmax_activation_input(self):
graph = build_graph(
{'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'softmax': {'type': 'SoftmaxActivation', 'value': None, 'kind': 'op', 'op': 'SoftmaxActivation'},
},
[('node_1', 'softmax')])
pattern = CheckSoftmaxNodeInputs()
pattern.find_and_replace_pattern(graph)
node_softmax = Node(graph, 'softmax')
self.assertEqual(len(node_softmax.in_nodes()), 1)
node_input1 = node_softmax.in_node(0)
self.assertEqual(node_input1.name, 'node_1')
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.mxnet.conv_ext import DeconvFrontExtractor
from mo.utils.unittest.extractors import PB
class TestDeconvShapesParsing(unittest.TestCase):
def test_conv_ext_ideal_numbers(self):
params = {'attrs': {
"kernel": "(4, 4)",
"no_bias": "True",
"num_filter": "21",
"num_group": "14",
"pad": "(4, 4)",
"stride": "(2, 2)",
"dilate": "(3, 3)",
"workspace": "1536"
}}
node = PB({'symbol_dict': params})
DeconvFrontExtractor.extract(node)
exp_res = {
'op': 'Deconvolution',
'pad': np.array([[0, 0], [0, 0], [4, 4], [4, 4]]),
'pad_spatial_shape': np.array([[4, 4], [4, 4]]),
'stride': np.array([1, 1, 2, 2]),
'kernel_spatial': np.array([4, 4]),
'dilation': np.array([1, 1, 3, 3]),
'group': 14,
'output': 21,
'bias_addable': True,
'bias_term': False,
}
for key in exp_res.keys():
if key in ('pad', 'pad_spatial_shape', 'stride', 'kernel_spatial', 'dilation'):
np.testing.assert_equal(node[key], exp_res[key])
else:
self.assertEqual(node[key], exp_res[key])
def test_conv_ext_no_bias(self):
params = { 'attrs':{
"kernel": "(4, 4)",
"num_filter": "21",
"num_group": "14",
"pad": "(4, 4)",
"stride": "(2, 2)",
"dilate": "(3, 3)",
"workspace": "1536"
}}
node = PB({'symbol_dict': params})
DeconvFrontExtractor.extract(node)
exp_res = {
'op': 'Deconvolution',
'pad': np.array([[0, 0], [0, 0], [4, 4], [4, 4]]),
'pad_spatial_shape': np.array([[4, 4], [4, 4]]),
'stride': np.array([1, 1, 2, 2]),
'kernel_spatial': np.array([4, 4]),
'dilation': np.array([1, 1, 3, 3]),
'group': 14,
'output': 21,
'bias_addable': True,
'bias_term': False,
}
for key in exp_res.keys():
if key in ('pad', 'pad_spatial_shape', 'stride', 'kernel_spatial', 'dilation'):
np.testing.assert_equal(node[key], exp_res[key])
else:
self.assertEqual(node[key], exp_res[key])
def test_conv_ext_with_bias(self):
params = { 'attrs':{
"kernel": "(4, 4)",
"no_bias": "False",
"num_filter": "21",
"num_group": "14",
"pad": "(4, 4)",
"stride": "(2, 2)",
"dilate": "(3, 3)",
"workspace": "1536"
}}
node = PB({'symbol_dict': params})
DeconvFrontExtractor.extract(node)
exp_res = {
'op': 'Deconvolution',
'pad': np.array([[0, 0], [0, 0], [4, 4], [4, 4]]),
'pad_spatial_shape': np.array([[4, 4], [4, 4]]),
'stride': np.array([1, 1, 2, 2]),
'kernel_spatial': np.array([4, 4]),
'dilation': np.array([1, 1, 3, 3]),
'group': 14,
'output': 21,
'bias_addable': True,
'bias_term': True,
}
for key in exp_res.keys():
if key in ('pad', 'pad_spatial_shape', 'stride', 'kernel_spatial', 'dilation'):
np.testing.assert_equal(node[key], exp_res[key])
else:
self.assertEqual(node[key], exp_res[key])
def test_deconv_ext_target_shape(self):
params = {'attrs': {
"kernel": "(4, 4)",
"no_bias": "True",
"num_filter": "21",
"num_group": "14",
"pad": "(4, 4)",
"stride": "(2, 2)",
"dilate": "(3, 3)",
"workspace": "1536",
"target_shape": "(120, 120)"
}}
node = PB({'symbol_dict': params})
DeconvFrontExtractor.extract(node)
exp_res = {
'op': 'Deconvolution',
'pad': np.array([[0, 0], [0, 0], [4, 4], [4, 4]]),
'pad_spatial_shape': np.array([[4, 4], [4, 4]]),
'stride': np.array([1, 1, 2, 2]),
'kernel_spatial': np.array([4, 4]),
'dilation': np.array([1, 1, 3, 3]),
'group': 14,
'output': 21,
'bias_addable': True,
'bias_term': False,
'output_spatial_shape': np.array([120, 120]),
}
for key in exp_res.keys():
if key in ('pad', 'pad_spatial_shape', 'stride', 'kernel_spatial', 'dilation', 'output_spatial_shape'):
np.testing.assert_equal(node[key], exp_res[key])
else:
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.mxnet.custom import CustomFrontExtractorOp
from mo.utils.unittest.graph import build_graph
from mo.front.extractor import FrontExtractorOp, MXNetCustomFrontExtractorOp
from mo.graph.graph import Node
attrs = {'test_attr': 1}
class FakeExtractor(MXNetCustomFrontExtractorOp):
@staticmethod
def extract(node: Node):
return True, attrs
class TestCustomFrontExtractorOp(unittest.TestCase):
@classmethod
def setUpClass(cls):
FrontExtractorOp.registered_ops['Custom'] = CustomFrontExtractorOp
def test_extract_custom_layer(self):
graph = build_graph(
{'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_custom': {'type': 'Custom', 'value': None, 'kind': 'op', 'op': 'Custom', },
'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
},
[('node_1', 'node_2'),
('node_2', 'node_custom'),
('node_custom', 'node_3'),
],
{
'node_custom': {'symbol_dict': {'attrs': {'op_type': 'test_type'}}},
})
custom_node = Node(graph, 'node_custom')
custom_op = FakeExtractor()
supported, op_attrs = custom_op.extract(custom_node)
self.assertTrue(supported)
self.assertEquals(op_attrs, attrs)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.mxnet.pooling_ext import PoolingFrontExtractor
from mo.utils.unittest.extractors import PB
class TestPoolingShapesParsing(unittest.TestCase):
def test_conv_ext_ideal_numbers(self):
params = {'attrs': {
"kernel": "(3, 4)",
"stride": "(3, 2)",
"pad": "(7, 8)",
"pool_type": "max"
}}
node = PB({'symbol_dict': params})
PoolingFrontExtractor.extract(node)
exp_res = {
'op': 'Pooling',
'pad': np.array([[0, 0], [0, 0], [7, 7], [8, 8]]),
'pad_spatial_shape': np.array([[7, 7], [8, 8]]),
'stride': np.array([1, 1, 3, 2]),
'window': np.array([1, 1, 3, 4]),
'pool_method': 'max',
'exclude_pad': 'false',
}
for key in exp_res.keys():
if key in ('pad', 'stride', 'window', 'pad_spatial_shape'):
np.testing.assert_equal(node[key], exp_res[key])
else:
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.mxnet.slice_channel_ext import SliceChannelFrontExtractor
from mo.utils.unittest.extractors import PB
class TestSliceChannelParsing(unittest.TestCase):
def test_parse_values(self):
params = {'attrs': {
"num_outputs": "2",
'axis': "2",
}}
node = PB({'symbol_dict': params})
SliceChannelFrontExtractor.extract(node)
exp_res = {
'op': 'Split',
'axis': 2,
'num_split': 2,
}
for key in exp_res.keys():
self.assertEqual(node[key], exp_res[key])
def test_parse_dafault_values(self):
params = {'attrs': {
"num_outputs": "2",
}}
node = PB({'symbol_dict': params})
SliceChannelFrontExtractor.extract(node)
exp_res = {
'op': 'Split',
'axis': 1,
'num_split': 2,
}
for key in exp_res.keys():
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.mxnet.ssd_pattern_flatten_softmax_activation import SsdPatternFlattenSoftmaxActivation
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TestSsdPatternFlattenSoftmaxActivation(unittest.TestCase):
def test_pattern_remove_transpose(self):
graph = build_graph({'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'kind': 'op'},
'node_3': {'type': 'Identity', 'kind': 'op'},
'node_softmax_activation': {'type': 'SoftMax', 'kind': 'op', 'op': 'SoftMax'},
'node_multi_box_detection': {'type': '_contrib_MultiBoxDetection', 'kind': 'op',
'op': '_contrib_MultiBoxDetection'},
'node_4': {'type': 'Identity', 'kind': 'op'},
},
[('node_1', 'node_softmax_activation'),
('node_2', 'node_multi_box_detection'),
('node_softmax_activation', 'node_multi_box_detection'),
('node_3', 'node_multi_box_detection'),
('node_multi_box_detection', 'node_4'), ],
)
pattern = SsdPatternFlattenSoftmaxActivation()
pattern.find_and_replace_pattern(graph)
flatten_name = list(graph.nodes())[-1]
self.assertTrue(graph.has_node(flatten_name))
self.assertFalse(graph.has_edge(Node(graph, 'softmax_activation').id, Node(graph, 'multi_box_detection').id))
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.mxnet.ssd_pattern_remove_flatten import SsdPatternRemoveFlatten
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TestSsdPatternRemoveFlatten(unittest.TestCase):
def test_pattern_remove_transpose(self):
graph = build_graph({'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'kind': 'op'},
'node_multi_box_prior': {'type': '_contrib_MultiBoxPrior', 'kind': 'op',
'op': '_contrib_MultiBoxPrior'},
'node_flatten': {'type': 'Flatten', 'kind': 'op', 'op': 'Flatten'},
'node_3': {'type': 'Identity', 'kind': 'op'},
},
[('node_1', 'node_2'),
('node_2', 'node_multi_box_prior'),
('node_multi_box_prior', 'node_flatten'),
('node_flatten', 'node_3'), ],
)
pattern = SsdPatternRemoveFlatten()
pattern.find_and_replace_pattern(graph)
self.assertFalse(graph.has_node('node_flatten'))
self.assertTrue(graph.has_edge(Node(graph, 'node_multi_box_prior').id, Node(graph, 'node_3').id))
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.mxnet.ssd_pattern_remove_reshape import SsdPatternRemoveReshape
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TestSsdPatternRemoveReshape(unittest.TestCase):
def test_pattern_remove_reshape(self):
graph = build_graph({'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'kind': 'op'},
'node_multi_box_prior1': {'type': '_contrib_MultiBoxPrior', 'kind': 'op',
'op': '_contrib_MultiBoxPrior'},
'node_multi_box_prior2': {'type': '_contrib_MultiBoxPrior', 'kind': 'op',
'op': '_contrib_MultiBoxPrior'},
'node_multi_box_prior3': {'type': '_contrib_MultiBoxPrior', 'kind': 'op',
'op': '_contrib_MultiBoxPrior'},
'node_concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
'node_reshape': {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'},
'node_3': {'type': 'Identity', 'kind': 'op'},
},
[('node_1', 'node_2'),
('node_2', 'node_multi_box_prior1'),
('node_2', 'node_multi_box_prior2'),
('node_2', 'node_multi_box_prior3'),
('node_multi_box_prior1', 'node_concat'),
('node_multi_box_prior2', 'node_concat'),
('node_multi_box_prior3', 'node_concat'),
('node_concat', 'node_reshape'),
('node_reshape', 'node_3'), ],
{
'node_concat': {'symbol_dict': {'attrs': {'dim': 3}}},
})
pattern = SsdPatternRemoveReshape()
pattern.find_and_replace_pattern(graph)
node_concat = Node(graph, 'node_concat')
self.assertEqual(node_concat['symbol_dict']['attrs']['dim'], 2)
self.assertFalse(graph.has_node('node_reshape'))
self.assertTrue(graph.has_edge(Node(graph, 'node_concat').id, Node(graph, 'node_3').id))
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.mxnet.ssd_pattern_remove_transpose import SsdPatternRemoveTranspose
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TestSsdPatternRemoveTranspose(unittest.TestCase):
def test_pattern_remove_transpose(self):
graph = build_graph({'node_1': {'type': 'Identity', 'value': None, 'kind': 'op', 'op': 'Placeholder'},
'node_3': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_4': {'type': 'Identity', 'value': None, 'kind': 'op'},
'node_transpose': {'type': 'transpose', 'value': None, 'kind': 'op', 'op': 'transpose'},
'node_softmax_activation': {'type': 'SoftMax', 'value': None, 'kind': 'op',
'op': 'SoftMax'},
'node_multi_box_detection': {'type': '_contrib_MultiBoxDetection', 'value': None,
'kind': 'op', 'op': '_contrib_MultiBoxDetection'},
'node_5': {'type': 'Identity', 'value': None, 'kind': 'op'},
},
[('node_1', 'node_transpose'),
('node_transpose', 'node_softmax_activation'),
('node_3', 'node_multi_box_detection'),
('node_softmax_activation', 'node_multi_box_detection'),
('node_4', 'node_multi_box_detection'),
('node_multi_box_detection', 'node_5'), ],
)
pattern = SsdPatternRemoveTranspose()
pattern.find_and_replace_pattern(graph)
self.assertFalse(graph.has_node('node_transpose'))
self.assertTrue(graph.has_edge(Node(graph, 'node_1').id, Node(graph, 'node_softmax_activation').id))
"""
Copyright (c) 2017-2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.front.mxnet.ssd_reorder_detection_out_inputs import SsdReorderDetectionOutInputs
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TestSsdReorderDetectionOutInputs(unittest.TestCase):
def test_reorder_detection_out_inputs(self):
graph = build_graph(
{'node_1': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
'node_2': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
'node_3': {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'},
'multi_box_detection': {'type': '_contrib_MultiBoxDetection', 'kind': 'op',
'op': '_contrib_MultiBoxDetection'},
},
[('node_1', 'multi_box_detection'),
('node_2', 'multi_box_detection'),
('node_3', 'multi_box_detection')],
{
'node_1': {'shape': np.array([1, 34928])},
'node_2': {'shape': np.array([1, 183372])},
'node_3': {'shape': np.array([1, 2, 34928])},
})
pattern = SsdReorderDetectionOutInputs()
pattern.find_and_replace_pattern(graph)
node_multi_box = Node(graph, 'multi_box_detection')
node_input1 = node_multi_box.in_node(0)
node_input2 = node_multi_box.in_node(1)
node_input3 = node_multi_box.in_node(2)
self.assertEqual(node_input1.name, 'node_2')
self.assertEqual(node_input2.name, 'node_1')
self.assertEqual(node_input3.name, 'node_3')
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from extensions.front.onnx.affine_ext import AffineFrontExtractor
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class AffineONNXExtractorTest(unittest.TestCase):
@staticmethod
def _create_node(attrs: dict):
pb = onnx.helper.make_node("Affine", ["X"], ["Y"], **attrs)
graph = build_graph({'node_0': {'pb': pb}}, [])
return Node(graph, 'node_0')
@staticmethod
def _base_attrs():
# Commonly used attributes in the tests
# Each test takes these ones and then adds/modifies/deletes particular fields
return (
# test input ONNX attributes
dict(
alpha=1.0,
beta=0.0
),
# reference output Node attributes
dict(
op='ImageScaler',
scale=1.0,
bias=0.0
)
)
@staticmethod
def _extract(inp):
node = __class__._create_node(inp)
AffineFrontExtractor.extract(node)
return node.graph.node[node.id]
def _match(self, out, ref):
for key in ref.keys():
status = out[key] == ref[key]
if type(status) in [list, np.ndarray]:
status = np.all(status)
self.assertTrue(status, 'Mismatch for field {}, observed: {}, expected: {}'.format(key, out[key], ref[key]))
def test_default(self):
inp, ref = self._base_attrs()
out = self._extract(inp)
self._match(out, ref)
def test_random(self):
inp, ref = self._base_attrs()
inp['alpha'] = 123.
inp['beta'] = 321.
ref['scale'] = 123.
ref['bias'] = 321.
out = self._extract(inp)
self._match(out, ref)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from extensions.front.onnx.conv_ext import ConvTransposeFrontExtractor
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
from mo.utils.error import Error
class ConvTransposeONNXExtractorTest(unittest.TestCase):
@staticmethod
def _create_node(attrs: dict):
pb = onnx.helper.make_node("ConvTranspose", ["X", "W"], ["Y"], **attrs)
graph = build_graph({'node_0': {'pb': pb}}, [])
return Node(graph, 'node_0')
@staticmethod
def _base_attrs():
# Commonly used attributes in the tests
# Each test takes these ones and then adds/modifies/deletes particular fields
return (
# test input ONNX attributes
dict(
pads=[1, 2, 3, 4],
kernel_shape=[5, 6]
),
# reference output Node attributes
dict(
type='Deconvolution',
pad=[[0, 0], [0, 0], [1, 3], [2, 4]],
pad_spatial_shape=[[1, 3], [2, 4]],
kernel_spatial=[5, 6],
bias_term=None,
output_shape=None,
output_padding=[0, 0, 0, 0],
dilation=[1, 1, 1, 1],
stride=[1, 1, 1, 1],
output_spatial_shape=None,
group=1
)
)
@staticmethod
def _extract(inp):
node = __class__._create_node(inp)
ConvTransposeFrontExtractor.extract(node)
return node.graph.node[node.id]
def _match(self, out, ref):
for key in ref.keys():
status = out[key] == ref[key]
if type(status) in [list, np.ndarray]:
status = np.all(status)
self.assertTrue(status, 'Mismatch for field {}, observed: {}, expected: {}'.format(key, out[key], ref[key]))
def test_all_valid_default(self):
inp, ref = self._base_attrs()
del inp['pads']
ref['pad'] = [[0, 0], [0, 0], [0, 0], [0, 0]]
ref['pad_spatial_shape'] = [[0, 0], [0, 0]]
out = self._extract(inp)
self._match(out, ref)
def test_most_used(self):
inp, ref = self._base_attrs()
out = self._extract(inp)
self._match(out, ref)
def test_dilation(self):
inp, ref = self._base_attrs()
inp['dilations'] = [10, 11]
ref['dilation'] = [1, 1, 10, 11]
out = self._extract(inp)
self._match(out, ref)
def test_stride(self):
inp, ref = self._base_attrs()
inp['strides'] = [12, 13]
ref['stride'] = [1, 1, 12, 13]
out = self._extract(inp)
self._match(out, ref)
def test_group(self):
inp, ref = self._base_attrs()
inp['group'] = 14
ref['group'] = 14
out = self._extract(inp)
self._match(out, ref)
def test_auto_pad_supported(self):
inp, ref = self._base_attrs()
del inp['pads']
inp['auto_pad'] = 'SAME_UPPER'
ref['auto_pad'] = 'same_upper'
ref['pad'] = [[0, 0], [0, 0], [0, 0], [0, 0]]
ref['pad_spatial_shape'] = [[0, 0], [0, 0]]
out = self._extract(inp)
self._match(out, ref)
def test_pads_not_even_invalid(self):
inp, ref = self._base_attrs()
inp['pads'] = [1, 2, 3]
with self.assertRaisesRegex(Error, '.*pads.*not correct.*'):
out = self._extract(inp)
def test_missing_kernel_shape_not_supported(self):
inp, ref = self._base_attrs()
del inp['kernel_shape']
with self.assertRaisesRegex(Error, '.*kernel_shape.*not supported.*'):
out = self._extract(inp)
def test_output_padding(self):
inp, ref = self._base_attrs()
inp['output_padding'] = [19, 20]
ref['output_padding'] = [0, 0, 19, 20]
out = self._extract(inp)
self._match(out, ref)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from extensions.front.onnx.crop_ext import CropFrontExtractor
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class CropONNXExtractorTest(unittest.TestCase):
@staticmethod
def _create_node(attrs: dict):
pb = onnx.helper.make_node("Crop", ["X"], ["Y"], **attrs)
graph = build_graph({'node_0': {'pb': pb}}, [])
return Node(graph, 'node_0')
@staticmethod
def _base_attrs():
# Commonly used attributes in the tests
# Each test takes these ones and then adds/modifies/deletes particular fields
return (
# test input ONNX attributes
dict(
border=[5, 10, 15, 20],
),
# reference output Node attributes
dict(
op='Crop',
crop_begin=np.array([10, 5]),
crop_end=np.array([20, 15]),
axis=np.array([2, 3])
)
)
@staticmethod
def _extract(inp):
node = __class__._create_node(inp)
CropFrontExtractor.extract(node)
return node.graph.node[node.id]
def _match(self, out, ref):
for key in ref.keys():
status = out[key] == ref[key]
if type(status) in [list, np.ndarray]:
status = np.all(status)
self.assertTrue(status, 'Mismatch for field {}, observed: {}, expected: {}'.format(key, out[key], ref[key]))
def test_default(self):
inp, ref = self._base_attrs()
out = self._extract(inp)
self._match(out, ref)
def test_with_scale(self):
inp, ref = self._base_attrs()
inp['scale'] = np.array([34, 50])
del ref['crop_begin']
del ref['crop_end']
ref['dim'] = np.array([34, 50])
ref['offset'] = np.array([10, 5])
out = self._extract(inp)
self._match(out, ref)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import onnx
from generator import generator, generate
from extensions.front.onnx.elu_ext import EluFrontExtractor
from mo.ops.activation import Activation
from mo.ops.op import Op
from mo.utils.unittest.extractors import PB
@generator
class TestEluONNXExt(unittest.TestCase):
@staticmethod
def _create_elu_node(alpha=1.0):
pb = onnx.helper.make_node(
'Elu',
inputs=['x'],
outputs=['y'],
alpha=alpha
)
node = PB({'pb': pb})
return node
@classmethod
def setUpClass(cls):
Op.registered_ops['Elu'] = Activation
@generate(*[1.0, 2.0, 3.0])
def test_elu_ext(self, alpha):
node = self._create_elu_node(alpha)
EluFrontExtractor.extract(node)
exp_res = {
'type': 'Activation',
'operation': 'elu',
'alpha': alpha,
'infer': Activation.infer
}
for key in exp_res.keys():
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import onnx
from generator import generator, generate
from extensions.front.onnx.flatten_ext import FlattenFrontExtractor
from mo.ops.flatten_onnx import FlattenONNX
from mo.ops.op import Op
from mo.utils.unittest.extractors import PB
@generator
class TestFlattenONNXExt(unittest.TestCase):
@staticmethod
def _create_flatten_node(axis):
pb = onnx.helper.make_node(
'Flatten',
inputs=['a'],
outputs=['b'],
axis=axis,
)
node = PB({'pb': pb})
return node
@classmethod
def setUpClass(cls):
Op.registered_ops['Flatten'] = FlattenONNX
@generate(*[x for x in range(4)])
def test_flatten_ext(self, axis):
node = self._create_flatten_node(axis)
FlattenFrontExtractor.extract(node)
exp_res = {
'type': 'Reshape',
'axis': axis,
'infer': FlattenONNX.infer
}
for key in exp_res.keys():
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import onnx
from generator import generator, generate
from extensions.front.onnx.gather_ext import GatherFrontExtractor
from extensions.ops.gather import Gather
from mo.ops.op import Op
from mo.utils.unittest.extractors import PB
@generator
class TestGatherONNXExt(unittest.TestCase):
@staticmethod
def _create_gather_node(axis=0):
pb = onnx.helper.make_node(
'Gather',
inputs=['data', 'indices'],
outputs=['y'],
axis=axis,
)
node = PB({'pb': pb})
return node
@classmethod
def setUpClass(cls):
Op.registered_ops['Gather'] = Gather
@generate(*[0, 1, 2, 3])
def test_gather_ext(self, axis):
node = self._create_gather_node(axis)
GatherFrontExtractor.extract(node)
exp_res = {
'type': 'Gather',
'axis': axis,
'infer': Gather.infer
}
for key in exp_res.keys():
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from extensions.front.onnx.image_scaler_ext import ImageScalerFrontExtractor
from mo.utils.unittest.extractors import PB
class TestImageScalerONNXExt(unittest.TestCase):
@staticmethod
def _create_image_scaler_node():
pb = onnx.helper.make_node(
'ImageScaler',
inputs=['a'],
outputs=['b'],
scale=1.0,
bias=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
)
node = PB({'pb': pb, 'graph': PB({'graph': {'layout': 'NCHW'}})})
return node
def test_image_scaler_ext(self):
node = self._create_image_scaler_node()
ImageScalerFrontExtractor.extract(node)
exp_res = {
'scale': 1.0,
'bias': [[[1.0]], [[2.0]], [[3.0]], [[4.0]], [[5.0]], [[6.0]], [[7.0]], [[8.0]]],
}
for key in exp_res.keys():
if type(node[key]) in [list, np.ndarray]:
self.assertTrue(np.array_equal(np.array(node[key]), np.array(exp_res[key])))
else:
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import onnx
from extensions.front.onnx.instance_normalization_ext import InstanceNormalizationExtractor
from mo.utils.unittest.extractors import PB, BaseExtractorsTestingClass
class TestInstanceNormalization(BaseExtractorsTestingClass):
@staticmethod
def _create_node():
pb = onnx.helper.make_node(
'InstanceNormalization',
inputs=['a'],
outputs=['b'],
epsilon=0.5,
)
node = PB({'pb': pb})
return node
def test_image_scaler_ext(self):
node = self._create_node()
InstanceNormalizationExtractor.extract(node)
self.res = node
self.expected = {
'epsilon': 0.5,
}
self.compare()
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import onnx
from extensions.front.onnx.pad_ext import PadFrontExtractor
from mo.utils.unittest.extractors import PB, BaseExtractorsTestingClass
class TestPad(BaseExtractorsTestingClass):
@staticmethod
def _create_node(pads=None, value=None, mode=None):
if pads is None:
pads = [1, 2, 3, 4]
if value is None:
value = 0.0
if mode is None:
mode = 'constant'
pb = onnx.helper.make_node(
'Pad',
pads=pads,
mode=mode,
value=value,
inputs=['a'],
outputs=['b']
)
node = PB({'pb': pb})
return node
def test_ok(self):
node = self._create_node()
PadFrontExtractor.extract(node)
self.res = node
self.expected = {
'pads': [[1, 3], [2, 4]],
'mode': 'constant',
'fill_value': 0
}
self.compare()
def test_reflect(self):
node = self._create_node(mode='reflect')
PadFrontExtractor.extract(node)
self.res = node
self.expected = {
'pads': [[1, 3], [2, 4]],
'mode': 'reflect',
'fill_value': 0
}
self.compare()
def test_non_zero_fill_value(self):
node = self._create_node(value=1.0)
PadFrontExtractor.extract(node)
self.res = node
self.expected = {
'pads': [[1, 3], [2, 4]],
'mode': 'constant',
'fill_value': 1.0
}
self.compare()
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from extensions.front.onnx.sigmoid_ext import SigmoidFrontExtractor
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class SigmoidONNXExtractorTest(unittest.TestCase):
@staticmethod
def _create_node():
pb = onnx.helper.make_node("Sigmoid", ["X"], ["Y"])
graph = build_graph({'node_0': {'pb': pb}}, [])
return Node(graph, 'node_0')
@staticmethod
def _base_attrs():
# Commonly used attributes in the tests
# Each test takes these ones and then adds/modifies/deletes particular fields
return (
# reference output Node attributes
dict(
op='Activation',
operation='sigmoid'
)
)
@staticmethod
def _extract():
node = __class__._create_node()
SigmoidFrontExtractor.extract(node)
return node.graph.node[node.id]
def _match(self, out, ref):
for key in ref.keys():
status = out[key] == ref[key]
if type(status) in [list, np.ndarray]:
status = np.all(status)
self.assertTrue(status, 'Mismatch for field {}, observed: {}, expected: {}'.format(key, out[key], ref[key]))
def test_default(self):
ref = self._base_attrs()
out = self._extract()
self._match(out, ref)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from generator import generator, generate
from extensions.front.onnx.slice_ext import SliceFrontExtractor
from mo.ops.op import Op
from mo.ops.slice import Slice
from mo.utils.unittest.extractors import PB
@generator
class TestSliceONNXExt(unittest.TestCase):
@staticmethod
def _create_slice_node(axes, starts, ends):
if axes is None:
pb = onnx.helper.make_node(
'Slice',
inputs=['x'],
outputs=['y'],
starts=starts,
ends=ends,
)
else:
pb = onnx.helper.make_node(
'Slice',
inputs=['x'],
outputs=['y'],
axes=axes,
starts=starts,
ends=ends,
)
node = PB({'pb': pb})
return node
@classmethod
def setUpClass(cls):
Op.registered_ops['Slice'] = Slice
@generate(*[([0, 1], [0, 0], [28, 28]), (None, [0, 0], [28, 28])])
def test_slice_ext(self, axes, starts, ends):
node = self._create_slice_node(axes, starts, ends)
SliceFrontExtractor.extract(node)
exp_res = {
'op': 'Slice',
'axis': axes,
'start': starts,
'end': ends,
'infer': Slice.infer
}
for key in exp_res.keys():
if type(node[key]) in [list, np.ndarray]:
self.assertTrue(np.array_equal(np.array(node[key]), np.array(exp_res[key])))
else:
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from generator import generator, generate
from extensions.front.onnx.squeeze_ext import SqueezeFrontExtractor
from mo.ops.op import Op
from mo.ops.squeeze import Squeeze
from mo.utils.unittest.extractors import PB
@generator
class TestSqueezeONNXExt(unittest.TestCase):
@staticmethod
def _create_squeeze_node(axes):
if axes is None:
pb = onnx.helper.make_node(
'Squeeze',
inputs=['x'],
outputs=['y'],
)
else:
pb = onnx.helper.make_node(
'Squeeze',
inputs=['x'],
outputs=['y'],
axes=axes,
)
node = PB({'pb': pb})
return node
@classmethod
def setUpClass(cls):
Op.registered_ops['Squeeze'] = Squeeze
@generate(*[[0, 1, 2, 3], [1], None])
def test_squeeze_ext(self, axes):
node = self._create_squeeze_node(axes)
SqueezeFrontExtractor.extract(node)
exp_res = {
'type': 'Reshape',
'squeeze_dims': axes,
}
for key in exp_res.keys():
if type(node[key]) in [list, np.ndarray]:
self.assertTrue(np.array_equal(np.array(node[key]), np.array(exp_res[key])))
else:
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from extensions.front.onnx.tanh_ext import TanhFrontExtractor
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
class TanhONNXExtractorTest(unittest.TestCase):
@staticmethod
def _create_node():
pb = onnx.helper.make_node("Tanh", ["X"], ["Y"])
graph = build_graph({'node_0': {'pb': pb}}, [])
return Node(graph, 'node_0')
@staticmethod
def _base_attrs():
# Commonly used attributes in the tests
# Each test takes these ones and then adds/modifies/deletes particular fields
return (
# reference output Node attributes
dict(
op='Activation',
operation='tanh'
)
)
@staticmethod
def _extract():
node = __class__._create_node()
TanhFrontExtractor.extract(node)
return node.graph.node[node.id]
def _match(self, out, ref):
for key in ref.keys():
status = out[key] == ref[key]
if type(status) in [list, np.ndarray]:
status = np.all(status)
self.assertTrue(status, 'Mismatch for field {}, observed: {}, expected: {}'.format(key, out[key], ref[key]))
def test_default(self):
ref = self._base_attrs()
out = self._extract()
self._match(out, ref)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import itertools
import unittest
import numpy as np
import onnx
from generator import generator, generate
from extensions.front.onnx.transpose_ext import TransposeFrontExtractor
from mo.ops.op import Op
from mo.ops.permute import Permute
from mo.utils.unittest.extractors import PB
@generator
class TestTransposeONNXExt(unittest.TestCase):
@staticmethod
def _create_transpose_node(order: list):
if order is None:
# Default transpose
pb = onnx.helper.make_node(
'Transpose',
inputs=['data'],
outputs=['transposed'],
)
else:
# Transpose with order
pb = onnx.helper.make_node(
'Transpose',
inputs=['data'],
outputs=['transposed'],
perm=order
)
node = PB({'pb': pb})
return node
@classmethod
def setUpClass(cls):
Op.registered_ops['Permute'] = Permute
pass
# This generator generates all permutations for [0,1,2,3] and [0,1,2] orders
@generate(*[list(order) for order in list(itertools.permutations(np.arange(4)))] +
[list(order) for order in list(itertools.permutations(np.arange(3)))] + [None])
def test_transpose_ext(self, order):
node = self._create_transpose_node(order)
TransposeFrontExtractor.extract(node)
exp_res = {
'type': 'Permute',
'order': order,
'infer': Permute.infer
}
for key in exp_res.keys():
if isinstance(exp_res[key], list):
self.assertTrue(np.array_equal(node[key], exp_res[key]),
"Orders are not the same: {} and {}".format(node[key], exp_res[key]))
else:
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
import onnx
from generator import generator, generate
from extensions.front.onnx.unsqueeze_ext import UnsqueezeFrontExtractor
from mo.ops.op import Op
from mo.ops.unsqueeze import Unsqueeze
from mo.utils.unittest.extractors import PB
@generator
class TestUnsqueezeONNXExt(unittest.TestCase):
@staticmethod
def _create_unsqueeze_node(axes):
pb = onnx.helper.make_node(
'Unsqueeze',
inputs=['x'],
outputs=['y'],
axes=axes,
)
node = PB({'pb': pb})
return node
@classmethod
def setUpClass(cls):
Op.registered_ops['Unsqueeze'] = Unsqueeze
@generate(*[[0, 1, 2, 3], [1]])
def test_unsqueeze_ext(self, axes):
node = self._create_unsqueeze_node(axes)
UnsqueezeFrontExtractor.extract(node)
exp_res = {
'unsqueeze_dims': axes,
}
for key in exp_res.keys():
if type(node[key]) in [list, np.ndarray]:
self.assertTrue(np.array_equal(np.array(node[key]), np.array(exp_res[key])))
else:
self.assertEqual(node[key], exp_res[key])
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import onnx
from extensions.front.onnx.upsample_ext import UpsampleFrontExtractor
from mo.utils.unittest.graph import build_graph
from mo.graph.graph import Node
from mo.utils.error import Error
from mo.utils.unittest.extractors import BaseExtractorsTestingClass
class UpsampleONNXExtractorTest(BaseExtractorsTestingClass):
@staticmethod
def _create_node(attrs: dict):
pb = onnx.helper.make_node("Upsample", ["X"], ["Y"], **attrs)
graph = build_graph({'node_0': {'pb': pb}}, [])
return Node(graph, 'node_0')
@staticmethod
def _base_attrs():
# Commonly used attributes in the tests
# Each test takes these ones and then adds/modifies/deletes particular fields
return (
# test input ONNX attributes
dict(
mode='nearest',
width_scale=2.0,
height_scale=2.0,
),
# reference output Node attributes
dict(
type='Resample',
resample_type='caffe.ResampleParameter.NEAREST',
factor=2,
antialias=0,
)
)
@staticmethod
def _extract(inp):
node = __class__._create_node(inp)
UpsampleFrontExtractor.extract(node)
return node
def _match(self, out, ref):
self.res = out
self.expected = ref
self.compare()
def test_all_valid_default(self):
inp, ref = self._base_attrs()
out = self._extract(inp)
self._match(out, ref)
def test_invalid_mode(self):
inp, ref = self._base_attrs()
inp['mode'] = 'invalid_mode'
with self.assertRaisesRegex(Error, '.*decoding Upsample.*supported modes.*'):
out = self._extract(inp)
def test_unsupported_linear(self):
inp, ref = self._base_attrs()
inp['mode'] = 'linear'
with self.assertRaisesRegex(Error, '.*Only nearest is supported.*'):
out = self._extract(inp)
def test_unsupported_scale(self):
inp, ref = self._base_attrs()
inp['scales'] = [2.0, 2.0]
with self.assertRaisesRegex(Error, '.*Only scale_width and scale_height are supported.*'):
out = self._extract(inp)
def test_missing_width_scale(self):
inp, ref = self._base_attrs()
del inp['width_scale']
with self.assertRaisesRegex(Error, '.*One/both of widths_scale.*and height_scale.*is not defined.*'):
out = self._extract(inp)
def test_missing_height_scale(self):
inp, ref = self._base_attrs()
del inp['height_scale']
with self.assertRaisesRegex(Error, '.*One/both of widths_scale.*and height_scale.*is not defined.*'):
out = self._extract(inp)
def test_different_scales(self):
inp, ref = self._base_attrs()
inp['height_scale'] = 2.0
inp['width_scale'] = 3.0
with self.assertRaisesRegex(Error, '.*different widths_scale.*and height_scale.*not supported.*'):
out = self._extract(inp)
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.reciprocal import ReciprocalReplacer
from mo.utils.unittest.graph import build_graph, compare_graphs
class ReciprocalReplacerTests(unittest.TestCase):
@staticmethod
def _create_graphs():
return (
build_graph(
{'placeholder': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
'reciprocal': {'kind': 'op', 'op': 'Reciprocal'}},
[('placeholder', 'reciprocal')]),
build_graph(
{'placeholder': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
'power': {'type': 'Power', 'kind': 'op', 'op': 'Power', 'scale': 1, 'power': -1, 'shift': 0}},
[('placeholder', 'power')])
)
def test_replace_reciprocal(self):
graph, graph_ref = __class__._create_graphs()
pattern = ReciprocalReplacer()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'reciprocal/power_', last_node_ref='power', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_neg_replace_reciprocal(self):
graph, graph_ref = __class__._create_graphs()
graph_ref.node['power']['power'] = 0
pattern = ReciprocalReplacer()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'reciprocal/power_', last_node_ref='power', check_op_attrs=True)
self.assertTrue(not flag)
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
"""
Copyright (c) 2018 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from extensions.front.tf.pad_ext import PadFrontExtractor
from mo.utils.unittest.extractors import PB
class TestPad(unittest.TestCase):
def test_no_pads(self):
node = PB({})
PadFrontExtractor.extract(node)
self.assertTrue(not 'pads' in node or node['pads'] is None)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment