Commit 7b711340 authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[Py] Added operators Shuffle Channels, Squared Difference and Squeeze to Python API. (#3393)

* [Py] Added operators Shuffle Channels, Squared Difference and Squeeze to Python API.

* [Py] Changed docstring.

* [Py] Changed docstring.

* [Py] Changed docstring.
parent bbb9a566
......@@ -74,6 +74,7 @@ ngraph.ops
reverse
scale_shift
select
shuffle_channels
sign
sin
sinh
......@@ -81,6 +82,8 @@ ngraph.ops
softmax
space_to_depth
sqrt
squared_difference
squeeze
subtract
sum
tan
......
......@@ -87,6 +87,7 @@ from ngraph.ops import reshape
from ngraph.ops import reverse
from ngraph.ops import scale_shift
from ngraph.ops import select
from ngraph.ops import shuffle_channels
from ngraph.ops import sign
from ngraph.ops import sin
from ngraph.ops import sinh
......@@ -94,6 +95,8 @@ from ngraph.ops import slice
from ngraph.ops import softmax
from ngraph.ops import space_to_depth
from ngraph.ops import sqrt
from ngraph.ops import squared_difference
from ngraph.ops import squeeze
from ngraph.ops import subtract
from ngraph.ops import sum
from ngraph.ops import tan
......
......@@ -113,6 +113,7 @@ from _pyngraph.op import Reshape
from _pyngraph.op import Reverse
from _pyngraph.op import ScaleShift
from _pyngraph.op import Select
from _pyngraph.op import ShuffleChannels
from _pyngraph.op import Sign
from _pyngraph.op import Sin
from _pyngraph.op import Sinh
......@@ -120,6 +121,8 @@ from _pyngraph.op import Slice
from _pyngraph.op import Softmax
from _pyngraph.op import SpaceToDepth
from _pyngraph.op import Sqrt
from _pyngraph.op import SquaredDifference
from _pyngraph.op import Squeeze
from _pyngraph.op import Subtract
from _pyngraph.op import Sum
from _pyngraph.op import Tan
......
......@@ -26,8 +26,8 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgP
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, HardSigmoid, Less, \
LessEq, Log, LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, MVN, Negative, Not, NotEqual, \
OneHot, Or, Pad, Parameter, Product, Power, PRelu, Relu, ReplaceSlice, Reshape, Reverse, \
ScaleShift, Select, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, Sqrt, Subtract, Sum, Tan, \
Tanh, TopK, Unsqueeze
ScaleShift, Select, ShuffleChannels, Sign, Sin, Sinh, Slice, Softmax, SpaceToDepth, Sqrt, \
SquaredDifference, Squeeze, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze
from typing import Callable, Iterable, List, Set, Union
......@@ -80,6 +80,81 @@ def elu(data, alpha, name=None): # type: (NodeInput, NumericType, str) -> Node
@nameable_op
def shuffle_channels(data, axis, groups, name=None): # type: (Node, int, int, str) -> Node
"""Perform permutation on data in the channel dimension of the input tensor.
The operation is the equivalent with the following transformation of the input tensor
:code:`data` of shape [N, C, H, W]:
:code:`data_reshaped` = reshape(:code:`data`, [N, group, C / group, H * W])
:code:`data_trnasposed` = transpose(:code:`data_reshaped`, [0, 2, 1, 3])
:code:`output` = reshape(:code:`data_trnasposed`, [N, C, H, W])
For example:
.. code-block:: python
Inputs: tensor of shape [1, 6, 2, 2]
data = [[[[ 0., 1.], [ 2., 3.]],
[[ 4., 5.], [ 6., 7.]],
[[ 8., 9.], [10., 11.]],
[[12., 13.], [14., 15.]],
[[16., 17.], [18., 19.]],
[[20., 21.], [22., 23.]]]]
axis = 1
groups = 3
Output: tensor of shape [1, 6, 2, 2]
output = [[[[ 0., 1.], [ 2., 3.]],
[[ 8., 9.], [10., 11.]],
[[16., 17.], [18., 19.]],
[[ 4., 5.], [ 6., 7.]],
[[12., 13.], [14., 15.]],
[[20., 21.], [22., 23.]]]]
:param data: The node with input tensor.
:param axis: Channel dimension index in the data tensor.
A negative value means that the index should be calculated
from the back of the input data shape.
:param group:The channel dimension specified by the axis parameter
should be split into this number of groups.
:param name: Optional output node name.
:return: The new node performing a permutation on data in the channel dimension
of the input tensor.
"""
return ShuffleChannels(data, axis, groups)
@nameable_op
def squeeze(data, axes, name=None): # type: (Node, NodeInput, str) -> Node
"""Perform squeeze operation on input tensor.
Remove single-dimensional entries from the shape of a tensor.
Takes a parameter :code:`axes` with a list of axes to squeeze.
If :code:`axes` is not provided, all the single dimensions will be removed from the shape.
If an :code:`axis` is selected with shape entry not equal to one, an error is raised.
For example:
Inputs: tensor with shape [1, 2, 1, 3, 1, 1], axes=[2, 4]
Result: tensor with shape [1, 2, 3, 1]
:param data: The node with data tensor.
:param axes: List of non-negative integers, indicate the dimensions to squeeze.
One of: input node or array.
:param name: Optional new name for output node.
:return: The new node performing a squeeze operation on input tensor.
"""
return Squeeze(data, as_node(axes))
def unsqueeze(data, axes, name=None): # type: (Node, NodeInput, str) -> Node
"""Perform unsqueeze operation on input tensor.
......@@ -532,6 +607,20 @@ def logical_not(node, name=None): # type: (Node, str) -> Node
return Not(node)
@binary_op
def squared_difference(x1, x2, name=None): # type: (Node, Node, str) -> Node
"""Perform an element-wise squared difference between two tensors.
.. math:: y[i] = (x_1[i] - x_2[i])^2
:param x1: The node with first input tensor.
:param x2: The node with second input tensor.
:param name: Optional new name for output node.
:return: The new node performing a squared difference between two tensors.
"""
return SquaredDifference(x1, x2)
# Extend Node class to support binary operators
Node.__add__ = add
Node.__sub__ = subtract
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "pyngraph/ops/fused/shuffle_channels.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_ShuffleChannels(py::module m)
{
py::class_<ngraph::op::ShuffleChannels,
std::shared_ptr<ngraph::op::ShuffleChannels>,
ngraph::op::Op>
shufflechannels(m, "ShuffleChannels");
shufflechannels.doc() = "ngraph.impl.op.ShuffleChannels wraps ngraph::op::ShuffleChannels";
shufflechannels.def(py::init<const std::shared_ptr<ngraph::Node>&, int&, int&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_ShuffleChannels(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/squared_difference.hpp"
#include "pyngraph/ops/fused/squared_difference.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_SquaredDifference(py::module m)
{
py::class_<ngraph::op::SquaredDifference,
std::shared_ptr<ngraph::op::SquaredDifference>,
ngraph::op::Op>
squareddifference(m, "SquaredDifference");
squareddifference.doc() =
"ngraph.impl.op.SquaredDifference wraps ngraph::op::SquaredDifference";
squareddifference.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_SquaredDifference(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/squeeze.hpp"
#include "pyngraph/ops/fused/squeeze.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Squeeze(py::module m)
{
py::class_<ngraph::op::Squeeze, std::shared_ptr<ngraph::op::Squeeze>, ngraph::op::Op> squeeze(
m, "Squeeze");
squeeze.doc() = "ngraph.impl.op.Squeeze wraps ngraph::op::Squeeze";
squeeze.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Squeeze(py::module m);
......@@ -93,6 +93,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Reverse(m_op);
regclass_pyngraph_op_ScaleShift(m_op);
regclass_pyngraph_op_Select(m_op);
regclass_pyngraph_op_ShuffleChannels(m_op);
regclass_pyngraph_op_Sign(m_op);
regclass_pyngraph_op_Sin(m_op);
regclass_pyngraph_op_Sinh(m_op);
......@@ -100,6 +101,8 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Softmax(m_op);
regclass_pyngraph_op_SpaceToDepth(m_op);
regclass_pyngraph_op_Sqrt(m_op);
regclass_pyngraph_op_SquaredDifference(m_op);
regclass_pyngraph_op_Squeeze(m_op);
regclass_pyngraph_op_Subtract(m_op);
regclass_pyngraph_op_Sum(m_op);
regclass_pyngraph_op_Tan(m_op);
......
......@@ -53,7 +53,10 @@
#include "pyngraph/ops/fused/mvn.hpp"
#include "pyngraph/ops/fused/prelu.hpp"
#include "pyngraph/ops/fused/scale_shift.hpp"
#include "pyngraph/ops/fused/shuffle_channels.hpp"
#include "pyngraph/ops/fused/space_to_depth.hpp"
#include "pyngraph/ops/fused/squared_difference.hpp"
#include "pyngraph/ops/fused/squeeze.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
......
......@@ -221,12 +221,15 @@ sources = [
'pyngraph/ops/reverse.cpp',
'pyngraph/ops/fused/scale_shift.cpp',
'pyngraph/ops/select.cpp',
'pyngraph/ops/fused/shuffle_channels.cpp',
'pyngraph/ops/sign.cpp',
'pyngraph/ops/sin.cpp',
'pyngraph/ops/sinh.cpp',
'pyngraph/ops/slice.cpp',
'pyngraph/ops/fused/space_to_depth.cpp',
'pyngraph/ops/sqrt.cpp',
'pyngraph/ops/fused/squared_difference.cpp',
'pyngraph/ops/fused/squeeze.cpp',
'pyngraph/ops/subtract.cpp',
'pyngraph/ops/sum.cpp',
'pyngraph/ops/tan.cpp',
......
......@@ -226,6 +226,67 @@ def test_clamp_operator_with_array():
assert np.allclose(result, expected)
def test_squeeze_operator():
runtime = get_runtime()
data_shape = [1, 2, 1, 3, 1, 1]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
data_value = np.arange(6., dtype=np.float32).reshape(1, 2, 1, 3, 1, 1)
axes = [2, 4]
model = ng.squeeze(parameter_data, axes)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = np.arange(6., dtype=np.float32).reshape(1, 2, 3, 1)
assert np.allclose(result, expected)
def test_squared_difference_operator():
runtime = get_runtime()
x1_shape = [1, 2, 3, 4]
x2_shape = [2, 3, 4]
parameter_x1 = ng.parameter(x1_shape, name='x1', dtype=np.float32)
parameter_x2 = ng.parameter(x2_shape, name='x2', dtype=np.float32)
x1_value = np.arange(24., dtype=np.float32).reshape(x1_shape)
x2_value = np.arange(start=4., stop=28., step=1.0, dtype=np.float32).reshape(x2_shape)
model = ng.squared_difference(parameter_x1, parameter_x2)
computation = runtime.computation(model, parameter_x1, parameter_x2)
result = computation(x1_value, x2_value)
expected = np.square(np.subtract(x1_value, x2_value))
assert np.allclose(result, expected)
def test_shuffle_channels_operator():
runtime = get_runtime()
data_shape = [1, 15, 2, 2]
axis = 1
groups = 5
parameter = ng.parameter(data_shape, name='Data', dtype=np.float32)
data_value = np.arange(60., dtype=np.float32).reshape(data_shape)
model = ng.shuffle_channels(parameter, axis, groups)
computation = runtime.computation(model, parameter)
result = computation(data_value)
expected = np.array([[[[0., 1.], [2., 3.]], [[12., 13.], [14., 15.]],
[[24., 25.], [26., 27.]], [[36., 37.], [38., 39.]],
[[48., 49.], [50., 51.]], [[4., 5.], [6., 7.]],
[[16., 17.], [18., 19.]], [[28., 29.], [30., 31.]],
[[40., 41.], [42., 43.]], [[52., 53.], [54., 55.]],
[[8., 9.], [10., 11.]], [[20., 21.], [22., 23.]],
[[32., 33.], [34., 35.]], [[44., 45.], [46., 47.]],
[[56., 57.], [58., 59.]]]], dtype=np.float32)
assert np.allclose(result, expected)
def test_unsqueeze():
runtime = get_runtime()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment