Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
97a44e27
Commit
97a44e27
authored
5 years ago
by
Adam Procter
Committed by
Scott Cyphers
5 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add support for ArgMin and ArgMax to ZeroDimTensorElimination (#3022)
parent
c555b36a
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
51 additions
and
0 deletions
+51
-0
argmax.cpp
src/ngraph/op/argmax.cpp
+7
-0
argmax.hpp
src/ngraph/op/argmax.hpp
+2
-0
argmin.cpp
src/ngraph/op/argmin.cpp
+7
-0
argmin.hpp
src/ngraph/op/argmin.hpp
+2
-0
index_reduction.cpp
src/ngraph/op/util/index_reduction.cpp
+1
-0
zero_dim_tensor_elimination.cpp
test/zero_dim_tensor_elimination.cpp
+32
-0
No files found.
src/ngraph/op/argmax.cpp
View file @
97a44e27
...
...
@@ -37,3 +37,10 @@ shared_ptr<Node> op::ArgMax::copy_with_new_args(const NodeVector& new_args) cons
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
ArgMax
>
(
new_args
.
at
(
0
),
m_axis
,
this
->
get_element_type
());
}
std
::
shared_ptr
<
Node
>
op
::
ArgMax
::
get_default_value
()
const
{
// Choice of value here is arbitrary, because validation should be rejecting cases where the
// axis of reduction has size zero.
return
ngraph
::
make_constant_from_string
(
"0"
,
get_element_type
(),
get_shape
());
}
This diff is collapsed.
Click to expand it.
src/ngraph/op/argmax.hpp
View file @
97a44e27
...
...
@@ -42,6 +42,8 @@ namespace ngraph
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
std
::
shared_ptr
<
Node
>
get_default_value
()
const
override
;
};
}
}
This diff is collapsed.
Click to expand it.
src/ngraph/op/argmin.cpp
View file @
97a44e27
...
...
@@ -36,3 +36,10 @@ shared_ptr<Node> op::ArgMin::copy_with_new_args(const NodeVector& new_args) cons
check_new_args_count
(
this
,
new_args
);
return
make_shared
<
ArgMin
>
(
new_args
.
at
(
0
),
m_axis
,
this
->
get_element_type
());
}
std
::
shared_ptr
<
Node
>
op
::
ArgMin
::
get_default_value
()
const
{
// Choice of value here is arbitrary, because validation should be rejecting cases where the
// axis of reduction has size zero.
return
ngraph
::
make_constant_from_string
(
"0"
,
get_element_type
(),
get_shape
());
}
This diff is collapsed.
Click to expand it.
src/ngraph/op/argmin.hpp
View file @
97a44e27
...
...
@@ -43,6 +43,8 @@ namespace ngraph
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
virtual
std
::
shared_ptr
<
Node
>
get_default_value
()
const
override
;
};
}
}
This diff is collapsed.
Click to expand it.
src/ngraph/op/util/index_reduction.cpp
View file @
97a44e27
...
...
@@ -72,6 +72,7 @@ void op::util::IndexReduction::set_index_element_type(const element::Type& index
void
op
::
util
::
IndexReduction
::
validate_and_infer_types
()
{
// TODO(amprocte): Should reject if size of reduction axis is zero.
const
PartialShape
&
arg_shape
=
get_input_partial_shape
(
0
);
Rank
rank
=
arg_shape
.
rank
();
...
...
This diff is collapsed.
Click to expand it.
test/zero_dim_tensor_elimination.cpp
View file @
97a44e27
...
...
@@ -192,6 +192,38 @@ TEST(zero_dim_tensor_elimination, zero_const_slice)
EXPECT_EQ
(
count_ops_of_type
<
op
::
Slice
>
(
f
),
0
);
}
TEST
(
zero_dim_tensor_elimination
,
zero_argmax
)
{
auto
A
=
std
::
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
0
,
2
,
3
});
auto
argmax
=
make_shared
<
op
::
ArgMax
>
(
A
,
1
,
element
::
i32
);
auto
f
=
std
::
make_shared
<
Function
>
(
NodeVector
{
argmax
},
ParameterVector
{
A
});
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"zero_argmax_before.png"
);
pass_manager
.
register_pass
<
ngraph
::
pass
::
ZeroDimTensorElimination
>
();
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"zero_argmax_after.png"
);
EXPECT_EQ
(
count_ops_of_type
<
op
::
ArgMax
>
(
f
),
1
);
pass_manager
.
run_passes
(
f
);
EXPECT_EQ
(
count_ops_of_type
<
op
::
ArgMax
>
(
f
),
0
);
EXPECT_EQ
(
f
->
get_results
().
at
(
0
)
->
get_shape
(),
(
Shape
{
0
,
3
}));
}
TEST
(
zero_dim_tensor_elimination
,
zero_argmin
)
{
auto
A
=
std
::
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
0
,
2
,
3
});
auto
argmin
=
make_shared
<
op
::
ArgMin
>
(
A
,
1
,
element
::
i32
);
auto
f
=
std
::
make_shared
<
Function
>
(
NodeVector
{
argmin
},
ParameterVector
{
A
});
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"zero_argmin_before.png"
);
pass_manager
.
register_pass
<
ngraph
::
pass
::
ZeroDimTensorElimination
>
();
pass_manager
.
register_pass
<
pass
::
VisualizeTree
>
(
"zero_argmin_after.png"
);
EXPECT_EQ
(
count_ops_of_type
<
op
::
ArgMin
>
(
f
),
1
);
pass_manager
.
run_passes
(
f
);
EXPECT_EQ
(
count_ops_of_type
<
op
::
ArgMin
>
(
f
),
0
);
EXPECT_EQ
(
f
->
get_results
().
at
(
0
)
->
get_shape
(),
(
Shape
{
0
,
3
}));
}
TEST
(
zero_dim_tensor_elimination
,
pass_property
)
{
auto
pass
=
std
::
make_shared
<
ngraph
::
pass
::
ZeroDimTensorElimination
>
();
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment