Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
09242c31
Commit
09242c31
authored
Sep 05, 2018
by
Adam Rogowiec
Committed by
Michał Karzyński
Sep 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use get_default_axis_vector utility function for Reshape op. (#1558)
parent
42cc4b82
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
34 deletions
+21
-34
unsqueeze.cpp
src/ngraph/frontend/onnx_import/op/unsqueeze.cpp
+3
-5
broadcasting.cpp
src/ngraph/frontend/onnx_import/utils/broadcasting.cpp
+7
-12
reshape.cpp
src/ngraph/frontend/onnx_import/utils/reshape.cpp
+11
-17
No files found.
src/ngraph/frontend/onnx_import/op/unsqueeze.cpp
View file @
09242c31
...
@@ -16,11 +16,11 @@
...
@@ -16,11 +16,11 @@
#include <numeric>
#include <numeric>
#include "unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "utils/reshape.hpp"
#include "exceptions.hpp"
#include "exceptions.hpp"
#include "unsqueeze.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -41,9 +41,7 @@ namespace ngraph
...
@@ -41,9 +41,7 @@ namespace ngraph
}
}
std
::
sort
(
std
::
begin
(
axes
),
std
::
end
(
axes
),
std
::
greater
<
int64_t
>
());
std
::
sort
(
std
::
begin
(
axes
),
std
::
end
(
axes
),
std
::
greater
<
int64_t
>
());
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
AxisVector
input_order
{
reshape
::
get_default_axis_vector
(
data_shape
.
size
())};
AxisVector
input_order
(
data_shape
.
size
());
std
::
iota
(
std
::
begin
(
input_order
),
std
::
end
(
input_order
),
0
);
for
(
auto
axis
:
axes
)
for
(
auto
axis
:
axes
)
{
{
...
...
src/ngraph/frontend/onnx_import/utils/broadcasting.cpp
View file @
09242c31
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "broadcasting.hpp"
#include "broadcasting.hpp"
#include "reshape.hpp"
/// \brief Calculate output shape of numpy - style broadcast operation.
/// \brief Calculate output shape of numpy - style broadcast operation.
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
/// https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules
...
@@ -83,21 +84,15 @@ namespace ngraph
...
@@ -83,21 +84,15 @@ namespace ngraph
:
new_right_shape
.
push_back
(
right_full_shape
.
at
(
index
));
:
new_right_shape
.
push_back
(
right_full_shape
.
at
(
index
));
}
}
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std
::
vector
<
size_t
>
left_input_order
(
left
->
get_shape
().
size
());
std
::
iota
(
std
::
begin
(
left_input_order
),
std
::
end
(
left_input_order
),
0
);
// Remove dims which have length of 1 from source shape
// Remove dims which have length of 1 from source shape
std
::
shared_ptr
<
Node
>
broadcasted_left
=
std
::
shared_ptr
<
Node
>
broadcasted_left
=
std
::
make_shared
<
op
::
Reshape
>
(
std
::
make_shared
<
op
::
Reshape
>
(
left
,
left_input_order
,
new_left_shape
);
left
,
reshape
::
get_default_axis_vector
(
left
->
get_shape
().
size
()),
new_left_shape
);
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
std
::
vector
<
size_t
>
right_input_order
(
right
->
get_shape
().
size
());
std
::
iota
(
std
::
begin
(
right_input_order
),
std
::
end
(
right_input_order
),
0
);
// Remove dims which have length of 1 from source shape
// Remove dims which have length of 1 from source shape
std
::
shared_ptr
<
Node
>
broadcasted_right
=
std
::
shared_ptr
<
Node
>
broadcasted_right
=
std
::
make_shared
<
op
::
Reshape
>
(
std
::
make_shared
<
op
::
Reshape
>
(
right
,
right_input_order
,
new_right_shape
);
right
,
reshape
::
get_default_axis_vector
(
right
->
get_shape
().
size
()),
new_right_shape
);
broadcasted_left
=
std
::
make_shared
<
op
::
Broadcast
>
(
broadcasted_left
=
std
::
make_shared
<
op
::
Broadcast
>
(
broadcasted_left
,
output_shape
,
left_broadcast_axes
);
broadcasted_left
,
output_shape
,
left_broadcast_axes
);
...
...
src/ngraph/frontend/onnx_import/utils/reshape.cpp
View file @
09242c31
...
@@ -39,28 +39,22 @@ namespace ngraph
...
@@ -39,28 +39,22 @@ namespace ngraph
{
{
auto
data_shape
=
node
->
get_shape
();
auto
data_shape
=
node
->
get_shape
();
size_t
first_dim_size
=
1
;
size_t
last_dim_size
=
1
;
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// First dimension of output tensor is the product of [d_0, ... d_{axis-1}] dimensions of input tensor.
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
// The last dimension is the product of the rest of input tensor dimensions: [d_{axis}, ..., d_n]
for
(
auto
index
=
0
;
index
<
data_shape
.
size
();
++
index
)
size_t
first_dim_size
=
std
::
accumulate
(
std
::
begin
(
data_shape
),
{
std
::
next
(
std
::
begin
(
data_shape
),
axis
),
last_dim_size
*=
data_shape
.
at
(
index
);
1UL
,
if
(
index
<
axis
)
std
::
multiplies
<
std
::
size_t
>
());
{
first_dim_size
=
last_dim_size
;
}
}
last_dim_size
/=
first_dim_size
;
// Generate an increasing sequence (0,1,2,3..) as input_order for Reshape
size_t
last_dim_size
=
std
::
accumulate
(
std
::
next
(
std
::
begin
(
data_shape
),
axis
),
std
::
vector
<
size_t
>
input_order
(
data_shape
.
size
());
std
::
end
(
data_shape
),
std
::
iota
(
std
::
begin
(
input_order
),
std
::
end
(
input_order
),
0
);
1UL
,
std
::
multiplies
<
std
::
size_t
>
());
return
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
return
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
node
,
AxisVector
{
input_order
},
Shape
{
first_dim_size
,
last_dim_size
});
node
,
get_default_axis_vector
(
data_shape
.
size
()),
Shape
{
first_dim_size
,
last_dim_size
});
}
}
AxisVector
get_default_axis_vector
(
std
::
size_t
data_shape_size
,
std
::
size_t
start_value
)
AxisVector
get_default_axis_vector
(
std
::
size_t
data_shape_size
,
std
::
size_t
start_value
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment