Commit 09242c31 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Michał Karzyński

Use get_default_axis_vector utility function for Reshape op. (#1558)

parent 42cc4b82
......@@ -16,11 +16,11 @@
#include <numeric>
#include "unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
#include "utils/reshape.hpp"
#include "exceptions.hpp"
#include "unsqueeze.hpp"
namespace ngraph
{
......@@ -41,9 +41,7 @@ namespace ngraph
}
std::sort(std::begin(axes), std::end(axes), std::greater<int64_t>());
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
AxisVector input_order(data_shape.size());
std::iota(std::begin(input_order), std::end(input_order), 0);
AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
for (auto axis : axes)
{
......
......@@ -22,6 +22,7 @@
#include "ngraph/op/reshape.hpp"
#include "broadcasting.hpp"
#include "reshape.hpp"
/// \brief Calculate output shape of numpy - style broadcast operation.
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
......@@ -83,21 +84,15 @@ namespace ngraph
: new_right_shape.push_back(right_full_shape.at(index));
}
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> left_input_order(left->get_shape().size());
std::iota(std::begin(left_input_order), std::end(left_input_order), 0);
// Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_left =
std::make_shared<op::Reshape>(left, left_input_order, new_left_shape);
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> right_input_order(right->get_shape().size());
std::iota(std::begin(right_input_order), std::end(right_input_order), 0);
std::shared_ptr<Node> broadcasted_left = std::make_shared<op::Reshape>(
left, reshape::get_default_axis_vector(left->get_shape().size()), new_left_shape);
// Remove dims which have length of 1 from source shape
std::shared_ptr<Node> broadcasted_right =
std::make_shared<op::Reshape>(right, right_input_order, new_right_shape);
std::shared_ptr<Node> broadcasted_right = std::make_shared<op::Reshape>(
right,
reshape::get_default_axis_vector(right->get_shape().size()),
new_right_shape);
broadcasted_left = std::make_shared<op::Broadcast>(
broadcasted_left, output_shape, left_broadcast_axes);
......
......@@ -39,28 +39,22 @@ namespace ngraph
{
auto data_shape = node->get_shape();
size_t first_dim_size = 1;
size_t last_dim_size = 1;
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
for (auto index = 0; index < data_shape.size(); ++index)
{
last_dim_size *= data_shape.at(index);
if (index < axis)
{
first_dim_size = last_dim_size;
}
}
last_dim_size /= first_dim_size;
size_t first_dim_size = std::accumulate(std::begin(data_shape),
std::next(std::begin(data_shape), axis),
1UL,
std::multiplies<std::size_t>());
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std::vector<size_t> input_order(data_shape.size());
std::iota(std::begin(input_order), std::end(input_order), 0);
size_t last_dim_size = std::accumulate(std::next(std::begin(data_shape), axis),
std::end(data_shape),
1UL,
std::multiplies<std::size_t>());
return std::make_shared<ngraph::op::Reshape>(
node, AxisVector{input_order}, Shape{first_dim_size, last_dim_size});
node,
get_default_axis_vector(data_shape.size()),
Shape{first_dim_size, last_dim_size});
}
AxisVector get_default_axis_vector(std::size_t data_shape_size, std::size_t start_value)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment