Commit 809f7f69 authored by Ewa21's avatar Ewa21

[Py] Added clamp operator to Python API. Created new package for registration of fused ops in pyngraph.

[Py] Added clamp operator to Python API. Created new package for registration of fused ops in pyngraph.
parent feefdbb2
...@@ -23,6 +23,7 @@ ngraph.ops ...@@ -23,6 +23,7 @@ ngraph.ops
broadcast broadcast
broadcast_to broadcast_to
ceiling ceiling
clamp
concat concat
constant constant
convert convert
......
...@@ -36,6 +36,7 @@ from ngraph.ops import broadcast ...@@ -36,6 +36,7 @@ from ngraph.ops import broadcast
from ngraph.ops import broadcast_to from ngraph.ops import broadcast_to
from ngraph.ops import ceiling from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil from ngraph.ops import ceiling as ceil
from ngraph.ops import clamp
from ngraph.ops import concat from ngraph.ops import concat
from ngraph.ops import constant from ngraph.ops import constant
from ngraph.ops import convert from ngraph.ops import convert
......
...@@ -61,6 +61,7 @@ from _pyngraph.op import Constant ...@@ -61,6 +61,7 @@ from _pyngraph.op import Constant
""" """
Constant.get_data = lambda self: np.array(self, copy=True) Constant.get_data = lambda self: np.array(self, copy=True)
from _pyngraph.op import Clamp
from _pyngraph.op import Convert from _pyngraph.op import Convert
from _pyngraph.op import Convolution from _pyngraph.op import Convolution
from _pyngraph.op import ConvolutionBackpropData from _pyngraph.op import ConvolutionBackpropData
......
...@@ -21,7 +21,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio ...@@ -21,7 +21,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
Shape, Strides Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \ from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Concat, Constant, Convert, \ BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \ Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \ GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \ Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
...@@ -555,6 +555,35 @@ def tanh(node, name=None): # type: (Node, str) -> Node ...@@ -555,6 +555,35 @@ def tanh(node, name=None): # type: (Node, str) -> Node
return Tanh(node) return Tanh(node)
@nameable_op
def clamp(data, min_value, max_value, name=None):
    # type: (NodeInput, ScalarData, ScalarData, str) -> Node
    """Perform clamp element-wise on data from input node.

    Performs a clipping operation on an input value between a pair of boundary values.

    If :code:`data` compares less than :code:`min_value`, sets :code:`min_value`;
    else if :code:`max_value` compares less than :code:`data`, sets :code:`max_value`;
    otherwise :code:`data` remains unchanged.

    Computes clamp:

    .. code-block:: python

        if data < min_value:
            data = min_value
        elif data > max_value:
            data = max_value

    :param data: Input tensor. One of: input node, array or scalar.
    :param min_value: The lower bound of the <min_value;max_value> range. Scalar value.
    :param max_value: The upper bound of the <min_value;max_value> range. Scalar value.
    :param name: Optional output node name.
    :return: The new node performing a clamp operation on its input data element-wise.
    """
    # Non-node inputs (arrays/scalars) are wrapped into Constant nodes by as_node.
    return Clamp(as_node(data), min_value, max_value)
# matmul ops # matmul ops
@nameable_op @nameable_op
def dot(left_node, right_node, reduction_axes_count=None, name=None): def dot(left_node, right_node, reduction_axes_count=None, name=None):
......
...@@ -17,14 +17,15 @@ ...@@ -17,14 +17,15 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "ngraph/op/fused/clamp.hpp"
#include "pyngraph/ops/fused/clamp.hpp"

namespace py = pybind11;

// Registers the ngraph::op::Clamp fused op as the Python class "Clamp" on
// the given pybind11 module. The constructor takes (data node, min, max).
void regclass_pyngraph_op_Clamp(py::module m)
{
    py::class_<ngraph::op::Clamp, std::shared_ptr<ngraph::op::Clamp>, ngraph::op::Op> clamp(
        m, "Clamp");
    clamp.doc() = "ngraph.impl.op.Clamp wraps ngraph::op::Clamp";
    clamp.def(py::init<const std::shared_ptr<ngraph::Node>&, const double, const double>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
// Registers the pybind11 binding for ngraph::op::Clamp on module m
// (implemented in pyngraph/ops/fused/clamp.cpp).
void regclass_pyngraph_op_Clamp(py::module m);
...@@ -39,6 +39,7 @@ void regmodule_pyngraph_op(py::module m_op) ...@@ -39,6 +39,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Broadcast(m_op); regclass_pyngraph_op_Broadcast(m_op);
regclass_pyngraph_op_BroadcastDistributed(m_op); regclass_pyngraph_op_BroadcastDistributed(m_op);
regclass_pyngraph_op_Ceiling(m_op); regclass_pyngraph_op_Ceiling(m_op);
regclass_pyngraph_op_Clamp(m_op);
regclass_pyngraph_op_Concat(m_op); regclass_pyngraph_op_Concat(m_op);
regclass_pyngraph_op_Constant(m_op); regclass_pyngraph_op_Constant(m_op);
regclass_pyngraph_op_Convert(m_op); regclass_pyngraph_op_Convert(m_op);
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "pyngraph/ops/broadcast.hpp" #include "pyngraph/ops/broadcast.hpp"
#include "pyngraph/ops/broadcast_distributed.hpp" #include "pyngraph/ops/broadcast_distributed.hpp"
#include "pyngraph/ops/ceiling.hpp" #include "pyngraph/ops/ceiling.hpp"
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/concat.hpp" #include "pyngraph/ops/concat.hpp"
#include "pyngraph/ops/constant.hpp" #include "pyngraph/ops/constant.hpp"
#include "pyngraph/ops/convert.hpp" #include "pyngraph/ops/convert.hpp"
...@@ -39,7 +40,7 @@ ...@@ -39,7 +40,7 @@
#include "pyngraph/ops/cosh.hpp" #include "pyngraph/ops/cosh.hpp"
#include "pyngraph/ops/divide.hpp" #include "pyngraph/ops/divide.hpp"
#include "pyngraph/ops/dot.hpp" #include "pyngraph/ops/dot.hpp"
#include "pyngraph/ops/elu.hpp" #include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/equal.hpp" #include "pyngraph/ops/equal.hpp"
#include "pyngraph/ops/exp.hpp" #include "pyngraph/ops/exp.hpp"
#include "pyngraph/ops/floor.hpp" #include "pyngraph/ops/floor.hpp"
......
...@@ -170,6 +170,7 @@ sources = [ ...@@ -170,6 +170,7 @@ sources = [
'pyngraph/ops/avg_pool.cpp', 'pyngraph/ops/avg_pool.cpp',
'pyngraph/ops/broadcast.cpp', 'pyngraph/ops/broadcast.cpp',
'pyngraph/ops/broadcast_distributed.cpp', 'pyngraph/ops/broadcast_distributed.cpp',
'pyngraph/ops/fused/clamp.cpp',
'pyngraph/ops/concat.cpp', 'pyngraph/ops/concat.cpp',
'pyngraph/ops/constant.cpp', 'pyngraph/ops/constant.cpp',
'pyngraph/ops/convert.cpp', 'pyngraph/ops/convert.cpp',
...@@ -179,7 +180,7 @@ sources = [ ...@@ -179,7 +180,7 @@ sources = [
'pyngraph/ops/ceiling.cpp', 'pyngraph/ops/ceiling.cpp',
'pyngraph/ops/divide.cpp', 'pyngraph/ops/divide.cpp',
'pyngraph/ops/dot.cpp', 'pyngraph/ops/dot.cpp',
'pyngraph/ops/elu.cpp', 'pyngraph/ops/fused/elu.cpp',
'pyngraph/ops/equal.cpp', 'pyngraph/ops/equal.cpp',
'pyngraph/ops/exp.cpp', 'pyngraph/ops/exp.cpp',
'pyngraph/ops/floor.cpp', 'pyngraph/ops/floor.cpp',
......
# ****************************************************************************** # ******************************************************************************
# Copyright 2017-2019 Intel Corporation # Copyright 2018-2019 Intel Corporation
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -67,3 +67,36 @@ def test_elu_operator_with_scalar(): ...@@ -67,3 +67,36 @@ def test_elu_operator_with_scalar():
result = computation(data_value) result = computation(data_value)
expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32) expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
assert np.allclose(result, expected) assert np.allclose(result, expected)
def test_clamp_operator():
    """Clamp applied to a parameter node matches numpy.clip element-wise."""
    runtime = get_runtime()

    shape = [2, 2]
    param = ng.parameter(shape, name='Data', dtype=np.float32)
    lower = np.float32(3)
    upper = np.float32(12)

    model = ng.clamp(param, lower, upper)
    computation = runtime.computation(model, param)

    values = np.array([[-5, 9], [45, 3]], dtype=np.float32)
    result = computation(values)

    assert np.allclose(result, np.clip(values, lower, upper))
def test_clamp_operator_with_array():
    """Clamp accepts a constant numpy array input and matches numpy.clip."""
    runtime = get_runtime()

    values = np.array([[-5, 9], [45, 3]], dtype=np.float32)
    lower, upper = np.float32(3), np.float32(12)

    model = ng.clamp(values, lower, upper)
    result = runtime.computation(model)()

    assert np.allclose(result, np.clip(values, lower, upper))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment