submodule / ngraph / Commits / 01698d7a

Commit 01698d7a (unverified), authored 5 years ago by Tomasz Socha, committed by GitHub 5 years ago.
Parent: 5d8c39e9

Add attribute visitor for ops M-P (#4344)

Showing 21 changed files with 295 additions and 0 deletions.
Changed files:

  src/ngraph/op/experimental/layers/psroi_pooling.cpp  +12  -0
  src/ngraph/op/experimental/layers/psroi_pooling.hpp   +1  -0
  src/ngraph/op/fused/matmul.cpp                         +8  -0
  src/ngraph/op/fused/matmul.hpp                         +1  -0
  src/ngraph/op/fused/mod.cpp                            +7  -0
  src/ngraph/op/fused/mod.hpp                            +1  -0
  src/ngraph/op/fused/normalize_l2.cpp                   +8  -0
  src/ngraph/op/fused/normalize_l2.hpp                   +1  -0
  src/ngraph/op/fused/prelu.cpp                          +5  -0
  src/ngraph/op/fused/prelu.hpp                          +1  -0
  src/ngraph/op/max_pool.cpp                            +12  -0
  src/ngraph/op/max_pool.hpp                             +1  -0
  src/ngraph/op/negative.cpp                             +5  -0
  src/ngraph/op/negative.hpp                             +1  -0
  src/ngraph/op/non_max_suppression.cpp                 +31  -0
  src/ngraph/op/non_max_suppression.hpp                 +19  -0
  src/ngraph/op/one_hot.cpp                              +7  -0
  src/ngraph/op/one_hot.hpp                              +1  -0
  src/ngraph/op/pad.cpp                                  +7  -0
  src/ngraph/op/pad.hpp                                  +1  -0
  test/attributes.cpp                                  +165  -0
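Each op touched by this commit (MatMul, MaxPool, Mod, Negative, NonMaxSuppression, NormalizeL2, OneHot, Pad, PRelu, PSROIPooling) gains a visit_attributes(AttributeVisitor&) override that registers its attribute members by name, which is what the generic serialization machinery and the round-trip tests in test/attributes.cpp rely on. A minimal illustrative sketch of the idiom only (MyOp and m_axis are hypothetical names, not part of the diff):

// Hypothetical op, shown only to illustrate the idiom repeated in the real diffs below.
// In the header, the op declares the override:
//     bool visit_attributes(AttributeVisitor& visitor) override;
// In the source file, it registers each attribute member under its serialized name:
bool ngraph::op::v0::MyOp::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("axis", m_axis); // hypothetical int64_t member
    return true;                          // report that all attributes were visited
}

Ops that carry no attributes (Negative, PRelu) still provide the override and simply return true.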
src/ngraph/op/experimental/layers/psroi_pooling.cpp

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "psroi_pooling.hpp"
#include "ngraph/attribute_visitor.hpp"

using namespace std;
using namespace ngraph;

@@ -40,6 +41,17 @@ op::PSROIPooling::PSROIPooling(const Output<Node>& input,
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::PSROIPooling::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("output_dim", m_output_dim);
    visitor.on_attribute("group_size", m_group_size);
    visitor.on_attribute("spatial_scale", m_spatial_scale);
    visitor.on_attribute("mode", m_mode);
    visitor.on_attribute("spatial_bins_x", m_spatial_bins_x);
    visitor.on_attribute("spatial_bins_y", m_spatial_bins_y);
    return true;
}

void op::PSROIPooling::validate_and_infer_types()
{
    auto input_et = get_input_element_type(0);
src/ngraph/op/experimental/layers/psroi_pooling.hpp

@@ -51,6 +51,7 @@ namespace ngraph
                         int spatial_bins_y,
                         const std::string& mode);

            bool visit_attributes(AttributeVisitor& visitor) override;
            void validate_and_infer_types() override;

            virtual std::shared_ptr<Node>
src/ngraph/op/fused/matmul.cpp

@@ -17,6 +17,7 @@
#include <numeric>
#include "matmul.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"

@@ -37,6 +38,13 @@ op::MatMul::MatMul(const Output<Node>& A,
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::MatMul::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("transpose_a", m_transpose_a);
    visitor.on_attribute("transpose_b", m_transpose_b);
    return true;
}

void op::MatMul::pre_validate_and_infer_types()
{
    element::Type result_et;
src/ngraph/op/fused/matmul.hpp

@@ -44,6 +44,7 @@ namespace ngraph
                       const bool& transpose_a = 0,
                       const bool& transpose_b = 0);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual void pre_validate_and_infer_types() override;
                virtual NodeVector decompose_op() const override;
src/ngraph/op/fused/mod.cpp

@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/mod.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/convert.hpp"

@@ -35,6 +36,12 @@ op::v1::Mod::Mod(const Output<Node>& A,
{
}

bool ngraph::op::v1::Mod::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("auto_broadcast", m_auto_broadcast);
    return true;
}

NodeVector op::v1::Mod::decompose_op() const
{
    const auto dividend = make_shared<op::Abs>(input_value(0));
src/ngraph/op/fused/mod.hpp

@@ -43,6 +43,7 @@ namespace ngraph
                    const Output<Node>& B,
                    const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual NodeVector decompose_op() const override;
                virtual std::shared_ptr<Node>
src/ngraph/op/fused/normalize_l2.cpp

@@ -16,6 +16,7 @@
#include <algorithm>
#include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/constant.hpp"

@@ -39,6 +40,13 @@ op::NormalizeL2::NormalizeL2(const Output<Node>& data,
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::NormalizeL2::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("eps", m_eps);
    visitor.on_attribute("eps_mode", m_eps_mode);
    return true;
}

void op::NormalizeL2::pre_validate_and_infer_types()
{
    auto axes_node = input_value(1).get_node_shared_ptr();
src/ngraph/op/fused/normalize_l2.hpp

@@ -52,6 +52,7 @@ namespace ngraph
                            float eps,
                            EpsMode eps_mode);

                bool visit_attributes(AttributeVisitor& visitor) override;
                float get_eps() const { return m_eps; }
                EpsMode get_eps_mode() const { return m_eps_mode; }
                virtual NodeVector decompose_op() const override;
src/ngraph/op/fused/prelu.cpp

@@ -35,6 +35,11 @@ op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

NodeVector op::PRelu::decompose_op() const
{
    auto data = input_value(0);
src/ngraph/op/fused/prelu.hpp

@@ -42,6 +42,7 @@ namespace ngraph
                /// \param slope Multipliers for negative values
                PRelu(const Output<Node>& data, const Output<Node>& slope);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual NodeVector decompose_op() const override;
                virtual std::shared_ptr<Node>
src/ngraph/op/max_pool.cpp

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/max_pool.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"

@@ -303,6 +304,17 @@ op::v1::MaxPool::MaxPool(const Output<Node>& arg,
{
}

bool ngraph::op::v1::MaxPool::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("strides", m_strides);
    visitor.on_attribute("pads_begin", m_pads_begin);
    visitor.on_attribute("pads_end", m_pads_end);
    visitor.on_attribute("kernel", m_kernel);
    visitor.on_attribute("rounding_type", m_rounding_type);
    visitor.on_attribute("auto_pad", m_auto_pad);
    return true;
}

void op::v1::MaxPool::validate_and_infer_types()
{
    if (0 == m_strides.size())
src/ngraph/op/max_pool.hpp

@@ -247,6 +247,7 @@ namespace ngraph
                    const Shape& kernel,
                    op::RoundingType rounding_mode);

                bool visit_attributes(AttributeVisitor& visitor) override;
                size_t get_version() const override { return 1; }
                void validate_and_infer_types() override;
src/ngraph/op/negative.cpp

@@ -27,6 +27,11 @@ op::Negative::Negative(const Output<Node>& arg)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::Negative::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

shared_ptr<Node> op::Negative::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
src/ngraph/op/negative.hpp

@@ -37,6 +37,7 @@ namespace ngraph
                /// \param arg Node that produces the input tensor.
                Negative(const Output<Node>& arg);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/non_max_suppression.cpp

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"

using namespace std;

@@ -65,6 +66,13 @@ shared_ptr<Node> op::v1::NonMaxSuppression::copy_with_new_args(const NodeVector&
                                                m_sort_result_descending);
}

bool ngraph::op::v1::NonMaxSuppression::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("box_encoding", m_box_encoding);
    visitor.on_attribute("sort_result_descending", m_sort_result_descending);
    return true;
}

void op::v1::NonMaxSuppression::validate_and_infer_types()
{
    const auto boxes_ps = get_input_partial_shape(0);

@@ -157,3 +165,26 @@ int64_t op::v1::NonMaxSuppression::max_boxes_output_from_input() const
    return max_output_boxes;
}

namespace ngraph
{
    template <>
    EnumNames<op::v1::NonMaxSuppression::BoxEncodingType>&
        EnumNames<op::v1::NonMaxSuppression::BoxEncodingType>::get()
    {
        static auto enum_names = EnumNames<op::v1::NonMaxSuppression::BoxEncodingType>(
            "op::v1::NonMaxSuppression::BoxEncodingType",
            {{"corner", op::v1::NonMaxSuppression::BoxEncodingType::CORNER},
             {"center", op::v1::NonMaxSuppression::BoxEncodingType::CENTER}});
        return enum_names;
    }

    constexpr DiscreteTypeInfo AttributeAdapter<op::v1::NonMaxSuppression::BoxEncodingType>::type_info;

    std::ostream& operator<<(std::ostream& s, const op::v1::NonMaxSuppression::BoxEncodingType& type)
    {
        return s << as_string(type);
    }
} // namespace ngraph
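The non_max_suppression.cpp hunk also shows the extra wiring an enum-typed attribute needs: an EnumNames<> specialization mapping each enumerator to a string, a constexpr type_info for the matching AttributeAdapter, and an operator<< that streams the name via as_string. A small usage sketch (illustrative only, not part of the commit; it assumes the header shown above and the "corner"/"center" mapping it defines):

#include <iostream>
#include "ngraph/op/non_max_suppression.hpp"

// Streams the registered string name of a BoxEncodingType value.
// With the EnumNames mapping above, this prints "center".
void print_box_encoding()
{
    auto encoding = ngraph::op::v1::NonMaxSuppression::BoxEncodingType::CENTER;
    std::cout << encoding << std::endl;
}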
src/ngraph/op/non_max_suppression.hpp

@@ -68,6 +68,7 @@ namespace ngraph
                    const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
                    const bool sort_result_descending = true);

                bool visit_attributes(AttributeVisitor& visitor) override;
                void validate_and_infer_types() override;

                std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

@@ -93,4 +94,22 @@ namespace ngraph
            };
        }
    }

    std::ostream& operator<<(std::ostream& s, const op::v1::NonMaxSuppression::BoxEncodingType& type);

    template <>
    class NGRAPH_API AttributeAdapter<op::v1::NonMaxSuppression::BoxEncodingType>
        : public EnumAttributeAdapterBase<op::v1::NonMaxSuppression::BoxEncodingType>
    {
    public:
        AttributeAdapter(op::v1::NonMaxSuppression::BoxEncodingType& value)
            : EnumAttributeAdapterBase<op::v1::NonMaxSuppression::BoxEncodingType>(value)
        {
        }

        static constexpr DiscreteTypeInfo type_info{
            "AttributeAdapter<op::v1::NonMaxSuppression::BoxEncodingType>", 1};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    };
}
src/ngraph/op/one_hot.cpp

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/one_hot.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/validation_util.hpp"

using namespace std;

@@ -194,6 +195,12 @@ void op::v1::OneHot::validate_and_infer_types()
    set_output_type(0, on_value_et, result_shape);
}

bool ngraph::op::v1::OneHot::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("axis", m_axis);
    return true;
}

shared_ptr<Node> op::v1::OneHot::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
src/ngraph/op/one_hot.hpp

@@ -98,6 +98,7 @@ namespace ngraph
                   const Output<Node>& off_value,
                   int64_t axis);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
                void validate_and_infer_types() override;
src/ngraph/op/pad.cpp

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/pad.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"

@@ -222,6 +223,12 @@ CoordinateDiff op::v1::Pad::get_pads_end() const
    return pads_end_coord;
}

bool ngraph::op::v1::Pad::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("pad_mode", m_pad_mode);
    return true;
}

void op::v1::Pad::validate_and_infer_types()
{
    element::Type result_et;
src/ngraph/op/pad.hpp

@@ -135,6 +135,7 @@ namespace ngraph
                /// \brief Constructs a generic padding operation.
                Pad() = default;

                bool visit_attributes(AttributeVisitor& visitor) override;
                size_t get_version() const override { return 1; }
                void validate_and_infer_types() override;
                virtual std::shared_ptr<Node>
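The test additions below all follow one round-trip idiom: register the op's factory with FactoryRegistry<Node>, construct a node with known attribute values, hand it to NodeBuilder (which captures the attributes through the new visit_attributes overrides), rebuild the node with builder.create(), and compare every attribute getter on the rebuilt node against the original.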
test/attributes.cpp

@@ -285,6 +285,171 @@ TEST(attributes, user_op)
    EXPECT_EQ(g_oracle->get_ultra_parameters(), oracle->get_ultra_parameters());
}

TEST(attributes, matmul_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::MatMul>();
    auto A = make_shared<op::Parameter>(element::f32, Shape{0, 2});
    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 0});

    bool transpose_a = true;
    bool transpose_b = true;

    auto matmul = make_shared<opset1::MatMul>(A, B, transpose_a, transpose_b);
    NodeBuilder builder(matmul);
    auto g_matmul = as_type_ptr<opset1::MatMul>(builder.create());

    EXPECT_EQ(g_matmul->get_transpose_a(), matmul->get_transpose_a());
    EXPECT_EQ(g_matmul->get_transpose_b(), matmul->get_transpose_b());
}

TEST(attributes, max_pool_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::MaxPool>();
    auto data = make_shared<op::Parameter>(element::f32, Shape{64, 3, 5});

    auto strides = Strides{2};
    auto pads_begin = Shape{1};
    auto pads_end = Shape{1};
    auto kernel = Shape{1};
    auto rounding_mode = op::RoundingType::FLOOR;
    auto auto_pad = op::PadType::EXPLICIT;

    auto max_pool = make_shared<opset1::MaxPool>(
        data, strides, pads_begin, pads_end, kernel, rounding_mode, auto_pad);
    NodeBuilder builder(max_pool);
    auto g_max_pool = as_type_ptr<opset1::MaxPool>(builder.create());

    EXPECT_EQ(g_max_pool->get_strides(), max_pool->get_strides());
    EXPECT_EQ(g_max_pool->get_pads_begin(), max_pool->get_pads_begin());
    EXPECT_EQ(g_max_pool->get_pads_end(), max_pool->get_pads_end());
    EXPECT_EQ(g_max_pool->get_kernel(), max_pool->get_kernel());
    EXPECT_EQ(g_max_pool->get_rounding_type(), max_pool->get_rounding_type());
    EXPECT_EQ(g_max_pool->get_auto_pad(), max_pool->get_auto_pad());
}

TEST(attributes, mod_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::Mod>();
    auto A = make_shared<op::Parameter>(element::f32, Shape{0, 2});
    auto B = make_shared<op::Parameter>(element::f32, Shape{2, 0});

    auto auto_broadcast = op::AutoBroadcastType::NUMPY;

    auto mod = make_shared<opset1::Mod>(A, B, auto_broadcast);
    NodeBuilder builder(mod);
    auto g_mod = as_type_ptr<opset1::Mod>(builder.create());

    EXPECT_EQ(g_mod->get_auto_broadcast(), mod->get_auto_broadcast());
}

TEST(attributes, non_max_suppression_op_custom_attributes)
{
    FactoryRegistry<Node>::get().register_factory<opset1::NonMaxSuppression>();
    auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
    auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1});

    auto box_encoding = opset1::NonMaxSuppression::BoxEncodingType::CENTER;
    bool sort_result_descending = false;

    auto nms = make_shared<opset1::NonMaxSuppression>(
        boxes, scores, box_encoding, sort_result_descending);
    NodeBuilder builder(nms);
    auto g_nms = as_type_ptr<opset1::NonMaxSuppression>(builder.create());

    EXPECT_EQ(g_nms->get_box_encoding(), nms->get_box_encoding());
    EXPECT_EQ(g_nms->get_sort_result_descending(), nms->get_sort_result_descending());
}

TEST(attributes, non_max_suppression_op_default_attributes)
{
    FactoryRegistry<Node>::get().register_factory<opset1::NonMaxSuppression>();
    auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
    auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1});

    auto nms = make_shared<opset1::NonMaxSuppression>(boxes, scores);
    NodeBuilder builder(nms);
    auto g_nms = as_type_ptr<opset1::NonMaxSuppression>(builder.create());

    EXPECT_EQ(g_nms->get_box_encoding(), nms->get_box_encoding());
    EXPECT_EQ(g_nms->get_sort_result_descending(), nms->get_sort_result_descending());
}

TEST(attributes, normalize_l2_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::NormalizeL2>();
    auto data = make_shared<op::Parameter>(element::i32, Shape{1});
    const auto axes = make_shared<op::Constant>(element::i32, Shape{}, vector<int32_t>{0});

    float eps{1e-6f};
    auto eps_mode = op::EpsMode::ADD;

    auto normalize_l2 = make_shared<opset1::NormalizeL2>(data, axes, eps, eps_mode);
    NodeBuilder builder(normalize_l2);
    auto g_normalize_l2 = as_type_ptr<opset1::NormalizeL2>(builder.create());

    EXPECT_EQ(g_normalize_l2->get_eps(), normalize_l2->get_eps());
    EXPECT_EQ(g_normalize_l2->get_eps_mode(), normalize_l2->get_eps_mode());
}

TEST(attributes, one_hot_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::OneHot>();
    auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
    auto depth = op::Constant::create(element::i64, Shape{}, {4});
    auto on_value = op::Constant::create(element::f32, Shape{}, {1.0f});
    auto off_value = op::Constant::create(element::f32, Shape{}, {0.0f});

    int64_t axis = 3;

    auto one_hot = make_shared<opset1::OneHot>(indices, depth, on_value, off_value, axis);
    NodeBuilder builder(one_hot);
    auto g_one_hot = as_type_ptr<opset1::OneHot>(builder.create());

    EXPECT_EQ(g_one_hot->get_axis(), one_hot->get_axis());
}

TEST(attributes, pad_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::Pad>();
    auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    auto pads_begin = make_shared<op::Parameter>(element::i64, Shape{1});
    auto pads_end = make_shared<op::Parameter>(element::i64, Shape{1});

    auto pad_mode = op::PadMode::EDGE;

    auto pad = make_shared<opset1::Pad>(arg, pads_begin, pads_end, pad_mode);
    NodeBuilder builder(pad);
    auto g_pad = as_type_ptr<opset1::Pad>(builder.create());

    EXPECT_EQ(g_pad->get_pad_mode(), pad->get_pad_mode());
}

TEST(attributes, psroi_pooling_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::PSROIPooling>();
    auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1024, 63, 38});
    auto coords = make_shared<op::Parameter>(element::f32, Shape{300, 5});

    const int64_t output_dim = 882;
    const int64_t group_size = 3;
    const float spatial_scale = 0.0625;
    int spatial_bins_x = 1;
    int spatial_bins_y = 1;
    string mode = "Avg";

    auto psroi_pool = make_shared<opset1::PSROIPooling>(
        input, coords, output_dim, group_size, spatial_scale, spatial_bins_x, spatial_bins_y, mode);
    NodeBuilder builder(psroi_pool);
    auto g_psroi_pool = as_type_ptr<opset1::PSROIPooling>(builder.create());

    EXPECT_EQ(g_psroi_pool->get_output_dim(), psroi_pool->get_output_dim());
    EXPECT_EQ(g_psroi_pool->get_group_size(), psroi_pool->get_group_size());
    EXPECT_EQ(g_psroi_pool->get_spatial_scale(), psroi_pool->get_spatial_scale());
    EXPECT_EQ(g_psroi_pool->get_spatial_bins_x(), psroi_pool->get_spatial_bins_x());
    EXPECT_EQ(g_psroi_pool->get_spatial_bins_y(), psroi_pool->get_spatial_bins_y());
    EXPECT_EQ(g_psroi_pool->get_mode(), psroi_pool->get_mode());
}

TEST(attributes, reduce_logical_and_op)
{
    // ReduceLogicalAnd derives visit_attributes from op::util::LogicalReductionKeepDims
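The hunk ends with the tests shown above; the attribute-less ops in this commit (Negative, PRelu) are not exercised there. A minimal sketch of what such a round-trip test could look like, assuming the same headers, fixtures, and NodeBuilder idiom as the file above (this test is not part of the commit):

// Illustrative sketch only -- not in the commit. Round-trips an op with no attributes,
// so a successful rebuild through the registered factory is the whole check.
TEST(attributes, prelu_op_sketch)
{
    FactoryRegistry<Node>::get().register_factory<opset1::PRelu>();
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    auto slope = make_shared<op::Parameter>(element::f32, Shape{3});

    auto prelu = make_shared<opset1::PRelu>(data, slope);
    NodeBuilder builder(prelu);
    auto g_prelu = as_type_ptr<opset1::PRelu>(builder.create());

    EXPECT_NE(g_prelu, nullptr);
}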