Commit 0c181e9d authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[ONNX] Add support for negative axes (#3643)

parent 1a5288ab
......@@ -33,12 +33,8 @@ namespace ngraph
{
NodeVector inputs{node.get_ng_inputs()};
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
size_t valid_axis =
common::convert_negative_axis(axis, inputs.at(0)->get_shape().size());
ASSERT_VALID_ARGUMENT(node, valid_axis >= 0)
<< "Incorrect value of axis attribute: " << axis;
common::validate_axis(node, axis, inputs.at(0)->get_shape().size());
return {std::make_shared<ngraph::op::Concat>(inputs, valid_axis)};
}
......
......@@ -19,7 +19,7 @@
#include "exceptions.hpp"
#include "flatten.hpp"
#include "ngraph/builder/reshape.hpp"
#include "utils/common.hpp"
namespace ngraph
{
namespace onnx_import
......@@ -33,11 +33,12 @@ namespace ngraph
NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0);
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto data_rank = data->get_shape().size();
// Accepted range is [-r, r] where r = rank(input).
auto valid_axis =
common::validate_axis(node, axis, data_rank, -data_rank, data_rank);
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
<< "provided 'axis' attribute is not valid.";
return {ngraph::builder::flatten(data, axis)};
return {ngraph::builder::flatten(data, valid_axis)};
}
} // namespace set_1
......
......@@ -19,6 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/gather.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -34,12 +35,9 @@ namespace ngraph
auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0);
if (axis < 0)
{
axis += data->get_shape().size();
}
auto valid_axis = common::validate_axis(node, axis, data->get_shape().size());
return {std::make_shared<ngraph::op::Gather>(data, indices, axis)};
return {std::make_shared<ngraph::op::Gather>(data, indices, valid_axis)};
}
} // namespace set_1
......
......@@ -35,12 +35,10 @@ namespace ngraph
const auto& input_shape = input->get_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis < input_shape.size())
<< "The provided axis value " << axis
<< " does not match the input tensor dimensions";
auto valid_axis = common::validate_axis(node, axis, input_shape.size());
// reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::builder::flatten(input, axis);
const auto coerced_tensor = ngraph::builder::flatten(input, valid_axis);
const auto& coerced_shape = coerced_tensor->get_shape();
const std::shared_ptr<ngraph::Node> argmax_2d =
......
......@@ -23,6 +23,7 @@
#include "ngraph/builder/norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/divide.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -37,17 +38,14 @@ namespace ngraph
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
if (axis < 0)
{
axis += data->get_shape().size();
}
std::size_t valid_axis =
common::validate_axis(node, axis, data->get_shape().size());
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported.";
const AxisSet reduction_axes{static_cast<std::size_t>(axis)};
const AxisSet reduction_axes{valid_axis};
std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm(
data, reduction_axes, static_cast<std::size_t>(p_norm));
norm = std::make_shared<ngraph::op::Broadcast>(
......
......@@ -19,6 +19,7 @@
#include "mean_variance_normalization.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -47,9 +48,11 @@ namespace ngraph
NodeVector mean_variance_normalization(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<size_t>>("axes", {0, 2, 3});
auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3});
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, data->get_shape().size());
return {std::make_shared<ngraph::op::MVN>(data, AxisSet(axes))};
return {std::make_shared<ngraph::op::MVN>(data, AxisSet(valid_axes))};
}
} // namespace set_9
......
......@@ -28,6 +28,7 @@
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "onehot.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -51,14 +52,13 @@ namespace ngraph
std::make_shared<ngraph::op::Slice>(values, Coordinate{1}, Coordinate{2});
auto axis = node.get_attribute_value<std::int64_t>("axis", -1);
if (axis < 0)
{
axis += indices_shape.size() + 1;
}
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= indices_shape.size()))
<< "invalid 'axis' attribute: "
<< node.get_attribute_value<std::int64_t>("axis", -1);
// Accepted range for axis is [-r-1, r] where r = rank(indices). Validate
// against rank+1.
std::size_t valid_axis = common::validate_axis(node,
axis,
indices_shape.size() + 1,
-indices_shape.size() - 1,
indices_shape.size());
auto constant_depth = std::dynamic_pointer_cast<ngraph::op::Constant>(depth);
......@@ -74,10 +74,11 @@ namespace ngraph
// axis = 1
// depth = 10
// output_shape = (2, 10, 2)
output_shape.insert(std::next(std::begin(output_shape), axis), depth_value);
output_shape.insert(std::next(std::begin(output_shape), valid_axis),
depth_value);
std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis),
std::make_shared<ngraph::op::OneHot>(indices, output_shape, valid_axis),
values->get_element_type());
auto broadcasted_values =
ngraph::op::numpy_style_broadcast({one_hot, on_value, off_value});
......
......@@ -21,6 +21,7 @@
#include "ngraph/op/convert.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -40,22 +41,26 @@ namespace ngraph
node.get_ng_inputs().at(1), element::i32);
const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
std::size_t valid_batch_axis =
common::validate_axis(node, batch_axis, data->get_shape().size());
const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
std::size_t valid_time_axis =
common::validate_axis(node, time_axis, data->get_shape().size());
NGRAPH_CHECK(batch_axis == 0 || batch_axis == 1,
NGRAPH_CHECK(valid_batch_axis == 0 || valid_batch_axis == 1,
"Allowed values of the 'batch_axis' attribute for ReverseSequence "
"operator are 0 and 1");
NGRAPH_CHECK(time_axis == 0 || time_axis == 1,
NGRAPH_CHECK(valid_time_axis == 0 || valid_time_axis == 1,
"Allowed values of the 'time_axis' attribute for ReverseSequence "
"operator are 0 and 1");
NGRAPH_CHECK(batch_axis != time_axis,
NGRAPH_CHECK(valid_batch_axis != valid_time_axis,
"'batch_axis' and 'time_axis' attributes of the ReverseSequence "
"operator can't point to the same dimension");
return {std::make_shared<ngraph::op::ReverseSequence>(
data, sequence_lengths_i32, batch_axis, time_axis)};
data, sequence_lengths_i32, valid_batch_axis, valid_time_axis)};
}
} // namespace set_1
......
......@@ -19,6 +19,7 @@
#include "exceptions.hpp"
#include "ngraph/op/softmax.hpp"
#include "softmax.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -35,22 +36,13 @@ namespace ngraph
auto data_shape = data->get_shape();
int axis = node.get_attribute_value<int64_t>("axis", 1);
if (axis < 0)
{
axis = data_shape.size() + axis;
}
ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
<< "provided 'axis' value:" << axis
<< " is out of input tensor dimensions range.";
std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size());
// create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - axis);
std::iota(std::begin(axes), std::end(axes), axis);
std::vector<size_t> axes(data_shape.size() - valid_axis);
std::iota(std::begin(axes), std::end(axes), valid_axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)};
}
} // namespace set_1
} // namespace op
......
......@@ -19,6 +19,7 @@
#include "ngraph/op/fused/split.hpp"
#include "op/split.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -33,13 +34,15 @@ namespace ngraph
const auto input = node.get_ng_inputs().at(0);
const auto outputs_number = node.get_output_names().size();
const auto axis = node.get_attribute_value<int64_t>("axis", 0);
std::size_t valid_axis =
common::validate_axis(node, axis, input->get_shape().size());
try
{
const auto length_parts =
node.get_attribute_value<std::vector<std::size_t>>("split");
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, length_parts);
std::make_shared<ngraph::op::Split>(input, valid_axis, length_parts);
return fused_split->decompose_op();
}
......@@ -49,7 +52,7 @@ namespace ngraph
// the 'split' attribute - this means we should split the input tensor
// into same-length parts equal to the number of node outputs
const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, outputs_number);
std::make_shared<ngraph::op::Split>(input, valid_axis, outputs_number);
return fused_split->decompose_op();
}
......
......@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "squeeze.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -32,17 +33,12 @@ namespace ngraph
NodeVector squeeze(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
for (auto axis : axes)
{
ASSERT_VALID_ARGUMENT(node, axis >= 0)
<< "provided axes attribute is invalid. Only non-negative "
<< "integers are allowed, got " << axis << ".";
}
std::vector<std::int64_t> axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, data->get_shape().size());
auto axes_node = std::make_shared<ngraph::op::Constant>(
element::u64, Shape{axes.size()}, axes);
element::u64, Shape{valid_axes.size()}, valid_axes);
return {std::make_shared<ngraph::op::Squeeze>(data, axes_node)};
}
......
......@@ -23,6 +23,7 @@
#include "ngraph/op/topk.hpp"
#include "ngraph/type/element_type.hpp"
#include "topk.hpp"
#include "utils/common.hpp"
namespace ngraph
{
......@@ -35,21 +36,14 @@ namespace ngraph
NodeVector topk(const Node& node)
{
auto data = node.get_ng_inputs().at(0);
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
auto num_dimensions = data->get_shape().size();
if (axis < 0)
{
axis += num_dimensions;
}
ASSERT_VALID_ARGUMENT(node, axis < num_dimensions)
<< "`axis` parameter is out of range: " << axis;
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t valid_axis = common::validate_axis(node, axis, num_dimensions);
std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k);
std::make_shared<ngraph::op::TopK>(data, valid_axis, element::i64, k);
std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
......
......@@ -46,6 +46,53 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type)));
}
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank)
{
// Accepted range of value for axis is [-tensor_rank, tensor_rank-1].
return validate_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
}
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
// Accepted range of value for axis is [axis_range_min, axis_range_max].
NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
node.get_description(),
"Parameter axis ",
axis,
" out of the tensor rank [-",
axis_range_min,
", ",
axis_range_max,
"].");
if (axis < 0)
{
axis = axis + tensor_rank;
}
return static_cast<size_t>(axis);
}
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank)
{
std::vector<std::size_t> new_axes;
for (auto a : axes)
{
new_axes.push_back(validate_axis(node, a, tensor_rank));
}
return new_axes;
}
} // namespace common
} // namespace onnx_import
} // namespace ngraph
......@@ -25,6 +25,7 @@
#include <type_traits> // std::enable_if
#include <vector>
#include "core/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
......@@ -67,28 +68,52 @@ namespace ngraph
return range;
}
/// \brief Handle negative axis value.
/// \brief Handle out of range axis.
///
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_dim The corresponding tensor dimensionality.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \tparam T Provided axis value type.
/// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error.
/// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
/// \return If negative axis, then return sum of tensor dimension and axis.
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank);
/// \brief Handle out of range axis.
///
template <typename T,
typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
std::int64_t convert_negative_axis(T axis, std::size_t tensor_dim)
{
if (axis >= 0)
{
return std::min(axis, static_cast<T>(tensor_dim));
}
else
{
return static_cast<std::int64_t>(tensor_dim) + axis;
}
}
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The min value of accepted range for axis.
/// \param[in] axis_range_max The max value of accepted range for axis.
///
/// \return Checking if axis is in range [axis_range_min, axis_range_max], otherwise
/// returns error.
//// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
/// \brief Handle out of range axes in vector.
///
/// \param[in] node The node with requested axes.
/// \param[in] axes The requested vector of axes.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return If any negative axis in vector, it counts from the last to the first
/// axis, by adding tensor_rank to axis.
///
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank);
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
......
......@@ -32,13 +32,17 @@ namespace ngraph
AxisSet get_reduction_axes(const Node& node)
{
auto reduction_axes =
node.get_attribute_value<std::vector<std::size_t>>("axes", {});
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_reduction_axes = common::validate_axes(
node, reduction_axes, node.get_ng_inputs().at(0)->get_shape().size());
if (reduction_axes.empty())
{
reduction_axes = onnx_import::common::get_monotonic_range<std::size_t>(
valid_reduction_axes =
onnx_import::common::get_monotonic_range<std::size_t>(
node.get_ng_inputs().at(0)->get_shape().size());
}
return AxisSet{reduction_axes};
return AxisSet{valid_reduction_axes};
}
} // namespace detail
......
......@@ -26,6 +26,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace ngraph
......@@ -64,8 +65,10 @@ namespace ngraph
auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
auto input_node = node.get_ng_inputs().at(0);
auto valid_axis = common::validate_axis(node, axis, input_node->get_shape().size());
auto op_node = std::make_shared<IndexReduction>(input_node, axis, element::i64);
auto op_node =
std::make_shared<IndexReduction>(input_node, valid_axis, element::i64);
if (keepdims == 0)
{
......@@ -76,7 +79,7 @@ namespace ngraph
auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32);
auto output_shape = input_node->get_shape();
output_shape.at(axis) = 1;
output_shape.at(valid_axis) = 1;
auto reshape_node = std::make_shared<ngraph::op::Reshape>(
convert_node,
ngraph::get_default_order(op_node->get_shape().size()),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment