Commit 09242c31 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

Use get_default_axis_vector utility function for Reshape op. (#1558)

parent 42cc4b82
...@@ -16,11 +16,11 @@ ...@@ -16,11 +16,11 @@
#include <numeric> #include <numeric>
#include "unsqueeze.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "utils/reshape.hpp"
#include "exceptions.hpp" #include "exceptions.hpp"
#include "unsqueeze.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -41,9 +41,7 @@ namespace ngraph ...@@ -41,9 +41,7 @@ namespace ngraph
} }
std::sort(std::begin(axes), std::end(axes), std::greater<int64_t>()); std::sort(std::begin(axes), std::end(axes), std::greater<int64_t>());
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
AxisVector input_order(data_shape.size());
std::iota(std::begin(input_order), std::end(input_order), 0);
for (auto axis : axes) for (auto axis : axes)
{ {
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "broadcasting.hpp" #include "broadcasting.hpp"
#include "reshape.hpp"
/// \brief Calculate output shape of numpy - style broadcast operation. /// \brief Calculate output shape of numpy - style broadcast operation.
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules /// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
...@@ -83,21 +84,15 @@ namespace ngraph ...@@ -83,21 +84,15 @@ namespace ngraph
: new_right_shape.push_back(right_full_shape.at(index)); : new_right_shape.push_back(right_full_shape.at(index));
} }
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> left_input_order(left->get_shape().size());
std::iota(std::begin(left_input_order), std::end(left_input_order), 0);
// Remove dims which have length of 1 from source shape // Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_left = std::shared_ptr<Node> broadcasted_left = std::make_shared<op::Reshape>(
std::make_shared<op::Reshape>(left, left_input_order, new_left_shape); left, reshape::get_default_axis_vector(left->get_shape().size()), new_left_shape);
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> right_input_order(right->get_shape().size());
std::iota(std::begin(right_input_order), std::end(right_input_order), 0);
// Remove dims which have length of 1 from source shape // Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_right = std::shared_ptr<Node> broadcasted_right = std::make_shared<op::Reshape>(
std::make_shared<op::Reshape>(right, right_input_order, new_right_shape); right,
reshape::get_default_axis_vector(right->get_shape().size()),
new_right_shape);
broadcasted_left = std::make_shared<op::Broadcast>( broadcasted_left = std::make_shared<op::Broadcast>(
broadcasted_left, output_shape, left_broadcast_axes); broadcasted_left, output_shape, left_broadcast_axes);
......
...@@ -39,28 +39,22 @@ namespace ngraph ...@@ -39,28 +39,22 @@ namespace ngraph
{ {
auto data_shape = node->get_shape(); auto data_shape = node->get_shape();
size_t first_dim_size = 1;
size_t last_dim_size = 1;
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor. // First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n] // The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
for (auto index = 0; index < data_shape.size(); ++index) size_t first_dim_size = std::accumulate(std::begin(data_shape),
{ std::next(std::begin(data_shape), axis),
last_dim_size *= data_shape.at(index); 1UL,
if (index < axis) std::multiplies<std::size_t>());
{
first_dim_size = last_dim_size;
}
}
last_dim_size /= first_dim_size;
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape size_t last_dim_size = std::accumulate(std::next(std::begin(data_shape), axis),
std::vector<size_t> input_order(data_shape.size()); std::end(data_shape),
std::iota(std::begin(input_order), std::end(input_order), 0); 1UL,
std::multiplies<std::size_t>());
return std::make_shared<ngraph::op::Reshape>( return std::make_shared<ngraph::op::Reshape>(
node, AxisVector{input_order}, Shape{first_dim_size, last_dim_size}); node,
get_default_axis_vector(data_shape.size()),
Shape{first_dim_size, last_dim_size});
} }
AxisVector get_default_axis_vector(std::size_t data_shape_size, std::size_t start_value) AxisVector get_default_axis_vector(std::size_t data_shape_size, std::size_t start_value)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment