Commit 250dddbc authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

[ONNX] Shrink op support (#3024)

* Initial implementation of the Shrink op

* Multiply the values by the correct masks

* Basic test case for Shrink with floats

* Shrink test on integers

* Code formatting

* Shrink documentation and typo fix

* Rephrase the Shrink docs

* Out of <memory> ;)
parent f21db619
...@@ -150,6 +150,8 @@ add_library(onnx_import STATIC ...@@ -150,6 +150,8 @@ add_library(onnx_import STATIC
op/selu.hpp op/selu.hpp
op/shape.hpp op/shape.hpp
op/shape.cpp op/shape.cpp
op/shrink.hpp
op/shrink.cpp
op/sigmoid.hpp op/sigmoid.hpp
op/sign.hpp op/sign.hpp
op/sin.hpp op/sin.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "exceptions.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "shrink.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// \brief Expand the ONNX Shrink operator into an nGraph subgraph.
///
/// Shrink zeroes out every input value inside the <-lambd; lambd> range,
/// adds 'bias' to values below '-lambd' and subtracts 'bias' from values
/// above 'lambd'.
///
/// \param node The ONNX node carrying a single input tensor and the
///             optional 'bias' (default 0.0) and 'lambd' (default 0.5)
///             float attributes.
/// \return A single-element NodeVector with the shrunk tensor.
NodeVector shrink(const Node& node)
{
    const auto input = node.get_ng_inputs().at(0);
    const float bias = node.get_attribute_value<float>("bias", 0.0f);
    const float lambd = node.get_attribute_value<float>("lambd", 0.5f);

    // Negated comparison (rather than 'lambd >= 0.0f') so that a NaN
    // attribute value is also rejected.
    ASSERT_VALID_ARGUMENT(node, !(lambd < 0.0f))
        << "The provided 'lambd' value: " << lambd << " must not be negative.";

    // Constants are broadcast to the input's full shape so that they can be
    // used directly in element-wise comparisons and arithmetic below.
    const auto negative_lambd = ngraph::op::Constant::create(
        input->get_element_type(), input->get_shape(), {-lambd});

    const auto positive_lambd = ngraph::op::Constant::create(
        input->get_element_type(), input->get_shape(), {lambd});

    const auto bias_tensor = ngraph::op::Constant::create(
        input->get_element_type(), input->get_shape(), {bias});

    // Create a mask indicating locations of values that need to be adjusted
    // by adding and subtracting bias
    // All other values indicated by 'false' in the masks need to be zeroed out
    std::shared_ptr<ngraph::Node> values_below_neg_lambd =
        std::make_shared<ngraph::op::Less>(input, negative_lambd);

    std::shared_ptr<ngraph::Node> values_above_pos_lambd =
        std::make_shared<ngraph::op::Greater>(input, positive_lambd);

    // Convert from bool to the input type to be able to multiply adjusted inputs
    // by the created masks
    values_below_neg_lambd = std::make_shared<ngraph::op::Convert>(
        values_below_neg_lambd, input->get_element_type());

    values_above_pos_lambd = std::make_shared<ngraph::op::Convert>(
        values_above_pos_lambd, input->get_element_type());

    std::shared_ptr<ngraph::Node> input_minus_bias = input - bias_tensor;
    std::shared_ptr<ngraph::Node> input_plus_bias = input + bias_tensor;

    // multiply by the corresponding mask to zero-out the values within
    // the <-lambd;lambd> range and keep the bias-adjusted values from outside of it
    input_minus_bias = values_above_pos_lambd * input_minus_bias;

    input_plus_bias = values_below_neg_lambd * input_plus_bias;

    // The two products have disjoint non-zero locations, so their sum
    // combines both adjusted halves with the zeroed-out middle range.
    return {input_plus_bias + input_minus_bias};
}
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
/// @brief ONNX Shrink operator
///
/// @note It operates on a single input tensor and two attributes: lambd and bias.
/// Input values greater or equal to '-lambd' and less or equal to 'lambd' are zeroed-out.
/// 'Bias' is added to the values that are less than '-lambd'
/// and subtracted from values greater than 'lambd'.
///
/// @param node The ONNX node representing the Shrink operation.
/// @return A NodeVector containing the single output node of the expanded subgraph.
NodeVector shrink(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -84,6 +84,7 @@ opset versions starting from `1` to `6` and to the latest opset version. ...@@ -84,6 +84,7 @@ opset versions starting from `1` to `6` and to the latest opset version.
| Relu | 1-6- | | Relu | 1-6- |
| Selu | 1-6- | | Selu | 1-6- |
| Shape | 1- | | Shape | 1- |
| Shrink | 1- |
| Sigmoid | 1-6- | | Sigmoid | 1-6- |
| Sign | 9- | | Sign | 9- |
| Sin | 7- | | Sin | 7- |
......
...@@ -94,6 +94,7 @@ ...@@ -94,6 +94,7 @@
#include "op/reshape.hpp" #include "op/reshape.hpp"
#include "op/selu.hpp" #include "op/selu.hpp"
#include "op/shape.hpp" #include "op/shape.hpp"
#include "op/shrink.hpp"
#include "op/sigmoid.hpp" #include "op/sigmoid.hpp"
#include "op/sign.hpp" #include "op/sign.hpp"
#include "op/sin.hpp" #include "op/sin.hpp"
...@@ -311,6 +312,7 @@ namespace ngraph ...@@ -311,6 +312,7 @@ namespace ngraph
REGISTER_OPERATOR("Reshape", 1, reshape); REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("Selu", 1, selu); REGISTER_OPERATOR("Selu", 1, selu);
REGISTER_OPERATOR("Shape", 1, shape); REGISTER_OPERATOR("Shape", 1, shape);
REGISTER_OPERATOR("Shrink", 1, shrink);
REGISTER_OPERATOR("Sigmoid", 1, sigmoid); REGISTER_OPERATOR("Sigmoid", 1, sigmoid);
REGISTER_OPERATOR("Sign", 1, sign); REGISTER_OPERATOR("Sign", 1, sign);
REGISTER_OPERATOR("Sin", 1, sin); REGISTER_OPERATOR("Sin", 1, sin);
......
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "Shrink"
attribute {
name: "lambd"
f: 1.5
type: FLOAT
}
attribute {
name: "bias"
f: 0.5
type: FLOAT
}
}
name: "shrink_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 11
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 11
}
}
}
}
}
}
opset_import {
version: 9
}
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "Shrink"
attribute {
name: "lambd"
f: 1.4
type: FLOAT
}
attribute {
name: "bias"
f: 1.5
type: FLOAT
}
}
name: "shrink_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 11
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 11
}
}
}
}
}
}
opset_import {
version: 9
}
...@@ -1456,3 +1456,29 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_hardmax) ...@@ -1456,3 +1456,29 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_hardmax)
test_case.run(); test_case.run();
} }
// Shrink on a float tensor; the model declares lambd = 1.5 and bias = 0.5,
// so values inside [-1.5; 1.5] are zeroed and the rest are shifted towards zero.
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_float)
{
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/shrink_float.prototxt"));

    auto tc = ngraph::test::NgraphTestCase(model, "${BACKEND_NAME}");
    tc.add_input<float>(
        {-2.0f, -1.6f, -1.5f, -1.4f, -1.0f, 0.0f, 1.0f, 1.4f, 1.5f, 1.6f, 2.0f});
    tc.add_expected_output<float>(
        Shape{11}, {-1.5f, -1.1f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.1f, 1.5f});

    tc.run();
}
// Shrink on an int32 tensor; the model declares lambd = 1.4 and bias = 1.5,
// which exercises the attribute-to-integer-element-type conversion path.
NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
{
    const auto model = onnx_import::import_onnx_model(
        file_util::path_join(SERIALIZED_ZOO, "onnx/shrink_int.prototxt"));

    auto tc = ngraph::test::NgraphTestCase(model, "${BACKEND_NAME}");
    tc.add_input<int>({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5});
    tc.add_expected_output<int>(Shape{11}, {-4, -3, -2, -1, 0, 0, 0, 1, 2, 3, 4});

    tc.run();
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment