Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
ea40bc41
Unverified
Commit
ea40bc41
authored
6 years ago
by
Adam Procter
Committed by
GitHub
6 years ago
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #1708 from NervanaSystems/aprocter/cherry-pick-1663
Cherry-pick "zero dim elem fix (#1663)"
parents
f3c88459
427bcc1f
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
3 deletions
+18
-3
zero_dim_tensor_elimination.cpp
src/ngraph/pass/zero_dim_tensor_elimination.cpp
+18
-3
No files found.
src/ngraph/pass/zero_dim_tensor_elimination.cpp
View file @
ea40bc41
...
@@ -37,7 +37,9 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
...
@@ -37,7 +37,9 @@ static bool has_zero_dim(std::shared_ptr<Node> node)
{
{
throw
ngraph_error
(
"has_zero_dim is called on multi-output op"
);
throw
ngraph_error
(
"has_zero_dim is called on multi-output op"
);
}
}
return
shape_size
(
node
->
get_shape
())
==
0
;
const
auto
&
shape
=
node
->
get_shape
();
return
std
::
find
(
shape
.
begin
(),
shape
.
end
(),
0
)
!=
shape
.
end
();
}
}
static
bool
verify_no_internal_zero_length_ops
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
static
bool
verify_no_internal_zero_length_ops
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
...
@@ -75,6 +77,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
...
@@ -75,6 +77,7 @@ static bool verify_no_internal_zero_length_ops(std::shared_ptr<ngraph::Function>
bool
ngraph
::
pass
::
ZeroDimTensorElimination
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
bool
ngraph
::
pass
::
ZeroDimTensorElimination
::
run_on_function
(
std
::
shared_ptr
<
ngraph
::
Function
>
f
)
{
{
bool
replaced
=
false
;
bool
replaced
=
false
;
auto
cvals
=
std
::
vector
<
std
::
string
>
(
0
);
// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
// we need to go over all nodes since we could have sum or any other 0-length-tensor-to scalar op
// as an internal node (i.e. a node that isn't an argument to `op::Result`)
// as an internal node (i.e. a node that isn't an argument to `op::Result`)
for
(
auto
n
:
f
->
get_ordered_ops
())
for
(
auto
n
:
f
->
get_ordered_ops
())
...
@@ -93,7 +96,6 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
...
@@ -93,7 +96,6 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
{
{
// we don't have to create constants every time but this is the easiest
// we don't have to create constants every time but this is the easiest
// and it's CSE's job to eliminate the same ones
// and it's CSE's job to eliminate the same ones
auto
cvals
=
std
::
vector
<
std
::
string
>
(
0
);
auto
constant
=
auto
constant
=
std
::
make_shared
<
op
::
Constant
>
(
n
->
get_element_type
(),
n
->
get_shape
(),
cvals
);
std
::
make_shared
<
op
::
Constant
>
(
n
->
get_element_type
(),
n
->
get_shape
(),
cvals
);
replace_node
(
n
,
constant
);
replace_node
(
n
,
constant
);
...
@@ -102,8 +104,21 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
...
@@ -102,8 +104,21 @@ bool ngraph::pass::ZeroDimTensorElimination::run_on_function(std::shared_ptr<ngr
continue
;
continue
;
}
}
if
(
n
->
get_inputs
().
size
()
==
0
)
{
continue
;
}
auto
arg
=
n
->
get_inputs
().
at
(
0
).
get_output
().
get_node
();
if
(
arg
->
get_outputs
().
size
()
!=
1
||
!
has_zero_dim
(
arg
))
{
continue
;
}
auto
new_node
=
n
->
get_default_value
();
auto
new_node
=
n
->
get_default_value
();
if
(
!
new_node
||
!
has_zero_dim
(
n
->
get_argument
(
0
)))
if
(
!
new_node
)
{
{
continue
;
continue
;
}
}
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment