Commit e4955613 authored by Ewa Tusień, committed by Scott Cyphers

[Py] Added elu operator to Python API. (#3236)

* Added elu operator to Python API.

* Added missing file.

* Specified elu function description.

* Expand docstring

* [Py] Added test with scalar for elu operator.

* Bugfix

*  [Py] Changed input type in elu test.

* Update test_ops_binary.py

* [Py] Syntax bugfix.

* [Py] Added elu operator to list in documentation.
parent a58d3bc2
@@ -32,6 +32,7 @@ ngraph.ops
cosh
divide
dot
elu
equal
exp
floor
...
@@ -45,6 +45,7 @@ from ngraph.ops import cos
from ngraph.ops import cosh
from ngraph.ops import divide
from ngraph.ops import dot
from ngraph.ops import elu
from ngraph.ops import equal
from ngraph.ops import exp
from ngraph.ops import floor
...
@@ -69,6 +69,7 @@ from _pyngraph.op import Cos
from _pyngraph.op import Cosh
from _pyngraph.op import Divide
from _pyngraph.op import Dot
from _pyngraph.op import Elu
from _pyngraph.op import Equal
from _pyngraph.op import Exp
from _pyngraph.op import Floor
...
@@ -22,7 +22,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
    BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Concat, Constant, Convert, \
-    Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Equal, Exp, Floor, \
+    Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
    GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
    Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
    Power, Relu, ReplaceSlice, Reshape, Reverse, Select, Sign, Sin, Sinh, Slice, Softmax, \
@@ -35,7 +35,7 @@ from ngraph.utils.decorators import nameable_op, binary_op, unary_op
from ngraph.utils.input_validation import assert_list_of_ints
from ngraph.utils.reduction import get_reduction_axes
from ngraph.utils.types import NumericType, NumericData, TensorShape, make_constant_node, \
-    NodeInput, ScalarData
+    NodeInput, ScalarData, as_node
from ngraph.utils.types import get_element_type
@@ -60,6 +60,24 @@ def constant(value, dtype=None, name=None):  # type: (NumericData, NumericType,
    return make_constant_node(value, dtype)
@nameable_op
def elu(data, alpha, name=None):  # type: (NodeInput, NodeInput, str) -> Node
    """Perform Exponential Linear Unit operation element-wise on data from input node.

    Computes exponential linear: alpha * (exp(data) - 1) if data < 0, data otherwise.

    For more information refer to:
    `Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
    <http://arxiv.org/abs/1511.07289>`_

    :param data: Input tensor. One of: input node, array or scalar.
    :param alpha: Multiplier for negative values. One of: input node or scalar value.
    :param name: Optional output node name.
    :return: The new node performing an ELU operation on its input data element-wise.
    """
    return Elu(as_node(data), as_node(alpha))
# Unary ops
@unary_op
def absolute(node, name=None):  # type: (NodeInput, str) -> Node
...
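A minimal construction-only sketch of the new wrapper, for quick reference (evaluation through a backend is covered by the tests added further down; the numbers simply restate the docstring formula):

import numpy as np
import ngraph as ng

# Both arguments go through as_node, so plain numpy arrays and scalars are accepted.
node = ng.elu(np.array([[-5, 1], [-2, 3]], dtype=np.float32), np.float32(3))
# Element-wise: 3 * (exp(-5) - 1) ~= -2.98 and 3 * (exp(-2) - 1) ~= -2.59; non-negative values pass through.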
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/elu.hpp"
#include "pyngraph/ops/elu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Elu(py::module m)
{
    py::class_<ngraph::op::Elu, std::shared_ptr<ngraph::op::Elu>, ngraph::op::Op> elu(m, "Elu");
    elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu";
    elu.def(py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Elu(py::module m);
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/elu.hpp"
#include "pyngraph/ops/fused/elu.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Elu(py::module m)
{
    py::class_<ngraph::op::Elu, std::shared_ptr<ngraph::op::Elu>, ngraph::op::Op> elu(m, "Elu");
    elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu";
    elu.def(py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
}
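For orientation, a rough sketch of how the class registered above surfaces on the Python side; the Constant/Shape/Type helpers and their exact signatures are assumed here for illustration and are not part of this change:

from ngraph.impl import Shape, Type
from ngraph.impl.op import Constant, Elu

# Both constructor arguments are Node shared pointers on the C++ side,
# so any two nodes with matching element types work.
data = Constant(Type.f32, Shape([2, 2]), [-5.0, 1.0, -2.0, 3.0])
alpha = Constant(Type.f32, Shape([2, 2]), [3.0, 3.0, 3.0, 3.0])
node = Elu(data, alpha)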
@@ -49,6 +49,7 @@ void regmodule_pyngraph_op(py::module m_op)
    regclass_pyngraph_op_Cosh(m_op);
    regclass_pyngraph_op_Divide(m_op);
    regclass_pyngraph_op_Dot(m_op);
    regclass_pyngraph_op_Elu(m_op);
    regclass_pyngraph_op_Equal(m_op);
    regclass_pyngraph_op_Exp(m_op);
    regclass_pyngraph_op_Floor(m_op);
...
@@ -39,6 +39,7 @@
#include "pyngraph/ops/cosh.hpp"
#include "pyngraph/ops/divide.hpp"
#include "pyngraph/ops/dot.hpp"
#include "pyngraph/ops/elu.hpp"
#include "pyngraph/ops/equal.hpp"
#include "pyngraph/ops/exp.hpp"
#include "pyngraph/ops/floor.hpp"
...
@@ -179,6 +179,7 @@ sources = [
    'pyngraph/ops/ceiling.cpp',
    'pyngraph/ops/divide.cpp',
    'pyngraph/ops/dot.cpp',
    'pyngraph/ops/elu.cpp',
    'pyngraph/ops/equal.cpp',
    'pyngraph/ops/exp.cpp',
    'pyngraph/ops/floor.cpp',
...
# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import ngraph as ng
from test.ngraph.util import get_runtime
def test_elu_operator():
    runtime = get_runtime()

    data_shape = [2, 2]
    alpha_shape = [2]

    parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
    parameter_alpha = ng.parameter(alpha_shape, name='Alpha', dtype=np.float32)

    model = ng.elu(parameter_data, parameter_alpha)
    computation = runtime.computation(model, parameter_data, parameter_alpha)

    value_data = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
    value_alpha = np.array([3, 3], dtype=np.float32)

    result = computation(value_data, value_alpha)
    expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
    assert np.allclose(result, expected)


def test_elu_operator_with_scalar_and_array():
    runtime = get_runtime()

    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
    alpha_value = np.float32(3)

    model = ng.elu(data_value, alpha_value)
    computation = runtime.computation(model)

    result = computation()
    expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
    assert np.allclose(result, expected)


def test_elu_operator_with_scalar():
    runtime = get_runtime()

    data_value = np.array([[-5, 1], [-2, 3]], dtype=np.float32)
    alpha_value = np.float32(3)

    data_shape = [2, 2]
    parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)

    model = ng.elu(parameter_data, alpha_value)
    computation = runtime.computation(model, parameter_data)

    result = computation(data_value)
    expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
    assert np.allclose(result, expected)
@@ -39,7 +39,7 @@ NodeVector op::Elu::decompose_op() const
    auto data = get_argument(0);
    auto alpha_node = get_argument(1);
-    alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
+    alpha_node = ngraph::op::numpy_style_broadcast(alpha_node, data->get_shape());
    shared_ptr<ngraph::Node> zero_node =
        builder::make_constant(data->get_element_type(), data->get_shape(), 0);
...
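For context, a small numpy reference of the element-wise math the decomposed graph evaluates, with alpha broadcast to the data shape as in the changed line above (the helper name is illustrative, not part of the nGraph sources):

import numpy as np

def elu_reference(data, alpha):
    # Broadcast alpha against the data shape, mirroring numpy_style_broadcast.
    alpha = np.broadcast_to(alpha, data.shape)
    # Negative entries become alpha * (exp(x) - 1); the rest pass through.
    return np.where(data < 0, alpha * (np.exp(data) - 1), data)

elu_reference(np.array([[-5.0, 1.0], [-2.0, 3.0]]), 3.0)
# -> approximately [[-2.9797862, 1.], [-2.5939941, 3.]], matching the tests above.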