Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
bb665f19
Unverified
Commit
bb665f19
authored
Feb 13, 2020
by
Scott Cyphers
Committed by
GitHub
Feb 13, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Avoid size_t issue in Gather (#4329)
parent
a05b4823
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
27 deletions
+46
-27
dimension.cpp
src/ngraph/dimension.cpp
+26
-0
dimension.hpp
src/ngraph/dimension.hpp
+7
-12
gather.cpp
src/ngraph/op/gather.cpp
+12
-13
gather.hpp
src/ngraph/op/gather.hpp
+1
-2
No files found.
src/ngraph/dimension.cpp
View file @
bb665f19
...
@@ -133,3 +133,29 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
...
@@ -133,3 +133,29 @@ bool Dimension::broadcast_merge(Dimension& dst, const Dimension d1, const Dimens
}
}
}
}
}
}
uint64_t
Dimension
::
get_length
()
const
{
if
(
is_dynamic
())
{
throw
std
::
invalid_argument
(
"Cannot get length of dynamic dimension"
);
}
if
(
m_dimension
<
0
)
{
throw
std
::
invalid_argument
(
"Cannot get_length of negative dimension"
);
}
return
m_dimension
;
}
Dimension
::
operator
size_t
()
const
{
if
(
is_dynamic
())
{
throw
std
::
invalid_argument
(
"Cannot convert dynamic dimension to size_t"
);
}
if
(
m_dimension
<
0
)
{
throw
std
::
invalid_argument
(
"Cannot convert negative dimension to size_t"
);
}
return
m_dimension
;
}
src/ngraph/dimension.hpp
View file @
bb665f19
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include <stddef.h>
#include <stddef.h>
#include <stdexcept>
#include <stdexcept>
#include "ngraph/deprecated.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -61,18 +62,12 @@ namespace ngraph
...
@@ -61,18 +62,12 @@ namespace ngraph
/// \brief Convert this dimension to `size_t`. This dimension must be static and
/// \brief Convert this dimension to `size_t`. This dimension must be static and
/// non-negative.
/// non-negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
/// \throws std::invalid_argument If this dimension is dynamic or negative.
explicit
operator
size_t
()
const
explicit
operator
size_t
()
const
NGRAPH_DEPRECATED
(
"use get_length() instead"
);
{
if
(
is_dynamic
())
/// \brief Convert this dimension to `uint64_t`. This dimension must be static and
{
/// non-negative.
throw
std
::
invalid_argument
(
"Cannot convert dynamic dimension to size_t"
);
/// \throws std::invalid_argument If this dimension is dynamic or negative.
}
uint64_t
get_length
()
const
;
if
(
m_dimension
<
0
)
{
throw
std
::
invalid_argument
(
"Cannot convert negative dimension to size_t"
);
}
return
m_dimension
;
}
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// \brief Check whether this dimension represents the same scheme as the argument (both
/// dynamic, or equal).
/// dynamic, or equal).
...
...
src/ngraph/op/gather.cpp
View file @
bb665f19
...
@@ -120,9 +120,9 @@ void op::v1::Gather::validate_and_infer_types()
...
@@ -120,9 +120,9 @@ void op::v1::Gather::validate_and_infer_types()
if
(
axis_rank
.
is_static
()
&&
axis_shape
.
is_static
())
if
(
axis_rank
.
is_static
()
&&
axis_shape
.
is_static
())
{
{
const
auto
axis_is_scalar
=
static_cast
<
size_t
>
(
axis_rank
)
==
0
;
const
auto
axis_is_scalar
=
axis_rank
.
get_length
(
)
==
0
;
const
auto
axis_has_one_elem
=
const
auto
axis_has_one_elem
=
static_cast
<
size_t
>
(
axis_rank
)
==
1
&&
static_cast
<
size_t
>
(
axis_shape
[
0
]
)
==
1
;
axis_rank
.
get_length
()
==
1
&&
axis_shape
[
0
].
get_length
(
)
==
1
;
NODE_VALIDATION_CHECK
(
this
,
NODE_VALIDATION_CHECK
(
this
,
axis_is_scalar
||
axis_has_one_elem
,
axis_is_scalar
||
axis_has_one_elem
,
"Axes input must be scalar or have 1 element (shape: "
,
"Axes input must be scalar or have 1 element (shape: "
,
...
@@ -134,7 +134,7 @@ void op::v1::Gather::validate_and_infer_types()
...
@@ -134,7 +134,7 @@ void op::v1::Gather::validate_and_infer_types()
if
(
input_rank
.
is_static
()
&&
axis
!=
AXIS_NOT_SET_VALUE
)
if
(
input_rank
.
is_static
()
&&
axis
!=
AXIS_NOT_SET_VALUE
)
{
{
NODE_VALIDATION_CHECK
(
this
,
NODE_VALIDATION_CHECK
(
this
,
axis
<
static_cast
<
size_t
>
(
input_rank
),
axis
<
input_rank
.
get_length
(
),
"The axis must => 0 and <= input_rank (axis: "
,
"The axis must => 0 and <= input_rank (axis: "
,
axis
,
axis
,
")."
);
")."
);
...
@@ -150,19 +150,18 @@ void op::v1::Gather::validate_and_infer_types()
...
@@ -150,19 +150,18 @@ void op::v1::Gather::validate_and_infer_types()
if
(
params_shape
.
rank
().
is_static
()
&&
indices_shape
.
rank
().
is_static
()
&&
if
(
params_shape
.
rank
().
is_static
()
&&
indices_shape
.
rank
().
is_static
()
&&
axis
!=
AXIS_NOT_SET_VALUE
)
axis
!=
AXIS_NOT_SET_VALUE
)
{
{
std
::
vector
<
Dimension
>
result_dims
(
static_cast
<
size_t
>
(
params_shape
.
rank
()
)
+
std
::
vector
<
Dimension
>
result_dims
(
params_shape
.
rank
().
get_length
(
)
+
static_cast
<
size_t
>
(
indices_shape
.
rank
()
)
-
1
);
indices_shape
.
rank
().
get_length
(
)
-
1
);
size
_t
i
=
0
;
uint64
_t
i
=
0
;
for
(;
i
<
static_cast
<
size_t
>
(
axis
)
;
i
++
)
for
(;
i
<
axis
;
i
++
)
{
{
result_dims
[
i
]
=
params_shape
[
i
];
result_dims
[
i
]
=
params_shape
[
i
];
}
}
for
(
size_t
j
=
0
;
j
<
static_cast
<
size_t
>
(
indices_shape
.
rank
()
);
i
++
,
j
++
)
for
(
uint64_t
j
=
0
;
j
<
indices_shape
.
rank
().
get_length
(
);
i
++
,
j
++
)
{
{
result_dims
[
i
]
=
indices_shape
[
j
];
result_dims
[
i
]
=
indices_shape
[
j
];
}
}
for
(
size_t
j
=
static_cast
<
size_t
>
(
axis
)
+
1
;
j
<
static_cast
<
size_t
>
(
params_shape
.
rank
());
for
(
uint64_t
j
=
axis
+
1
;
j
<
params_shape
.
rank
().
get_length
();
i
++
,
j
++
)
i
++
,
j
++
)
{
{
result_dims
[
i
]
=
params_shape
[
j
];
result_dims
[
i
]
=
params_shape
[
j
];
}
}
...
@@ -177,7 +176,7 @@ void op::v1::Gather::validate_and_infer_types()
...
@@ -177,7 +176,7 @@ void op::v1::Gather::validate_and_infer_types()
set_output_type
(
0
,
result_et
,
result_shape
);
set_output_type
(
0
,
result_et
,
result_shape
);
}
}
size
_t
op
::
v1
::
Gather
::
get_axis
()
const
int64
_t
op
::
v1
::
Gather
::
get_axis
()
const
{
{
int64_t
axis
=
AXIS_NOT_SET_VALUE
;
int64_t
axis
=
AXIS_NOT_SET_VALUE
;
auto
axes_input_node
=
input_value
(
AXIS
).
get_node_shared_ptr
();
auto
axes_input_node
=
input_value
(
AXIS
).
get_node_shared_ptr
();
...
@@ -190,10 +189,10 @@ size_t op::v1::Gather::get_axis() const
...
@@ -190,10 +189,10 @@ size_t op::v1::Gather::get_axis() const
const
auto
&
input_rank
=
get_input_partial_shape
(
PARAMS
).
rank
();
const
auto
&
input_rank
=
get_input_partial_shape
(
PARAMS
).
rank
();
if
(
input_rank
.
is_static
())
if
(
input_rank
.
is_static
())
{
{
axis
+=
static_cast
<
size_t
>
(
input_rank
);
axis
+=
input_rank
.
get_length
(
);
}
}
}
}
return
static_cast
<
size_t
>
(
axis
)
;
return
axis
;
}
}
void
op
::
v1
::
Gather
::
generate_adjoints
(
autodiff
::
Adjoints
&
/* adjoints */
,
void
op
::
v1
::
Gather
::
generate_adjoints
(
autodiff
::
Adjoints
&
/* adjoints */
,
...
...
src/ngraph/op/gather.hpp
View file @
bb665f19
...
@@ -67,8 +67,7 @@ namespace ngraph
...
@@ -67,8 +67,7 @@ namespace ngraph
const
Output
<
Node
>&
indices
,
const
Output
<
Node
>&
indices
,
const
Output
<
Node
>&
axis
);
const
Output
<
Node
>&
axis
);
size_t
get_version
()
const
override
{
return
1
;
}
int64_t
get_axis
()
const
;
size_t
get_axis
()
const
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment