Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
02a6b07c
Commit
02a6b07c
authored
6 years ago
by
nikolay.korovaiko
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Sink Concat
parent
800d5f14
No related merge requests found
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
153 additions
and
26 deletions
+153
-26
reshape_sinking.cpp
src/ngraph/pass/reshape_sinking.cpp
+80
-26
reshape_sinking.cpp
test/reshape_sinking.cpp
+73
-0
No files found.
src/ngraph/pass/reshape_sinking.cpp
View file @
02a6b07c
...
...
@@ -26,6 +26,7 @@
#include "ngraph/log.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/get_output_element.hpp"
...
...
@@ -230,6 +231,32 @@ static void convert_binary_to_default_order(
reorders
[
binary
]
=
reorders
.
at
(
right
);
}
static
void
materialize_shapes
(
std
::
shared_ptr
<
Node
>
n
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
//skip multiple output nodes and deal with GOEs exclusively
if
(
n
->
get_outputs
().
size
()
>
1
)
{
return
;
}
for
(
size_t
i
=
0
;
i
<
n
->
get_arguments
().
size
();
i
++
)
{
//materialize all pending reshapes, flush pending reshapes
auto
arg
=
n
->
get_argument
(
i
);
if
(
reorders
.
count
(
arg
)
!=
0
)
{
NGRAPH_DEBUG
<<
"Materializing "
<<
describe_reshape
(
reorders
.
at
(
arg
))
<<
" for "
<<
arg
->
get_name
();
mark_reshape_for_deletion
(
reorders
.
at
(
arg
),
reshapes_to_delete
);
insert_reshape
(
n
,
reorders
.
at
(
arg
),
i
);
//no swimming up
}
}
reorders
[
n
]
=
create_default_reshape
(
n
);
}
static
void
sink_reshape
(
std
::
shared_ptr
<
op
::
Reshape
>
reshape
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
...
...
@@ -379,6 +406,55 @@ static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
reorders
[
new_quantize
]
=
arg_reshape
;
}
static
void
sink_concat
(
std
::
shared_ptr
<
op
::
Concat
>
n
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
n
->
get_argument
(
0
));
auto
order
=
arg_reshape
->
get_input_order
();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new slice will have the right shape
auto
def_order
=
ngraph
::
get_permutation_to_default_order
(
order
);
auto
input_shape
=
ngraph
::
apply_permutation
(
arg_reshape
->
get_shape
(),
def_order
);
auto
dummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
NodeVector
new_args
;
new_args
.
push_back
(
dummy_correct_shape
);
for
(
size_t
i
=
1
;
i
<
n
->
get_input_size
();
i
++
)
{
auto
iarg_reshape
=
reorders
.
at
(
n
->
get_argument
(
i
));
auto
iorder
=
iarg_reshape
->
get_input_order
();
if
(
iorder
!=
order
)
{
NGRAPH_DEBUG
<<
" input order at "
<<
i
<<
"-th arg is different from first arg"
;
materialize_shapes
(
n
,
reorders
,
reshapes_to_delete
);
return
;
}
auto
iinput_shape
=
ngraph
::
apply_permutation
(
iarg_reshape
->
get_shape
(),
def_order
);
auto
idummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iarg_reshape
->
get_element_type
(),
input_shape
);
new_args
.
push_back
(
idummy_correct_shape
);
}
auto
new_axis
=
order
.
at
(
n
->
get_concatenation_axis
());
auto
new_concat
=
std
::
make_shared
<
op
::
Concat
>
(
new_args
,
new_axis
);
//put back the original arguments
for
(
size_t
i
=
0
;
i
<
new_concat
->
get_input_size
();
i
++
)
{
ngraph
::
replace_node
(
new_args
.
at
(
i
),
n
->
get_argument
(
i
));
}
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_concat
->
get_name
();
ngraph
::
replace_node
(
n
,
new_concat
);
auto
new_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
new_concat
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_concat
]
=
new_reshape
;
}
static
void
sink_dequantize
(
std
::
shared_ptr
<
op
::
Dequantize
>
dequantize
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
...
...
@@ -396,32 +472,6 @@ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
reorders
[
new_dequantize
]
=
arg_reshape
;
}
static
void
materialize_shapes
(
std
::
shared_ptr
<
Node
>
n
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
//skip multiple output nodes and deal with GOEs exclusively
if
(
n
->
get_outputs
().
size
()
>
1
)
{
return
;
}
for
(
size_t
i
=
0
;
i
<
n
->
get_arguments
().
size
();
i
++
)
{
//materialize all pending reshapes, flush pending reshapes
auto
arg
=
n
->
get_argument
(
i
);
if
(
reorders
.
count
(
arg
)
!=
0
)
{
NGRAPH_DEBUG
<<
"Materializing "
<<
describe_reshape
(
reorders
.
at
(
arg
))
<<
" for "
<<
arg
->
get_name
();
mark_reshape_for_deletion
(
reorders
.
at
(
arg
),
reshapes_to_delete
);
insert_reshape
(
n
,
reorders
.
at
(
arg
),
i
);
//no swimming up
}
}
reorders
[
n
]
=
create_default_reshape
(
n
);
}
//The goal of ReshapeSinking is to remove
//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
...
...
@@ -479,6 +529,10 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
{
sink_pad
(
pad
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
concat
=
std
::
dynamic_pointer_cast
<
op
::
Concat
>
(
n
))
{
sink_concat
(
concat
,
reorders
,
reshapes_to_delete
);
}
else
{
materialize_shapes
(
n
,
reorders
,
reshapes_to_delete
);
...
...
This diff is collapsed.
Click to expand it.
test/reshape_sinking.cpp
View file @
02a6b07c
...
...
@@ -203,3 +203,76 @@ TEST(reshape_sinking, slice_pad)
size_t
before_after
=
count_ops_of_type
<
op
::
Reshape
>
(
f
);
ASSERT_LE
(
before_after
,
before_count
);
}
TEST
(
reshape_sinking
,
concat
)
{
Shape
shape
{};
Shape
shape_w
{
1
,
1
,
1
,
1
};
Shape
shape_x
{
1
,
3
,
3
,
1
};
Shape
shape_b
{
1
,
3
,
3
,
1
};
Shape
r_shape
{
1
,
3
,
3
,
2
};
auto
B_
=
op
::
Constant
::
create
(
element
::
f32
,
shape_w
,
{
3
});
auto
B
=
make_shared
<
op
::
Reshape
>
(
B_
,
AxisVector
{
3
,
2
,
0
,
1
},
Shape
{
1
,
1
,
1
,
1
});
/* nchw */
auto
A_
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_x
);
auto
A
=
make_shared
<
op
::
Reshape
>
(
A_
,
AxisVector
{
0
,
3
,
1
,
2
},
Shape
{
1
,
1
,
3
,
3
});
/* nchw */
auto
C
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
1
},
{
2
});
auto
R
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
r_shape
);
auto
conv
=
make_shared
<
op
::
Convolution
>
(
A
,
B
,
Strides
{
1
,
1
},
Strides
{
1
,
1
},
CoordinateDiff
{
0
,
0
},
CoordinateDiff
{
0
,
0
},
Strides
{
1
,
1
});
auto
reshape_conv
=
make_shared
<
op
::
Reshape
>
(
conv
,
AxisVector
{
0
,
2
,
3
,
1
},
Shape
{
1
,
3
,
3
,
1
});
/* nhwc */
auto
broadcast
=
make_shared
<
op
::
Broadcast
>
(
C
,
reshape_conv
->
get_shape
(),
AxisSet
{
0
,
1
,
2
});
auto
add
=
broadcast
+
reshape_conv
;
auto
B1_
=
op
::
Constant
::
create
(
element
::
f32
,
shape_w
,
{
3
});
auto
B1
=
make_shared
<
op
::
Reshape
>
(
B1_
,
AxisVector
{
3
,
2
,
0
,
1
},
Shape
{
1
,
1
,
1
,
1
});
auto
A1_
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_x
);
auto
A1
=
make_shared
<
op
::
Reshape
>
(
A1_
,
AxisVector
{
0
,
3
,
1
,
2
},
Shape
{
1
,
1
,
3
,
3
});
auto
C1
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
1
},
{
2
});
auto
R1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
r_shape
);
auto
conv1
=
make_shared
<
op
::
Convolution
>
(
A1
,
B1
,
Strides
{
1
,
1
},
Strides
{
1
,
1
},
CoordinateDiff
{
0
,
0
},
CoordinateDiff
{
0
,
0
},
Strides
{
1
,
1
});
auto
reshape_conv1
=
make_shared
<
op
::
Reshape
>
(
conv1
,
AxisVector
{
0
,
2
,
3
,
1
},
Shape
{
1
,
3
,
3
,
1
});
auto
broadcast1
=
make_shared
<
op
::
Broadcast
>
(
C1
,
reshape_conv
->
get_shape
(),
AxisSet
{
0
,
1
,
2
});
auto
add1
=
broadcast1
+
reshape_conv1
;
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
add
,
add1
},
3
);
auto
relu
=
make_shared
<
op
::
Relu
>
(
concat
);
auto
reshape_relu
=
make_shared
<
op
::
Reshape
>
(
relu
,
AxisVector
{
0
,
3
,
1
,
2
},
Shape
{
1
,
2
,
3
,
3
});
/* nchw */
auto
B2_
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
1
,
1
,
2
,
1
},
{
2
});
auto
B2
=
make_shared
<
op
::
Reshape
>
(
B2_
,
AxisVector
{
3
,
2
,
0
,
1
},
Shape
{
1
,
2
,
1
,
1
});
auto
conv2
=
make_shared
<
op
::
Convolution
>
(
reshape_relu
,
B2
,
Strides
{
1
,
1
},
Strides
{
1
,
1
},
CoordinateDiff
{
0
,
0
},
CoordinateDiff
{
0
,
0
},
Strides
{
1
,
1
});
auto
reshape_conv2
=
make_shared
<
op
::
Reshape
>
(
conv2
,
AxisVector
{
0
,
2
,
3
,
1
},
Shape
{
1
,
3
,
3
,
1
});
/* nhwc */
auto
f
=
make_shared
<
Function
>
(
reshape_conv2
,
ParameterVector
{
A_
,
A1_
});
pass
::
Manager
pass_manager
;
size_t
before_count
=
count_ops_of_type
<
op
::
Reshape
>
(
f
);
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"before.pdf"
);
pass_manager
.
register_pass
<
pass
::
ReshapeSinking
>
();
pass_manager
.
register_pass
<
pass
::
ReshapeElimination
>
();
pass_manager
.
register_pass
<
pass
::
CommonSubexpressionElimination
>
();
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"after.pdf"
);
pass_manager
.
run_passes
(
f
);
size_t
before_after
=
count_ops_of_type
<
op
::
Reshape
>
(
f
);
ASSERT_LE
(
before_after
,
before_count
);
}
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment