Commit 439d4b51 authored by Ewa Tusień's avatar Ewa Tusień Committed by Michał Karzyński

[Py] Added unsqueeze operator to Python API (#3375)

parent 90eb91a5
......@@ -81,6 +81,7 @@ ngraph.ops
tan
tanh
topk
unsqueeze
......
......@@ -94,5 +94,7 @@ from ngraph.ops import sum
from ngraph.ops import tan
from ngraph.ops import tanh
from ngraph.ops import topk
from ngraph.ops import unsqueeze
from ngraph.runtime import runtime
......@@ -120,3 +120,4 @@ from _pyngraph.op import Sum
from _pyngraph.op import Tan
from _pyngraph.op import Tanh
from _pyngraph.op import TopK
from _pyngraph.op import Unsqueeze
......@@ -26,7 +26,7 @@ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgP
Equal, Exp, Floor, Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, \
LRN, Max, Maximum, MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, \
Parameter, Product, Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, \
Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK
Slice, Softmax, Sqrt, Subtract, Sum, Tan, Tanh, TopK, Unsqueeze
from typing import Callable, Iterable, List, Union
......@@ -79,6 +79,24 @@ def elu(data, alpha, name=None): # type: (NodeInput, NodeInput, str) -> Node
@nameable_op
def unsqueeze(data, axes, name=None): # type: (Node, NodeInput, str) -> Node
"""Perform unsqueeze operation on input tensor.
Insert single-dimensional entries to the shape of a tensor. Takes one required argument axes,
a list of dimensions that will be inserted.
Dimension indices in axes are as seen in the output tensor.
For example: Inputs: tensor with shape [3, 4, 5], axes=[0, 4]
Result: tensor with shape [1, 3, 4, 5, 1]
:param data: The node with data tensor.
:param axes: List of non-negative integers, indicate the dimensions to be inserted.
One of: input node or array.
:return: The new node performing an unsqueeze operation on input tensor.
"""
return Unsqueeze(data, as_node(axes))
def grn(data, bias, name=None): # type: (Node, float, str) -> Node
r"""Perform Global Response Normalization with L2 norm (across channels only).
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/unsqueeze.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Unsqueeze(py::module m)
{
py::class_<ngraph::op::Unsqueeze, std::shared_ptr<ngraph::op::Unsqueeze>, ngraph::op::Op>
unsqueeze(m, "Unsqueeze");
unsqueeze.doc() = "ngraph.impl.op.Unsqueeze wraps ngraph::op::Unsqueeze";
unsqueeze.def(
py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Unsqueeze(py::module m);
......@@ -101,4 +101,5 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Tanh(m_op);
regclass_pyngraph_op_TopK(m_op);
regclass_pyngraph_op_Result(m_op);
regclass_pyngraph_op_Unsqueeze(m_op);
}
......@@ -49,6 +49,7 @@
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp"
#include "pyngraph/ops/fused/unsqueeze.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp"
......
......@@ -233,6 +233,7 @@ sources = [
'pyngraph/ops/batch_norm.cpp',
'pyngraph/ops/softmax.cpp',
'pyngraph/ops/result.cpp',
'pyngraph/ops/fused/unsqueeze.cpp',
'pyngraph/runtime/backend.cpp',
'pyngraph/runtime/executable.cpp',
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp',
......
......@@ -245,6 +245,21 @@ def test_clamp_operator_with_array():
assert np.allclose(result, expected)
def test_unsqueeze():
runtime = get_runtime()
data_shape = [3, 4, 5]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
data_value = np.arange(60., dtype=np.float32).reshape(3, 4, 5)
axes = [0, 4]
model = ng.unsqueeze(parameter_data, axes)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = np.arange(60., dtype=np.float32).reshape(1, 3, 4, 5, 1)
assert np.allclose(result, expected)
def test_grn_operator():
runtime = get_runtime()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment