Commit a006b5c4 authored by Tomasz Dołbniak, committed by Scott Cyphers

[FusedOps] Clamp (#2886)

* Fused Clamp op implementation

* Basic clamp test with some edge cases

* Dump the expected and actual values for NgraphTestCase

* Validate the min and max params for Clamp

* Use clamp in clip

* Disable Clamp and its test on iGPU

* Use getters for Clamp's parameters

* Validate clamp's params in pre_validate_and_infer_types

* Unit tests for clamp op validation

* Revert "Dump the expected and actual values for NgraphTestCase"

This reverts commit 3a029d70e62339ee84aadf2bf16e418281b85ff7.

* Clamp op docs
parent 421311ed
......@@ -274,6 +274,8 @@ set (SRC
op/tanh.hpp
op/topk.cpp
op/topk.hpp
op/fused/clamp.cpp
op/fused/clamp.hpp
op/fused/conv_fused.cpp
op/fused/conv_fused.hpp
op/fused/hard_sigmoid.cpp
......
......@@ -18,12 +18,7 @@
 #include <memory>

 #include "clip.hpp"
 #include "core/node.hpp"
-#include "ngraph/node.hpp"
-#include "ngraph/op/constant.hpp"
-#include "ngraph/op/maximum.hpp"
-#include "ngraph/op/minimum.hpp"
-#include "ngraph/op/util/broadcasting.hpp"
+#include "ngraph/op/fused/clamp.hpp"

 namespace ngraph
 {
......@@ -35,30 +30,15 @@ namespace ngraph
         {
             NodeVector clip(const Node& node)
             {
-                auto data = node.get_ng_inputs().at(0);
+                const auto data = node.get_ng_inputs().at(0);

-                double max_value =
+                const double max_value =
                     node.get_attribute_value<double>("max", std::numeric_limits<double>::max());

-                double min_value = node.get_attribute_value<double>(
-                    "min", std::numeric_limits<double>::lowest());
-
-                std::shared_ptr<ngraph::Node> max_value_node =
-                    std::make_shared<ngraph::op::Constant>(data->get_element_type(),
-                                                           ngraph::Shape{},
-                                                           std::vector<double>{max_value});
-                max_value_node =
-                    ngraph::op::make_broadcast_node(max_value_node, data->get_shape());
-
-                std::shared_ptr<ngraph::Node> min_value_node =
-                    std::make_shared<ngraph::op::Constant>(data->get_element_type(),
-                                                           ngraph::Shape{},
-                                                           std::vector<double>{min_value});
-                min_value_node =
-                    ngraph::op::make_broadcast_node(min_value_node, data->get_shape());
+                const double min_value = node.get_attribute_value<double>(
+                    "min", std::numeric_limits<double>::lowest());

-                return {std::make_shared<ngraph::op::Minimum>(
-                    max_value_node,
-                    std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
+                return {std::make_shared<ngraph::op::Clamp>(data, min_value, max_value)};
             }
         } // namespace set_1
......
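The importer change above replaces the hand-built Constant/broadcast/Minimum/Maximum chain with a single op::Clamp node. The elementwise result is unchanged; a minimal scalar sketch of the semantics (clamp_scalar is a hypothetical helper for illustration, not part of nGraph):

    #include <algorithm>
    #include <limits>

    // Hypothetical scalar equivalent of what the fused Clamp computes per
    // element. The defaults mirror ONNX Clip: "max" falls back to
    // numeric_limits<double>::max() and "min" to numeric_limits<double>::lowest().
    double clamp_scalar(double x,
                        double min_value = std::numeric_limits<double>::lowest(),
                        double max_value = std::numeric_limits<double>::max())
    {
        return std::min(max_value, std::max(min_value, x));
    }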
......@@ -95,6 +95,7 @@
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
using namespace std;
using namespace ngraph;
op::Clamp::Clamp(const shared_ptr<Node>& data, const double min, const double max)
: FusedOp("Clamp", {data})
, m_min{min}
, m_max{max}
{
constructor_validate_and_infer_types();
}
void op::Clamp::pre_validate_and_infer_types()
{
NODE_VALIDATION_CHECK(
this, m_min < m_max, "The 'min' parameter needs to be less than 'max' for Clamp");
}
NodeVector op::Clamp::decompose_op() const
{
const auto data = get_argument(0);
const auto data_shape = data->get_shape();
const auto clamp_min = builder::make_constant(data->get_element_type(), data_shape, m_min);
const auto clamp_max = builder::make_constant(data->get_element_type(), data_shape, m_max);
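    // Elementwise, this computes output[i] = min(m_max, max(m_min, data[i])).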
return {std::make_shared<ngraph::op::Minimum>(
clamp_max, std::make_shared<ngraph::op::Maximum>(clamp_min, data))};
}
shared_ptr<Node> op::Clamp::copy_with_new_args(const NodeVector& new_args) const
{
NODE_VALIDATION_CHECK(this,
new_args.size() == 1,
"Expected 1 element in new_args for the Clamp op but got ",
new_args.size());
return make_shared<Clamp>(new_args.at(0), m_min, m_max);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
{
namespace op
{
/// \brief Performs a clipping operation on all elements of the input node
///
/// All input values outside of the <min; max> range are set to 'min' or 'max',
/// depending on which side of the range they fall. Values within the range
/// remain unchanged.
class Clamp : public ngraph::op::util::FusedOp
{
public:
/// \brief Constructs a Clamp node.
///
/// \param data - Node producing the input tensor
/// \param min - the lower bound of the <min;max> range
/// \param max - the upper bound of the <min;max> range
Clamp(const std::shared_ptr<ngraph::Node>& data, const double min, const double max);
void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
double get_min() const { return m_min; }
double get_max() const { return m_max; }
private:
const double m_min;
const double m_max;
};
}
}
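A minimal usage sketch, mirroring the unit tests below and assuming only the nGraph headers shown in this commit; it builds a function that bounds every element of a 2x2 f32 tensor to the [0, 6] range (a ReLU6-style activation):

    #include "ngraph/function.hpp"
    #include "ngraph/op/fused/clamp.hpp"
    #include "ngraph/op/parameter.hpp"

    using namespace ngraph;

    std::shared_ptr<Function> make_clamp_function()
    {
        const auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
        const auto clamp = std::make_shared<op::Clamp>(data, 0.0, 6.0);
        return std::make_shared<Function>(clamp, ParameterVector{data});
    }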
......@@ -20,6 +20,7 @@
NGRAPH_OP(Elu, ngraph::op)
NGRAPH_OP(Gemm, ngraph::op)
NGRAPH_OP(PRelu, ngraph::op)
NGRAPH_OP(Clamp, ngraph::op)
NGRAPH_OP(ConvolutionBias, ngraph::op)
NGRAPH_OP(ConvolutionBiasAdd, ngraph::op)
NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
......
......@@ -1977,6 +1977,7 @@ shared_ptr<runtime::Executable>
case OP_TYPEID::BatchMatMul:
case OP_TYPEID::BroadcastDistributed:
case OP_TYPEID::BroadcastLike:
case OP_TYPEID::Clamp:
case OP_TYPEID::DepthToSpace:
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::DynPad:
......
......@@ -48,7 +48,7 @@ pad_reflect_2d_with_neg
 # Not implemented
 erf
-zero_sized_erf
+fused_clamp
 gather_no_axis
 gather
 gather_nd_scalar_from_2d
......@@ -67,3 +67,4 @@ gather_nd_single_indices
 gemm
 gemm_broadcast_input_C
 hardsigmoid
+zero_sized_erf
......@@ -66,6 +66,7 @@
#include "ngraph/op/experimental/tile.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/fused/conv_fused.hpp"
#include "ngraph/op/fused/depth_to_space.hpp"
#include "ngraph/op/fused/elu.hpp"
......@@ -650,6 +651,13 @@ static shared_ptr<ngraph::Function>
node = make_shared<op::Ceiling>(args[0]);
break;
}
case OP_TYPEID::Clamp:
{
const auto clamp_min = node_js.at("min").get<float>();
const auto clamp_max = node_js.at("max").get<float>();
node = make_shared<op::Clamp>(args[0], clamp_min, clamp_max);
break;
}
case OP_TYPEID::Concat:
{
auto axis = node_js.at("axis").get<size_t>();
......@@ -1672,6 +1680,13 @@ static json write(const Node& n, bool binary_constant_data)
}
case OP_TYPEID::Ceiling: { break;
}
case OP_TYPEID::Clamp:
{
auto tmp = dynamic_cast<const op::Clamp*>(&n);
node["min"] = tmp->get_min();
node["max"] = tmp->get_max();
break;
}
case OP_TYPEID::Concat:
{
auto tmp = dynamic_cast<const op::Concat*>(&n);
......
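In serialized form a Clamp node carries just its two scalar attributes. A hedged sketch of the write path, assuming the nlohmann::json API the serializer already uses above (serialize_clamp_attrs is illustrative, not a real serializer entry point):

    #include <nlohmann/json.hpp>
    #include "ngraph/op/fused/clamp.hpp"

    // Illustrative only: stores the two attributes written for OP_TYPEID::Clamp.
    nlohmann::json serialize_clamp_attrs(const ngraph::op::Clamp& clamp)
    {
        nlohmann::json node;
        node["min"] = clamp.get_min();
        node["max"] = clamp.get_max();
        return node;
    }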
......@@ -418,5 +418,49 @@ NGRAPH_TEST(${BACKEND_NAME}, gemm_broadcast_input_C)
test_case.add_input<double>(vector<double>{1});
//output
test_case.add_expected_output<double>(Shape{3, 4}, vector<double>(12, 7));
}
NGRAPH_TEST(${BACKEND_NAME}, fused_clamp)
{
auto data = make_shared<op::Parameter>(element::f64, Shape{4, 4});
auto tested_op = make_shared<op::Clamp>(data, 10.0, 20.0);
auto function = make_shared<Function>(tested_op, ParameterVector{data});
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
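    // Note: numeric_limits<double>::min() is the smallest positive normal
    // double (~2.2e-308), not the most negative value, so it clamps up to 10.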
test_case.add_input<double>({std::numeric_limits<double>::min(),
std::numeric_limits<double>::max(),
-std::numeric_limits<double>::infinity(),
std::numeric_limits<double>::infinity(),
-1.0,
0.0,
1.0,
9.99999,
10.0,
10.0000001,
15.0,
19.9999999,
20.0,
20.0000001,
21.0,
100.0});
test_case.add_expected_output<double>(Shape{4, 4},
{10.0,
20.0,
10.0,
20.0,
10.0,
10.0,
10.0,
10.0,
10.0,
10.0000001,
15.0,
19.9999999,
20.0,
20.0,
20.0,
20.0});
test_case.run();
}
......@@ -13888,3 +13888,24 @@ TEST(type_prop, gemm_broadcast_input_C)
EXPECT_EQ(gemm_func->get_element_type(), element::f32);
EXPECT_EQ(gemm_func->get_shape(), (Shape{3, 4}));
}
TEST(type_prop, fused_clamp)
{
const auto data = make_shared<op::Parameter>(element::f64, Shape{2, 2});
try
{
const auto clamp = make_shared<op::Clamp>(data, 2.0, 1.0);
EXPECT_FALSE(clamp.get())
<< "Clamp validation did not work. Op node was created with incorrect params.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(), std::string("The 'min' parameter needs to be less than 'max' for Clamp"));
}
const auto clamp = make_shared<op::Clamp>(data, 1.0, 2.0);
EXPECT_EQ(clamp->get_element_type(), element::f64);
EXPECT_EQ(clamp->get_shape(), (Shape{2, 2}));
}