Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
5d8c39e9
Unverified
Commit
5d8c39e9
authored
Mar 03, 2020
by
Tomasz Socha
Committed by
GitHub
Mar 03, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add attribute visitor for ops R (#4340)
parent
1b611294
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
23 changed files
with
128 additions
and
0 deletions
+128
-0
region_yolo.cpp
src/ngraph/op/experimental/layers/region_yolo.cpp
+14
-0
region_yolo.hpp
src/ngraph/op/experimental/layers/region_yolo.hpp
+1
-0
range.cpp
src/ngraph/op/experimental/range.cpp
+5
-0
range.hpp
src/ngraph/op/experimental/range.hpp
+1
-0
rnn_cell.cpp
src/ngraph/op/fused/rnn_cell.cpp
+5
-0
rnn_cell.hpp
src/ngraph/op/fused/rnn_cell.hpp
+1
-0
reshape.cpp
src/ngraph/op/reshape.cpp
+6
-0
reshape.hpp
src/ngraph/op/reshape.hpp
+1
-0
result.cpp
src/ngraph/op/result.cpp
+5
-0
result.hpp
src/ngraph/op/result.hpp
+1
-0
reverse.cpp
src/ngraph/op/reverse.cpp
+26
-0
reverse.hpp
src/ngraph/op/reverse.hpp
+17
-0
reverse_sequence.cpp
src/ngraph/op/reverse_sequence.cpp
+8
-0
reverse_sequence.hpp
src/ngraph/op/reverse_sequence.hpp
+1
-0
arithmetic_reductions_keep_dims.cpp
src/ngraph/op/util/arithmetic_reductions_keep_dims.cpp
+7
-0
arithmetic_reductions_keep_dims.hpp
src/ngraph/op/util/arithmetic_reductions_keep_dims.hpp
+2
-0
logical_reduction_keep_dims.cpp
src/ngraph/op/util/logical_reduction_keep_dims.cpp
+7
-0
logical_reduction_keep_dims.hpp
src/ngraph/op/util/logical_reduction_keep_dims.hpp
+2
-0
rnn_cell_base.cpp
src/ngraph/op/util/rnn_cell_base.cpp
+11
-0
rnn_cell_base.hpp
src/ngraph/op/util/rnn_cell_base.hpp
+1
-0
unary_elementwise_arithmetic.cpp
src/ngraph/op/util/unary_elementwise_arithmetic.cpp
+5
-0
unary_elementwise_arithmetic.hpp
src/ngraph/op/util/unary_elementwise_arithmetic.hpp
+1
-0
attributes.cpp
test/attributes.cpp
+0
-0
No files found.
src/ngraph/op/experimental/layers/region_yolo.cpp
View file @
5d8c39e9
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "region_yolo.hpp"
#include "ngraph/attribute_visitor.hpp"
using
namespace
std
;
using
namespace
ngraph
;
...
...
@@ -43,6 +44,19 @@ op::RegionYolo::RegionYolo(const Output<Node>& input,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
RegionYolo
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"anchors"
,
m_anchors
);
visitor
.
on_attribute
(
"axis"
,
m_axis
);
visitor
.
on_attribute
(
"coords"
,
m_num_coords
);
visitor
.
on_attribute
(
"classes"
,
m_num_classes
);
visitor
.
on_attribute
(
"end_axis"
,
m_end_axis
);
visitor
.
on_attribute
(
"num"
,
m_num_regions
);
visitor
.
on_attribute
(
"do_softmax"
,
m_do_softmax
);
visitor
.
on_attribute
(
"mask"
,
m_mask
);
return
true
;
}
void
op
::
RegionYolo
::
validate_and_infer_types
()
{
auto
input_et
=
get_input_element_type
(
0
);
...
...
src/ngraph/op/experimental/layers/region_yolo.hpp
View file @
5d8c39e9
...
...
@@ -55,6 +55,7 @@ namespace ngraph
const
int
end_axis
,
const
std
::
vector
<
float
>&
anchors
=
std
::
vector
<
float
>
{});
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/experimental/range.cpp
View file @
5d8c39e9
...
...
@@ -182,6 +182,11 @@ static PartialShape infer_output_shape(const op::Range* node, const element::Typ
return
result
;
}
bool
ngraph
::
op
::
v0
::
Range
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
Range
::
validate_and_infer_types
()
{
set_input_is_relevant_to_shape
(
0
);
...
...
src/ngraph/op/experimental/range.hpp
View file @
5d8c39e9
...
...
@@ -46,6 +46,7 @@ namespace ngraph
const
Output
<
Node
>&
stop
,
const
Output
<
Node
>&
step
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/rnn_cell.cpp
View file @
5d8c39e9
...
...
@@ -65,6 +65,11 @@ op::RNNCell::RNNCell(const Output<Node>& X,
constructor_validate_and_infer_types
();
}
bool
op
::
RNNCell
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
op
::
util
::
RNNCellBase
::
visit_attributes
(
visitor
);
}
void
op
::
RNNCell
::
pre_validate_and_infer_types
()
{
if
(
is_dynamic
())
...
...
src/ngraph/op/fused/rnn_cell.hpp
View file @
5d8c39e9
...
...
@@ -132,6 +132,7 @@ namespace ngraph
const
std
::
vector
<
float
>&
activations_beta
=
{},
float
clip
=
0.
f
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/reshape.cpp
View file @
5d8c39e9
...
...
@@ -156,6 +156,12 @@ op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, b
constructor_validate_and_infer_types
();
}
bool
op
::
v1
::
Reshape
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"special_zero"
,
m_special_zero
);
return
true
;
}
void
op
::
v1
::
Reshape
::
validate_and_infer_types
()
{
auto
pattern_et
=
get_input_element_type
(
1
);
...
...
src/ngraph/op/reshape.hpp
View file @
5d8c39e9
...
...
@@ -137,6 +137,7 @@ namespace ngraph
/// from input shape at the same index.
Reshape
(
const
Output
<
Node
>&
arg
,
const
Output
<
Node
>&
pattern
,
bool
special_zero
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
size_t
get_version
()
const
override
{
return
1
;
}
...
...
src/ngraph/op/result.cpp
View file @
5d8c39e9
...
...
@@ -35,6 +35,11 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
set_placement_index
(
input_value
(
0
).
get_node
()
->
get_placement_index
());
}
bool
ngraph
::
op
::
v0
::
Result
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
Result
::
validate_and_infer_types
()
{
NODE_VALIDATION_CHECK
(
...
...
src/ngraph/op/result.hpp
View file @
5d8c39e9
...
...
@@ -38,6 +38,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Result
(
const
Output
<
Node
>&
arg
,
bool
needs_default_layout
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/reverse.cpp
View file @
5d8c39e9
...
...
@@ -17,6 +17,7 @@
#include <algorithm>
#include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp"
...
...
@@ -91,6 +92,12 @@ op::v1::Reverse::Reverse(const Output<Node>& data,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v1
::
Reverse
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"mode"
,
m_mode
);
return
true
;
}
void
op
::
v1
::
Reverse
::
validate_and_infer_types
()
{
if
(
m_mode
==
Mode
::
MASK
)
...
...
@@ -194,3 +201,22 @@ op::v1::Reverse::Mode op::v1::Reverse::mode_from_string(const std::string& mode)
return
allowed_values
.
at
(
mode
);
}
namespace
ngraph
{
template
<>
EnumNames
<
op
::
v1
::
Reverse
::
Mode
>&
EnumNames
<
op
::
v1
::
Reverse
::
Mode
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
v1
::
Reverse
::
Mode
>
(
"op::v1::Reverse::Mode"
,
{{
"index"
,
op
::
v1
::
Reverse
::
Mode
::
INDEX
},
{
"mask"
,
op
::
v1
::
Reverse
::
Mode
::
MASK
}});
return
enum_names
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
op
::
v1
::
Reverse
::
Mode
>::
type_info
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
Reverse
::
Mode
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
// namespace ngraph
src/ngraph/op/reverse.hpp
View file @
5d8c39e9
...
...
@@ -108,6 +108,7 @@ namespace ngraph
const
Output
<
Node
>&
reversed_axes
,
const
Mode
mode
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
@@ -135,4 +136,20 @@ namespace ngraph
// default opset version
using
v0
::
Reverse
;
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v1
::
Reverse
::
Mode
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
v1
::
Reverse
::
Mode
>
:
public
EnumAttributeAdapterBase
<
op
::
v1
::
Reverse
::
Mode
>
{
public
:
AttributeAdapter
(
op
::
v1
::
Reverse
::
Mode
&
value
)
:
EnumAttributeAdapterBase
<
op
::
v1
::
Reverse
::
Mode
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::v1::Reverse::Mode>"
,
1
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
src/ngraph/op/reverse_sequence.cpp
View file @
5d8c39e9
...
...
@@ -17,6 +17,7 @@
#include <algorithm>
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/validation_util.hpp"
...
...
@@ -39,6 +40,13 @@ op::ReverseSequence::ReverseSequence(const Output<Node>& arg,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
ReverseSequence
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"batch_axis"
,
m_batch_axis
);
visitor
.
on_attribute
(
"seq_axis"
,
m_seq_axis
);
return
true
;
}
void
op
::
ReverseSequence
::
validate_and_infer_types
()
{
auto
input_shape
=
get_input_partial_shape
(
0
);
...
...
src/ngraph/op/reverse_sequence.hpp
View file @
5d8c39e9
...
...
@@ -38,6 +38,7 @@ namespace ngraph
int64_t
batch_axis
,
int64_t
seq_axis
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/util/arithmetic_reductions_keep_dims.cpp
View file @
5d8c39e9
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
...
...
@@ -30,6 +31,12 @@ op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
{
}
bool
ngraph
::
op
::
util
::
ArithmeticReductionKeepDims
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"keep_dims"
,
m_keep_dims
);
return
true
;
}
void
op
::
util
::
ArithmeticReductionKeepDims
::
validate_and_infer_types
()
{
if
(
m_keep_dims
)
...
...
src/ngraph/op/util/arithmetic_reductions_keep_dims.hpp
View file @
5d8c39e9
...
...
@@ -37,6 +37,8 @@ namespace ngraph
const
Output
<
Node
>&
reduction_axes
,
bool
keep_dims
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
public
:
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/util/logical_reduction_keep_dims.cpp
View file @
5d8c39e9
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/logical_reduction_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
...
...
@@ -30,6 +31,12 @@ op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
{
}
bool
ngraph
::
op
::
util
::
LogicalReductionKeepDims
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"keep_dims"
,
m_keep_dims
);
return
true
;
}
void
op
::
util
::
LogicalReductionKeepDims
::
validate_and_infer_types
()
{
if
(
m_keep_dims
)
...
...
src/ngraph/op/util/logical_reduction_keep_dims.hpp
View file @
5d8c39e9
...
...
@@ -37,6 +37,8 @@ namespace ngraph
const
Output
<
Node
>&
reduction_axes
,
const
bool
keep_dims
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
public
:
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/util/rnn_cell_base.cpp
View file @
5d8c39e9
...
...
@@ -17,6 +17,7 @@
#include <algorithm>
#include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/multiply.hpp"
...
...
@@ -48,6 +49,16 @@ op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
{
}
bool
ngraph
::
op
::
util
::
RNNCellBase
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"hidden_size"
,
m_hidden_size
);
visitor
.
on_attribute
(
"activations"
,
m_activations
);
visitor
.
on_attribute
(
"activations_alpha"
,
m_activations_alpha
);
visitor
.
on_attribute
(
"activations_beta"
,
m_activations_beta
);
visitor
.
on_attribute
(
"clip"
,
m_clip
);
return
true
;
}
op
::
util
::
ActivationFunction
op
::
util
::
RNNCellBase
::
get_activation_function
(
size_t
idx
)
const
{
op
::
util
::
ActivationFunction
afunc
=
get_activation_func_by_name
(
m_activations
.
at
(
idx
));
...
...
src/ngraph/op/util/rnn_cell_base.hpp
View file @
5d8c39e9
...
...
@@ -58,6 +58,7 @@ namespace ngraph
RNNCellBase
()
=
default
;
virtual
bool
visit_attributes
(
AttributeVisitor
&
visitor
);
std
::
size_t
get_hidden_size
()
const
{
return
m_hidden_size
;
}
float
get_clip
()
const
{
return
m_clip
;
}
const
std
::
vector
<
std
::
string
>&
get_activations
()
const
{
return
m_activations
;
}
...
...
src/ngraph/op/util/unary_elementwise_arithmetic.cpp
View file @
5d8c39e9
...
...
@@ -43,3 +43,8 @@ void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic
();
}
bool
op
::
util
::
UnaryElementwiseArithmetic
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
src/ngraph/op/util/unary_elementwise_arithmetic.hpp
View file @
5d8c39e9
...
...
@@ -68,6 +68,7 @@ namespace ngraph
public
:
void
validate_and_infer_types
()
override
;
bool
is_unary_elementwise_arithmetic
()
const
override
{
return
true
;
}
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
};
}
}
...
...
test/attributes.cpp
View file @
5d8c39e9
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment