Commit 3f2cd153 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Robert Kimball

Handle negative axis values in Concat op. (#2252)

parent 9940123b
......@@ -16,11 +16,10 @@
#include <istream>
#include <memory>
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <vector>
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/function.hpp"
......
......@@ -14,8 +14,12 @@
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include "concat.hpp"
#include "exceptions.hpp"
#include "ngraph/op/concat.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -28,9 +32,15 @@ namespace ngraph
NodeVector concat(const Node& node)
{
NodeVector inputs{node.get_ng_inputs()};
auto axis = node.get_attribute_value<int64_t>("axis");
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
size_t valid_axis =
common::convert_negative_axis(axis, inputs.at(0)->get_shape().size());
ASSERT_VALID_ARGUMENT(node, valid_axis >= 0)
<< "Incorrect value of axis attribute: " << axis;
return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
return {std::make_shared<ngraph::op::Concat>(inputs, valid_axis)};
}
} // namespace set_1
......
......@@ -16,7 +16,7 @@
#pragma once
#include <cmath> // std::floor
#include <cmath> // std::floor, std::min
#include <cstddef> // std::size_t
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
......@@ -135,6 +135,29 @@ namespace ngraph
return node;
}
/// \brief Handle negative axis value.
///
/// \param[in] axis The requested axis value.
/// \param[in] tensor_dim The corresponding tensor dimensionality.
///
/// \tparam T Provided axis value type.
///
/// \return If negative axis, then return sum of tensor dimension and axis.
///
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
std::int64_t convert_negative_axis(T axis, std::size_t tensor_dim)
{
if (axis >= 0)
{
return std::min(axis, static_cast<T>(tensor_dim));
}
else
{
return static_cast<std::int64_t>(tensor_dim) + axis;
}
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment