Commit 56618491 authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[Py] Added GRN operator to Python API. (#3365)

* [Py] Added grn operator to Python API.

* [Py] Changed order of included files.

* [Py] Changed docstring.

* style
parent 72fa2015
......@@ -42,6 +42,7 @@ ngraph.ops
get_output_element
greater
greater_eq
grn
less
less_eq
log
......
......@@ -55,6 +55,7 @@ from ngraph.ops import gemm
from ngraph.ops import get_output_element
from ngraph.ops import greater
from ngraph.ops import greater_eq
from ngraph.ops import grn
from ngraph.ops import less
from ngraph.ops import less_eq
from ngraph.ops import log
......
......@@ -79,6 +79,7 @@ from _pyngraph.op import Gemm
from _pyngraph.op import GetOutputElement
from _pyngraph.op import Greater
from _pyngraph.op import GreaterEq
from _pyngraph.op import GRN
from _pyngraph.op import Less
from _pyngraph.op import LessEq
from _pyngraph.op import Log
......
......@@ -23,7 +23,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
Gelu, Gemm, GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, \
Gelu, Gemm, GetOutputElement, Greater, GreaterEq, GRN, Less, LessEq, Log, LRN, Max, Maximum, \
MaxPool, Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
Sqrt, Subtract, Sum, Tan, Tanh, TopK
......@@ -78,6 +78,22 @@ def elu(data, alpha, name=None): # type: (NodeInput, NodeInput, str) -> Node
return Elu(as_node(data), as_node(alpha))
@nameable_op
def grn(data, bias, name=None): # type: (Node, float, str) -> Node
r"""Perform Global Response Normalization with L2 norm (across channels only).
Computes GRN operation on channels for input tensor:
.. math:: output_i = \dfrac{input_i}{\sqrt{\sum_{i}^{C} input_i}}
:param data: The node with data tensor.
:param bias: The bias added to the variance. Scalar value.
:param name: Optional output node name.
:return: The new node performing a GRN operation on tensor's channels.
"""
return GRN(data, bias)
# Unary ops
@unary_op
def absolute(node, name=None): # type: (NodeInput, str) -> Node
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/grn.hpp"
#include "pyngraph/ops/fused/grn.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_GRN(py::module m)
{
py::class_<ngraph::op::GRN, std::shared_ptr<ngraph::op::GRN>, ngraph::op::Op> grn(m, "GRN");
grn.doc() = "ngraph.impl.op.GRN wraps ngraph::op::GRN";
grn.def(py::init<const std::shared_ptr<ngraph::Node>&, float&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_GRN(py::module m);
......@@ -59,6 +59,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_GetOutputElement(m_op);
regclass_pyngraph_op_Greater(m_op);
regclass_pyngraph_op_GreaterEq(m_op);
regclass_pyngraph_op_GRN(m_op);
regclass_pyngraph_op_Less(m_op);
regclass_pyngraph_op_LessEq(m_op);
regclass_pyngraph_op_Log(m_op);
......
......@@ -46,6 +46,7 @@
#include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/fused/gelu.hpp"
#include "pyngraph/ops/fused/gemm.hpp"
#include "pyngraph/ops/fused/grn.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp"
......
......@@ -188,6 +188,7 @@ sources = [
'pyngraph/ops/fused/gemm.cpp',
'pyngraph/ops/greater.cpp',
'pyngraph/ops/greater_eq.cpp',
'pyngraph/ops/fused/grn.cpp',
'pyngraph/ops/less.cpp',
'pyngraph/ops/less_eq.cpp',
'pyngraph/ops/log.cpp',
......
......@@ -170,3 +170,26 @@ def test_clamp_operator_with_array():
expected = np.clip(data_value, min_value, max_value)
assert np.allclose(result, expected)
def test_grn_operator():
runtime = get_runtime()
data_value = np.arange(start=1.0, stop=25.0, dtype=np.float32).reshape(1, 2, 3, 4)
bias = np.float32(1e-6)
data_shape = [1, 2, 3, 4]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
model = ng.grn(parameter_data, bias)
computation = runtime.computation(model, parameter_data)
result = computation(data_value)
expected = np.array([[[[0.0766965, 0.14142136, 0.19611613, 0.24253564],
[0.28216633, 0.31622776, 0.34570536, 0.37139067],
[0.39391932, 0.41380295, 0.4314555, 0.4472136]],
[[0.9970545, 0.98994946, 0.9805807, 0.97014254],
[0.9593655, 0.9486833, 0.9383431, 0.9284767],
[0.91914505, 0.9103665, 0.9021342, 0.8944272]]]], dtype=np.float32)
assert np.allclose(result, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment