Commit aa9a2b91 authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[ONNX] Add support for changed clip op in set 11. (#3829)

* Added support for ONNX's clip op in set_11.

* Registered clip in set_11.

* Code formatting.

* Added reshape.

* Changed to auto_braodcast.

* Code refactoring.

* Removed unnecessary import.
parent a5908869
......@@ -18,7 +18,12 @@
#include <memory>
#include "clip.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/reshape.hpp"
namespace ngraph
{
......@@ -43,6 +48,54 @@ namespace ngraph
} // namespace set_1
namespace set_11
{
NodeVector clip(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
std::shared_ptr<ngraph::Node> data = inputs.at(0);
const element::Type data_type = data->get_element_type();
const Shape data_shape = data->get_shape();
std::shared_ptr<ngraph::Node> min;
std::shared_ptr<ngraph::Node> max;
// If second input is provided, assign to min input, otherwise set lowest
// numeric limit of double as min input.
if (inputs.size() > 1 && !inputs.at(1)->is_null())
{
min = inputs.at(1);
}
else
{
min = builder::make_constant(
data_type, data_shape, std::numeric_limits<double>::lowest());
}
// If third input is provided, assign to max input, otherwise set maximum
// numeric limit of double as max input.
if (inputs.size() == 3 && !inputs.at(2)->is_null())
{
max = inputs.at(2);
}
else
{
max = builder::make_constant(
data_type, data_shape, std::numeric_limits<double>::max());
}
auto max_of_min_and_data = std::make_shared<ngraph::op::Maximum>(
min,
data,
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY));
return {std::make_shared<ngraph::op::Minimum>(
max,
max_of_min_and_data,
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
}
} // namespace set_11
} // namespace op
} // namespace onnx_import
......
......@@ -31,6 +31,12 @@ namespace ngraph
} // namespace set_1
namespace set_11
{
NodeVector clip(const Node& node);
} // namespace set_11
} // namespace op
} // namespace onnx_import
......
......@@ -249,6 +249,7 @@ namespace ngraph
REGISTER_OPERATOR("Cast", 1, cast);
REGISTER_OPERATOR("Ceil", 1, ceil);
REGISTER_OPERATOR("Clip", 1, clip);
REGISTER_OPERATOR("Clip", 11, clip);
REGISTER_OPERATOR("Concat", 1, concat);
REGISTER_OPERATOR("Constant", 1, constant);
REGISTER_OPERATOR("Conv", 1, conv);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment