Commit 7b711340 authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[Py] Added operators Shuffle Channels, Squared Difference and Squeeze to Python API. (#3393)

* [Py] Added operators Shuffle Channels, Squared Difference and Squeeze to Python API.

* [Py] Changed docstring.

* [Py] Changed docstring.

* [Py] Changed docstring.
parent bbb9a566
...@@ -74,6 +74,7 @@ ngraph.ops ...@@ -74,6 +74,7 @@ ngraph.ops
reverse reverse
scale_shift scale_shift
select select
shuffle_channels
sign sign
sin sin
sinh sinh
...@@ -81,6 +82,8 @@ ngraph.ops ...@@ -81,6 +82,8 @@ ngraph.ops
softmax softmax
space_to_depth space_to_depth
sqrt sqrt
squared_difference
squeeze
subtract subtract
sum sum
tan tan
......
...@@ -87,6 +87,7 @@ from ngraph.ops import reshape ...@@ -87,6 +87,7 @@ from ngraph.ops import reshape
from ngraph.ops import reverse from ngraph.ops import reverse
from ngraph.ops import scale_shift from ngraph.ops import scale_shift
from ngraph.ops import select from ngraph.ops import select
from ngraph.ops import shuffle_channels
from ngraph.ops import sign from ngraph.ops import sign
from ngraph.ops import sin from ngraph.ops import sin
from ngraph.ops import sinh from ngraph.ops import sinh
...@@ -94,6 +95,8 @@ from ngraph.ops import slice ...@@ -94,6 +95,8 @@ from ngraph.ops import slice
from ngraph.ops import softmax from ngraph.ops import softmax
from ngraph.ops import space_to_depth from ngraph.ops import space_to_depth
from ngraph.ops import sqrt from ngraph.ops import sqrt
from ngraph.ops import squared_difference
from ngraph.ops import squeeze
from ngraph.ops import subtract from ngraph.ops import subtract
from ngraph.ops import sum from ngraph.ops import sum
from ngraph.ops import tan from ngraph.ops import tan
......
...@@ -113,6 +113,7 @@ from _pyngraph.op import Reshape ...@@ -113,6 +113,7 @@ from _pyngraph.op import Reshape
from _pyngraph.op import Reverse from _pyngraph.op import Reverse
from _pyngraph.op import ScaleShift from _pyngraph.op import ScaleShift
from _pyngraph.op import Select from _pyngraph.op import Select
from _pyngraph.op import ShuffleChannels
from _pyngraph.op import Sign from _pyngraph.op import Sign
from _pyngraph.op import Sin from _pyngraph.op import Sin
from _pyngraph.op import Sinh from _pyngraph.op import Sinh
...@@ -120,6 +121,8 @@ from _pyngraph.op import Slice ...@@ -120,6 +121,8 @@ from _pyngraph.op import Slice
from _pyngraph.op import Softmax from _pyngraph.op import Softmax
from _pyngraph.op import SpaceToDepth from _pyngraph.op import SpaceToDepth
from _pyngraph.op import Sqrt from _pyngraph.op import Sqrt
from _pyngraph.op import SquaredDifference
from _pyngraph.op import Squeeze
from _pyngraph.op import Subtract from _pyngraph.op import Subtract
from _pyngraph.op import Sum from _pyngraph.op import Sum
from _pyngraph.op import Tan from _pyngraph.op import Tan
......
...@@ -26,8 +26,8 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgP ...@@ -26,8 +26,8 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgP
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, HardSigmoid, Less, \ Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, HardSigmoid, Less, \
LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, MVN, Negative, Not, NotEqual, \ LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, MVN, Negative, Not, NotEqual, \
OneHot, Or, Pad, Parameter, Product, Power, PRelu, Relu, ReplaceSlice, Reshape, Reverse, \ OneHot, Or, Pad, Parameter, Product, Power, PRelu, Relu, ReplaceSlice, Reshape, Reverse, \
ScaleShift, Select, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, Sqrt, Subtract, Sum, Tan, \ ScaleShift, Select, ShuffleChannels, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, Sqrt, \
Tanh, TopK, Unsqueeze SquaredDifference, Squeeze, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze
from typing import Callable, Iterable, List, Set, Union from typing import Callable, Iterable, List, Set, Union
...@@ -80,6 +80,81 @@ def elu(data, alpha, name=None): # type: (NodeInput, NumericType, str) -> Node ...@@ -80,6 +80,81 @@ def elu(data, alpha, name=None): # type: (NodeInput, NumericType, str) -> Node
@nameable_op @nameable_op
def shuffle_channels(data, axis, groups, name=None): # type: (Node, int, int, str) -> Node
"""Perform permutation on data in the channel dimension of the input tensor.
The operation is the equivalent with the following transformation of the input tensor
:code:`data` of shape [N, C, H, W]:
:code:`data_reshaped` = reshape(:code:`data`, [N, group, C / group, H * W])
:code:`data_trnasposed` = transpose(:code:`data_reshaped`, [0, 2, 1, 3])
:code:`output` = reshape(:code:`data_trnasposed`, [N, C, H, W])
For example:
.. code-block:: python
Inputs: tensor of shape [1, 6, 2, 2]
data = [[[[ 0., 1.], [ 2., 3.]],
[[ 4., 5.], [ 6., 7.]],
[[ 8., 9.], [10., 11.]],
[[12., 13.], [14., 15.]],
[[16., 17.], [18., 19.]],
[[20., 21.], [22., 23.]]]]
axis = 1
groups = 3
Output: tensor of shape [1, 6, 2, 2]
output = [[[[ 0., 1.], [ 2., 3.]],
[[ 8., 9.], [10., 11.]],
[[16., 17.], [18., 19.]],
[[ 4., 5.], [ 6., 7.]],
[[12., 13.], [14., 15.]],
[[20., 21.], [22., 23.]]]]
:param data: The node with input tensor.
:param axis: Channel dimension index in the data tensor.
A negative value means that the index should be calculated
from the back of the input data shape.
:param group:The channel dimension specified by the axis parameter
should be split into this number of groups.
:param name: Optional output node name.
:return: The new node performing a permutation on data in the channel dimension
of the input tensor.
"""
return ShuffleChannels(data, axis, groups)
@nameable_op
def squeeze(data, axes, name=None): # type: (Node, NodeInput, str) -> Node
"""Perform squeeze operation on input tensor.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter :code:`axes` with a list of axes to squeeze.
If :code:`axes` is not provided, all the single dimensions will be removed from the shape.
If an :code:`axis` is selected with shape entry not equal to one, an error is raised.
For example:
Inputs: tensor with shape [1, 2, 1, 3, 1, 1], axes=[2, 4]
Result: tensor with shape [1, 2, 3, 1]
:param data: The node with data tensor.
:param axes: List of non-negative integers, indicate the dimensions to squeeze.
One of: input node or array.
:param name: Optional new name for output node.
:return: The new node performing a squeeze operation on input tensor.
"""
return Squeeze(data, as_node(axes))
def unsqueeze(data, axes, name=None): # type: (Node, NodeInput, str) -> Node def unsqueeze(data, axes, name=None): # type: (Node, NodeInput, str) -> Node
"""Perform unsqueeze operation on input tensor. """Perform unsqueeze operation on input tensor.
...@@ -532,6 +607,20 @@ def logical_not(node, name=None): # type: (Node, str) -> Node ...@@ -532,6 +607,20 @@ def logical_not(node, name=None): # type: (Node, str) -> Node
return Not(node) return Not(node)
@binary_op
def squared_difference(x1, x2, name=None): # type: (Node, Node, str) -> Node
"""Perform an element-wise squared difference between two tensors.
.. math:: y[i] = (x_1[i] - x_2[i])^2
:param x1: The node with first input tensor.
:param x2: The node with second input tensor.
:param name: Optional new name for output node.
:return: The new node performing a squared difference between two tensors.
"""
return SquaredDifference(x1, x2)
# Extend Node class to support binary operators # Extend Node class to support binary operators
Node.__add__ = add Node.__add__ = add
Node.__sub__ = subtract Node.__sub__ = subtract
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "pyngraph/ops/fused/shuffle_channels.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_ShuffleChannels(py::module m)
{
py::class_<ngraph::op::ShuffleChannels,
std::shared_ptr<ngraph::op::ShuffleChannels>,
ngraph::op::Op>
shufflechannels(m, "ShuffleChannels");
shufflechannels.doc() = "ngraph.impl.op.ShuffleChannels wraps ngraph::op::ShuffleChannels";
shufflechannels.def(py::init<const std::shared_ptr<ngraph::Node>&, int&, int&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_ShuffleChannels(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/squared_difference.hpp"
#include "pyngraph/ops/fused/squared_difference.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_SquaredDifference(py::module m)
{
py::class_<ngraph::op::SquaredDifference,
std::shared_ptr<ngraph::op::SquaredDifference>,
ngraph::op::Op>
squareddifference(m, "SquaredDifference");
squareddifference.doc() =
"ngraph.impl.op.SquaredDifference wraps ngraph::op::SquaredDifference";
squareddifference.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_SquaredDifference(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/squeeze.hpp"
#include "pyngraph/ops/fused/squeeze.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Squeeze(py::module m)
{
py::class_<ngraph::op::Squeeze, std::shared_ptr<ngraph::op::Squeeze>, ngraph::op::Op> squeeze(
m, "Squeeze");
squeeze.doc() = "ngraph.impl.op.Squeeze wraps ngraph::op::Squeeze";
squeeze.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Squeeze(py::module m);
...@@ -93,6 +93,7 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -93,6 +93,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Reverse(m_op); regclass_pyngraph_op_Reverse(m_op);
regclass_pyngraph_op_ScaleShift(m_op); regclass_pyngraph_op_ScaleShift(m_op);
regclass_pyngraph_op_Select(m_op); regclass_pyngraph_op_Select(m_op);
regclass_pyngraph_op_ShuffleChannels(m_op);
regclass_pyngraph_op_Sign(m_op); regclass_pyngraph_op_Sign(m_op);
regclass_pyngraph_op_Sin(m_op); regclass_pyngraph_op_Sin(m_op);
regclass_pyngraph_op_Sinh(m_op); regclass_pyngraph_op_Sinh(m_op);
...@@ -100,6 +101,8 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -100,6 +101,8 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Softmax(m_op); regclass_pyngraph_op_Softmax(m_op);
regclass_pyngraph_op_SpaceToDepth(m_op); regclass_pyngraph_op_SpaceToDepth(m_op);
regclass_pyngraph_op_Sqrt(m_op); regclass_pyngraph_op_Sqrt(m_op);
regclass_pyngraph_op_SquaredDifference(m_op);
regclass_pyngraph_op_Squeeze(m_op);
regclass_pyngraph_op_Subtract(m_op); regclass_pyngraph_op_Subtract(m_op);
regclass_pyngraph_op_Sum(m_op); regclass_pyngraph_op_Sum(m_op);
regclass_pyngraph_op_Tan(m_op); regclass_pyngraph_op_Tan(m_op);
......
...@@ -53,7 +53,10 @@ ...@@ -53,7 +53,10 @@
#include "pyngraph/ops/fused/mvn.hpp" #include "pyngraph/ops/fused/mvn.hpp"
#include "pyngraph/ops/fused/prelu.hpp" #include "pyngraph/ops/fused/prelu.hpp"
#include "pyngraph/ops/fused/scale_shift.hpp" #include "pyngraph/ops/fused/scale_shift.hpp"
#include "pyngraph/ops/fused/shuffle_channels.hpp"
#include "pyngraph/ops/fused/space_to_depth.hpp" #include "pyngraph/ops/fused/space_to_depth.hpp"
#include "pyngraph/ops/fused/squared_difference.hpp"
#include "pyngraph/ops/fused/squeeze.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp" #include "pyngraph/ops/fused/unsqueeze.hpp"
#include "pyngraph/ops/get_output_element.hpp" #include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp" #include "pyngraph/ops/greater.hpp"
......
...@@ -221,12 +221,15 @@ sources = [ ...@@ -221,12 +221,15 @@ sources = [
'pyngraph/ops/reverse.cpp', 'pyngraph/ops/reverse.cpp',
'pyngraph/ops/fused/scale_shift.cpp', 'pyngraph/ops/fused/scale_shift.cpp',
'pyngraph/ops/select.cpp', 'pyngraph/ops/select.cpp',
'pyngraph/ops/fused/shuffle_channels.cpp',
'pyngraph/ops/sign.cpp', 'pyngraph/ops/sign.cpp',
'pyngraph/ops/sin.cpp', 'pyngraph/ops/sin.cpp',
'pyngraph/ops/sinh.cpp', 'pyngraph/ops/sinh.cpp',
'pyngraph/ops/slice.cpp', 'pyngraph/ops/slice.cpp',
'pyngraph/ops/fused/space_to_depth.cpp', 'pyngraph/ops/fused/space_to_depth.cpp',
'pyngraph/ops/sqrt.cpp', 'pyngraph/ops/sqrt.cpp',
'pyngraph/ops/fused/squared_difference.cpp',
'pyngraph/ops/fused/squeeze.cpp',
'pyngraph/ops/subtract.cpp', 'pyngraph/ops/subtract.cpp',
'pyngraph/ops/sum.cpp', 'pyngraph/ops/sum.cpp',
'pyngraph/ops/tan.cpp', 'pyngraph/ops/tan.cpp',
......
...@@ -226,6 +226,67 @@ def test_clamp_operator_with_array(): ...@@ -226,6 +226,67 @@ def test_clamp_operator_with_array():
assert np.allclose(result, expected) assert np.allclose(result, expected)
def test_squeeze_operator():
runtime = get_runtime()
data_shape = [1, 2, 1, 3, 1, 1]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
data_value = np.arange(6., dtype=np.float32).reshape(1, 2, 1, 3, 1, 1)
axes = [2, 4]
model = ng.squeeze(parameter_data, axes)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = np.arange(6., dtype=np.float32).reshape(1, 2, 3, 1)
assert np.allclose(result, expected)
def test_squared_difference_operator():
runtime = get_runtime()
x1_shape = [1, 2, 3, 4]
x2_shape = [2, 3, 4]
parameter_x1 = ng.parameter(x1_shape, name='x1', dtype=np.float32)
parameter_x2 = ng.parameter(x2_shape, name='x2', dtype=np.float32)
x1_value = np.arange(24., dtype=np.float32).reshape(x1_shape)
x2_value = np.arange(start=4., stop=28., step=1.0, dtype=np.float32).reshape(x2_shape)
model = ng.squared_difference(parameter_x1, parameter_x2)
computation = runtime.computation(model, parameter_x1, parameter_x2)
result = computation(x1_value, x2_value)
expected = np.square(np.subtract(x1_value, x2_value))
assert np.allclose(result, expected)
def test_shuffle_channels_operator():
runtime = get_runtime()
data_shape = [1, 15, 2, 2]
axis = 1
groups = 5
parameter = ng.parameter(data_shape, name='Data', dtype=np.float32)
data_value = np.arange(60., dtype=np.float32).reshape(data_shape)
model = ng.shuffle_channels(parameter, axis, groups)
computation = runtime.computation(model, parameter)
result = computation(data_value)
expected = np.array([[[[0., 1.], [2., 3.]], [[12., 13.], [14., 15.]],
[[24., 25.], [26., 27.]], [[36., 37.], [38., 39.]],
[[48., 49.], [50., 51.]], [[4., 5.], [6., 7.]],
[[16., 17.], [18., 19.]], [[28., 29.], [30., 31.]],
[[40., 41.], [42., 43.]], [[52., 53.], [54., 55.]],
[[8., 9.], [10., 11.]], [[20., 21.], [22., 23.]],
[[32., 33.], [34., 35.]], [[44., 45.], [46., 47.]],
[[56., 57.], [58., 59.]]]], dtype=np.float32)
assert np.allclose(result, expected)
def test_unsqueeze(): def test_unsqueeze():
runtime = get_runtime() runtime = get_runtime()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment