Commit 3f2cd153 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Robert Kimball

Handle negative axis values in Concat op. (#2252)

parent 9940123b
...@@ -16,11 +16,10 @@ ...@@ -16,11 +16,10 @@
#include <istream> #include <istream>
#include <memory> #include <memory>
#include <string>
#include <vector>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <string>
#include <vector>
#include "ngraph/frontend/onnx_import/onnx.hpp" #include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
......
...@@ -14,8 +14,12 @@ ...@@ -14,8 +14,12 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <cstdint>
#include "concat.hpp" #include "concat.hpp"
#include "exceptions.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -28,9 +32,15 @@ namespace ngraph ...@@ -28,9 +32,15 @@ namespace ngraph
NodeVector concat(const Node& node) NodeVector concat(const Node& node)
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
auto axis = node.get_attribute_value<int64_t>("axis"); std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
size_t valid_axis =
common::convert_negative_axis(axis, inputs.at(0)->get_shape().size());
ASSERT_VALID_ARGUMENT(node, valid_axis >= 0)
<< "Incorrect value of axis attribute: " << axis;
return {std::make_shared<ngraph::op::Concat>(inputs, axis)}; return {std::make_shared<ngraph::op::Concat>(inputs, valid_axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include <cmath> // std::floor #include <cmath> // std::floor, std::min
#include <cstddef> // std::size_t #include <cstddef> // std::size_t
#include <iterator> // std::begin, std::end #include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared #include <memory> // std::shared_ptr, std::make_shared
...@@ -135,6 +135,29 @@ namespace ngraph ...@@ -135,6 +135,29 @@ namespace ngraph
return node; return node;
} }
/// \brief Handle negative axis value.
///
/// \param[in] axis The requested axis value.
/// \param[in] tensor_dim The corresponding tensor dimensionality.
///
/// \tparam T Provided axis value type.
///
/// \return If negative axis, then return sum of tensor dimension and axis.
///
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
std::int64_t convert_negative_axis(T axis, std::size_t tensor_dim)
{
if (axis >= 0)
{
return std::min(axis, static_cast<T>(tensor_dim));
}
else
{
return static_cast<std::int64_t>(tensor_dim) + axis;
}
}
} // namespace common } // namespace common
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment