Commit 34326357 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Fix shrink operator for uint* types (#3188)

* Fix shrink operator for uint* types

* Add a comment for negative_lambd=0 for uint*
parent 599f0f21
......@@ -43,14 +43,26 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, !(lambd < 0.0f))
<< " The provided 'lambd' value:" << lambd << " must not be negative.";
const auto negative_lambd = ngraph::op::Constant::create(
input->get_element_type(), input->get_shape(), {-lambd});
std::shared_ptr<ngraph::op::Constant> negative_lambd;
const auto input_element_type = input->get_element_type();
if (input_element_type.is_signed())
{
negative_lambd = ngraph::op::Constant::create(
input_element_type, input->get_shape(), {-lambd});
}
else
{
// Passing -lambd to unsigned type constant will cause an overflow.
// For unsigned types the lowest possible value is 0.
negative_lambd = ngraph::op::Constant::create(
input_element_type, input->get_shape(), {0});
}
const auto positive_lambd = ngraph::op::Constant::create(
input->get_element_type(), input->get_shape(), {lambd});
input_element_type, input->get_shape(), {lambd});
const auto bias_tensor = ngraph::op::Constant::create(
input->get_element_type(), input->get_shape(), {bias});
input_element_type, input->get_shape(), {bias});
// Create a mask indicating locations of values that need to be adjusted
// by adding and subtracting bias
......@@ -63,9 +75,9 @@ namespace ngraph
// Convert from bool to the input type to be able to multiply adjusted inputs
// by the created masks
values_below_neg_lambd = std::make_shared<ngraph::op::Convert>(
values_below_neg_lambd, input->get_element_type());
values_below_neg_lambd, input_element_type);
values_above_pos_lambd = std::make_shared<ngraph::op::Convert>(
values_above_pos_lambd, input->get_element_type());
values_above_pos_lambd, input_element_type);
std::shared_ptr<ngraph::Node> input_minus_bias = input - bias_tensor;
std::shared_ptr<ngraph::Node> input_plus_bias = input + bias_tensor;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment