Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
bf365b12
Unverified
Commit
bf365b12
authored
Jan 17, 2019
by
Adam Procter
Committed by
GitHub
Jan 17, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #2313 from NervanaSystems/krovatkin/rs_concat
Sink Concat
parents
ad30e973
80923525
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
153 additions
and
26 deletions
+153
-26
reshape_sinking.cpp
src/ngraph/pass/reshape_sinking.cpp
+80
-26
reshape_sinking.cpp
test/reshape_sinking.cpp
+73
-0
No files found.
src/ngraph/pass/reshape_sinking.cpp
View file @
bf365b12
...
...
@@ -26,6 +26,7 @@
#include "ngraph/log.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/get_output_element.hpp"
...
...
@@ -230,6 +231,32 @@ static void convert_binary_to_default_order(
reorders
[
binary
]
=
reorders
.
at
(
right
);
}
static
void
materialize_shapes
(
std
::
shared_ptr
<
Node
>
n
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
//skip multiple output nodes and deal with GOEs exclusively
if
(
n
->
get_outputs
().
size
()
>
1
)
{
return
;
}
for
(
size_t
i
=
0
;
i
<
n
->
get_arguments
().
size
();
i
++
)
{
//materialize all pending reshapes, flush pending reshapes
auto
arg
=
n
->
get_argument
(
i
);
if
(
reorders
.
count
(
arg
)
!=
0
)
{
NGRAPH_DEBUG
<<
"Materializing "
<<
describe_reshape
(
reorders
.
at
(
arg
))
<<
" for "
<<
arg
->
get_name
();
mark_reshape_for_deletion
(
reorders
.
at
(
arg
),
reshapes_to_delete
);
insert_reshape
(
n
,
reorders
.
at
(
arg
),
i
);
//no swimming up
}
}
reorders
[
n
]
=
create_default_reshape
(
n
);
}
static
void
sink_reshape
(
std
::
shared_ptr
<
op
::
Reshape
>
reshape
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
...
...
@@ -379,6 +406,55 @@ static void sink_quantize(std::shared_ptr<op::Quantize> quantize,
reorders
[
new_quantize
]
=
arg_reshape
;
}
static
void
sink_concat
(
std
::
shared_ptr
<
op
::
Concat
>
n
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
auto
arg_reshape
=
reorders
.
at
(
n
->
get_argument
(
0
));
auto
order
=
arg_reshape
->
get_input_order
();
// we need the correct input shape to produce the right output shape
// we are going to create a label of the right input shape,
// so a new slice will have the right shape
auto
def_order
=
ngraph
::
get_permutation_to_default_order
(
order
);
auto
input_shape
=
ngraph
::
apply_permutation
(
arg_reshape
->
get_shape
(),
def_order
);
auto
dummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
arg_reshape
->
get_element_type
(),
input_shape
);
NodeVector
new_args
;
new_args
.
push_back
(
dummy_correct_shape
);
for
(
size_t
i
=
1
;
i
<
n
->
get_input_size
();
i
++
)
{
auto
iarg_reshape
=
reorders
.
at
(
n
->
get_argument
(
i
));
auto
iorder
=
iarg_reshape
->
get_input_order
();
if
(
iorder
!=
order
)
{
NGRAPH_DEBUG
<<
" input order at "
<<
i
<<
"-th arg is different from first arg"
;
materialize_shapes
(
n
,
reorders
,
reshapes_to_delete
);
return
;
}
auto
iinput_shape
=
ngraph
::
apply_permutation
(
iarg_reshape
->
get_shape
(),
def_order
);
auto
idummy_correct_shape
=
std
::
make_shared
<
pattern
::
op
::
Label
>
(
iarg_reshape
->
get_element_type
(),
iinput_shape
);
new_args
.
push_back
(
idummy_correct_shape
);
}
auto
new_axis
=
order
.
at
(
n
->
get_concatenation_axis
());
auto
new_concat
=
std
::
make_shared
<
op
::
Concat
>
(
new_args
,
new_axis
);
//put back the original arguments
for
(
size_t
i
=
0
;
i
<
new_concat
->
get_input_size
();
i
++
)
{
ngraph
::
replace_node
(
new_args
.
at
(
i
),
n
->
get_argument
(
i
));
}
NGRAPH_DEBUG
<<
"Replacing "
<<
n
->
get_name
()
<<
" with "
<<
new_concat
->
get_name
();
ngraph
::
replace_node
(
n
,
new_concat
);
auto
new_reshape
=
std
::
make_shared
<
op
::
Reshape
>
(
new_concat
,
order
,
n
->
get_shape
());
NGRAPH_DEBUG
<<
"Propagating "
<<
describe_reshape
(
new_reshape
)
<<
" for "
<<
n
->
get_name
();
reorders
[
new_concat
]
=
new_reshape
;
}
static
void
sink_dequantize
(
std
::
shared_ptr
<
op
::
Dequantize
>
dequantize
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
...
...
@@ -396,32 +472,6 @@ static void sink_dequantize(std::shared_ptr<op::Dequantize> dequantize,
reorders
[
new_dequantize
]
=
arg_reshape
;
}
static
void
materialize_shapes
(
std
::
shared_ptr
<
Node
>
n
,
ReshapeMap
&
reorders
,
std
::
set
<
std
::
shared_ptr
<
Node
>>&
reshapes_to_delete
)
{
//skip multiple output nodes and deal with GOEs exclusively
if
(
n
->
get_outputs
().
size
()
>
1
)
{
return
;
}
for
(
size_t
i
=
0
;
i
<
n
->
get_arguments
().
size
();
i
++
)
{
//materialize all pending reshapes, flush pending reshapes
auto
arg
=
n
->
get_argument
(
i
);
if
(
reorders
.
count
(
arg
)
!=
0
)
{
NGRAPH_DEBUG
<<
"Materializing "
<<
describe_reshape
(
reorders
.
at
(
arg
))
<<
" for "
<<
arg
->
get_name
();
mark_reshape_for_deletion
(
reorders
.
at
(
arg
),
reshapes_to_delete
);
insert_reshape
(
n
,
reorders
.
at
(
arg
),
i
);
//no swimming up
}
}
reorders
[
n
]
=
create_default_reshape
(
n
);
}
//The goal of ReshapeSinking is to remove
//round-trip reshapes(i.e. nhwc->nchw(nchw-only-op)->nhwc)
//around nchw-only-op (e.g.Convolution, Batchnorm, Avg/MaxPool)
...
...
@@ -493,6 +543,10 @@ bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Funct
{
sink_pad
(
pad
,
reorders
,
reshapes_to_delete
);
}
else
if
(
auto
concat
=
std
::
dynamic_pointer_cast
<
op
::
Concat
>
(
n
))
{
sink_concat
(
concat
,
reorders
,
reshapes_to_delete
);
}
else
{
materialize_shapes
(
n
,
reorders
,
reshapes_to_delete
);
...
...
test/reshape_sinking.cpp
View file @
bf365b12
...
...
@@ -203,3 +203,76 @@ TEST(reshape_sinking, slice_pad)
size_t
before_after
=
count_ops_of_type
<
op
::
Reshape
>
(
f
);
ASSERT_LE
(
before_after
,
before_count
);
}
TEST
(
reshape_sinking
,
concat
)
{
Shape
shape
{};
Shape
shape_w
{
1
,
1
,
1
,
1
};
Shape
shape_x
{
1
,
3
,
3
,
1
};
Shape
shape_b
{
1
,
3
,
3
,
1
};
Shape
r_shape
{
1
,
3
,
3
,
2
};
auto
B_
=
op
::
Constant
::
create
(
element
::
f32
,
shape_w
,
{
3
});
auto
B
=
make_shared
<
op
::
Reshape
>
(
B_
,
AxisVector
{
3
,
2
,
0
,
1
},
Shape
{
1
,
1
,
1
,
1
});
/* nchw */
auto
A_
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_x
);
auto
A
=
make_shared
<
op
::
Reshape
>
(
A_
,
AxisVector
{
0
,
3
,
1
,
2
},
Shape
{
1
,
1
,
3
,
3
});
/* nchw */
auto
C
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
1
},
{
2
});
auto
R
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
r_shape
);
auto
conv
=
make_shared
<
op
::
Convolution
>
(
A
,
B
,
Strides
{
1
,
1
},
Strides
{
1
,
1
},
CoordinateDiff
{
0
,
0
},
CoordinateDiff
{
0
,
0
},
Strides
{
1
,
1
});
auto
reshape_conv
=
make_shared
<
op
::
Reshape
>
(
conv
,
AxisVector
{
0
,
2
,
3
,
1
},
Shape
{
1
,
3
,
3
,
1
});
/* nhwc */
auto
broadcast
=
make_shared
<
op
::
Broadcast
>
(
C
,
reshape_conv
->
get_shape
(),
AxisSet
{
0
,
1
,
2
});
auto
add
=
broadcast
+
reshape_conv
;
auto
B1_
=
op
::
Constant
::
create
(
element
::
f32
,
shape_w
,
{
3
});
auto
B1
=
make_shared
<
op
::
Reshape
>
(
B1_
,
AxisVector
{
3
,
2
,
0
,
1
},
Shape
{
1
,
1
,
1
,
1
});
auto
A1_
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_x
);
auto
A1
=
make_shared
<
op
::
Reshape
>
(
A1_
,
AxisVector
{
0
,
3
,
1
,
2
},
Shape
{
1
,
1
,
3
,
3
});
auto
C1
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
1
},
{
2
});
auto
R1
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
r_shape
);
auto
conv1
=
make_shared
<
op
::
Convolution
>
(
A1
,
B1
,
Strides
{
1
,
1
},
Strides
{
1
,
1
},
CoordinateDiff
{
0
,
0
},
CoordinateDiff
{
0
,
0
},
Strides
{
1
,
1
});
auto
reshape_conv1
=
make_shared
<
op
::
Reshape
>
(
conv1
,
AxisVector
{
0
,
2
,
3
,
1
},
Shape
{
1
,
3
,
3
,
1
});
auto
broadcast1
=
make_shared
<
op
::
Broadcast
>
(
C1
,
reshape_conv
->
get_shape
(),
AxisSet
{
0
,
1
,
2
});
auto
add1
=
broadcast1
+
reshape_conv1
;
auto
concat
=
make_shared
<
op
::
Concat
>
(
NodeVector
{
add
,
add1
},
3
);
auto
relu
=
make_shared
<
op
::
Relu
>
(
concat
);
auto
reshape_relu
=
make_shared
<
op
::
Reshape
>
(
relu
,
AxisVector
{
0
,
3
,
1
,
2
},
Shape
{
1
,
2
,
3
,
3
});
/* nchw */
auto
B2_
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{
1
,
1
,
2
,
1
},
{
2
});
auto
B2
=
make_shared
<
op
::
Reshape
>
(
B2_
,
AxisVector
{
3
,
2
,
0
,
1
},
Shape
{
1
,
2
,
1
,
1
});
auto
conv2
=
make_shared
<
op
::
Convolution
>
(
reshape_relu
,
B2
,
Strides
{
1
,
1
},
Strides
{
1
,
1
},
CoordinateDiff
{
0
,
0
},
CoordinateDiff
{
0
,
0
},
Strides
{
1
,
1
});
auto
reshape_conv2
=
make_shared
<
op
::
Reshape
>
(
conv2
,
AxisVector
{
0
,
2
,
3
,
1
},
Shape
{
1
,
3
,
3
,
1
});
/* nhwc */
auto
f
=
make_shared
<
Function
>
(
reshape_conv2
,
ParameterVector
{
A_
,
A1_
});
pass
::
Manager
pass_manager
;
size_t
before_count
=
count_ops_of_type
<
op
::
Reshape
>
(
f
);
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"before.pdf"
);
pass_manager
.
register_pass
<
pass
::
ReshapeSinking
>
();
pass_manager
.
register_pass
<
pass
::
ReshapeElimination
>
();
pass_manager
.
register_pass
<
pass
::
CommonSubexpressionElimination
>
();
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"after.pdf"
);
pass_manager
.
run_passes
(
f
);
size_t
before_after
=
count_ops_of_type
<
op
::
Reshape
>
(
f
);
ASSERT_LE
(
before_after
,
before_count
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment