Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
7277a9fd
Commit
7277a9fd
authored
6 years ago
by
Nick Korovaiko
Committed by
Robert Kimball
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make sure Slice is reshaped if needed (#1803)
parent
b339ea71
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
1 deletion
+15
-1
cpu_mat_fusion.cpp
src/ngraph/runtime/cpu/pass/cpu_mat_fusion.cpp
+15
-1
No files found.
src/ngraph/runtime/cpu/pass/cpu_mat_fusion.cpp
View file @
7277a9fd
...
@@ -147,6 +147,13 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
...
@@ -147,6 +147,13 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
auto
matched_weight
=
matcher_v2
->
get_pattern_map
()[
W
]
->
get_argument
(
0
);
auto
matched_weight
=
matcher_v2
->
get_pattern_map
()[
W
]
->
get_argument
(
0
);
auto
matched_data
=
matcher_v2
->
get_pattern_map
()[
input_data
];
auto
matched_data
=
matcher_v2
->
get_pattern_map
()[
input_data
];
auto
matched_bias
=
matcher_v2
->
get_pattern_map
()[
b
]
->
get_argument
(
0
);
auto
matched_bias
=
matcher_v2
->
get_pattern_map
()[
b
]
->
get_argument
(
0
);
if
(
matcher_v2
->
get_match_root
()
->
get_shape
().
size
()
!=
2
&&
matcher_v2
->
get_match_root
()
->
get_shape
().
size
()
!=
3
)
{
NGRAPH_DEBUG
<<
"mat fusion (v2) root "
<<
matcher_v2
->
get_match_root
()
->
get_name
()
<<
" isn't 2D or 3D"
;
continue
;
}
map_weights_to_pattern
[
matched_weight
].
push_back
(
matcher_v2
->
get_match_root
());
map_weights_to_pattern
[
matched_weight
].
push_back
(
matcher_v2
->
get_match_root
());
map_weights_bias_to_data
[
std
::
make_pair
(
matched_weight
,
matched_bias
)].
push_back
(
map_weights_bias_to_data
[
std
::
make_pair
(
matched_weight
,
matched_bias
)].
push_back
(
matched_data
);
matched_data
);
...
@@ -248,8 +255,15 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
...
@@ -248,8 +255,15 @@ bool runtime::cpu::pass::CPURnnMatFusion::run_on_function(std::shared_ptr<Functi
size_t
end_index
=
batch_size
;
size_t
end_index
=
batch_size
;
for
(
auto
&
matched_root_node
:
map_weights_to_pattern
[
weights
])
for
(
auto
&
matched_root_node
:
map_weights_to_pattern
[
weights
])
{
{
auto
slice_node
=
std
::
make_shared
<
op
::
Slice
>
(
std
::
shared_ptr
<
Node
>
slice_node
=
std
::
make_shared
<
op
::
Slice
>
(
new_add_bias
,
Coordinate
{
start_index
,
0
},
Coordinate
{
end_index
,
shape_axis_1
});
new_add_bias
,
Coordinate
{
start_index
,
0
},
Coordinate
{
end_index
,
shape_axis_1
});
if
(
matched_root_node
->
get_shape
().
size
()
!=
2
)
{
NGRAPH_ASSERT
(
matched_root_node
->
get_shape
().
size
()
==
3
);
slice_node
=
std
::
make_shared
<
op
::
Reshape
>
(
slice_node
,
AxisVector
{
0
,
1
},
matched_root_node
->
get_shape
());
}
start_index
+=
batch_size
;
start_index
+=
batch_size
;
end_index
+=
batch_size
;
end_index
+=
batch_size
;
NGRAPH_DEBUG
<<
"Replacing op "
<<
matched_root_node
->
get_name
()
<<
" with "
NGRAPH_DEBUG
<<
"Replacing op "
<<
matched_root_node
->
get_name
()
<<
" with "
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment