Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
8eb63379
Commit
8eb63379
authored
Jul 26, 2019
by
Sang Ik Lee
Committed by
Scott Cyphers
Jul 26, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Reshape sinking: fix issue with handling rank changing reshape. (#3314)
parent
c04b5588
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
70 additions
and
35 deletions
+70
-35
reshape_sinking.cpp
src/ngraph/pass/reshape_sinking.cpp
+70
-35
No files found.
src/ngraph/pass/reshape_sinking.cpp
View file @
8eb63379
...
...
@@ -56,20 +56,51 @@ static string describe_reshape(shared_ptr<Node> node)
return
ss
.
str
();
}
static
shared_ptr
<
op
::
Reshape
>
make_reshape
(
shared_ptr
<
Node
>
arg
,
const
AxisVector
&
input_order
,
const
Shape
&
output_shape
)
{
auto
reshape
=
make_shared
<
op
::
Reshape
>
(
arg
,
input_order
,
output_shape
);
NGRAPH_DEBUG
<<
"Make Reshape "
<<
describe_reshape
(
reshape
);
return
reshape
;
}
static
void
write_reshapemap
(
ReshapeMap
&
reorders
,
shared_ptr
<
Node
>
target
,
shared_ptr
<
op
::
Reshape
>
reshape
)
{
NGRAPH_DEBUG
<<
"Write ReshapeMap["
<<
target
->
get_name
()
<<
"] = "
<<
describe_reshape
(
reshape
);
reorders
[
target
]
=
reshape
;
}
static
shared_ptr
<
op
::
Reshape
>
read_reshapemap
(
ReshapeMap
&
reorders
,
shared_ptr
<
Node
>
target
)
{
auto
reorder
=
reorders
.
at
(
target
);
NGRAPH_DEBUG
<<
"Read ReshapeMap["
<<
target
->
get_name
()
<<
"] -> "
<<
describe_reshape
(
reorder
);
return
reorder
;
}
static
shared_ptr
<
op
::
Reshape
>
combine_reshapes
(
shared_ptr
<
op
::
Reshape
>
r1
,
shared_ptr
<
op
::
Reshape
>
r2
)
{
auto
default_order
=
ngraph
::
get_default_order
(
r1
->
get_shape
());
auto
perm_r1
=
apply_permutation
(
default_order
,
r1
->
get_input_order
());
auto
perm_r2
=
apply_permutation
(
perm_r1
,
r2
->
get_input_order
());
auto
rreshape
=
make_shared
<
op
::
Reshape
>
(
r2
->
get_argument
(
0
),
perm_r2
,
r2
->
get_shape
());
auto
rreshape
=
make_reshape
(
r2
->
get_argument
(
0
),
perm_r2
,
r2
->
get_shape
());
NGRAPH_DEBUG
<<
"Combining "
<<
describe_reshape
(
r1
)
<<
" and "
<<
describe_reshape
(
r2
)
<<
" into "
<<
describe_reshape
(
rreshape
);
return
rreshape
;
}
static
void
insert_reshape
(
shared_ptr
<
Node
>
target
,
shared_ptr
<
Node
>
reshape
,
size_t
input_index
)
{
NGRAPH_DEBUG
<<
"Inserting reshape at input "
<<
target
->
get_name
()
<<
" input index "
<<
input_index
;
auto
arg
=
target
->
input
(
input_index
).
get_source_output
();
NGRAPH_DEBUG
<<
"Arg shape: "
<<
arg
.
get_shape
();
auto
new_reshape
=
reshape
->
copy_with_new_inputs
({
arg
});
NGRAPH_DEBUG
<<
"Inserting reshape "
<<
describe_reshape
(
new_reshape
)
<<
" at input "
<<
target
->
get_name
()
<<
" input index "
<<
input_index
;
target
->
input
(
input_index
).
replace_source_output
(
new_reshape
->
output
(
0
));
}
...
...
@@ -92,7 +123,8 @@ static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
static
shared_ptr
<
op
::
Reshape
>
create_default_reshape
(
shared_ptr
<
Node
>
n
)
{
auto
default_order
=
ngraph
::
get_default_order
(
n
->
get_shape
());
auto
default_reshape
=
make_shared
<
op
::
Reshape
>
(
n
,
default_order
,
n
->
get_shape
());
auto
default_reshape
=
make_reshape
(
n
,
default_order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Default reshape: "
<<
describe_reshape
(
default_reshape
);
return
default_reshape
;
}
...
...
@@ -187,7 +219,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
auto
new_arg_shape
=
ngraph
::
apply_permutation
(
broadcast_input
->
get_shape
(),
new_source_axis_order
);
broadcast_input
=
make_
shared
<
op
::
Reshape
>
(
broadcast_input
,
new_source_axis_order
,
new_arg_shape
);
make_
reshape
(
broadcast_input
,
new_source_axis_order
,
new_arg_shape
);
}
auto
new_broadcast
=
make_shared
<
op
::
Broadcast
>
(
...
...
@@ -209,12 +241,11 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
//of a binary op isn't in the default format (i.e. nhwc instead of nchw)
//We have to normalize this other argument to nchw by swimming nchw towards parameters
//as far as we can
static
void
convert_binary_to_default_order
(
shared_ptr
<
Node
>
binary
,
const
Input
<
Node
>&
input
,
shared_ptr
<
Node
>
right
,
unordered_map
<
shared_ptr
<
Node
>
,
shared_ptr
<
op
::
Reshape
>>&
reorders
,
set
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
static
void
convert_binary_to_default_order
(
shared_ptr
<
Node
>
binary
,
const
Input
<
Node
>&
input
,
shared_ptr
<
Node
>
right
,
ReshapeMap
&
reorders
,
set
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
left
=
input
.
get_source_output
().
get_node_shared_ptr
();
auto
perm_to_def
=
...
...
@@ -222,13 +253,13 @@ static void convert_binary_to_default_order(
auto
new_shape
=
apply_permutation
(
left
->
get_shape
(),
perm_to_def
);
NGRAPH_DEBUG
<<
"right = "
<<
ngraph
::
vector_to_string
(
right
->
get_shape
())
<<
", "
<<
right
->
get_name
();
auto
new_reshape
=
make_
shared
<
op
::
Reshape
>
(
left
,
perm_to_def
,
new_shape
);
auto
new_reshape
=
make_
reshape
(
left
,
perm_to_def
,
new_shape
);
NGRAPH_DEBUG
<<
"left : About to swim "
<<
describe_reshape
(
new_reshape
)
<<
" up to "
<<
left
->
get_name
();
//this should now insert and swim reshape on right
swim
(
input
,
new_reshape
);
mark_reshape_for_deletion
(
reorders
.
at
(
right
),
reshapes_to_delete
);
reorders
[
binary
]
=
reorders
.
at
(
right
);
write_reshapemap
(
reorders
,
binary
,
read_reshapemap
(
reorders
,
right
)
);
}
static
void
materialize_shapes
(
shared_ptr
<
Node
>
n
,
...
...
@@ -247,32 +278,37 @@ static void materialize_shapes(shared_ptr<Node> n,
auto
arg
=
n
->
get_argument
(
i
);
if
(
reorders
.
count
(
arg
)
!=
0
)
{
NGRAPH_DEBUG
<<
"Materializing "
<<
describe_reshape
(
reorders
.
at
(
arg
))
<<
" for "
auto
arg_reshape
=
reorders
.
at
(
arg
);
NGRAPH_DEBUG
<<
"Materializing "
<<
describe_reshape
(
arg_reshape
)
<<
" for "
<<
arg
->
get_name
();
mark_reshape_for_deletion
(
reorders
.
at
(
arg
),
reshapes_to_delete
);
if
(
reorders
.
at
(
arg
)
->
get_input_order
()
!=
get_default_order
(
arg
->
get_shape
()))
mark_reshape_for_deletion
(
arg_reshape
,
reshapes_to_delete
);
auto
arg_shape
=
arg
->
get_shape
();
if
(
arg_reshape
->
get_input_order
()
!=
get_default_order
(
arg
->
get_shape
()))
{
// Insert if arg needs to be transposed.
insert_reshape
(
n
,
reorders
.
at
(
arg
)
,
i
);
insert_reshape
(
n
,
arg_reshape
,
i
);
}
//no swimming up
}
}
reorders
[
n
]
=
create_default_reshape
(
n
);
write_reshapemap
(
reorders
,
n
,
create_default_reshape
(
n
)
);
}
static
void
sink_reshape
(
shared_ptr
<
op
::
Reshape
>
reshape
,
ReshapeMap
&
reorders
,
set
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
NGRAPH_DEBUG
<<
"Sinking Reshape :"
<<
describe_reshape
(
reshape
);
auto
orig_reshape
=
reorders
.
at
(
reshape
->
get_argument
(
0
));
if
(
!
reshape
->
get_is_transpose
())
// 1) Not a Transpose or 2) Rank changing operation.
if
((
reshape
->
get_output_shape
().
size
()
!=
reshape
->
get_input_order
().
size
())
||
(
!
reshape
->
get_is_transpose
()))
{
NGRAPH_DEBUG
<<
"Materializing "
<<
describe_reshape
(
orig_reshape
)
<<
" for reshape "
<<
reshape
->
get_name
(
);
<<
describe_reshape
(
reshape
);
insert_reshape
(
reshape
,
orig_reshape
,
0
);
mark_reshape_for_deletion
(
orig_reshape
,
reshapes_to_delete
);
reorders
[
reshape
]
=
create_default_reshape
(
reshape
);
write_reshapemap
(
reorders
,
reshape
,
create_default_reshape
(
reshape
)
);
}
else
{
...
...
@@ -284,9 +320,7 @@ static void sink_reshape(shared_ptr<op::Reshape> reshape,
//replace reshape with combined one
ngraph
::
replace_node
(
reshape
,
new_reshape
);
mark_reshape_for_deletion
(
new_reshape
,
reshapes_to_delete
);
reorders
[
new_reshape
]
=
new_reshape
;
NGRAPH_DEBUG
<<
"Combining "
<<
describe_reshape
(
orig_reshape
)
<<
" and"
<<
describe_reshape
(
reshape
)
<<
" into "
<<
describe_reshape
(
new_reshape
);
write_reshapemap
(
reorders
,
new_reshape
,
new_reshape
);
}
}
...
...
@@ -294,9 +328,9 @@ static void sink_unary(shared_ptr<op::util::UnaryElementwiseArithmetic> n,
ReshapeMap
&
reorders
,
set
<
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
re
orders
.
at
(
n
->
get_argument
(
0
));
auto
arg_reshape
=
re
ad_reshapemap
(
reorders
,
n
->
get_argument
(
0
));
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
arg_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
n
]
=
reorders
[
n
->
get_argument
(
0
)]
;
write_reshapemap
(
reorders
,
n
,
arg_reshape
)
;
}
static
void
sink_binary
(
shared_ptr
<
op
::
util
::
BinaryElementwiseArithmetic
>
binary
,
...
...
@@ -310,7 +344,7 @@ static void sink_binary(shared_ptr<op::util::BinaryElementwiseArithmetic> binary
{
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
reorders
.
at
(
left
))
<<
" for "
<<
binary
->
get_name
();
reorders
[
binary
]
=
reorders
.
at
(
left
);
write_reshapemap
(
reorders
,
binary
,
read_reshapemap
(
reorders
,
left
)
);
//at this point, both reshapes will be eventually removed
mark_reshape_for_deletion
(
reorders
.
at
(
left
),
reshapes_to_delete
);
mark_reshape_for_deletion
(
reorders
.
at
(
right
),
reshapes_to_delete
);
...
...
@@ -360,9 +394,9 @@ static void sink_slice(shared_ptr<op::Slice> n,
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_slice
->
get_name
();
ngraph
::
replace_node
(
n
,
new_slice
);
auto
new_reshape
=
make_
shared
<
op
::
Reshape
>
(
new_slice
,
order
,
n
->
get_shape
());
auto
new_reshape
=
make_
reshape
(
new_slice
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_slice
]
=
new_reshape
;
write_reshapemap
(
reorders
,
new_slice
,
new_reshape
)
;
}
static
void
...
...
@@ -385,9 +419,9 @@ static void
ngraph
::
replace_node
(
dummy_correct_shape
,
n
->
get_argument
(
0
));
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_pad
->
get_name
();
ngraph
::
replace_node
(
n
,
new_pad
);
auto
new_reshape
=
make_
shared
<
op
::
Reshape
>
(
new_pad
,
order
,
n
->
get_shape
());
auto
new_reshape
=
make_
reshape
(
new_pad
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_pad
]
=
new_reshape
;
write_reshapemap
(
reorders
,
new_pad
,
new_reshape
)
;
}
static
void
sink_quantize
(
shared_ptr
<
op
::
Quantize
>
quantize
,
ReshapeMap
&
reorders
,
...
...
@@ -404,7 +438,7 @@ static void sink_quantize(shared_ptr<op::Quantize> quantize,
quantize
->
get_round_mode
());
ngraph
::
replace_node
(
quantize
,
new_quantize
);
reorders
[
new_quantize
]
=
arg_reshape
;
write_reshapemap
(
reorders
,
new_quantize
,
arg_reshape
)
;
}
static
void
sink_concat
(
shared_ptr
<
op
::
Concat
>
n
,
...
...
@@ -451,9 +485,9 @@ static void sink_concat(shared_ptr<op::Concat> n,
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_concat
->
get_name
();
ngraph
::
replace_node
(
n
,
new_concat
);
auto
new_reshape
=
make_
shared
<
op
::
Reshape
>
(
new_concat
,
order
,
n
->
get_shape
());
auto
new_reshape
=
make_
reshape
(
new_concat
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_concat
]
=
new_reshape
;
write_reshapemap
(
reorders
,
new_concat
,
new_reshape
)
;
}
static
void
sink_dequantize
(
shared_ptr
<
op
::
Dequantize
>
dequantize
,
...
...
@@ -470,7 +504,7 @@ static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
axes_in_def_order
);
ngraph
::
replace_node
(
dequantize
,
new_dequantize
);
reorders
[
new_dequantize
]
=
arg_reshape
;
write_reshapemap
(
reorders
,
new_dequantize
,
arg_reshape
)
;
}
//The goal of ReshapeSinking is to remove
...
...
@@ -491,7 +525,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
//STEP 1 : Sink or Swim reshapes away for op clusters
for
(
auto
n
:
f
->
get_ordered_ops
())
{
NGRAPH_DEBUG
<<
"Processing node "
<<
n
->
get_name
();
NGRAPH_DEBUG
<<
"
Start:
Processing node "
<<
n
->
get_name
();
//collect all Result nodes for a sanity check
if
(
n
->
is_output
())
{
...
...
@@ -512,7 +546,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
}
else
if
(
auto
goe
=
dynamic_pointer_cast
<
op
::
GetOutputElement
>
(
n
))
{
reorders
[
goe
]
=
create_default_reshape
(
goe
);
write_reshapemap
(
reorders
,
goe
,
create_default_reshape
(
goe
)
);
}
else
if
(
auto
quantize
=
dynamic_pointer_cast
<
op
::
Quantize
>
(
n
))
{
...
...
@@ -555,6 +589,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
{
materialize_shapes
(
n
,
reorders
,
reshapes_to_delete
);
}
NGRAPH_DEBUG
<<
"End: Processing node "
<<
n
->
get_name
();
}
//STEP 2: purge all the reshapes we either sunk or swam.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment