Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
5d8c39e9
Unverified
Commit
5d8c39e9
authored
Mar 03, 2020
by
Tomasz Socha
Committed by
GitHub
Mar 03, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add attribute visitor for ops R (#4340)
parent
1b611294
Hide whitespace changes
Inline
Side-by-side
Showing
23 changed files
with
381 additions
and
0 deletions
+381
-0
region_yolo.cpp
src/ngraph/op/experimental/layers/region_yolo.cpp
+14
-0
region_yolo.hpp
src/ngraph/op/experimental/layers/region_yolo.hpp
+1
-0
range.cpp
src/ngraph/op/experimental/range.cpp
+5
-0
range.hpp
src/ngraph/op/experimental/range.hpp
+1
-0
rnn_cell.cpp
src/ngraph/op/fused/rnn_cell.cpp
+5
-0
rnn_cell.hpp
src/ngraph/op/fused/rnn_cell.hpp
+1
-0
reshape.cpp
src/ngraph/op/reshape.cpp
+6
-0
reshape.hpp
src/ngraph/op/reshape.hpp
+1
-0
result.cpp
src/ngraph/op/result.cpp
+5
-0
result.hpp
src/ngraph/op/result.hpp
+1
-0
reverse.cpp
src/ngraph/op/reverse.cpp
+26
-0
reverse.hpp
src/ngraph/op/reverse.hpp
+17
-0
reverse_sequence.cpp
src/ngraph/op/reverse_sequence.cpp
+8
-0
reverse_sequence.hpp
src/ngraph/op/reverse_sequence.hpp
+1
-0
arithmetic_reductions_keep_dims.cpp
src/ngraph/op/util/arithmetic_reductions_keep_dims.cpp
+7
-0
arithmetic_reductions_keep_dims.hpp
src/ngraph/op/util/arithmetic_reductions_keep_dims.hpp
+2
-0
logical_reduction_keep_dims.cpp
src/ngraph/op/util/logical_reduction_keep_dims.cpp
+7
-0
logical_reduction_keep_dims.hpp
src/ngraph/op/util/logical_reduction_keep_dims.hpp
+2
-0
rnn_cell_base.cpp
src/ngraph/op/util/rnn_cell_base.cpp
+11
-0
rnn_cell_base.hpp
src/ngraph/op/util/rnn_cell_base.hpp
+1
-0
unary_elementwise_arithmetic.cpp
src/ngraph/op/util/unary_elementwise_arithmetic.cpp
+5
-0
unary_elementwise_arithmetic.hpp
src/ngraph/op/util/unary_elementwise_arithmetic.hpp
+1
-0
attributes.cpp
test/attributes.cpp
+253
-0
No files found.
src/ngraph/op/experimental/layers/region_yolo.cpp
View file @
5d8c39e9
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "region_yolo.hpp"
#include "ngraph/attribute_visitor.hpp"
using
namespace
std
;
using
namespace
ngraph
;
...
...
@@ -43,6 +44,19 @@ op::RegionYolo::RegionYolo(const Output<Node>& input,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
RegionYolo
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"anchors"
,
m_anchors
);
visitor
.
on_attribute
(
"axis"
,
m_axis
);
visitor
.
on_attribute
(
"coords"
,
m_num_coords
);
visitor
.
on_attribute
(
"classes"
,
m_num_classes
);
visitor
.
on_attribute
(
"end_axis"
,
m_end_axis
);
visitor
.
on_attribute
(
"num"
,
m_num_regions
);
visitor
.
on_attribute
(
"do_softmax"
,
m_do_softmax
);
visitor
.
on_attribute
(
"mask"
,
m_mask
);
return
true
;
}
void
op
::
RegionYolo
::
validate_and_infer_types
()
{
auto
input_et
=
get_input_element_type
(
0
);
...
...
src/ngraph/op/experimental/layers/region_yolo.hpp
View file @
5d8c39e9
...
...
@@ -55,6 +55,7 @@ namespace ngraph
const
int
end_axis
,
const
std
::
vector
<
float
>&
anchors
=
std
::
vector
<
float
>
{});
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/experimental/range.cpp
View file @
5d8c39e9
...
...
@@ -182,6 +182,11 @@ static PartialShape infer_output_shape(const op::Range* node, const element::Typ
return
result
;
}
bool
ngraph
::
op
::
v0
::
Range
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
Range
::
validate_and_infer_types
()
{
set_input_is_relevant_to_shape
(
0
);
...
...
src/ngraph/op/experimental/range.hpp
View file @
5d8c39e9
...
...
@@ -46,6 +46,7 @@ namespace ngraph
const
Output
<
Node
>&
stop
,
const
Output
<
Node
>&
step
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/rnn_cell.cpp
View file @
5d8c39e9
...
...
@@ -65,6 +65,11 @@ op::RNNCell::RNNCell(const Output<Node>& X,
constructor_validate_and_infer_types
();
}
bool
op
::
RNNCell
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
op
::
util
::
RNNCellBase
::
visit_attributes
(
visitor
);
}
void
op
::
RNNCell
::
pre_validate_and_infer_types
()
{
if
(
is_dynamic
())
...
...
src/ngraph/op/fused/rnn_cell.hpp
View file @
5d8c39e9
...
...
@@ -132,6 +132,7 @@ namespace ngraph
const
std
::
vector
<
float
>&
activations_beta
=
{},
float
clip
=
0.
f
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/reshape.cpp
View file @
5d8c39e9
...
...
@@ -156,6 +156,12 @@ op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, b
constructor_validate_and_infer_types
();
}
bool
op
::
v1
::
Reshape
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"special_zero"
,
m_special_zero
);
return
true
;
}
void
op
::
v1
::
Reshape
::
validate_and_infer_types
()
{
auto
pattern_et
=
get_input_element_type
(
1
);
...
...
src/ngraph/op/reshape.hpp
View file @
5d8c39e9
...
...
@@ -137,6 +137,7 @@ namespace ngraph
/// from input shape at the same index.
Reshape
(
const
Output
<
Node
>&
arg
,
const
Output
<
Node
>&
pattern
,
bool
special_zero
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
size_t
get_version
()
const
override
{
return
1
;
}
...
...
src/ngraph/op/result.cpp
View file @
5d8c39e9
...
...
@@ -35,6 +35,11 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
set_placement_index
(
input_value
(
0
).
get_node
()
->
get_placement_index
());
}
bool
ngraph
::
op
::
v0
::
Result
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
Result
::
validate_and_infer_types
()
{
NODE_VALIDATION_CHECK
(
...
...
src/ngraph/op/result.hpp
View file @
5d8c39e9
...
...
@@ -38,6 +38,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Result
(
const
Output
<
Node
>&
arg
,
bool
needs_default_layout
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/reverse.cpp
View file @
5d8c39e9
...
...
@@ -17,6 +17,7 @@
#include <algorithm>
#include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp"
...
...
@@ -91,6 +92,12 @@ op::v1::Reverse::Reverse(const Output<Node>& data,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v1
::
Reverse
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"mode"
,
m_mode
);
return
true
;
}
void
op
::
v1
::
Reverse
::
validate_and_infer_types
()
{
if
(
m_mode
==
Mode
::
MASK
)
...
...
@@ -194,3 +201,22 @@ op::v1::Reverse::Mode op::v1::Reverse::mode_from_string(const std::string& mode)
return
allowed_values
.
at
(
mode
);
}
namespace
ngraph
{
template
<>
EnumNames
<
op
::
v1
::
Reverse
::
Mode
>&
EnumNames
<
op
::
v1
::
Reverse
::
Mode
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
v1
::
Reverse
::
Mode
>
(
"op::v1::Reverse::Mode"
,
{{
"index"
,
op
::
v1
::
Reverse
::
Mode
::
INDEX
},
{
"mask"
,
op
::
v1
::
Reverse
::
Mode
::
MASK
}});
return
enum_names
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
op
::
v1
::
Reverse
::
Mode
>::
type_info
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
Reverse
::
Mode
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
// namespace ngraph
src/ngraph/op/reverse.hpp
View file @
5d8c39e9
...
...
@@ -108,6 +108,7 @@ namespace ngraph
const
Output
<
Node
>&
reversed_axes
,
const
Mode
mode
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
@@ -135,4 +136,20 @@ namespace ngraph
// default opset version
using
v0
::
Reverse
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
Reverse
::
Mode
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
v1
::
Reverse
::
Mode
>
:
public
EnumAttributeAdapterBase
<
op
::
v1
::
Reverse
::
Mode
>
{
public
:
AttributeAdapter
(
op
::
v1
::
Reverse
::
Mode
&
value
)
:
EnumAttributeAdapterBase
<
op
::
v1
::
Reverse
::
Mode
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::v1::Reverse::Mode>"
,
1
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
src/ngraph/op/reverse_sequence.cpp
View file @
5d8c39e9
...
...
@@ -17,6 +17,7 @@
#include <algorithm>
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/validation_util.hpp"
...
...
@@ -39,6 +40,13 @@ op::ReverseSequence::ReverseSequence(const Output<Node>& arg,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
ReverseSequence
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"batch_axis"
,
m_batch_axis
);
visitor
.
on_attribute
(
"seq_axis"
,
m_seq_axis
);
return
true
;
}
void
op
::
ReverseSequence
::
validate_and_infer_types
()
{
auto
input_shape
=
get_input_partial_shape
(
0
);
...
...
src/ngraph/op/reverse_sequence.hpp
View file @
5d8c39e9
...
...
@@ -38,6 +38,7 @@ namespace ngraph
int64_t
batch_axis
,
int64_t
seq_axis
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/util/arithmetic_reductions_keep_dims.cpp
View file @
5d8c39e9
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
...
...
@@ -30,6 +31,12 @@ op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
{
}
bool
ngraph
::
op
::
util
::
ArithmeticReductionKeepDims
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"keep_dims"
,
m_keep_dims
);
return
true
;
}
void
op
::
util
::
ArithmeticReductionKeepDims
::
validate_and_infer_types
()
{
if
(
m_keep_dims
)
...
...
src/ngraph/op/util/arithmetic_reductions_keep_dims.hpp
View file @
5d8c39e9
...
...
@@ -37,6 +37,8 @@ namespace ngraph
const
Output
<
Node
>&
reduction_axes
,
bool
keep_dims
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
public
:
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/util/logical_reduction_keep_dims.cpp
View file @
5d8c39e9
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/logical_reduction_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
...
...
@@ -30,6 +31,12 @@ op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
{
}
bool
ngraph
::
op
::
util
::
LogicalReductionKeepDims
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"keep_dims"
,
m_keep_dims
);
return
true
;
}
void
op
::
util
::
LogicalReductionKeepDims
::
validate_and_infer_types
()
{
if
(
m_keep_dims
)
...
...
src/ngraph/op/util/logical_reduction_keep_dims.hpp
View file @
5d8c39e9
...
...
@@ -37,6 +37,8 @@ namespace ngraph
const
Output
<
Node
>&
reduction_axes
,
const
bool
keep_dims
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
public
:
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/util/rnn_cell_base.cpp
View file @
5d8c39e9
...
...
@@ -17,6 +17,7 @@
#include <algorithm>
#include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/multiply.hpp"
...
...
@@ -48,6 +49,16 @@ op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
{
}
bool
ngraph
::
op
::
util
::
RNNCellBase
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"hidden_size"
,
m_hidden_size
);
visitor
.
on_attribute
(
"activations"
,
m_activations
);
visitor
.
on_attribute
(
"activations_alpha"
,
m_activations_alpha
);
visitor
.
on_attribute
(
"activations_beta"
,
m_activations_beta
);
visitor
.
on_attribute
(
"clip"
,
m_clip
);
return
true
;
}
op
::
util
::
ActivationFunction
op
::
util
::
RNNCellBase
::
get_activation_function
(
size_t
idx
)
const
{
op
::
util
::
ActivationFunction
afunc
=
get_activation_func_by_name
(
m_activations
.
at
(
idx
));
...
...
src/ngraph/op/util/rnn_cell_base.hpp
View file @
5d8c39e9
...
...
@@ -58,6 +58,7 @@ namespace ngraph
RNNCellBase
()
=
default
;
virtual
bool
visit_attributes
(
AttributeVisitor
&
visitor
);
std
::
size_t
get_hidden_size
()
const
{
return
m_hidden_size
;
}
float
get_clip
()
const
{
return
m_clip
;
}
const
std
::
vector
<
std
::
string
>&
get_activations
()
const
{
return
m_activations
;
}
...
...
src/ngraph/op/util/unary_elementwise_arithmetic.cpp
View file @
5d8c39e9
...
...
@@ -43,3 +43,8 @@ void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic
();
}
bool
op
::
util
::
UnaryElementwiseArithmetic
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
src/ngraph/op/util/unary_elementwise_arithmetic.hpp
View file @
5d8c39e9
...
...
@@ -68,6 +68,7 @@ namespace ngraph
public
:
void
validate_and_infer_types
()
override
;
bool
is_unary_elementwise_arithmetic
()
const
override
{
return
true
;
}
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
};
}
}
...
...
test/attributes.cpp
View file @
5d8c39e9
...
...
@@ -285,6 +285,259 @@ TEST(attributes, user_op)
EXPECT_EQ
(
g_oracle
->
get_ultra_parameters
(),
oracle
->
get_ultra_parameters
());
}
TEST
(
attributes
,
reduce_logical_and_op
)
{
// ReduceLogicalAnd derives visit_attributes from op::util::LogicalReductionKeepDims
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReduceLogicalAnd
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
4
,
5
});
auto
reduction_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
2
});
bool
keep_dims
=
true
;
auto
reduce_logical_and
=
make_shared
<
opset1
::
ReduceSum
>
(
data
,
reduction_axes
,
keep_dims
);
NodeBuilder
builder
(
reduce_logical_and
);
auto
g_reduce_logical_and
=
as_type_ptr
<
opset1
::
ReduceSum
>
(
builder
.
create
());
EXPECT_EQ
(
g_reduce_logical_and
->
get_keep_dims
(),
reduce_logical_and
->
get_keep_dims
());
}
TEST
(
attributes
,
reduce_logical_or_op
)
{
// ReduceLogicalOr derives visit_attributes from op::util::LogicalReductionKeepDims
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReduceLogicalOr
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
4
,
5
});
auto
reduction_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
2
});
bool
keep_dims
=
true
;
auto
reduce_logical_or
=
make_shared
<
opset1
::
ReduceLogicalOr
>
(
data
,
reduction_axes
,
keep_dims
);
NodeBuilder
builder
(
reduce_logical_or
);
auto
g_reduce_logical_or
=
as_type_ptr
<
opset1
::
ReduceLogicalOr
>
(
builder
.
create
());
EXPECT_EQ
(
g_reduce_logical_or
->
get_keep_dims
(),
reduce_logical_or
->
get_keep_dims
());
}
TEST
(
attributes
,
reduce_max_op
)
{
// ReduceMax derives visit_attributes from op::util::ArithmeticReductionKeepDims
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReduceMax
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
4
,
5
});
auto
reduction_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
2
});
bool
keep_dims
=
true
;
auto
reduce_max
=
make_shared
<
opset1
::
ReduceMax
>
(
data
,
reduction_axes
,
keep_dims
);
NodeBuilder
builder
(
reduce_max
);
auto
g_reduce_max
=
as_type_ptr
<
opset1
::
ReduceMax
>
(
builder
.
create
());
EXPECT_EQ
(
g_reduce_max
->
get_keep_dims
(),
reduce_max
->
get_keep_dims
());
}
TEST
(
attributes
,
reduce_mean_op
)
{
// ReduceMean derives visit_attributes from op::util::ArithmeticReductionKeepDims
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReduceMean
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
4
,
5
});
auto
reduction_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
2
});
bool
keep_dims
=
true
;
auto
reduce_mean
=
make_shared
<
opset1
::
ReduceMean
>
(
data
,
reduction_axes
,
keep_dims
);
NodeBuilder
builder
(
reduce_mean
);
auto
g_reduce_mean
=
as_type_ptr
<
opset1
::
ReduceMean
>
(
builder
.
create
());
EXPECT_EQ
(
g_reduce_mean
->
get_keep_dims
(),
reduce_mean
->
get_keep_dims
());
}
TEST
(
attributes
,
reduce_min_op
)
{
// ReduceMin derives visit_attributes from op::util::ArithmeticReductionKeepDims
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReduceMin
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
4
,
5
});
auto
reduction_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
2
});
bool
keep_dims
=
true
;
auto
reduce_min
=
make_shared
<
opset1
::
ReduceMin
>
(
data
,
reduction_axes
,
keep_dims
);
NodeBuilder
builder
(
reduce_min
);
auto
g_reduce_min
=
as_type_ptr
<
opset1
::
ReduceMin
>
(
builder
.
create
());
EXPECT_EQ
(
g_reduce_min
->
get_keep_dims
(),
reduce_min
->
get_keep_dims
());
}
TEST
(
attributes
,
reduce_prod_op
)
{
// ReduceProd derives visit_attributes from op::util::ArithmeticReductionKeepDims
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReduceProd
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
4
,
5
});
auto
reduction_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
2
});
bool
keep_dims
=
true
;
auto
reduce_prod
=
make_shared
<
opset1
::
ReduceProd
>
(
data
,
reduction_axes
,
keep_dims
);
NodeBuilder
builder
(
reduce_prod
);
auto
g_reduce_prod
=
as_type_ptr
<
opset1
::
ReduceProd
>
(
builder
.
create
());
EXPECT_EQ
(
g_reduce_prod
->
get_keep_dims
(),
reduce_prod
->
get_keep_dims
());
}
TEST
(
attributes
,
reduce_sum_op
)
{
// ReduceSum derives visit_attributes from op::util::ArithmeticReductionKeepDims
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReduceSum
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
4
,
5
});
auto
reduction_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
2
});
bool
keep_dims
=
true
;
auto
reduce_sum
=
make_shared
<
opset1
::
ReduceSum
>
(
data
,
reduction_axes
,
keep_dims
);
NodeBuilder
builder
(
reduce_sum
);
auto
g_reduce_sum
=
as_type_ptr
<
opset1
::
ReduceSum
>
(
builder
.
create
());
EXPECT_EQ
(
g_reduce_sum
->
get_keep_dims
(),
reduce_sum
->
get_keep_dims
());
}
TEST
(
attributes
,
region_yolo_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
RegionYolo
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
i64
,
Shape
{
1
,
255
,
26
,
26
});
size_t
num_coords
=
4
;
size_t
num_classes
=
1
;
size_t
num_regions
=
6
;
auto
do_softmax
=
false
;
auto
mask
=
std
::
vector
<
int64_t
>
{
0
,
1
};
auto
axis
=
1
;
auto
end_axis
=
3
;
auto
anchors
=
std
::
vector
<
float
>
{
10
,
14
,
23
,
27
,
37
,
58
,
81
,
82
,
135
,
169
,
344
,
319
};
auto
region_yolo
=
make_shared
<
opset1
::
RegionYolo
>
(
data
,
num_coords
,
num_classes
,
num_regions
,
do_softmax
,
mask
,
axis
,
end_axis
,
anchors
);
NodeBuilder
builder
(
region_yolo
);
auto
g_region_yolo
=
as_type_ptr
<
opset1
::
RegionYolo
>
(
builder
.
create
());
EXPECT_EQ
(
g_region_yolo
->
get_num_coords
(),
region_yolo
->
get_num_coords
());
EXPECT_EQ
(
g_region_yolo
->
get_num_classes
(),
region_yolo
->
get_num_classes
());
EXPECT_EQ
(
g_region_yolo
->
get_num_regions
(),
region_yolo
->
get_num_regions
());
EXPECT_EQ
(
g_region_yolo
->
get_do_softmax
(),
region_yolo
->
get_do_softmax
());
EXPECT_EQ
(
g_region_yolo
->
get_mask
(),
region_yolo
->
get_mask
());
EXPECT_EQ
(
g_region_yolo
->
get_anchors
(),
region_yolo
->
get_anchors
());
EXPECT_EQ
(
g_region_yolo
->
get_axis
(),
region_yolo
->
get_axis
());
EXPECT_EQ
(
g_region_yolo
->
get_end_axis
(),
region_yolo
->
get_end_axis
());
}
TEST
(
attributes
,
reshape_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
Reshape
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
2
,
3
,
4
});
auto
pattern
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
2
});
bool
special_zero
=
true
;
auto
reshape
=
make_shared
<
opset1
::
Reshape
>
(
data
,
pattern
,
special_zero
);
NodeBuilder
builder
(
reshape
);
auto
g_reshape
=
as_type_ptr
<
opset1
::
Reshape
>
(
builder
.
create
());
EXPECT_EQ
(
g_reshape
->
get_special_zero
(),
reshape
->
get_special_zero
());
}
TEST
(
attributes
,
reverse_op_enum_mode
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
Reverse
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
200
});
auto
reversed_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
200
});
auto
reverse
=
make_shared
<
opset1
::
Reverse
>
(
data
,
reversed_axes
,
opset1
::
Reverse
::
Mode
::
INDEX
);
NodeBuilder
builder
(
reverse
);
auto
g_reverse
=
as_type_ptr
<
opset1
::
Reverse
>
(
builder
.
create
());
EXPECT_EQ
(
g_reverse
->
get_mode
(),
reverse
->
get_mode
());
}
TEST
(
attributes
,
reverse_op_string_mode
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
Reverse
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
200
});
auto
reversed_axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
200
});
std
::
string
mode
=
"index"
;
auto
reverse
=
make_shared
<
opset1
::
Reverse
>
(
data
,
reversed_axes
,
mode
);
NodeBuilder
builder
(
reverse
);
auto
g_reverse
=
as_type_ptr
<
opset1
::
Reverse
>
(
builder
.
create
());
EXPECT_EQ
(
g_reverse
->
get_mode
(),
reverse
->
get_mode
());
}
TEST
(
attributes
,
reverse_sequence_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
ReverseSequence
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
2
,
3
,
4
,
2
});
auto
seq_indices
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
4
});
auto
batch_axis
=
2
;
auto
seq_axis
=
1
;
auto
reverse_sequence
=
make_shared
<
opset1
::
ReverseSequence
>
(
data
,
seq_indices
,
batch_axis
,
seq_axis
);
NodeBuilder
builder
(
reverse_sequence
);
auto
g_reverse_sequence
=
as_type_ptr
<
opset1
::
ReverseSequence
>
(
builder
.
create
());
EXPECT_EQ
(
g_reverse_sequence
->
get_origin_batch_axis
(),
reverse_sequence
->
get_origin_batch_axis
());
EXPECT_EQ
(
g_reverse_sequence
->
get_origin_sequence_axis
(),
reverse_sequence
->
get_origin_sequence_axis
());
}
TEST
(
attributes
,
rnn_cell_op_custom_attributes
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
RNNCell
>
();
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
auto
H
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
auto
W
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
3
});
auto
R
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
3
});
const
size_t
hidden_size
=
3
;
auto
activations
=
std
::
vector
<
std
::
string
>
{
"sigmoid"
,
"tanh"
};
auto
activations_alpha
=
std
::
vector
<
float
>
{
1.0
,
1.5
};
auto
activations_beta
=
std
::
vector
<
float
>
{
2.0
,
1.0
};
float
clip
=
1.0
;
auto
rnn_cell
=
make_shared
<
opset1
::
RNNCell
>
(
X
,
H
,
W
,
R
,
hidden_size
,
activations
,
activations_alpha
,
activations_beta
,
clip
);
NodeBuilder
builder
(
rnn_cell
);
auto
g_rnn_cell
=
as_type_ptr
<
opset1
::
RNNCell
>
(
builder
.
create
());
EXPECT_EQ
(
g_rnn_cell
->
get_hidden_size
(),
rnn_cell
->
get_hidden_size
());
EXPECT_EQ
(
g_rnn_cell
->
get_clip
(),
rnn_cell
->
get_clip
());
EXPECT_EQ
(
g_rnn_cell
->
get_activations
(),
rnn_cell
->
get_activations
());
EXPECT_EQ
(
g_rnn_cell
->
get_activations_alpha
(),
rnn_cell
->
get_activations_alpha
());
EXPECT_EQ
(
g_rnn_cell
->
get_activations_beta
(),
rnn_cell
->
get_activations_beta
());
}
TEST
(
attributes
,
rnn_cell_op_default_attributes
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
RNNCell
>
();
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
auto
H
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
auto
W
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
3
});
auto
R
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
3
,
3
});
const
size_t
hidden_size
=
3
;
auto
rnn_cell
=
make_shared
<
opset1
::
RNNCell
>
(
X
,
H
,
W
,
R
,
hidden_size
);
NodeBuilder
builder
(
rnn_cell
);
auto
g_rnn_cell
=
as_type_ptr
<
opset1
::
RNNCell
>
(
builder
.
create
());
EXPECT_EQ
(
g_rnn_cell
->
get_hidden_size
(),
rnn_cell
->
get_hidden_size
());
EXPECT_EQ
(
g_rnn_cell
->
get_clip
(),
rnn_cell
->
get_clip
());
EXPECT_EQ
(
g_rnn_cell
->
get_activations
(),
rnn_cell
->
get_activations
());
EXPECT_EQ
(
g_rnn_cell
->
get_activations_alpha
(),
rnn_cell
->
get_activations_alpha
());
EXPECT_EQ
(
g_rnn_cell
->
get_activations_beta
(),
rnn_cell
->
get_activations_beta
());
}
TEST
(
attributes
,
elu_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
Elu
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment