Commit f07b95a2 authored by Michał Karzyński's avatar Michał Karzyński Committed by Robert Kimball

[ONNX] Add interpret_as_scalar helper (#2825)

* Add interpret_as_scalar helper

* Add interpret_as_scalar to OneHot

* clang-format

* Remove interpret_as_scalar from OneHot

We currently only support OneHot with constant depth

* Review comments
parent fb0ae59c
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "onehot.hpp" #include "onehot.hpp"
#include "utils/reshape.hpp"
namespace ngraph namespace ngraph
{ {
......
...@@ -201,6 +201,24 @@ namespace ngraph ...@@ -201,6 +201,24 @@ namespace ngraph
return split(node, length_parts, axis_to_split); return split(node, length_parts, axis_to_split);
} }
std::shared_ptr<ngraph::Node>
interpret_as_scalar(const std::shared_ptr<ngraph::Node>& node)
{
Shape node_shape = node->get_shape();
// If node is already a scalar, return original
if (node_shape.empty())
{
return node;
}
NGRAPH_CHECK((shape_size(node_shape) == 1),
"Scalar value can't be derived from a node with ",
node_shape);
return ngraph::op::util::reshape(node, Shape{});
}
} // namespace reshape } // namespace reshape
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -115,6 +115,19 @@ namespace ngraph ...@@ -115,6 +115,19 @@ namespace ngraph
std::size_t split_parts, std::size_t split_parts,
int axis = 0); int axis = 0);
/// \brief Handle a node which represents a scalar value.
///
/// \note Some ONNX nodes, which should provide scalar values are given as
/// tensors of shape {1}. This function will provide a reshape of
/// such a node with Shape{1} into a scalar with Shape{}.
///
/// \param[in] node Node to reshape.
///
/// \return Original node or a node representing a reshape of the original.
///
std::shared_ptr<ngraph::Node>
interpret_as_scalar(const std::shared_ptr<ngraph::Node>& node);
} // namespace reshape } // namespace reshape
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment