Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
3f2cd153
Commit
3f2cd153
authored
6 years ago
by
Adam Rogowiec
Committed by
Robert Kimball
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Handle negative axis values in Concat op. (#2252)
parent
9940123b
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
6 deletions
+38
-6
onnx_import.cpp
python/pyngraph/onnx_import/onnx_import.cpp
+2
-3
concat.cpp
src/ngraph/frontend/onnx_import/op/concat.cpp
+12
-2
common.hpp
src/ngraph/frontend/onnx_import/utils/common.hpp
+24
-1
No files found.
python/pyngraph/onnx_import/onnx_import.cpp
View file @
3f2cd153
...
...
@@ -16,11 +16,10 @@
#include <istream>
#include <memory>
#include <string>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include <vector>
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/function.hpp"
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/frontend/onnx_import/op/concat.cpp
View file @
3f2cd153
...
...
@@ -14,8 +14,12 @@
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include "concat.hpp"
#include "exceptions.hpp"
#include "ngraph/op/concat.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -28,9 +32,15 @@ namespace ngraph
NodeVector
concat
(
const
Node
&
node
)
{
NodeVector
inputs
{
node
.
get_ng_inputs
()};
auto
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
);
std
::
int64_t
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
);
size_t
valid_axis
=
common
::
convert_negative_axis
(
axis
,
inputs
.
at
(
0
)
->
get_shape
().
size
());
ASSERT_VALID_ARGUMENT
(
node
,
valid_axis
>=
0
)
<<
"Incorrect value of axis attribute: "
<<
axis
;
return
{
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
inputs
,
axis
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
inputs
,
valid_
axis
)};
}
}
// namespace set_1
...
...
This diff is collapsed.
Click to expand it.
src/ngraph/frontend/onnx_import/utils/common.hpp
View file @
3f2cd153
...
...
@@ -16,7 +16,7 @@
#pragma once
#include <cmath> // std::floor
#include <cmath> // std::floor
, std::min
#include <cstddef> // std::size_t
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
...
...
@@ -135,6 +135,29 @@ namespace ngraph
return
node
;
}
/// \brief Handle negative axis value.
///
/// \param[in] axis The requested axis value.
/// \param[in] tensor_dim The corresponding tensor dimensionality.
///
/// \tparam T Provided axis value type.
///
/// \return If negative axis, then return sum of tensor dimension and axis.
///
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
,
int
>::
type
=
0
>
std
::
int64_t
convert_negative_axis
(
T
axis
,
std
::
size_t
tensor_dim
)
{
if
(
axis
>=
0
)
{
return
std
::
min
(
axis
,
static_cast
<
T
>
(
tensor_dim
));
}
else
{
return
static_cast
<
std
::
int64_t
>
(
tensor_dim
)
+
axis
;
}
}
}
// namespace common
}
// namespace onnx_import
}
// namespace ngraph
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment