Commit 09242c31 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

Use get_default_axis_vector utility function for Reshape op. (#1558)

parent 42cc4b82
master v0.29.0-rc.0 v0.28.0-rc.1 v0.28.0-rc.0 v0.27.1-rc.3 v0.27.1-rc.2 v0.27.1-rc.1 v0.27.1-rc.0 v0.27.0-rc.1 v0.27.0-rc.0 v0.26.1-rc.0 v0.26.0 v0.26.0-rc.8 v0.26.0-rc.7 v0.26.0-rc.6 v0.26.0-rc.5 v0.26.0-rc.4 v0.26.0-rc.3 v0.26.0-rc.2 v0.26.0-rc.0 v0.25.1-rc.11 v0.25.1-rc.10 v0.25.1-rc.9 v0.25.1-rc.8 v0.25.1-rc.7 v0.25.1-rc.6 v0.25.1-rc.5 v0.25.1-rc.4 v0.25.1-rc.3 v0.25.1-rc.2 v0.25.1-rc.1 v0.25.1-rc.0 v0.25.0 v0.25.0-rc.3 v0.25.0-rc.2 v0.25.0-rc.1 v0.25.0-rc.0 v0.25.0-dev.0 v0.24.0 v0.24.0-rc.3 v0.24.0-rc.2 v0.24.0-rc.1 v0.24.0-rc.0 v0.23.0-rc.7 v0.23.0-rc.6 v0.23.0-rc.5 v0.23.0-rc.4 v0.23.0-rc.3 v0.23.0-rc.2 v0.23.0-rc.1 v0.23.0-rc.0 v0.22.2-rc.0 v0.22.1 v0.22.1-rc.0 v0.22.0 v0.22.0-rc.2 v0.22.0-rc.0 v0.21.0 v0.21.0-rc.1 v0.21.0-rc.0 v0.20.1-rc.4 v0.20.1-rc.3 v0.20.1-rc.2 v0.20.1-rc.1 v0.20.1-rc.0 v0.20.0-rc.2 v0.20.0-rc.1 v0.20.0-rc.0 v0.20.0-dev.0 v0.19.1 v0.19.1-rc.0 v0.19.0 v0.19.0-rc.5 v0.19.0-rc.4 v0.19.0-rc.3 v0.19.0-rc.2 v0.19.0-rc.1 v0.19.0-rc.0 v0.18.1 v0.18.1-rc.1 v0.18.1-rc.0 v0.18.0 v0.18.0-rc.2 v0.18.0-rc.1 v0.18.0-rc.0 v0.17.0-rc.1 v0.17.0-rc.0 v0.16.0-rc.3 v0.16.0-rc.2 v0.16.0-rc.1 v0.16.0-rc.0 v0.15.1-rc.2 v0.15.1-rc.1 v0.15.0 v0.15.0-rc.2 v0.15.0-rc.1 v0.15.0-rc.0 v0.14.0 v0.14.0-rc.1 v0.14.0-rc.0 v0.13.0 v0.12.0 v0.12.0-rc.2 v0.12.0-rc.1 v0.12.0-rc.0 v0.11.1 v0.11.0 v0.11.0-rc.1 v0.11.0-rc.0 v0.10.1 v0.10.0 v0.10.0-rc.6 v0.10.0-rc.5 v0.10.0-rc.4 v0.10.0-rc.3 v0.10.0-rc.2 v0.10.0-rc.1 v0.10.0-rc.0 v0.9.1 v0.9.1-rc.0 v0.9.0 v0.9.0-rc.5 v0.9.0-rc.4 v0.9.0-rc.3 v0.9.0-rc.2 v0.9.0-rc.1 v0.9.0-rc.0 v0.8.2-rc.0 v0.8.1 v0.8.1-rc.0 v0.8.0 v0.8.0-rc.2 v0.8.0-rc.1 v0.8.0-rc.0
No related merge requests found
......@@ -16,11 +16,11 @@
#include <numeric>
#include "unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
#include "utils/reshape.hpp"
#include "exceptions.hpp"
#include "unsqueeze.hpp"
namespace ngraph
{
......@@ -41,9 +41,7 @@ namespace ngraph
}
std::sort(std::begin(axes), std::end(axes), std::greater<int64_t>());
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
AxisVector input_order(data_shape.size());
std::iota(std::begin(input_order), std::end(input_order), 0);
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
for (auto axis : axes)
{
......
......@@ -22,6 +22,7 @@
#include "ngraph/op/reshape.hpp"
#include "broadcasting.hpp"
#include "reshape.hpp"
/// \brief Calculate output shape of numpy - style broadcast operation.
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
......@@ -83,21 +84,15 @@ namespace ngraph
: new_right_shape.push_back(right_full_shape.at(index));
}
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> left_input_order(left->get_shape().size());
std::iota(std::begin(left_input_order), std::end(left_input_order), 0);
// Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_left =
std::make_shared<op::Reshape>(left, left_input_order, new_left_shape);
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> right_input_order(right->get_shape().size());
std::iota(std::begin(right_input_order), std::end(right_input_order), 0);
std::shared_ptr<Node> broadcasted_left = std::make_shared<op::Reshape>(
left, reshape::get_default_axis_vector(left->get_shape().size()), new_left_shape);
// Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_right =
std::make_shared<op::Reshape>(right, right_input_order, new_right_shape);
std::shared_ptr<Node> broadcasted_right = std::make_shared<op::Reshape>(
right,
reshape::get_default_axis_vector(right->get_shape().size()),
new_right_shape);
broadcasted_left = std::make_shared<op::Broadcast>(
broadcasted_left, output_shape, left_broadcast_axes);
......
......@@ -39,28 +39,22 @@ namespace ngraph
{
auto data_shape = node->get_shape();
size_t first_dim_size = 1;
size_t last_dim_size = 1;
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
for (auto index = 0; index < data_shape.size(); ++index)
{
last_dim_size *= data_shape.at(index);
if (index < axis)
{
first_dim_size = last_dim_size;
}
}
last_dim_size /= first_dim_size;
size_t first_dim_size = std::accumulate(std::begin(data_shape),
std::next(std::begin(data_shape), axis),
1UL,
std::multiplies<std::size_t>());
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> input_order(data_shape.size());
std::iota(std::begin(input_order), std::end(input_order), 0);
size_t last_dim_size = std::accumulate(std::next(std::begin(data_shape), axis),
std::end(data_shape),
1UL,
std::multiplies<std::size_t>());
return std::make_shared<ngraph::op::Reshape>(
node, AxisVector{input_order}, Shape{first_dim_size, last_dim_size});
node,
get_default_axis_vector(data_shape.size()),
Shape{first_dim_size, last_dim_size});
}
AxisVector get_default_axis_vector(std::size_t data_shape_size, std::size_t start_value)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment