Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
01698d7a
Unverified
Commit
01698d7a
authored
Mar 03, 2020
by
Tomasz Socha
Committed by
GitHub
Mar 03, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add attribute visitor for ops M-P (#4344)
parent
5d8c39e9
Hide whitespace changes
Inline
Side-by-side
Showing
21 changed files
with
295 additions
and
0 deletions
+295
-0
psroi_pooling.cpp
src/ngraph/op/experimental/layers/psroi_pooling.cpp
+12
-0
psroi_pooling.hpp
src/ngraph/op/experimental/layers/psroi_pooling.hpp
+1
-0
matmul.cpp
src/ngraph/op/fused/matmul.cpp
+8
-0
matmul.hpp
src/ngraph/op/fused/matmul.hpp
+1
-0
mod.cpp
src/ngraph/op/fused/mod.cpp
+7
-0
mod.hpp
src/ngraph/op/fused/mod.hpp
+1
-0
normalize_l2.cpp
src/ngraph/op/fused/normalize_l2.cpp
+8
-0
normalize_l2.hpp
src/ngraph/op/fused/normalize_l2.hpp
+1
-0
prelu.cpp
src/ngraph/op/fused/prelu.cpp
+5
-0
prelu.hpp
src/ngraph/op/fused/prelu.hpp
+1
-0
max_pool.cpp
src/ngraph/op/max_pool.cpp
+12
-0
max_pool.hpp
src/ngraph/op/max_pool.hpp
+1
-0
negative.cpp
src/ngraph/op/negative.cpp
+5
-0
negative.hpp
src/ngraph/op/negative.hpp
+1
-0
non_max_suppression.cpp
src/ngraph/op/non_max_suppression.cpp
+31
-0
non_max_suppression.hpp
src/ngraph/op/non_max_suppression.hpp
+19
-0
one_hot.cpp
src/ngraph/op/one_hot.cpp
+7
-0
one_hot.hpp
src/ngraph/op/one_hot.hpp
+1
-0
pad.cpp
src/ngraph/op/pad.cpp
+7
-0
pad.hpp
src/ngraph/op/pad.hpp
+1
-0
attributes.cpp
test/attributes.cpp
+165
-0
No files found.
src/ngraph/op/experimental/layers/psroi_pooling.cpp
View file @
01698d7a
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "psroi_pooling.hpp"
#include "ngraph/attribute_visitor.hpp"
using
namespace
std
;
using
namespace
ngraph
;
...
...
@@ -40,6 +41,17 @@ op::PSROIPooling::PSROIPooling(const Output<Node>& input,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
PSROIPooling
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"output_dim"
,
m_output_dim
);
visitor
.
on_attribute
(
"group_size"
,
m_group_size
);
visitor
.
on_attribute
(
"spatial_scale"
,
m_spatial_scale
);
visitor
.
on_attribute
(
"mode"
,
m_mode
);
visitor
.
on_attribute
(
"spatial_bins_x"
,
m_spatial_bins_x
);
visitor
.
on_attribute
(
"spatial_bins_y"
,
m_spatial_bins_y
);
return
true
;
}
void
op
::
PSROIPooling
::
validate_and_infer_types
()
{
auto
input_et
=
get_input_element_type
(
0
);
...
...
src/ngraph/op/experimental/layers/psroi_pooling.hpp
View file @
01698d7a
...
...
@@ -51,6 +51,7 @@ namespace ngraph
int
spatial_bins_y
,
const
std
::
string
&
mode
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/matmul.cpp
View file @
01698d7a
...
...
@@ -17,6 +17,7 @@
#include <numeric>
#include "matmul.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp"
...
...
@@ -37,6 +38,13 @@ op::MatMul::MatMul(const Output<Node>& A,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
MatMul
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"transpose_a"
,
m_transpose_a
);
visitor
.
on_attribute
(
"transpose_b"
,
m_transpose_b
);
return
true
;
}
void
op
::
MatMul
::
pre_validate_and_infer_types
()
{
element
::
Type
result_et
;
...
...
src/ngraph/op/fused/matmul.hpp
View file @
01698d7a
...
...
@@ -44,6 +44,7 @@ namespace ngraph
const
bool
&
transpose_a
=
0
,
const
bool
&
transpose_b
=
0
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
...
...
src/ngraph/op/fused/mod.cpp
View file @
01698d7a
...
...
@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/mod.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/convert.hpp"
...
...
@@ -35,6 +36,12 @@ op::v1::Mod::Mod(const Output<Node>& A,
{
}
bool
ngraph
::
op
::
v1
::
Mod
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"auto_broadcast"
,
m_auto_broadcast
);
return
true
;
}
NodeVector
op
::
v1
::
Mod
::
decompose_op
()
const
{
const
auto
dividend
=
make_shared
<
op
::
Abs
>
(
input_value
(
0
));
...
...
src/ngraph/op/fused/mod.hpp
View file @
01698d7a
...
...
@@ -43,6 +43,7 @@ namespace ngraph
const
Output
<
Node
>&
B
,
const
AutoBroadcastSpec
&
auto_broadcast
=
AutoBroadcastType
::
NUMPY
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/normalize_l2.cpp
View file @
01698d7a
...
...
@@ -16,6 +16,7 @@
#include <algorithm>
#include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/constant.hpp"
...
...
@@ -39,6 +40,13 @@ op::NormalizeL2::NormalizeL2(const Output<Node>& data,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
NormalizeL2
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"eps"
,
m_eps
);
visitor
.
on_attribute
(
"eps_mode"
,
m_eps_mode
);
return
true
;
}
void
op
::
NormalizeL2
::
pre_validate_and_infer_types
()
{
auto
axes_node
=
input_value
(
1
).
get_node_shared_ptr
();
...
...
src/ngraph/op/fused/normalize_l2.hpp
View file @
01698d7a
...
...
@@ -52,6 +52,7 @@ namespace ngraph
float
eps
,
EpsMode
eps_mode
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
float
get_eps
()
const
{
return
m_eps
;
}
EpsMode
get_eps_mode
()
const
{
return
m_eps_mode
;
}
virtual
NodeVector
decompose_op
()
const
override
;
...
...
src/ngraph/op/fused/prelu.cpp
View file @
01698d7a
...
...
@@ -35,6 +35,11 @@ op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
PRelu
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
NodeVector
op
::
PRelu
::
decompose_op
()
const
{
auto
data
=
input_value
(
0
);
...
...
src/ngraph/op/fused/prelu.hpp
View file @
01698d7a
...
...
@@ -42,6 +42,7 @@ namespace ngraph
/// \param slope Multipliers for negative values
PRelu
(
const
Output
<
Node
>&
data
,
const
Output
<
Node
>&
slope
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/max_pool.cpp
View file @
01698d7a
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/max_pool.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
...
...
@@ -303,6 +304,17 @@ op::v1::MaxPool::MaxPool(const Output<Node>& arg,
{
}
bool
ngraph
::
op
::
v1
::
MaxPool
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"strides"
,
m_strides
);
visitor
.
on_attribute
(
"pads_begin"
,
m_pads_begin
);
visitor
.
on_attribute
(
"pads_end"
,
m_pads_end
);
visitor
.
on_attribute
(
"kernel"
,
m_kernel
);
visitor
.
on_attribute
(
"rounding_type"
,
m_rounding_type
);
visitor
.
on_attribute
(
"auto_pad"
,
m_auto_pad
);
return
true
;
}
void
op
::
v1
::
MaxPool
::
validate_and_infer_types
()
{
if
(
0
==
m_strides
.
size
())
...
...
src/ngraph/op/max_pool.hpp
View file @
01698d7a
...
...
@@ -247,6 +247,7 @@ namespace ngraph
const
Shape
&
kernel
,
op
::
RoundingType
rounding_mode
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
size_t
get_version
()
const
override
{
return
1
;
}
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/negative.cpp
View file @
01698d7a
...
...
@@ -27,6 +27,11 @@ op::Negative::Negative(const Output<Node>& arg)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
Negative
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Negative
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/negative.hpp
View file @
01698d7a
...
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Negative
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/non_max_suppression.cpp
View file @
01698d7a
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
using
namespace
std
;
...
...
@@ -65,6 +66,13 @@ shared_ptr<Node> op::v1::NonMaxSuppression::copy_with_new_args(const NodeVector&
m_sort_result_descending
);
}
bool
ngraph
::
op
::
v1
::
NonMaxSuppression
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"box_encoding"
,
m_box_encoding
);
visitor
.
on_attribute
(
"sort_result_descending"
,
m_sort_result_descending
);
return
true
;
}
void
op
::
v1
::
NonMaxSuppression
::
validate_and_infer_types
()
{
const
auto
boxes_ps
=
get_input_partial_shape
(
0
);
...
...
@@ -157,3 +165,26 @@ int64_t op::v1::NonMaxSuppression::max_boxes_output_from_input() const
return
max_output_boxes
;
}
namespace
ngraph
{
template
<>
EnumNames
<
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
>&
EnumNames
<
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
>
(
"op::v1::NonMaxSuppression::BoxEncodingType"
,
{{
"corner"
,
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
::
CORNER
},
{
"center"
,
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
::
CENTER
}});
return
enum_names
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
>::
type_info
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
// namespace ngraph
src/ngraph/op/non_max_suppression.hpp
View file @
01698d7a
...
...
@@ -68,6 +68,7 @@ namespace ngraph
const
BoxEncodingType
box_encoding
=
BoxEncodingType
::
CORNER
,
const
bool
sort_result_descending
=
true
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
@@ -93,4 +94,22 @@ namespace ngraph
};
}
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
>
:
public
EnumAttributeAdapterBase
<
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
>
{
public
:
AttributeAdapter
(
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
&
value
)
:
EnumAttributeAdapterBase
<
op
::
v1
::
NonMaxSuppression
::
BoxEncodingType
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::v1::NonMaxSuppression::BoxEncodingType>"
,
1
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
src/ngraph/op/one_hot.cpp
View file @
01698d7a
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/one_hot.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/validation_util.hpp"
using
namespace
std
;
...
...
@@ -194,6 +195,12 @@ void op::v1::OneHot::validate_and_infer_types()
set_output_type
(
0
,
on_value_et
,
result_shape
);
}
bool
ngraph
::
op
::
v1
::
OneHot
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"axis"
,
m_axis
);
return
true
;
}
shared_ptr
<
Node
>
op
::
v1
::
OneHot
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/one_hot.hpp
View file @
01698d7a
...
...
@@ -98,6 +98,7 @@ namespace ngraph
const
Output
<
Node
>&
off_value
,
int64_t
axis
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/pad.cpp
View file @
01698d7a
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/pad.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
...
...
@@ -222,6 +223,12 @@ CoordinateDiff op::v1::Pad::get_pads_end() const
return
pads_end_coord
;
}
bool
ngraph
::
op
::
v1
::
Pad
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"pad_mode"
,
m_pad_mode
);
return
true
;
}
void
op
::
v1
::
Pad
::
validate_and_infer_types
()
{
element
::
Type
result_et
;
...
...
src/ngraph/op/pad.hpp
View file @
01698d7a
...
...
@@ -135,6 +135,7 @@ namespace ngraph
/// \brief Constructs a generic padding operation.
Pad
()
=
default
;
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
size_t
get_version
()
const
override
{
return
1
;
}
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
test/attributes.cpp
View file @
01698d7a
...
...
@@ -285,6 +285,171 @@ TEST(attributes, user_op)
EXPECT_EQ
(
g_oracle
->
get_ultra_parameters
(),
oracle
->
get_ultra_parameters
());
}
TEST
(
attributes
,
matmul_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
MatMul
>
();
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
0
,
2
});
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
0
});
bool
transpose_a
=
true
;
bool
transpose_b
=
true
;
auto
matmul
=
make_shared
<
opset1
::
MatMul
>
(
A
,
B
,
transpose_a
,
transpose_b
);
NodeBuilder
builder
(
matmul
);
auto
g_matmul
=
as_type_ptr
<
opset1
::
MatMul
>
(
builder
.
create
());
EXPECT_EQ
(
g_matmul
->
get_transpose_a
(),
matmul
->
get_transpose_a
());
EXPECT_EQ
(
g_matmul
->
get_transpose_b
(),
matmul
->
get_transpose_b
());
}
TEST
(
attributes
,
max_pool_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
MaxPool
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
64
,
3
,
5
});
auto
strides
=
Strides
{
2
};
auto
pads_begin
=
Shape
{
1
};
auto
pads_end
=
Shape
{
1
};
auto
kernel
=
Shape
{
1
};
auto
rounding_mode
=
op
::
RoundingType
::
FLOOR
;
auto
auto_pad
=
op
::
PadType
::
EXPLICIT
;
auto
max_pool
=
make_shared
<
opset1
::
MaxPool
>
(
data
,
strides
,
pads_begin
,
pads_end
,
kernel
,
rounding_mode
,
auto_pad
);
NodeBuilder
builder
(
max_pool
);
auto
g_max_pool
=
as_type_ptr
<
opset1
::
MaxPool
>
(
builder
.
create
());
EXPECT_EQ
(
g_max_pool
->
get_strides
(),
max_pool
->
get_strides
());
EXPECT_EQ
(
g_max_pool
->
get_pads_begin
(),
max_pool
->
get_pads_begin
());
EXPECT_EQ
(
g_max_pool
->
get_pads_end
(),
max_pool
->
get_pads_end
());
EXPECT_EQ
(
g_max_pool
->
get_kernel
(),
max_pool
->
get_kernel
());
EXPECT_EQ
(
g_max_pool
->
get_rounding_type
(),
max_pool
->
get_rounding_type
());
EXPECT_EQ
(
g_max_pool
->
get_auto_pad
(),
max_pool
->
get_auto_pad
());
}
TEST
(
attributes
,
mod_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
Mod
>
();
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
0
,
2
});
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
0
});
auto
auto_broadcast
=
op
::
AutoBroadcastType
::
NUMPY
;
auto
mod
=
make_shared
<
opset1
::
Mod
>
(
A
,
B
,
auto_broadcast
);
NodeBuilder
builder
(
mod
);
auto
g_mod
=
as_type_ptr
<
opset1
::
Mod
>
(
builder
.
create
());
EXPECT_EQ
(
g_mod
->
get_auto_broadcast
(),
mod
->
get_auto_broadcast
());
}
TEST
(
attributes
,
non_max_suppression_op_custom_attributes
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
NonMaxSuppression
>
();
auto
boxes
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
1
,
4
});
auto
scores
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
1
,
1
});
auto
box_encoding
=
opset1
::
NonMaxSuppression
::
BoxEncodingType
::
CENTER
;
bool
sort_result_descending
=
false
;
auto
nms
=
make_shared
<
opset1
::
NonMaxSuppression
>
(
boxes
,
scores
,
box_encoding
,
sort_result_descending
);
NodeBuilder
builder
(
nms
);
auto
g_nms
=
as_type_ptr
<
opset1
::
NonMaxSuppression
>
(
builder
.
create
());
EXPECT_EQ
(
g_nms
->
get_box_encoding
(),
nms
->
get_box_encoding
());
EXPECT_EQ
(
g_nms
->
get_sort_result_descending
(),
nms
->
get_sort_result_descending
());
}
TEST
(
attributes
,
non_max_suppression_op_default_attributes
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
NonMaxSuppression
>
();
auto
boxes
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
1
,
4
});
auto
scores
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
1
,
1
});
auto
nms
=
make_shared
<
opset1
::
NonMaxSuppression
>
(
boxes
,
scores
);
NodeBuilder
builder
(
nms
);
auto
g_nms
=
as_type_ptr
<
opset1
::
NonMaxSuppression
>
(
builder
.
create
());
EXPECT_EQ
(
g_nms
->
get_box_encoding
(),
nms
->
get_box_encoding
());
EXPECT_EQ
(
g_nms
->
get_sort_result_descending
(),
nms
->
get_sort_result_descending
());
}
TEST
(
attributes
,
normalize_l2_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
NormalizeL2
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
1
});
const
auto
axes
=
make_shared
<
op
::
Constant
>
(
element
::
i32
,
Shape
{},
vector
<
int32_t
>
{
0
});
float
eps
{
1e-6
f
};
auto
eps_mode
=
op
::
EpsMode
::
ADD
;
auto
normalize_l2
=
make_shared
<
opset1
::
NormalizeL2
>
(
data
,
axes
,
eps
,
eps_mode
);
NodeBuilder
builder
(
normalize_l2
);
auto
g_normalize_l2
=
as_type_ptr
<
opset1
::
NormalizeL2
>
(
builder
.
create
());
EXPECT_EQ
(
g_normalize_l2
->
get_eps
(),
normalize_l2
->
get_eps
());
EXPECT_EQ
(
g_normalize_l2
->
get_eps_mode
(),
normalize_l2
->
get_eps_mode
());
}
TEST
(
attributes
,
one_hot_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
OneHot
>
();
auto
indices
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
1
,
3
,
2
,
3
});
auto
depth
=
op
::
Constant
::
create
(
element
::
i64
,
Shape
{},
{
4
});
auto
on_value
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{},
{
1.0
f
});
auto
off_value
=
op
::
Constant
::
create
(
element
::
f32
,
Shape
{},
{
0.0
f
});
int64_t
axis
=
3
;
auto
one_hot
=
make_shared
<
opset1
::
OneHot
>
(
indices
,
depth
,
on_value
,
off_value
,
axis
);
NodeBuilder
builder
(
one_hot
);
auto
g_one_hot
=
as_type_ptr
<
opset1
::
OneHot
>
(
builder
.
create
());
EXPECT_EQ
(
g_one_hot
->
get_axis
(),
one_hot
->
get_axis
());
}
TEST
(
attributes
,
pad_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
Pad
>
();
auto
arg
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
auto
pads_begin
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
1
});
auto
pads_end
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
1
});
auto
pad_mode
=
op
::
PadMode
::
EDGE
;
auto
pad
=
make_shared
<
opset1
::
Pad
>
(
arg
,
pads_begin
,
pads_end
,
pad_mode
);
NodeBuilder
builder
(
pad
);
auto
g_pad
=
as_type_ptr
<
opset1
::
Pad
>
(
builder
.
create
());
EXPECT_EQ
(
g_pad
->
get_pad_mode
(),
pad
->
get_pad_mode
());
}
TEST
(
attributes
,
psroi_pooling_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
PSROIPooling
>
();
auto
input
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
1024
,
63
,
38
});
auto
coords
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
300
,
5
});
const
int64_t
output_dim
=
882
;
const
int64_t
group_size
=
3
;
const
float
spatial_scale
=
0.0625
;
int
spatial_bins_x
=
1
;
int
spatial_bins_y
=
1
;
string
mode
=
"Avg"
;
auto
psroi_pool
=
make_shared
<
opset1
::
PSROIPooling
>
(
input
,
coords
,
output_dim
,
group_size
,
spatial_scale
,
spatial_bins_x
,
spatial_bins_y
,
mode
);
NodeBuilder
builder
(
psroi_pool
);
auto
g_psroi_pool
=
as_type_ptr
<
opset1
::
PSROIPooling
>
(
builder
.
create
());
EXPECT_EQ
(
g_psroi_pool
->
get_output_dim
(),
psroi_pool
->
get_output_dim
());
EXPECT_EQ
(
g_psroi_pool
->
get_group_size
(),
psroi_pool
->
get_group_size
());
EXPECT_EQ
(
g_psroi_pool
->
get_spatial_scale
(),
psroi_pool
->
get_spatial_scale
());
EXPECT_EQ
(
g_psroi_pool
->
get_spatial_bins_x
(),
psroi_pool
->
get_spatial_bins_x
());
EXPECT_EQ
(
g_psroi_pool
->
get_spatial_bins_y
(),
psroi_pool
->
get_spatial_bins_y
());
EXPECT_EQ
(
g_psroi_pool
->
get_mode
(),
psroi_pool
->
get_mode
());
}
TEST
(
attributes
,
reduce_logical_and_op
)
{
// ReduceLogicalAnd derives visit_attributes from op::util::LogicalReductionKeepDims
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment