Commit 0c181e9d authored by Ewa Tusień's avatar Ewa Tusień Committed by Scott Cyphers

[ONNX] Add support for negative axes (#3643)

parent 1a5288ab
...@@ -33,12 +33,8 @@ namespace ngraph ...@@ -33,12 +33,8 @@ namespace ngraph
{ {
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
std::int64_t axis = node.get_attribute_value<std::int64_t>("axis"); std::int64_t axis = node.get_attribute_value<std::int64_t>("axis");
size_t valid_axis = size_t valid_axis =
common::convert_negative_axis(axis, inputs.at(0)->get_shape().size()); common::validate_axis(node, axis, inputs.at(0)->get_shape().size());
ASSERT_VALID_ARGUMENT(node, valid_axis >= 0)
<< "Incorrect value of axis attribute: " << axis;
return {std::make_shared<ngraph::op::Concat>(inputs, valid_axis)}; return {std::make_shared<ngraph::op::Concat>(inputs, valid_axis)};
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include "exceptions.hpp" #include "exceptions.hpp"
#include "flatten.hpp" #include "flatten.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
namespace onnx_import namespace onnx_import
...@@ -33,11 +33,12 @@ namespace ngraph ...@@ -33,11 +33,12 @@ namespace ngraph
NodeVector inputs{node.get_ng_inputs()}; NodeVector inputs{node.get_ng_inputs()};
auto data = inputs.at(0); auto data = inputs.at(0);
auto axis = node.get_attribute_value<std::int64_t>("axis", 1); auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
auto data_rank = data->get_shape().size();
// Accepted range is [-r, r] where r = rank(input).
auto valid_axis =
common::validate_axis(node, axis, data_rank, -data_rank, data_rank);
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size())) return {ngraph::builder::flatten(data, valid_axis)};
<< "provided 'axis' attribute is not valid.";
return {ngraph::builder::flatten(data, axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "core/node.hpp" #include "core/node.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -34,12 +35,9 @@ namespace ngraph ...@@ -34,12 +35,9 @@ namespace ngraph
auto data = ng_inputs.at(0); auto data = ng_inputs.at(0);
auto indices = ng_inputs.at(1); auto indices = ng_inputs.at(1);
auto axis = node.get_attribute_value<int64_t>("axis", 0); auto axis = node.get_attribute_value<int64_t>("axis", 0);
if (axis < 0) auto valid_axis = common::validate_axis(node, axis, data->get_shape().size());
{
axis += data->get_shape().size();
}
return {std::make_shared<ngraph::op::Gather>(data, indices, axis)}; return {std::make_shared<ngraph::op::Gather>(data, indices, valid_axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -35,12 +35,10 @@ namespace ngraph ...@@ -35,12 +35,10 @@ namespace ngraph
const auto& input_shape = input->get_shape(); const auto& input_shape = input->get_shape();
auto axis = node.get_attribute_value<std::int64_t>("axis", 1); auto axis = node.get_attribute_value<std::int64_t>("axis", 1);
ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis < input_shape.size()) auto valid_axis = common::validate_axis(node, axis, input_shape.size());
<< "The provided axis value " << axis
<< " does not match the input tensor dimensions";
// reshape to 2D - "batch size" x "input feature dimensions" (NxD) // reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const auto coerced_tensor = ngraph::builder::flatten(input, axis); const auto coerced_tensor = ngraph::builder::flatten(input, valid_axis);
const auto& coerced_shape = coerced_tensor->get_shape(); const auto& coerced_shape = coerced_tensor->get_shape();
const std::shared_ptr<ngraph::Node> argmax_2d = const std::shared_ptr<ngraph::Node> argmax_2d =
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,17 +38,14 @@ namespace ngraph ...@@ -37,17 +38,14 @@ namespace ngraph
const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)}; const std::shared_ptr<ngraph::Node> data{node.get_ng_inputs().at(0)};
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)}; std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)}; const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
std::size_t valid_axis =
if (axis < 0) common::validate_axis(node, axis, data->get_shape().size());
{
axis += data->get_shape().size();
}
ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2) ASSERT_VALID_ARGUMENT(node, p_norm == 1 || p_norm == 2)
<< "Invalid `p` attribute value: " << p_norm << "Invalid `p` attribute value: " << p_norm
<< "Only normalization of 1st or 2nd order is supported."; << "Only normalization of 1st or 2nd order is supported.";
const AxisSet reduction_axes{static_cast<std::size_t>(axis)}; const AxisSet reduction_axes{valid_axis};
std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm( std::shared_ptr<ngraph::Node> norm = ngraph::builder::lp_norm(
data, reduction_axes, static_cast<std::size_t>(p_norm)); data, reduction_axes, static_cast<std::size_t>(p_norm));
norm = std::make_shared<ngraph::op::Broadcast>( norm = std::make_shared<ngraph::op::Broadcast>(
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "mean_variance_normalization.hpp" #include "mean_variance_normalization.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp" #include "ngraph/op/fused/mvn.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -47,9 +48,11 @@ namespace ngraph ...@@ -47,9 +48,11 @@ namespace ngraph
NodeVector mean_variance_normalization(const Node& node) NodeVector mean_variance_normalization(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<size_t>>("axes", {0, 2, 3}); auto axes = node.get_attribute_value<std::vector<int64_t>>("axes", {0, 2, 3});
std::vector<std::size_t> valid_axes =
common::validate_axes(node, axes, data->get_shape().size());
return {std::make_shared<ngraph::op::MVN>(data, AxisSet(axes))}; return {std::make_shared<ngraph::op::MVN>(data, AxisSet(valid_axes))};
} }
} // namespace set_9 } // namespace set_9
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "onehot.hpp" #include "onehot.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -51,14 +52,13 @@ namespace ngraph ...@@ -51,14 +52,13 @@ namespace ngraph
std::make_shared<ngraph::op::Slice>(values, Coordinate{1}, Coordinate{2}); std::make_shared<ngraph::op::Slice>(values, Coordinate{1}, Coordinate{2});
auto axis = node.get_attribute_value<std::int64_t>("axis", -1); auto axis = node.get_attribute_value<std::int64_t>("axis", -1);
if (axis < 0) // Accepted range for axis is [-r-1, r] where r = rank(indices). Validate
{ // against rank+1.
axis += indices_shape.size() + 1; std::size_t valid_axis = common::validate_axis(node,
} axis,
indices_shape.size() + 1,
ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= indices_shape.size())) -indices_shape.size() - 1,
<< "invalid 'axis' attribute: " indices_shape.size());
<< node.get_attribute_value<std::int64_t>("axis", -1);
auto constant_depth = std::dynamic_pointer_cast<ngraph::op::Constant>(depth); auto constant_depth = std::dynamic_pointer_cast<ngraph::op::Constant>(depth);
...@@ -74,10 +74,11 @@ namespace ngraph ...@@ -74,10 +74,11 @@ namespace ngraph
// axis = 1 // axis = 1
// depth = 10 // depth = 10
// output_shape = (2, 10, 2) // output_shape = (2, 10, 2)
output_shape.insert(std::next(std::begin(output_shape), axis), depth_value); output_shape.insert(std::next(std::begin(output_shape), valid_axis),
depth_value);
std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>( std::shared_ptr<ngraph::Node> one_hot = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::OneHot>(indices, output_shape, axis), std::make_shared<ngraph::op::OneHot>(indices, output_shape, valid_axis),
values->get_element_type()); values->get_element_type());
auto broadcasted_values = auto broadcasted_values =
ngraph::op::numpy_style_broadcast({one_hot, on_value, off_value}); ngraph::op::numpy_style_broadcast({one_hot, on_value, off_value});
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -40,22 +41,26 @@ namespace ngraph ...@@ -40,22 +41,26 @@ namespace ngraph
node.get_ng_inputs().at(1), element::i32); node.get_ng_inputs().at(1), element::i32);
const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1); const auto batch_axis = node.get_attribute_value<int64_t>("batch_axis", 1);
std::size_t valid_batch_axis =
common::validate_axis(node, batch_axis, data->get_shape().size());
const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0); const auto time_axis = node.get_attribute_value<int64_t>("time_axis", 0);
std::size_t valid_time_axis =
common::validate_axis(node, time_axis, data->get_shape().size());
NGRAPH_CHECK(batch_axis == 0 || batch_axis == 1, NGRAPH_CHECK(valid_batch_axis == 0 || valid_batch_axis == 1,
"Allowed values of the 'batch_axis' attribute for ReverseSequence " "Allowed values of the 'batch_axis' attribute for ReverseSequence "
"operator are 0 and 1"); "operator are 0 and 1");
NGRAPH_CHECK(time_axis == 0 || time_axis == 1, NGRAPH_CHECK(valid_time_axis == 0 || valid_time_axis == 1,
"Allowed values of the 'time_axis' attribute for ReverseSequence " "Allowed values of the 'time_axis' attribute for ReverseSequence "
"operator are 0 and 1"); "operator are 0 and 1");
NGRAPH_CHECK(batch_axis != time_axis, NGRAPH_CHECK(valid_batch_axis != valid_time_axis,
"'batch_axis' and 'time_axis' attributes of the ReverseSequence " "'batch_axis' and 'time_axis' attributes of the ReverseSequence "
"operator can't point to the same dimension"); "operator can't point to the same dimension");
return {std::make_shared<ngraph::op::ReverseSequence>( return {std::make_shared<ngraph::op::ReverseSequence>(
data, sequence_lengths_i32, batch_axis, time_axis)}; data, sequence_lengths_i32, valid_batch_axis, valid_time_axis)};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "exceptions.hpp" #include "exceptions.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "softmax.hpp" #include "softmax.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,22 +36,13 @@ namespace ngraph ...@@ -35,22 +36,13 @@ namespace ngraph
auto data_shape = data->get_shape(); auto data_shape = data->get_shape();
int axis = node.get_attribute_value<int64_t>("axis", 1); int axis = node.get_attribute_value<int64_t>("axis", 1);
std::size_t valid_axis = common::validate_axis(node, axis, data_shape.size());
if (axis < 0)
{
axis = data_shape.size() + axis;
}
ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
<< "provided 'axis' value:" << axis
<< " is out of input tensor dimensions range.";
// create vector of capacity data_dimensions - axis_divider position // create vector of capacity data_dimensions - axis_divider position
std::vector<size_t> axes(data_shape.size() - axis); std::vector<size_t> axes(data_shape.size() - valid_axis);
std::iota(std::begin(axes), std::end(axes), axis); std::iota(std::begin(axes), std::end(axes), valid_axis);
return {std::make_shared<ngraph::op::Softmax>(data, axes)}; return {std::make_shared<ngraph::op::Softmax>(data, axes)};
} }
} // namespace set_1 } // namespace set_1
} // namespace op } // namespace op
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "ngraph/op/fused/split.hpp" #include "ngraph/op/fused/split.hpp"
#include "op/split.hpp" #include "op/split.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -33,13 +34,15 @@ namespace ngraph ...@@ -33,13 +34,15 @@ namespace ngraph
const auto input = node.get_ng_inputs().at(0); const auto input = node.get_ng_inputs().at(0);
const auto outputs_number = node.get_output_names().size(); const auto outputs_number = node.get_output_names().size();
const auto axis = node.get_attribute_value<int64_t>("axis", 0); const auto axis = node.get_attribute_value<int64_t>("axis", 0);
std::size_t valid_axis =
common::validate_axis(node, axis, input->get_shape().size());
try try
{ {
const auto length_parts = const auto length_parts =
node.get_attribute_value<std::vector<std::size_t>>("split"); node.get_attribute_value<std::vector<std::size_t>>("split");
const auto fused_split = const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, length_parts); std::make_shared<ngraph::op::Split>(input, valid_axis, length_parts);
return fused_split->decompose_op(); return fused_split->decompose_op();
} }
...@@ -49,7 +52,7 @@ namespace ngraph ...@@ -49,7 +52,7 @@ namespace ngraph
// the 'split' attribute - this means we should split the input tensor // the 'split' attribute - this means we should split the input tensor
// into same-length parts equal to the number of node outputs // into same-length parts equal to the number of node outputs
const auto fused_split = const auto fused_split =
std::make_shared<ngraph::op::Split>(input, axis, outputs_number); std::make_shared<ngraph::op::Split>(input, valid_axis, outputs_number);
return fused_split->decompose_op(); return fused_split->decompose_op();
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp" #include "ngraph/op/fused/squeeze.hpp"
#include "squeeze.hpp" #include "squeeze.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -32,17 +33,12 @@ namespace ngraph ...@@ -32,17 +33,12 @@ namespace ngraph
NodeVector squeeze(const Node& node) NodeVector squeeze(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
auto axes = node.get_attribute_value<std::vector<std::int64_t>>("axes", {}); std::vector<std::int64_t> axes =
node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
for (auto axis : axes) std::vector<std::size_t> valid_axes =
{ common::validate_axes(node, axes, data->get_shape().size());
ASSERT_VALID_ARGUMENT(node, axis >= 0)
<< "provided axes attribute is invalid. Only non-negative "
<< "integers are allowed, got " << axis << ".";
}
auto axes_node = std::make_shared<ngraph::op::Constant>( auto axes_node = std::make_shared<ngraph::op::Constant>(
element::u64, Shape{axes.size()}, axes); element::u64, Shape{valid_axes.size()}, valid_axes);
return {std::make_shared<ngraph::op::Squeeze>(data, axes_node)}; return {std::make_shared<ngraph::op::Squeeze>(data, axes_node)};
} }
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
#include "topk.hpp" #include "topk.hpp"
#include "utils/common.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,21 +36,14 @@ namespace ngraph ...@@ -35,21 +36,14 @@ namespace ngraph
NodeVector topk(const Node& node) NodeVector topk(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
std::int64_t k{node.get_attribute_value<std::int64_t>("k")}; std::int64_t k{node.get_attribute_value<std::int64_t>("k")};
auto num_dimensions = data->get_shape().size(); auto num_dimensions = data->get_shape().size();
if (axis < 0) std::int64_t axis{node.get_attribute_value<std::int64_t>("axis", -1)};
{ std::int64_t valid_axis = common::validate_axis(node, axis, num_dimensions);
axis += num_dimensions;
}
ASSERT_VALID_ARGUMENT(node, axis < num_dimensions)
<< "`axis` parameter is out of range: " << axis;
std::shared_ptr<ngraph::Node> top_k = std::shared_ptr<ngraph::Node> top_k =
std::make_shared<ngraph::op::TopK>(data, axis, element::i64, k); std::make_shared<ngraph::op::TopK>(data, valid_axis, element::i64, k);
std::shared_ptr<ngraph::Node> indices = std::shared_ptr<ngraph::Node> indices =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0); std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
......
...@@ -46,6 +46,53 @@ namespace ngraph ...@@ -46,6 +46,53 @@ namespace ngraph
static_cast<onnx::TensorProto_DataType>(onnx_type))); static_cast<onnx::TensorProto_DataType>(onnx_type)));
} }
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank)
{
// Accepted range of value for axis is [-tensor_rank, tensor_rank-1].
return validate_axis(node, axis, tensor_rank, -tensor_rank, tensor_rank - 1);
}
std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max)
{
// Accepted range of value for axis is [axis_range_min, axis_range_max].
NGRAPH_CHECK(((axis >= axis_range_min) && (axis <= axis_range_max)),
node.get_description(),
"Parameter axis ",
axis,
" out of the tensor rank [-",
axis_range_min,
", ",
axis_range_max,
"].");
if (axis < 0)
{
axis = axis + tensor_rank;
}
return static_cast<size_t>(axis);
}
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank)
{
std::vector<std::size_t> new_axes;
for (auto a : axes)
{
new_axes.push_back(validate_axis(node, a, tensor_rank));
}
return new_axes;
}
} // namespace common } // namespace common
} // namespace onnx_import } // namespace onnx_import
} // namespace ngraph } // namespace ngraph
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <type_traits> // std::enable_if #include <type_traits> // std::enable_if
#include <vector> #include <vector>
#include "core/node.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp" #include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
...@@ -67,28 +68,52 @@ namespace ngraph ...@@ -67,28 +68,52 @@ namespace ngraph
return range; return range;
} }
/// \brief Handle negative axis value. /// \brief Handle out of range axis.
/// ///
/// \param[in] axis The requested axis value. /// \param[in] node The node with requested axis.
/// \param[in] tensor_dim The corresponding tensor dimensionality. /// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// ///
/// \tparam T Provided axis value type. /// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error.
/// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
/// ///
/// \return If negative axis, then return sum of tensor dimension and axis. std::size_t validate_axis(const ngraph::onnx_import::Node& node,
std::int64_t axis,
std::int64_t tensor_rank);
/// \brief Handle out of range axis.
/// ///
template <typename T, /// \param[in] node The node with requested axis.
typename std::enable_if<std::is_integral<T>::value, int>::type = 0> /// \param[in] axis The requested axis value.
std::int64_t convert_negative_axis(T axis, std::size_t tensor_dim) /// \param[in] tensor_rank The corresponding tensor rank.
{ /// \param[in] axis_range_min The min value of accepted range for axis.
if (axis >= 0) /// \param[in] axis_range_max The max value of accepted range for axis.
{ ///
return std::min(axis, static_cast<T>(tensor_dim)); /// \return Checking if axis is in range [axis_range_min, axis_range_max], otherwise
} /// returns error.
else //// If negative axis, it counts from the last to the first axis, by adding
{ /// tensor_rank to axis.
return static_cast<std::int64_t>(tensor_dim) + axis; ///
} std::size_t validate_axis(const ngraph::onnx_import::Node& node,
} std::int64_t axis,
std::int64_t tensor_rank,
std::int64_t axis_range_min,
std::int64_t axis_range_max);
/// \brief Handle out of range axes in vector.
///
/// \param[in] node The node with requested axes.
/// \param[in] axes The requested vector of axes.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return If any negative axis in vector, it counts from the last to the first
/// axis, by adding tensor_rank to axis.
///
std::vector<std::size_t> validate_axes(const ngraph::onnx_import::Node& node,
std::vector<std::int64_t> axes,
std::int64_t tensor_rank);
/// \brief Creates a shifted square identity matrix. /// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that /// \note Shifting in the context of this operator means that
......
...@@ -32,13 +32,17 @@ namespace ngraph ...@@ -32,13 +32,17 @@ namespace ngraph
AxisSet get_reduction_axes(const Node& node) AxisSet get_reduction_axes(const Node& node)
{ {
auto reduction_axes = auto reduction_axes =
node.get_attribute_value<std::vector<std::size_t>>("axes", {}); node.get_attribute_value<std::vector<std::int64_t>>("axes", {});
std::vector<std::size_t> valid_reduction_axes = common::validate_axes(
node, reduction_axes, node.get_ng_inputs().at(0)->get_shape().size());
if (reduction_axes.empty()) if (reduction_axes.empty())
{ {
reduction_axes = onnx_import::common::get_monotonic_range<std::size_t>( valid_reduction_axes =
node.get_ng_inputs().at(0)->get_shape().size()); onnx_import::common::get_monotonic_range<std::size_t>(
node.get_ng_inputs().at(0)->get_shape().size());
} }
return AxisSet{reduction_axes}; return AxisSet{valid_reduction_axes};
} }
} // namespace detail } // namespace detail
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp" #include "utils/reshape.hpp"
namespace ngraph namespace ngraph
...@@ -64,8 +65,10 @@ namespace ngraph ...@@ -64,8 +65,10 @@ namespace ngraph
auto axis = node.get_attribute_value<std::int64_t>("axis", 0); auto axis = node.get_attribute_value<std::int64_t>("axis", 0);
auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1); auto keepdims = node.get_attribute_value<std::int64_t>("keepdims", 1);
auto input_node = node.get_ng_inputs().at(0); auto input_node = node.get_ng_inputs().at(0);
auto valid_axis = common::validate_axis(node, axis, input_node->get_shape().size());
auto op_node = std::make_shared<IndexReduction>(input_node, axis, element::i64); auto op_node =
std::make_shared<IndexReduction>(input_node, valid_axis, element::i64);
if (keepdims == 0) if (keepdims == 0)
{ {
...@@ -76,7 +79,7 @@ namespace ngraph ...@@ -76,7 +79,7 @@ namespace ngraph
auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32); auto convert_node = std::make_shared<ngraph::op::Convert>(op_node, element::f32);
auto output_shape = input_node->get_shape(); auto output_shape = input_node->get_shape();
output_shape.at(axis) = 1; output_shape.at(valid_axis) = 1;
auto reshape_node = std::make_shared<ngraph::op::Reshape>( auto reshape_node = std::make_shared<ngraph::op::Reshape>(
convert_node, convert_node,
ngraph::get_default_order(op_node->get_shape().size()), ngraph::get_default_order(op_node->get_shape().size()),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment