Unverified Commit 30603191 authored by Robert Kimball's avatar Robert Kimball Committed by GitHub

Merge pull request #3294 from NervanaSystems/etusien/clamp

[Py] Added clamp operator to Python API
parents 5b59c095 138b6260
......@@ -23,6 +23,7 @@ ngraph.ops
broadcast
broadcast_to
ceiling
clamp
concat
constant
convert
......
......@@ -36,6 +36,7 @@ from ngraph.ops import broadcast
from ngraph.ops import broadcast_to
from ngraph.ops import ceiling
from ngraph.ops import ceiling as ceil
from ngraph.ops import clamp
from ngraph.ops import concat
from ngraph.ops import constant
from ngraph.ops import convert
......
......@@ -61,6 +61,7 @@ from _pyngraph.op import Constant
"""
Constant.get_data = lambda self: np.array(self, copy=True)
from _pyngraph.op import Clamp
from _pyngraph.op import Convert
from _pyngraph.op import Convolution
from _pyngraph.op import ConvolutionBackpropData
......
......@@ -21,7 +21,7 @@ from ngraph.impl import AxisSet, AxisVector, Coordinate, CoordinateDiff, Functio
Shape, Strides
from ngraph.impl.op import Abs, Acos, Add, And, Asin, ArgMax, ArgMin, Atan, AvgPool, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Concat, Constant, Convert, \
BatchNormTraining, BatchNormInference, Broadcast, Ceiling, Clamp, Concat, Constant, Convert, \
Convolution, ConvolutionBackpropData, Cos, Cosh, Divide, Dot, Elu, Equal, Exp, Floor, \
GetOutputElement, Greater, GreaterEq, Less, LessEq, Log, LRN, Max, Maximum, MaxPool, \
Min, Minimum, Multiply, Negative, Not, NotEqual, OneHot, Or, Pad, Parameter, Product, \
......@@ -555,6 +555,36 @@ def tanh(node, name=None): # type: (Node, str) -> Node
return Tanh(node)
@nameable_op
def clamp(data, min_value, max_value, name=None):
# type: (NodeInput, ScalarData, ScalarData, str) -> Node
"""Perform clamp element-wise on data from input node.
Performs a clipping operation on an input value between a pair of boundary values.
For each element in :code:`data`, if the element's value is lower than :code:`min_value`,
it will be replaced with :code:`min_value`. If the value is higher than :code:`max_value`,
it will be replaced by :code:`max_value`.
Intermediate values of :code:`data` are returned without change.
Clamp uses the following logic:
.. code-block:: python
if data < min_value:
data=min_value
elif data > max_value:
data=max_value
:param data: Input tensor. One of: input node, array or scalar.
:param min_value: The lower bound of the <min_value;max_value> range. Scalar value.
:param max_value: The upper bound of the <min_value;max_value> range. Scalar value.
:param name: Optional output node name.
:return: The new node performing a clamp operation on its input data element-wise.
"""
return Clamp(as_node(data), min_value, max_value)
# matmul ops
@nameable_op
def dot(left_node, right_node, reduction_axes_count=None, name=None):
......
......@@ -17,14 +17,15 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/op/fused/elu.hpp"
#include "pyngraph/ops/elu.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "pyngraph/ops/fused/clamp.hpp"
namespace py = pybind11;
void regclass_pyngraph_op_Elu(py::module m)
void regclass_pyngraph_op_Clamp(py::module m)
{
py::class_<ngraph::op::Elu, std::shared_ptr<ngraph::op::Elu>, ngraph::op::Op> elu(m, "Elu");
elu.doc() = "ngraph.impl.op.Elu wraps ngraph::op::Elu";
elu.def(py::init<const std::shared_ptr<ngraph::Node>&, const std::shared_ptr<ngraph::Node>&>());
py::class_<ngraph::op::Clamp, std::shared_ptr<ngraph::op::Clamp>, ngraph::op::Op> clamp(
m, "Clamp");
clamp.doc() = "ngraph.impl.op.Clamp wraps ngraph::op::Clamp";
clamp.def(py::init<const std::shared_ptr<ngraph::Node>&, const double, const double>());
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_pyngraph_op_Clamp(py::module m);
......@@ -39,6 +39,7 @@ void regmodule_pyngraph_op(py::module m_op)
regclass_pyngraph_op_Broadcast(m_op);
regclass_pyngraph_op_BroadcastDistributed(m_op);
regclass_pyngraph_op_Ceiling(m_op);
regclass_pyngraph_op_Clamp(m_op);
regclass_pyngraph_op_Concat(m_op);
regclass_pyngraph_op_Constant(m_op);
regclass_pyngraph_op_Convert(m_op);
......
......@@ -39,10 +39,11 @@
#include "pyngraph/ops/cosh.hpp"
#include "pyngraph/ops/divide.hpp"
#include "pyngraph/ops/dot.hpp"
#include "pyngraph/ops/elu.hpp"
#include "pyngraph/ops/equal.hpp"
#include "pyngraph/ops/exp.hpp"
#include "pyngraph/ops/floor.hpp"
#include "pyngraph/ops/fused/clamp.hpp"
#include "pyngraph/ops/fused/elu.hpp"
#include "pyngraph/ops/get_output_element.hpp"
#include "pyngraph/ops/greater.hpp"
#include "pyngraph/ops/greater_eq.hpp"
......
......@@ -170,6 +170,7 @@ sources = [
'pyngraph/ops/avg_pool.cpp',
'pyngraph/ops/broadcast.cpp',
'pyngraph/ops/broadcast_distributed.cpp',
'pyngraph/ops/fused/clamp.cpp',
'pyngraph/ops/concat.cpp',
'pyngraph/ops/constant.cpp',
'pyngraph/ops/convert.cpp',
......@@ -179,7 +180,7 @@ sources = [
'pyngraph/ops/ceiling.cpp',
'pyngraph/ops/divide.cpp',
'pyngraph/ops/dot.cpp',
'pyngraph/ops/elu.cpp',
'pyngraph/ops/fused/elu.cpp',
'pyngraph/ops/equal.cpp',
'pyngraph/ops/exp.cpp',
'pyngraph/ops/floor.cpp',
......
......@@ -67,3 +67,36 @@ def test_elu_operator_with_scalar():
result = computation(data_value)
expected = np.array([[-2.9797862, 1.], [-2.5939941, 3.]], dtype=np.float32)
assert np.allclose(result, expected)
def test_clamp_operator():
runtime = get_runtime()
data_shape = [2, 2]
parameter_data = ng.parameter(data_shape, name='Data', dtype=np.float32)
min_value = np.float32(3)
max_value = np.float32(12)
model = ng.clamp(parameter_data, min_value, max_value)
computation = runtime.computation(model, parameter_data)
data_value = np.array([[-5, 9], [45, 3]], dtype=np.float32)
result = computation(data_value)
expected = np.clip(data_value, min_value, max_value)
assert np.allclose(result, expected)
def test_clamp_operator_with_array():
runtime = get_runtime()
data_value = np.array([[-5, 9], [45, 3]], dtype=np.float32)
min_value = np.float32(3)
max_value = np.float32(12)
model = ng.clamp(data_value, min_value, max_value)
computation = runtime.computation(model)
result = computation()
expected = np.clip(data_value, min_value, max_value)
assert np.allclose(result, expected)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment