Commit 439d4b51 authored by Ewa Tusień's avatar Ewa Tusień Committed by Michał Karzyński

[Py] Added unsqueeze operator to Python API (#3375)

parent 90eb91a5
...@@ -81,6 +81,7 @@ ngraph.ops ...@@ -81,6 +81,7 @@ ngraph.ops
tan tan
tanh tanh
topk topk
unsqueeze
......
...@@ -94,5 +94,7 @@ from ngraph.ops import sum ...@@ -94,5 +94,7 @@ from ngraph.ops import sum
from ngraph.ops import tan from ngraph.ops import tan
from ngraph.ops import tanh from ngraph.ops import tanh
from ngraph.ops import topk from ngraph.ops import topk
from ngraph.ops import unsqueeze
from ngraph.runtime import runtime from ngraph.runtime import runtime
...@@ -120,3 +120,4 @@ from _pyngraph.op import Sum ...@@ -120,3 +120,4 @@ from _pyngraph.op import Sum
from _pyngraph.op import Tan from _pyngraph.op import Tan
from _pyngraph.op import Tanh from _pyngraph.op import Tanh
from _pyngraph.op import TopK from _pyngraph.op import TopK
from _pyngraph.op import Unsqueeze
...@@ -26,7 +26,7 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgP ...@@ -26,7 +26,7 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgP
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \ Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \
LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \ LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \ Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \
Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze
from typing import Callable, Iterable, List, Union from typing import Callable, Iterable, List, Union
...@@ -79,6 +79,24 @@ def elu(data, alpha, name=None): # type: (NodeInput, NodeInput, str) -> Node ...@@ -79,6 +79,24 @@ def elu(data, alpha, name=None): # type: (NodeInput, NodeInput, str) -> Node
@nameable_op @nameable_op
def unsqueeze(data, axes, name=None): # type: (Node, NodeInput, str) -> Node
"""Perform unsqueeze operation on input tensor.
Insert single-dimensional entries to the shape of a tensor. Takes one required argument axes,
a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example: Inputs: tensor with shape [3, 4, 5], axes=[0, 4]
Result: tensor with shape [1, 3, 4, 5, 1]
:param data: The node with data tensor.
:param axes: List of non-negative integers, indicate the dimensions to be inserted.
One of: input node or array.
:return: The new node performing an unsqueeze operation on input tensor.
"""
return Unsqueeze(data, as_node(axes))
def grn(data, bias, name=None): # type: (Node, float, str) -> Node def grn(data, bias, name=None): # type: (Node, float, str) -> Node
r"""Perform Global Response Normalization with L2 norm (across channels only). r"""Perform Global Response Normalization with L2 norm (across channels only).
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/unsqueeze.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Unsqueeze(py::module m)
{
py::class_<ngraph::op::Unsqueeze, std::shared_ptr<ngraph::op::Unsqueeze>, ngraph::op::Op>
unsqueeze(m, "Unsqueeze");
unsqueeze.doc() = "ngraph.impl.op.Unsqueeze wraps ngraph::op::Unsqueeze";
unsqueeze.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Unsqueeze(py::module m);
...@@ -101,4 +101,5 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -101,4 +101,5 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Tanh(m_op); regclass_pyngraph_op_Tanh(m_op);
regclass_pyngraph_op_TopK(m_op); regclass_pyngraph_op_TopK(m_op);
regclass_pyngraph_op_Result(m_op); regclass_pyngraph_op_Result(m_op);
regclass_pyngraph_op_Unsqueeze(m_op);
} }
...@@ -49,6 +49,7 @@ ...@@ -49,6 +49,7 @@
#include "pyngraph/ops/fused/gelu.hpp" #include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp" #include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp" #include "pyngraph/ops/fused/grn.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp"
#include "pyngraph/ops/get_output_element.hpp" #include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp" #include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp" #include "pyngraph/ops/greater_eq.hpp"
......
...@@ -233,6 +233,7 @@ sources = [ ...@@ -233,6 +233,7 @@ sources = [
'pyngraph/ops/batch_norm.cpp', 'pyngraph/ops/batch_norm.cpp',
'pyngraph/ops/softmax.cpp', 'pyngraph/ops/softmax.cpp',
'pyngraph/ops/result.cpp', 'pyngraph/ops/result.cpp',
'pyngraph/ops/fused/unsqueeze.cpp',
'pyngraph/runtime/backend.cpp', 'pyngraph/runtime/backend.cpp',
'pyngraph/runtime/executable.cpp', 'pyngraph/runtime/executable.cpp',
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp', 'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
......
...@@ -245,6 +245,21 @@ def test_clamp_operator_with_array(): ...@@ -245,6 +245,21 @@ def test_clamp_operator_with_array():
assert np.allclose(result, expected) assert np.allclose(result, expected)
def test_unsqueeze():
runtime = get_runtime()
data_shape = [3, 4, 5]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
data_value = np.arange(60., dtype=np.float32).reshape(3, 4, 5)
axes = [0, 4]
model = ng.unsqueeze(parameter_data, axes)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = np.arange(60., dtype=np.float32).reshape(1, 3, 4, 5, 1)
assert np.allclose(result, expected)
def test_grn_operator(): def test_grn_operator():
runtime = get_runtime() runtime = get_runtime()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment