Commit 34326357 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[ONNX] Fix shrink operator for uint* types (#3188)

* Fix shrink operator for uint* types

* Add a comment for negative_lambd=0 for uint*
parent 599f0f21
...@@ -43,14 +43,26 @@ namespace ngraph ...@@ -43,14 +43,26 @@ namespace ngraph
ASSERT_VALID_ARGUMENT(node, !(lambd < 0.0f)) ASSERT_VALID_ARGUMENT(node, !(lambd < 0.0f))
<< " The provided 'lambd' value:" << lambd << " must not be negative."; << " The provided 'lambd' value:" << lambd << " must not be negative.";
const auto negative_lambd = ngraph::op::Constant::create( std::shared_ptr<ngraph::op::Constant> negative_lambd;
input->get_element_type(), input->get_shape(), {-lambd}); const auto input_element_type = input->get_element_type();
if (input_element_type.is_signed())
{
negative_lambd = ngraph::op::Constant::create(
input_element_type, input->get_shape(), {-lambd});
}
else
{
// Passing -lambd to unsigned type constant will cause an overflow.
// For unsigned types the lowest possible value is 0.
negative_lambd = ngraph::op::Constant::create(
input_element_type, input->get_shape(), {0});
}
const auto positive_lambd = ngraph::op::Constant::create( const auto positive_lambd = ngraph::op::Constant::create(
input->get_element_type(), input->get_shape(), {lambd}); input_element_type, input->get_shape(), {lambd});
const auto bias_tensor = ngraph::op::Constant::create( const auto bias_tensor = ngraph::op::Constant::create(
input->get_element_type(), input->get_shape(), {bias}); input_element_type, input->get_shape(), {bias});
// Create a mask indicating locations of values that need to be adjusted // Create a mask indicating locations of values that need to be adjusted
// by adding and subtracting bias // by adding and subtracting bias
...@@ -63,9 +75,9 @@ namespace ngraph ...@@ -63,9 +75,9 @@ namespace ngraph
// Convert from bool to the input type to be able to multiply adjusted inputs // Convert from bool to the input type to be able to multiply adjusted inputs
// by the created masks // by the created masks
values_below_neg_lambd = std::make_shared<ngraph::op::Convert>( values_below_neg_lambd = std::make_shared<ngraph::op::Convert>(
values_below_neg_lambd, input->get_element_type()); values_below_neg_lambd, input_element_type);
values_above_pos_lambd = std::make_shared<ngraph::op::Convert>( values_above_pos_lambd = std::make_shared<ngraph::op::Convert>(
values_above_pos_lambd, input->get_element_type()); values_above_pos_lambd, input_element_type);
std::shared_ptr<ngraph::Node> input_minus_bias = input - bias_tensor; std::shared_ptr<ngraph::Node> input_minus_bias = input - bias_tensor;
std::shared_ptr<ngraph::Node> input_plus_bias = input + bias_tensor; std::shared_ptr<ngraph::Node> input_plus_bias = input + bias_tensor;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment