Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
0c181e9d
Commit
0c181e9d
authored
Sep 20, 2019
by
Ewa Tusień
Committed by
Scott Cyphers
Sep 20, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ONNX] Add support for negative axes (#3643)
parent
1a5288ab
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
163 additions
and
99 deletions
+163
-99
concat.cpp
src/ngraph/frontend/onnx_import/op/concat.cpp
+1
-5
flatten.cpp
src/ngraph/frontend/onnx_import/op/flatten.cpp
+6
-5
gather.hpp
src/ngraph/frontend/onnx_import/op/gather.hpp
+3
-5
hardmax.cpp
src/ngraph/frontend/onnx_import/op/hardmax.cpp
+2
-4
lp_norm.cpp
src/ngraph/frontend/onnx_import/op/lp_norm.cpp
+4
-6
mean_variance_normalization.cpp
...h/frontend/onnx_import/op/mean_variance_normalization.cpp
+5
-2
onehot.cpp
src/ngraph/frontend/onnx_import/op/onehot.cpp
+11
-10
reverse_sequence.cpp
src/ngraph/frontend/onnx_import/op/reverse_sequence.cpp
+9
-4
softmax.cpp
src/ngraph/frontend/onnx_import/op/softmax.cpp
+4
-12
split.cpp
src/ngraph/frontend/onnx_import/op/split.cpp
+5
-2
squeeze.cpp
src/ngraph/frontend/onnx_import/op/squeeze.cpp
+6
-10
topk.cpp
src/ngraph/frontend/onnx_import/op/topk.cpp
+4
-10
common.cpp
src/ngraph/frontend/onnx_import/utils/common.cpp
+47
-0
common.hpp
src/ngraph/frontend/onnx_import/utils/common.hpp
+43
-18
reduction.cpp
src/ngraph/frontend/onnx_import/utils/reduction.cpp
+8
-4
reduction.hpp
src/ngraph/frontend/onnx_import/utils/reduction.hpp
+5
-2
No files found.
src/ngraph/frontend/onnx_import/op/concat.cpp
View file @
0c181e9d
...
...
@@ -33,12 +33,8 @@ namespace ngraph
{
NodeVector
inputs
{
node
.
get_ng_inputs
()};
std
::
int64_t
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
);
size_t
valid_axis
=
common
::
convert_negative_axis
(
axis
,
inputs
.
at
(
0
)
->
get_shape
().
size
());
ASSERT_VALID_ARGUMENT
(
node
,
valid_axis
>=
0
)
<<
"Incorrect value of axis attribute: "
<<
axis
;
common
::
validate_axis
(
node
,
axis
,
inputs
.
at
(
0
)
->
get_shape
().
size
());
return
{
std
::
make_shared
<
ngraph
::
op
::
Concat
>
(
inputs
,
valid_axis
)};
}
...
...
src/ngraph/frontend/onnx_import/op/flatten.cpp
View file @
0c181e9d
...
...
@@ -19,7 +19,7 @@
#include "exceptions.hpp"
#include "flatten.hpp"
#include "ngraph/builder/reshape.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
namespace
onnx_import
...
...
@@ -33,11 +33,12 @@ namespace ngraph
NodeVector
inputs
{
node
.
get_ng_inputs
()};
auto
data
=
inputs
.
at
(
0
);
auto
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
1
);
auto
data_rank
=
data
->
get_shape
().
size
();
// Accepted range is [-r, r] where r = rank(input).
auto
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
data_rank
,
-
data_rank
,
data_rank
);
ASSERT_VALID_ARGUMENT
(
node
,
(
axis
>=
0
)
&&
(
axis
<=
data
->
get_shape
().
size
()))
<<
"provided 'axis' attribute is not valid."
;
return
{
ngraph
::
builder
::
flatten
(
data
,
axis
)};
return
{
ngraph
::
builder
::
flatten
(
data
,
valid_axis
)};
}
}
// namespace set_1
...
...
src/ngraph/frontend/onnx_import/op/gather.hpp
View file @
0c181e9d
...
...
@@ -19,6 +19,7 @@
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/gather.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -34,12 +35,9 @@ namespace ngraph
auto
data
=
ng_inputs
.
at
(
0
);
auto
indices
=
ng_inputs
.
at
(
1
);
auto
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
,
0
);
if
(
axis
<
0
)
{
axis
+=
data
->
get_shape
().
size
();
}
auto
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
data
->
get_shape
().
size
());
return
{
std
::
make_shared
<
ngraph
::
op
::
Gather
>
(
data
,
indices
,
axis
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Gather
>
(
data
,
indices
,
valid_
axis
)};
}
}
// namespace set_1
...
...
src/ngraph/frontend/onnx_import/op/hardmax.cpp
View file @
0c181e9d
...
...
@@ -35,12 +35,10 @@ namespace ngraph
const
auto
&
input_shape
=
input
->
get_shape
();
auto
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
1
);
ASSERT_VALID_ARGUMENT
(
node
,
axis
>=
0
&&
axis
<
input_shape
.
size
())
<<
"The provided axis value "
<<
axis
<<
" does not match the input tensor dimensions"
;
auto
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
input_shape
.
size
());
// reshape to 2D - "batch size" x "input feature dimensions" (NxD)
const
auto
coerced_tensor
=
ngraph
::
builder
::
flatten
(
input
,
axis
);
const
auto
coerced_tensor
=
ngraph
::
builder
::
flatten
(
input
,
valid_
axis
);
const
auto
&
coerced_shape
=
coerced_tensor
->
get_shape
();
const
std
::
shared_ptr
<
ngraph
::
Node
>
argmax_2d
=
...
...
src/ngraph/frontend/onnx_import/op/lp_norm.cpp
View file @
0c181e9d
...
...
@@ -23,6 +23,7 @@
#include "ngraph/builder/norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/divide.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -37,17 +38,14 @@ namespace ngraph
const
std
::
shared_ptr
<
ngraph
::
Node
>
data
{
node
.
get_ng_inputs
().
at
(
0
)};
std
::
int64_t
axis
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
)};
const
std
::
int64_t
p_norm
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"p"
,
2
)};
if
(
axis
<
0
)
{
axis
+=
data
->
get_shape
().
size
();
}
std
::
size_t
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
data
->
get_shape
().
size
());
ASSERT_VALID_ARGUMENT
(
node
,
p_norm
==
1
||
p_norm
==
2
)
<<
"Invalid `p` attribute value: "
<<
p_norm
<<
"Only normalization of 1st or 2nd order is supported."
;
const
AxisSet
reduction_axes
{
static_cast
<
std
::
size_t
>
(
axis
)
};
const
AxisSet
reduction_axes
{
valid_axis
};
std
::
shared_ptr
<
ngraph
::
Node
>
norm
=
ngraph
::
builder
::
lp_norm
(
data
,
reduction_axes
,
static_cast
<
std
::
size_t
>
(
p_norm
));
norm
=
std
::
make_shared
<
ngraph
::
op
::
Broadcast
>
(
...
...
src/ngraph/frontend/onnx_import/op/mean_variance_normalization.cpp
View file @
0c181e9d
...
...
@@ -19,6 +19,7 @@
#include "mean_variance_normalization.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/op/fused/mvn.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -47,9 +48,11 @@ namespace ngraph
NodeVector
mean_variance_normalization
(
const
Node
&
node
)
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
axes
=
node
.
get_attribute_value
<
std
::
vector
<
size_t
>>
(
"axes"
,
{
0
,
2
,
3
});
auto
axes
=
node
.
get_attribute_value
<
std
::
vector
<
int64_t
>>
(
"axes"
,
{
0
,
2
,
3
});
std
::
vector
<
std
::
size_t
>
valid_axes
=
common
::
validate_axes
(
node
,
axes
,
data
->
get_shape
().
size
());
return
{
std
::
make_shared
<
ngraph
::
op
::
MVN
>
(
data
,
AxisSet
(
axes
))};
return
{
std
::
make_shared
<
ngraph
::
op
::
MVN
>
(
data
,
AxisSet
(
valid_
axes
))};
}
}
// namespace set_9
...
...
src/ngraph/frontend/onnx_import/op/onehot.cpp
View file @
0c181e9d
...
...
@@ -28,6 +28,7 @@
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "onehot.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -51,14 +52,13 @@ namespace ngraph
std
::
make_shared
<
ngraph
::
op
::
Slice
>
(
values
,
Coordinate
{
1
},
Coordinate
{
2
});
auto
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
);
if
(
axis
<
0
)
{
axis
+=
indices_shape
.
size
()
+
1
;
}
ASSERT_VALID_ARGUMENT
(
node
,
(
axis
>=
0
)
&&
(
axis
<=
indices_shape
.
size
()))
<<
"invalid 'axis' attribute: "
<<
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
);
// Accepted range for axis is [-r-1, r] where r = rank(indices). Validate
// against rank+1.
std
::
size_t
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
indices_shape
.
size
()
+
1
,
-
indices_shape
.
size
()
-
1
,
indices_shape
.
size
());
auto
constant_depth
=
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
Constant
>
(
depth
);
...
...
@@ -74,10 +74,11 @@ namespace ngraph
// axis = 1
// depth = 10
// output_shape = (2, 10, 2)
output_shape
.
insert
(
std
::
next
(
std
::
begin
(
output_shape
),
axis
),
depth_value
);
output_shape
.
insert
(
std
::
next
(
std
::
begin
(
output_shape
),
valid_axis
),
depth_value
);
std
::
shared_ptr
<
ngraph
::
Node
>
one_hot
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
std
::
make_shared
<
ngraph
::
op
::
OneHot
>
(
indices
,
output_shape
,
axis
),
std
::
make_shared
<
ngraph
::
op
::
OneHot
>
(
indices
,
output_shape
,
valid_
axis
),
values
->
get_element_type
());
auto
broadcasted_values
=
ngraph
::
op
::
numpy_style_broadcast
({
one_hot
,
on_value
,
off_value
});
...
...
src/ngraph/frontend/onnx_import/op/reverse_sequence.cpp
View file @
0c181e9d
...
...
@@ -21,6 +21,7 @@
#include "ngraph/op/convert.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -40,22 +41,26 @@ namespace ngraph
node
.
get_ng_inputs
().
at
(
1
),
element
::
i32
);
const
auto
batch_axis
=
node
.
get_attribute_value
<
int64_t
>
(
"batch_axis"
,
1
);
std
::
size_t
valid_batch_axis
=
common
::
validate_axis
(
node
,
batch_axis
,
data
->
get_shape
().
size
());
const
auto
time_axis
=
node
.
get_attribute_value
<
int64_t
>
(
"time_axis"
,
0
);
std
::
size_t
valid_time_axis
=
common
::
validate_axis
(
node
,
time_axis
,
data
->
get_shape
().
size
());
NGRAPH_CHECK
(
batch_axis
==
0
||
batch_axis
==
1
,
NGRAPH_CHECK
(
valid_batch_axis
==
0
||
valid_
batch_axis
==
1
,
"Allowed values of the 'batch_axis' attribute for ReverseSequence "
"operator are 0 and 1"
);
NGRAPH_CHECK
(
time_axis
==
0
||
time_axis
==
1
,
NGRAPH_CHECK
(
valid_time_axis
==
0
||
valid_
time_axis
==
1
,
"Allowed values of the 'time_axis' attribute for ReverseSequence "
"operator are 0 and 1"
);
NGRAPH_CHECK
(
batch_axis
!=
time_axis
,
NGRAPH_CHECK
(
valid_batch_axis
!=
valid_
time_axis
,
"'batch_axis' and 'time_axis' attributes of the ReverseSequence "
"operator can't point to the same dimension"
);
return
{
std
::
make_shared
<
ngraph
::
op
::
ReverseSequence
>
(
data
,
sequence_lengths_i32
,
batch_axis
,
time_axis
)};
data
,
sequence_lengths_i32
,
valid_batch_axis
,
valid_
time_axis
)};
}
}
// namespace set_1
...
...
src/ngraph/frontend/onnx_import/op/softmax.cpp
View file @
0c181e9d
...
...
@@ -19,6 +19,7 @@
#include "exceptions.hpp"
#include "ngraph/op/softmax.hpp"
#include "softmax.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -35,22 +36,13 @@ namespace ngraph
auto
data_shape
=
data
->
get_shape
();
int
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
,
1
);
if
(
axis
<
0
)
{
axis
=
data_shape
.
size
()
+
axis
;
}
ASSERT_VALID_ARGUMENT
(
node
,
axis
<
data_shape
.
size
())
<<
"provided 'axis' value:"
<<
axis
<<
" is out of input tensor dimensions range."
;
std
::
size_t
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
data_shape
.
size
());
// create vector of capacity data_dimensions - axis_divider position
std
::
vector
<
size_t
>
axes
(
data_shape
.
size
()
-
axis
);
std
::
iota
(
std
::
begin
(
axes
),
std
::
end
(
axes
),
axis
);
std
::
vector
<
size_t
>
axes
(
data_shape
.
size
()
-
valid_
axis
);
std
::
iota
(
std
::
begin
(
axes
),
std
::
end
(
axes
),
valid_
axis
);
return
{
std
::
make_shared
<
ngraph
::
op
::
Softmax
>
(
data
,
axes
)};
}
}
// namespace set_1
}
// namespace op
...
...
src/ngraph/frontend/onnx_import/op/split.cpp
View file @
0c181e9d
...
...
@@ -19,6 +19,7 @@
#include "ngraph/op/fused/split.hpp"
#include "op/split.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -33,13 +34,15 @@ namespace ngraph
const
auto
input
=
node
.
get_ng_inputs
().
at
(
0
);
const
auto
outputs_number
=
node
.
get_output_names
().
size
();
const
auto
axis
=
node
.
get_attribute_value
<
int64_t
>
(
"axis"
,
0
);
std
::
size_t
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
input
->
get_shape
().
size
());
try
{
const
auto
length_parts
=
node
.
get_attribute_value
<
std
::
vector
<
std
::
size_t
>>
(
"split"
);
const
auto
fused_split
=
std
::
make_shared
<
ngraph
::
op
::
Split
>
(
input
,
axis
,
length_parts
);
std
::
make_shared
<
ngraph
::
op
::
Split
>
(
input
,
valid_
axis
,
length_parts
);
return
fused_split
->
decompose_op
();
}
...
...
@@ -49,7 +52,7 @@ namespace ngraph
// the 'split' attribute - this means we should split the input tensor
// into same-length parts equal to the number of node outputs
const
auto
fused_split
=
std
::
make_shared
<
ngraph
::
op
::
Split
>
(
input
,
axis
,
outputs_number
);
std
::
make_shared
<
ngraph
::
op
::
Split
>
(
input
,
valid_
axis
,
outputs_number
);
return
fused_split
->
decompose_op
();
}
...
...
src/ngraph/frontend/onnx_import/op/squeeze.cpp
View file @
0c181e9d
...
...
@@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/squeeze.hpp"
#include "squeeze.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -32,17 +33,12 @@ namespace ngraph
NodeVector
squeeze
(
const
Node
&
node
)
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
axes
=
node
.
get_attribute_value
<
std
::
vector
<
std
::
int64_t
>>
(
"axes"
,
{});
for
(
auto
axis
:
axes
)
{
ASSERT_VALID_ARGUMENT
(
node
,
axis
>=
0
)
<<
"provided axes attribute is invalid. Only non-negative "
<<
"integers are allowed, got "
<<
axis
<<
"."
;
}
std
::
vector
<
std
::
int64_t
>
axes
=
node
.
get_attribute_value
<
std
::
vector
<
std
::
int64_t
>>
(
"axes"
,
{});
std
::
vector
<
std
::
size_t
>
valid_axes
=
common
::
validate_axes
(
node
,
axes
,
data
->
get_shape
().
size
());
auto
axes_node
=
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
element
::
u64
,
Shape
{
axes
.
size
()},
axes
);
element
::
u64
,
Shape
{
valid_axes
.
size
()},
valid_
axes
);
return
{
std
::
make_shared
<
ngraph
::
op
::
Squeeze
>
(
data
,
axes_node
)};
}
...
...
src/ngraph/frontend/onnx_import/op/topk.cpp
View file @
0c181e9d
...
...
@@ -23,6 +23,7 @@
#include "ngraph/op/topk.hpp"
#include "ngraph/type/element_type.hpp"
#include "topk.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -35,21 +36,14 @@ namespace ngraph
NodeVector
topk
(
const
Node
&
node
)
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
std
::
int64_t
axis
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
)};
std
::
int64_t
k
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"k"
)};
auto
num_dimensions
=
data
->
get_shape
().
size
();
if
(
axis
<
0
)
{
axis
+=
num_dimensions
;
}
ASSERT_VALID_ARGUMENT
(
node
,
axis
<
num_dimensions
)
<<
"`axis` parameter is out of range: "
<<
axis
;
std
::
int64_t
axis
{
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
-
1
)};
std
::
int64_t
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
num_dimensions
);
std
::
shared_ptr
<
ngraph
::
Node
>
top_k
=
std
::
make_shared
<
ngraph
::
op
::
TopK
>
(
data
,
axis
,
element
::
i64
,
k
);
std
::
make_shared
<
ngraph
::
op
::
TopK
>
(
data
,
valid_
axis
,
element
::
i64
,
k
);
std
::
shared_ptr
<
ngraph
::
Node
>
indices
=
std
::
make_shared
<
ngraph
::
op
::
GetOutputElement
>
(
top_k
,
0
);
...
...
src/ngraph/frontend/onnx_import/utils/common.cpp
View file @
0c181e9d
...
...
@@ -46,6 +46,53 @@ namespace ngraph
static_cast
<
onnx
::
TensorProto_DataType
>
(
onnx_type
)));
}
std
::
size_t
validate_axis
(
const
ngraph
::
onnx_import
::
Node
&
node
,
std
::
int64_t
axis
,
std
::
int64_t
tensor_rank
)
{
// Accepted range of value for axis is [-tensor_rank, tensor_rank-1].
return
validate_axis
(
node
,
axis
,
tensor_rank
,
-
tensor_rank
,
tensor_rank
-
1
);
}
std
::
size_t
validate_axis
(
const
ngraph
::
onnx_import
::
Node
&
node
,
std
::
int64_t
axis
,
std
::
int64_t
tensor_rank
,
std
::
int64_t
axis_range_min
,
std
::
int64_t
axis_range_max
)
{
// Accepted range of value for axis is [axis_range_min, axis_range_max].
NGRAPH_CHECK
(((
axis
>=
axis_range_min
)
&&
(
axis
<=
axis_range_max
)),
node
.
get_description
(),
"Parameter axis "
,
axis
,
" out of the tensor rank [-"
,
axis_range_min
,
", "
,
axis_range_max
,
"]."
);
if
(
axis
<
0
)
{
axis
=
axis
+
tensor_rank
;
}
return
static_cast
<
size_t
>
(
axis
);
}
std
::
vector
<
std
::
size_t
>
validate_axes
(
const
ngraph
::
onnx_import
::
Node
&
node
,
std
::
vector
<
std
::
int64_t
>
axes
,
std
::
int64_t
tensor_rank
)
{
std
::
vector
<
std
::
size_t
>
new_axes
;
for
(
auto
a
:
axes
)
{
new_axes
.
push_back
(
validate_axis
(
node
,
a
,
tensor_rank
));
}
return
new_axes
;
}
}
// namespace common
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/utils/common.hpp
View file @
0c181e9d
...
...
@@ -25,6 +25,7 @@
#include <type_traits> // std::enable_if
#include <vector>
#include "core/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
...
...
@@ -67,28 +68,52 @@ namespace ngraph
return
range
;
}
/// \brief Handle
negative axis value
.
/// \brief Handle
out of range axis
.
///
/// \param[in] axis The requested axis value.
/// \param[in] tensor_dim The corresponding tensor dimensionality.
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \tparam T Provided axis value type.
/// \return Checking if axis is in range [-tensor_rank, tensor_rank-1], otherwise
/// returns error.
/// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
/// \return If negative axis, then return sum of tensor dimension and axis.
std
::
size_t
validate_axis
(
const
ngraph
::
onnx_import
::
Node
&
node
,
std
::
int64_t
axis
,
std
::
int64_t
tensor_rank
);
/// \brief Handle out of range axis.
///
template
<
typename
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
,
int
>::
type
=
0
>
std
::
int64_t
convert_negative_axis
(
T
axis
,
std
::
size_t
tensor_dim
)
{
if
(
axis
>=
0
)
{
return
std
::
min
(
axis
,
static_cast
<
T
>
(
tensor_dim
));
}
else
{
return
static_cast
<
std
::
int64_t
>
(
tensor_dim
)
+
axis
;
}
}
/// \param[in] node The node with requested axis.
/// \param[in] axis The requested axis value.
/// \param[in] tensor_rank The corresponding tensor rank.
/// \param[in] axis_range_min The min value of accepted range for axis.
/// \param[in] axis_range_max The max value of accepted range for axis.
///
/// \return Checking if axis is in range [axis_range_min, axis_range_max], otherwise
/// returns error.
//// If negative axis, it counts from the last to the first axis, by adding
/// tensor_rank to axis.
///
std
::
size_t
validate_axis
(
const
ngraph
::
onnx_import
::
Node
&
node
,
std
::
int64_t
axis
,
std
::
int64_t
tensor_rank
,
std
::
int64_t
axis_range_min
,
std
::
int64_t
axis_range_max
);
/// \brief Handle out of range axes in vector.
///
/// \param[in] node The node with requested axes.
/// \param[in] axes The requested vector of axes.
/// \param[in] tensor_rank The corresponding tensor rank.
///
/// \return If any negative axis in vector, it counts from the last to the first
/// axis, by adding tensor_rank to axis.
///
std
::
vector
<
std
::
size_t
>
validate_axes
(
const
ngraph
::
onnx_import
::
Node
&
node
,
std
::
vector
<
std
::
int64_t
>
axes
,
std
::
int64_t
tensor_rank
);
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
...
...
src/ngraph/frontend/onnx_import/utils/reduction.cpp
View file @
0c181e9d
...
...
@@ -32,13 +32,17 @@ namespace ngraph
AxisSet
get_reduction_axes
(
const
Node
&
node
)
{
auto
reduction_axes
=
node
.
get_attribute_value
<
std
::
vector
<
std
::
size_t
>>
(
"axes"
,
{});
node
.
get_attribute_value
<
std
::
vector
<
std
::
int64_t
>>
(
"axes"
,
{});
std
::
vector
<
std
::
size_t
>
valid_reduction_axes
=
common
::
validate_axes
(
node
,
reduction_axes
,
node
.
get_ng_inputs
().
at
(
0
)
->
get_shape
().
size
());
if
(
reduction_axes
.
empty
())
{
reduction_axes
=
onnx_import
::
common
::
get_monotonic_range
<
std
::
size_t
>
(
node
.
get_ng_inputs
().
at
(
0
)
->
get_shape
().
size
());
valid_reduction_axes
=
onnx_import
::
common
::
get_monotonic_range
<
std
::
size_t
>
(
node
.
get_ng_inputs
().
at
(
0
)
->
get_shape
().
size
());
}
return
AxisSet
{
reduction_axes
};
return
AxisSet
{
valid_
reduction_axes
};
}
}
// namespace detail
...
...
src/ngraph/frontend/onnx_import/utils/reduction.hpp
View file @
0c181e9d
...
...
@@ -26,6 +26,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/util.hpp"
#include "utils/common.hpp"
#include "utils/reshape.hpp"
namespace
ngraph
...
...
@@ -64,8 +65,10 @@ namespace ngraph
auto
axis
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"axis"
,
0
);
auto
keepdims
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"keepdims"
,
1
);
auto
input_node
=
node
.
get_ng_inputs
().
at
(
0
);
auto
valid_axis
=
common
::
validate_axis
(
node
,
axis
,
input_node
->
get_shape
().
size
());
auto
op_node
=
std
::
make_shared
<
IndexReduction
>
(
input_node
,
axis
,
element
::
i64
);
auto
op_node
=
std
::
make_shared
<
IndexReduction
>
(
input_node
,
valid_axis
,
element
::
i64
);
if
(
keepdims
==
0
)
{
...
...
@@ -76,7 +79,7 @@ namespace ngraph
auto
convert_node
=
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
op_node
,
element
::
f32
);
auto
output_shape
=
input_node
->
get_shape
();
output_shape
.
at
(
axis
)
=
1
;
output_shape
.
at
(
valid_
axis
)
=
1
;
auto
reshape_node
=
std
::
make_shared
<
ngraph
::
op
::
Reshape
>
(
convert_node
,
ngraph
::
get_default_order
(
op_node
->
get_shape
().
size
()),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment