Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
205ff94a
Unverified
Commit
205ff94a
authored
Mar 02, 2020
by
Scott Cyphers
Committed by
GitHub
Mar 02, 2020
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into cloudhan_fix-mkldnn_v1
parents
4973b763
61b4e6f9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
35 changed files
with
304 additions
and
3 deletions
+304
-3
attribute_adapter.cpp
src/ngraph/attribute_adapter.cpp
+36
-0
attribute_adapter.hpp
src/ngraph/attribute_adapter.hpp
+34
-0
attribute_visitor.hpp
src/ngraph/attribute_visitor.hpp
+17
-2
coordinate_diff.cpp
src/ngraph/coordinate_diff.cpp
+1
-1
erf.cpp
src/ngraph/op/erf.cpp
+5
-0
erf.hpp
src/ngraph/op/erf.hpp
+1
-0
exp.cpp
src/ngraph/op/exp.cpp
+5
-0
exp.hpp
src/ngraph/op/exp.hpp
+1
-0
floor.cpp
src/ngraph/op/floor.cpp
+5
-0
floor.hpp
src/ngraph/op/floor.hpp
+1
-0
elu.cpp
src/ngraph/op/fused/elu.cpp
+7
-0
elu.hpp
src/ngraph/op/fused/elu.hpp
+1
-0
fake_quantize.cpp
src/ngraph/op/fused/fake_quantize.cpp
+8
-0
fake_quantize.hpp
src/ngraph/op/fused/fake_quantize.hpp
+1
-0
grn.cpp
src/ngraph/op/fused/grn.cpp
+7
-0
grn.hpp
src/ngraph/op/fused/grn.hpp
+1
-0
group_conv.cpp
src/ngraph/op/fused/group_conv.cpp
+22
-0
group_conv.hpp
src/ngraph/op/fused/group_conv.hpp
+3
-0
hard_sigmoid.cpp
src/ngraph/op/fused/hard_sigmoid.cpp
+5
-0
hard_sigmoid.hpp
src/ngraph/op/fused/hard_sigmoid.hpp
+1
-0
lstm_cell.cpp
src/ngraph/op/fused/lstm_cell.cpp
+37
-0
lstm_cell.hpp
src/ngraph/op/fused/lstm_cell.hpp
+17
-0
lstm_sequence.cpp
src/ngraph/op/fused/lstm_sequence.cpp
+35
-0
lstm_sequence.hpp
src/ngraph/op/fused/lstm_sequence.hpp
+18
-0
gather.cpp
src/ngraph/op/gather.cpp
+5
-0
gather.hpp
src/ngraph/op/gather.hpp
+1
-0
gather_tree.cpp
src/ngraph/op/gather_tree.cpp
+5
-0
gather_tree.hpp
src/ngraph/op/gather_tree.hpp
+1
-0
log.cpp
src/ngraph/op/log.cpp
+5
-0
log.hpp
src/ngraph/op/log.hpp
+1
-0
lrn.cpp
src/ngraph/op/lrn.cpp
+10
-0
lrn.hpp
src/ngraph/op/lrn.hpp
+1
-0
not.cpp
src/ngraph/op/not.cpp
+5
-0
not.hpp
src/ngraph/op/not.hpp
+1
-0
attributes.cpp
test/attributes.cpp
+0
-0
No files found.
src/ngraph/attribute_adapter.cpp
View file @
205ff94a
...
...
@@ -242,4 +242,40 @@ namespace ngraph
m_value
=
copy_from
<
vector
<
uint64_t
>>
(
value
);
m_buffer_valid
=
false
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
vector
<
float
>>::
type_info
;
const
vector
<
float
>&
AttributeAdapter
<
vector
<
float
>>::
get
()
{
if
(
!
m_buffer_valid
)
{
m_buffer
=
copy_from
<
vector
<
float
>>
(
m_value
);
m_buffer_valid
=
true
;
}
return
m_buffer
;
}
void
AttributeAdapter
<
vector
<
float
>>::
set
(
const
vector
<
float
>&
value
)
{
m_value
=
copy_from
<
vector
<
float
>>
(
value
);
m_buffer_valid
=
false
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
vector
<
string
>>::
type_info
;
const
vector
<
string
>&
AttributeAdapter
<
vector
<
string
>>::
get
()
{
if
(
!
m_buffer_valid
)
{
m_buffer
=
copy_from
<
vector
<
string
>>
(
m_value
);
m_buffer_valid
=
true
;
}
return
m_buffer
;
}
void
AttributeAdapter
<
vector
<
string
>>::
set
(
const
vector
<
string
>&
value
)
{
m_value
=
copy_from
<
vector
<
string
>>
(
value
);
m_buffer_valid
=
false
;
}
}
src/ngraph/attribute_adapter.hpp
View file @
205ff94a
...
...
@@ -16,6 +16,7 @@
#pragma once
#include <string>
#include <type_traits>
#include <vector>
...
...
@@ -299,6 +300,39 @@ namespace ngraph
void
set
(
const
std
::
vector
<
int64_t
>&
value
)
override
;
};
template
<>
class
NGRAPH_API
AttributeAdapter
<
std
::
vector
<
float
>>
:
public
ValueReference
<
std
::
vector
<
float
>>
,
public
ValueAccessor
<
std
::
vector
<
float
>>
{
public
:
AttributeAdapter
(
std
::
vector
<
float
>&
value
)
:
ValueReference
<
std
::
vector
<
float
>>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<vector<float>>"
,
0
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
const
std
::
vector
<
float
>&
get
()
override
;
void
set
(
const
std
::
vector
<
float
>&
value
)
override
;
};
template
<>
class
NGRAPH_API
AttributeAdapter
<
std
::
vector
<
std
::
string
>>
:
public
ValueReference
<
std
::
vector
<
std
::
string
>>
,
public
ValueAccessor
<
std
::
vector
<
std
::
string
>>
{
public
:
AttributeAdapter
(
std
::
vector
<
std
::
string
>&
value
)
:
ValueReference
<
std
::
vector
<
std
::
string
>>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<vector<string>>"
,
0
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
const
std
::
vector
<
std
::
string
>&
get
()
override
;
void
set
(
const
std
::
vector
<
std
::
string
>&
value
)
override
;
};
template
<
typename
A
,
typename
B
>
A
copy_from
(
B
&
b
)
{
...
...
src/ngraph/attribute_visitor.hpp
View file @
205ff94a
...
...
@@ -46,16 +46,25 @@ namespace ngraph
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
};
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
double
>&
adapter
)
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
std
::
vector
<
int64_t
>>&
adapter
)
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
std
::
vector
<
float
>
>&
adapter
)
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
double
>&
adapter
)
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
std
::
vector
<
std
::
string
>>&
adapter
)
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
...
...
@@ -68,5 +77,11 @@ namespace ngraph
AttributeAdapter
<
T
>
adapter
(
value
);
on_adapter
(
name
,
adapter
);
}
void
on_attribute
(
const
std
::
string
&
name
,
op
::
AutoBroadcastSpec
&
value
)
{
AttributeAdapter
<
op
::
AutoBroadcastType
>
adapter
(
value
.
m_type
);
on_adapter
(
name
,
adapter
);
}
};
}
src/ngraph/coordinate_diff.cpp
View file @
205ff94a
...
...
@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
void
AttributeAdapter
<
CoordinateDiff
>::
set
(
const
vector
<
int64_t
>&
value
)
{
m_value
=
copy_from
<
CoordinateDiff
>
(
m_
value
);
m_value
=
copy_from
<
CoordinateDiff
>
(
value
);
m_buffer_valid
=
false
;
}
...
...
src/ngraph/op/erf.cpp
View file @
205ff94a
...
...
@@ -23,6 +23,11 @@ using namespace ngraph;
constexpr
NodeTypeInfo
op
::
Erf
::
type_info
;
bool
ngraph
::
op
::
v0
::
Erf
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Erf
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/erf.hpp
View file @
205ff94a
...
...
@@ -32,6 +32,7 @@ namespace ngraph
Erf
()
=
default
;
Erf
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
};
...
...
src/ngraph/op/exp.cpp
View file @
205ff94a
...
...
@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
Exp
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Exp
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/exp.hpp
View file @
205ff94a
...
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Exp
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/floor.cpp
View file @
205ff94a
...
...
@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
Floor
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Floor
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/floor.hpp
View file @
205ff94a
...
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Floor
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
};
...
...
src/ngraph/op/fused/elu.cpp
View file @
205ff94a
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp"
...
...
@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
Elu
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"alpha"
,
m_alpha
);
return
true
;
}
NodeVector
op
::
Elu
::
decompose_op
()
const
{
auto
data
=
input_value
(
0
);
...
...
src/ngraph/op/fused/elu.hpp
View file @
205ff94a
...
...
@@ -42,6 +42,7 @@ namespace ngraph
/// \param alpha Multiplier for negative values
Elu
(
const
Output
<
Node
>&
data
,
const
double
alpha
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/fake_quantize.cpp
View file @
205ff94a
...
...
@@ -17,6 +17,7 @@
#include <memory>
#include "fake_quantize.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
...
...
@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types()
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_partial_shape
(
0
));
}
bool
ngraph
::
op
::
v0
::
FakeQuantize
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"levels"
,
m_levels
);
visitor
.
on_attribute
(
"auto_broadcast"
,
m_auto_broadcast
);
return
true
;
}
NodeVector
op
::
FakeQuantize
::
decompose_op
()
const
{
Output
<
Node
>
data
{
input_value
(
0
)};
...
...
src/ngraph/op/fused/fake_quantize.hpp
View file @
205ff94a
...
...
@@ -67,6 +67,7 @@ namespace ngraph
const
AutoBroadcastSpec
&
auto_broadcast
=
AutoBroadcastSpec
(
AutoBroadcastType
::
NUMPY
));
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/fused/grn.cpp
View file @
205ff94a
...
...
@@ -17,6 +17,7 @@
#include <iterator>
#include "grn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"
...
...
@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
GRN
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"bias"
,
m_bias
);
return
true
;
}
void
op
::
GRN
::
pre_validate_and_infer_types
()
{
const
auto
&
data_pshape
=
get_input_partial_shape
(
0
);
...
...
src/ngraph/op/fused/grn.hpp
View file @
205ff94a
...
...
@@ -42,6 +42,7 @@ namespace ngraph
///
GRN
(
const
Output
<
Node
>&
data
,
float
bias
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
float
get_bias
()
const
{
return
m_bias
;
}
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
...
...
src/ngraph/op/fused/group_conv.cpp
View file @
205ff94a
...
...
@@ -18,6 +18,7 @@
#include "group_conv.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
...
...
@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v1
::
GroupConvolution
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"strides"
,
m_strides
);
visitor
.
on_attribute
(
"pads_begin"
,
m_pads_begin
);
visitor
.
on_attribute
(
"pads_end"
,
m_pads_end
);
visitor
.
on_attribute
(
"dilations"
,
m_dilations
);
visitor
.
on_attribute
(
"auto_pad"
,
m_auto_pad
);
return
true
;
}
void
op
::
v1
::
GroupConvolution
::
validate_and_infer_types
()
{
const
PartialShape
&
data_batch_pshape
=
get_input_partial_shape
(
0
);
...
...
@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v1
::
GroupConvolutionBackpropData
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"strides"
,
m_strides
);
visitor
.
on_attribute
(
"pads_begin"
,
m_pads_begin
);
visitor
.
on_attribute
(
"pads_end"
,
m_pads_end
);
visitor
.
on_attribute
(
"dilations"
,
m_dilations
);
visitor
.
on_attribute
(
"auto_pad"
,
m_auto_pad
);
visitor
.
on_attribute
(
"output_padding"
,
m_output_padding
);
return
true
;
}
bool
op
::
v1
::
GroupConvolutionBackpropData
::
is_dynamic
()
const
{
bool
is_dynamic
=
Node
::
is_dynamic
();
...
...
src/ngraph/op/fused/group_conv.hpp
View file @
205ff94a
...
...
@@ -61,6 +61,8 @@ namespace ngraph
const
CoordinateDiff
&
pads_end
,
const
Strides
&
dilations
,
const
PadType
&
auto_pad
=
PadType
::
EXPLICIT
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
// TODO - Remove supports_decompose and validate_and_infer_type once op supports
// decomposition
bool
supports_decompose
()
const
override
{
return
false
;
}
...
...
@@ -187,6 +189,7 @@ namespace ngraph
const
PadType
&
auto_pad
=
PadType
::
EXPLICIT
,
const
CoordinateDiff
&
output_padding
=
{});
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
bool
is_dynamic
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
...
...
src/ngraph/op/fused/hard_sigmoid.cpp
View file @
205ff94a
...
...
@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
HardSigmoid
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
HardSigmoid
::
pre_validate_and_infer_types
()
{
const
auto
&
alpha_pshape
=
get_input_partial_shape
(
1
);
...
...
src/ngraph/op/fused/hard_sigmoid.hpp
View file @
205ff94a
...
...
@@ -46,6 +46,7 @@ namespace ngraph
const
Output
<
Node
>&
alpha
,
const
Output
<
Node
>&
beta
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/lstm_cell.cpp
View file @
205ff94a
...
...
@@ -17,6 +17,7 @@
#include <cmath>
#include <functional>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
...
...
@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
LSTMCell
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"hidden_size"
,
m_hidden_size
);
visitor
.
on_attribute
(
"activations"
,
m_activations
);
visitor
.
on_attribute
(
"activations_alpha"
,
m_activations_alpha
);
visitor
.
on_attribute
(
"activations_beta"
,
m_activations_beta
);
visitor
.
on_attribute
(
"clip"
,
m_clip
);
visitor
.
on_attribute
(
"input_forget"
,
m_input_forget
);
visitor
.
on_attribute
(
"weights_format"
,
m_weights_format
);
return
true
;
}
void
op
::
LSTMCell
::
pre_validate_and_infer_types
()
{
set_output_size
(
2
);
...
...
@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
}
namespace
ngraph
{
template
<>
EnumNames
<
op
::
LSTMWeightsFormat
>&
EnumNames
<
op
::
LSTMWeightsFormat
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
LSTMWeightsFormat
>
(
"op::LSTMWeightsFormat"
,
{{
"fico"
,
op
::
LSTMWeightsFormat
::
FICO
},
{
"icof"
,
op
::
LSTMWeightsFormat
::
ICOF
},
{
"ifco"
,
op
::
LSTMWeightsFormat
::
IFCO
},
{
"ifoc"
,
op
::
LSTMWeightsFormat
::
IFOC
},
{
"iofc"
,
op
::
LSTMWeightsFormat
::
IOFC
}});
return
enum_names
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
op
::
LSTMWeightsFormat
>::
type_info
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
LSTMWeightsFormat
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
// namespace ngraph
src/ngraph/op/fused/lstm_cell.hpp
View file @
205ff94a
...
...
@@ -224,6 +224,7 @@ namespace ngraph
float
clip
=
0.
f
,
bool
input_forget
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
@@ -284,4 +285,20 @@ namespace ngraph
}
using
v0
::
LSTMCell
;
}
// namespace op
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
LSTMWeightsFormat
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
LSTMWeightsFormat
>
:
public
EnumAttributeAdapterBase
<
op
::
LSTMWeightsFormat
>
{
public
:
AttributeAdapter
(
op
::
LSTMWeightsFormat
&
value
)
:
EnumAttributeAdapterBase
<
op
::
LSTMWeightsFormat
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::LSTMWeightsFormat>"
,
1
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
// namespace ngraph
src/ngraph/op/fused/lstm_sequence.cpp
View file @
205ff94a
...
...
@@ -16,6 +16,7 @@
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
...
...
@@ -32,6 +33,19 @@ using namespace ngraph;
using
namespace
std
;
constexpr
NodeTypeInfo
op
::
LSTMSequence
::
type_info
;
bool
ngraph
::
op
::
v0
::
LSTMSequence
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"hidden_size"
,
m_hidden_size
);
visitor
.
on_attribute
(
"activations"
,
m_activations
);
visitor
.
on_attribute
(
"activations_alpha"
,
m_activations_alpha
);
visitor
.
on_attribute
(
"activations_beta"
,
m_activations_beta
);
visitor
.
on_attribute
(
"clip"
,
m_clip_threshold
);
visitor
.
on_attribute
(
"direction"
,
m_direction
);
visitor
.
on_attribute
(
"input_forget"
,
m_input_forget
);
visitor
.
on_attribute
(
"weights_format"
,
m_weights_format
);
return
true
;
}
NodeVector
op
::
LSTMSequence
::
decompose_op
()
const
{
NodeVector
results
;
...
...
@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
return
builder
::
squeeze
(
tmp
);
}
namespace
ngraph
{
template
<>
EnumNames
<
op
::
v0
::
LSTMSequence
::
direction
>&
EnumNames
<
op
::
v0
::
LSTMSequence
::
direction
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
v0
::
LSTMSequence
::
direction
>
(
"op::v0::LSTMSequence::direction"
,
{{
"forward"
,
op
::
v0
::
LSTMSequence
::
direction
::
FORWARD
},
{
"reverse"
,
op
::
v0
::
LSTMSequence
::
direction
::
REVERSE
},
{
"bidirectional"
,
op
::
v0
::
LSTMSequence
::
direction
::
BIDIRECTIONAL
}});
return
enum_names
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
op
::
v0
::
LSTMSequence
::
direction
>::
type_info
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v0
::
LSTMSequence
::
direction
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
// namespace ngraph
src/ngraph/op/fused/lstm_sequence.hpp
View file @
205ff94a
...
...
@@ -135,6 +135,7 @@ namespace ngraph
{
}
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
@@ -185,4 +186,21 @@ namespace ngraph
}
using
v0
::
LSTMSequence
;
}
// namespace op
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v0
::
LSTMSequence
::
direction
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
v0
::
LSTMSequence
::
direction
>
:
public
EnumAttributeAdapterBase
<
op
::
v0
::
LSTMSequence
::
direction
>
{
public
:
AttributeAdapter
(
op
::
v0
::
LSTMSequence
::
direction
&
value
)
:
EnumAttributeAdapterBase
<
op
::
v0
::
LSTMSequence
::
direction
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::v0::LSTMSequence::direction>"
,
1
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
// namespace ngraph
src/ngraph/op/gather.cpp
View file @
205ff94a
...
...
@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params,
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v1
::
Gather
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
v1
::
Gather
::
validate_and_infer_types
()
{
const
auto
&
input_rank
=
get_input_partial_shape
(
PARAMS
).
rank
();
...
...
src/ngraph/op/gather.hpp
View file @
205ff94a
...
...
@@ -67,6 +67,7 @@ namespace ngraph
const
Output
<
Node
>&
indices
,
const
Output
<
Node
>&
axis
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
int64_t
get_axis
()
const
;
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/gather_tree.cpp
View file @
205ff94a
...
...
@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
));
}
bool
ngraph
::
op
::
v1
::
GatherTree
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
v1
::
GatherTree
::
validate_and_infer_types
()
{
const
auto
&
step_ids_rank
=
get_input_partial_shape
(
0
);
...
...
src/ngraph/op/gather_tree.hpp
View file @
205ff94a
...
...
@@ -44,6 +44,7 @@ namespace ngraph
const
Output
<
Node
>&
max_seq_len
,
const
Output
<
Node
>&
end_token
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/log.cpp
View file @
205ff94a
...
...
@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v0
::
Log
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Log
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/log.hpp
View file @
205ff94a
...
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Log
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/lrn.cpp
View file @
205ff94a
...
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/lrn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
...
...
@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types()
")."
);
}
bool
ngraph
::
op
::
v0
::
LRN
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"alpha"
,
m_alpha
);
visitor
.
on_attribute
(
"beta"
,
m_beta
);
visitor
.
on_attribute
(
"bias"
,
m_bias
);
visitor
.
on_attribute
(
"size"
,
m_size
);
return
true
;
}
shared_ptr
<
Node
>
op
::
LRN
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/lrn.hpp
View file @
205ff94a
...
...
@@ -58,6 +58,7 @@ namespace ngraph
double
bias
,
size_t
size
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/not.cpp
View file @
205ff94a
...
...
@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
constructor_validate_and_infer_types
();
}
bool
ngraph
::
op
::
v1
::
LogicalNot
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void
op
::
v1
::
LogicalNot
::
validate_and_infer_types
()
{
...
...
src/ngraph/op/not.hpp
View file @
205ff94a
...
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
LogicalNot
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
...
...
test/attributes.cpp
View file @
205ff94a
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment