ngraph — Commit 205ff94a

Unverified commit 205ff94a, authored Mar 02, 2020 by Scott Cyphers, committed by GitHub on Mar 02, 2020.

Merge branch 'master' into cloudhan_fix-mkldnn_v1

Parents: 4973b763, 61b4e6f9

Showing 35 changed files with 555 additions and 12 deletions (+555 -12).
src/ngraph/attribute_adapter.cpp       +36   -0
src/ngraph/attribute_adapter.hpp       +34   -0
src/ngraph/attribute_visitor.hpp       +17   -2
src/ngraph/coordinate_diff.cpp         +1    -1
src/ngraph/op/erf.cpp                  +5    -0
src/ngraph/op/erf.hpp                  +1    -0
src/ngraph/op/exp.cpp                  +5    -0
src/ngraph/op/exp.hpp                  +1    -0
src/ngraph/op/floor.cpp                +5    -0
src/ngraph/op/floor.hpp                +1    -0
src/ngraph/op/fused/elu.cpp            +7    -0
src/ngraph/op/fused/elu.hpp            +1    -0
src/ngraph/op/fused/fake_quantize.cpp  +8    -0
src/ngraph/op/fused/fake_quantize.hpp  +1    -0
src/ngraph/op/fused/grn.cpp            +7    -0
src/ngraph/op/fused/grn.hpp            +1    -0
src/ngraph/op/fused/group_conv.cpp     +22   -0
src/ngraph/op/fused/group_conv.hpp     +3    -0
src/ngraph/op/fused/hard_sigmoid.cpp   +5    -0
src/ngraph/op/fused/hard_sigmoid.hpp   +1    -0
src/ngraph/op/fused/lstm_cell.cpp      +37   -0
src/ngraph/op/fused/lstm_cell.hpp      +17   -0
src/ngraph/op/fused/lstm_sequence.cpp  +35   -0
src/ngraph/op/fused/lstm_sequence.hpp  +18   -0
src/ngraph/op/gather.cpp               +5    -0
src/ngraph/op/gather.hpp               +1    -0
src/ngraph/op/gather_tree.cpp          +5    -0
src/ngraph/op/gather_tree.hpp          +1    -0
src/ngraph/op/log.cpp                  +5    -0
src/ngraph/op/log.hpp                  +1    -0
src/ngraph/op/lrn.cpp                  +10   -0
src/ngraph/op/lrn.hpp                  +1    -0
src/ngraph/op/not.cpp                  +5    -0
src/ngraph/op/not.hpp                  +1    -0
test/attributes.cpp                    +251  -9
src/ngraph/attribute_adapter.cpp

@@ -242,4 +242,40 @@ namespace ngraph
        m_value = copy_from<vector<uint64_t>>(value);
        m_buffer_valid = false;
    }

    constexpr DiscreteTypeInfo AttributeAdapter<vector<float>>::type_info;

    const vector<float>& AttributeAdapter<vector<float>>::get()
    {
        if (!m_buffer_valid)
        {
            m_buffer = copy_from<vector<float>>(m_value);
            m_buffer_valid = true;
        }
        return m_buffer;
    }

    void AttributeAdapter<vector<float>>::set(const vector<float>& value)
    {
        m_value = copy_from<vector<float>>(value);
        m_buffer_valid = false;
    }

    constexpr DiscreteTypeInfo AttributeAdapter<vector<string>>::type_info;

    const vector<string>& AttributeAdapter<vector<string>>::get()
    {
        if (!m_buffer_valid)
        {
            m_buffer = copy_from<vector<string>>(m_value);
            m_buffer_valid = true;
        }
        return m_buffer;
    }

    void AttributeAdapter<vector<string>>::set(const vector<string>& value)
    {
        m_value = copy_from<vector<string>>(value);
        m_buffer_valid = false;
    }
}
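The new vector<float> and vector<string> specializations follow the same lazy-buffer pattern as the existing adapters: set() copies into the wrapped value and invalidates the cache, while get() refills the internal buffer on demand. A minimal usage sketch (hypothetical standalone program, assuming it is built against this ngraph revision; main() and the variable names are illustrative only):

#include <vector>
#include "ngraph/attribute_adapter.hpp"

using namespace ngraph;

int main()
{
    std::vector<float> alphas{0.1f, 0.2f};

    // The adapter wraps a reference to the user-owned vector.
    AttributeAdapter<std::vector<float>> adapter(alphas);

    adapter.set({0.5f, 0.6f, 0.7f});    // updates the wrapped value, invalidates the cached buffer
    const auto& values = adapter.get(); // lazily rebuilds the cached buffer from the wrapped value

    return values.size() == 3 ? 0 : 1;
}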
src/ngraph/attribute_adapter.hpp

@@ -16,6 +16,7 @@
#pragma once

#include <string>
#include <type_traits>
#include <vector>

@@ -299,6 +300,39 @@ namespace ngraph
        void set(const std::vector<int64_t>& value) override;
    };

    template <>
    class NGRAPH_API AttributeAdapter<std::vector<float>>
        : public ValueReference<std::vector<float>>, public ValueAccessor<std::vector<float>>
    {
    public:
        AttributeAdapter(std::vector<float>& value)
            : ValueReference<std::vector<float>>(value)
        {
        }
        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
        const std::vector<float>& get() override;
        void set(const std::vector<float>& value) override;
    };

    template <>
    class NGRAPH_API AttributeAdapter<std::vector<std::string>>
        : public ValueReference<std::vector<std::string>>,
          public ValueAccessor<std::vector<std::string>>
    {
    public:
        AttributeAdapter(std::vector<std::string>& value)
            : ValueReference<std::vector<std::string>>(value)
        {
        }
        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
        const std::vector<std::string>& get() override;
        void set(const std::vector<std::string>& value) override;
    };

    template <typename A, typename B>
    A copy_from(B& b)
    {
src/ngraph/attribute_visitor.hpp

@@ -46,16 +46,25 @@ namespace ngraph
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        };
        virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        virtual void on_adapter(const std::string& name,
                                ValueAccessor<std::vector<int64_t>>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        virtual void on_adapter(const std::string& name,
                                ValueAccessor<std::vector<float>>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }
        virtual void on_adapter(const std::string& name,
                                ValueAccessor<std::vector<std::string>>& adapter)
        {
            on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
        }

@@ -68,5 +77,11 @@ namespace ngraph
            AttributeAdapter<T> adapter(value);
            on_adapter(name, adapter);
        }
        void on_attribute(const std::string& name, op::AutoBroadcastSpec& value)
        {
            AttributeAdapter<op::AutoBroadcastType> adapter(value.m_type);
            on_adapter(name, adapter);
        }
    };
}
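These virtual overloads let a concrete visitor intercept the new vector-typed attributes while every other type still funnels into the generic ValueAccessor<void> handler. A minimal sketch of such a visitor (hypothetical class, loosely modeled on the NodeSaver used in test/attributes.cpp further below):

#include <map>
#include <string>
#include <vector>
#include "ngraph/attribute_visitor.hpp"

using namespace ngraph;

// Records only float-vector attributes of a node; all other attribute
// types fall through to the generic void handler and are ignored here.
class FloatVectorRecorder : public AttributeVisitor
{
public:
    void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override {}
    void on_adapter(const std::string& name,
                    ValueAccessor<std::vector<float>>& adapter) override
    {
        m_vectors[name] = adapter.get();
    }
    const std::map<std::string, std::vector<float>>& vectors() const { return m_vectors; }

private:
    std::map<std::string, std::vector<float>> m_vectors;
};

Used as, for example, `FloatVectorRecorder recorder; lstm_cell->visit_attributes(recorder);` to capture "activations_alpha" and "activations_beta".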
src/ngraph/coordinate_diff.cpp

@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
 void AttributeAdapter<CoordinateDiff>::set(const vector<int64_t>& value)
 {
-    m_value = copy_from<CoordinateDiff>(m_value);
+    m_value = copy_from<CoordinateDiff>(value);
     m_buffer_valid = false;
 }
src/ngraph/op/erf.cpp

@@ -23,6 +23,11 @@ using namespace ngraph;
constexpr NodeTypeInfo op::Erf::type_info;

bool ngraph::op::v0::Erf::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
src/ngraph/op/erf.hpp

@@ -32,6 +32,7 @@ namespace ngraph
                Erf() = default;
                Erf(const Output<Node>& arg);
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
            };
src/ngraph/op/exp.cpp

@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::Exp::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
src/ngraph/op/exp.hpp

@@ -37,6 +37,7 @@ namespace ngraph
                /// \param arg Node that produces the input tensor.
                Exp(const Output<Node>& arg);
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/floor.cpp

@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::Floor::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
src/ngraph/op/floor.hpp

@@ -37,6 +37,7 @@ namespace ngraph
                /// \param arg Node that produces the input tensor.
                Floor(const Output<Node>& arg);
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
            };
src/ngraph/op/fused/elu.cpp

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/fused/elu.hpp"

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp"

@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::Elu::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("alpha", m_alpha);
    return true;
}

NodeVector op::Elu::decompose_op() const
{
    auto data = input_value(0);
src/ngraph/op/fused/elu.hpp

@@ -42,6 +42,7 @@ namespace ngraph
                /// \param alpha Multiplier for negative values
                Elu(const Output<Node>& data, const double alpha);
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual NodeVector decompose_op() const override;
                virtual std::shared_ptr<Node>
src/ngraph/op/fused/fake_quantize.cpp

@@ -17,6 +17,7 @@
#include <memory>

#include "fake_quantize.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"

@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types()
    set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

bool ngraph::op::v0::FakeQuantize::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("levels", m_levels);
    visitor.on_attribute("auto_broadcast", m_auto_broadcast);
    return true;
}

NodeVector op::FakeQuantize::decompose_op() const
{
    Output<Node> data{input_value(0)};
src/ngraph/op/fused/fake_quantize.hpp

@@ -67,6 +67,7 @@ namespace ngraph
                             const AutoBroadcastSpec& auto_broadcast =
                                 AutoBroadcastSpec(AutoBroadcastType::NUMPY));
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual NodeVector decompose_op() const override;
                virtual void validate_and_infer_types() override;
src/ngraph/op/fused/grn.cpp

@@ -17,6 +17,7 @@
#include <iterator>

#include "grn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"

@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::GRN::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("bias", m_bias);
    return true;
}

void op::GRN::pre_validate_and_infer_types()
{
    const auto& data_pshape = get_input_partial_shape(0);
src/ngraph/op/fused/grn.hpp

@@ -42,6 +42,7 @@ namespace ngraph
                ///
                GRN(const Output<Node>& data, float bias);
                bool visit_attributes(AttributeVisitor& visitor) override;
                float get_bias() const { return m_bias; }
                virtual void pre_validate_and_infer_types() override;
                virtual NodeVector decompose_op() const override;
src/ngraph/op/fused/group_conv.cpp

@@ -18,6 +18,7 @@
#include "group_conv.hpp"

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"

@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
    constructor_validate_and_infer_types();
}

bool ngraph::op::v1::GroupConvolution::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("strides", m_strides);
    visitor.on_attribute("pads_begin", m_pads_begin);
    visitor.on_attribute("pads_end", m_pads_end);
    visitor.on_attribute("dilations", m_dilations);
    visitor.on_attribute("auto_pad", m_auto_pad);
    return true;
}

void op::v1::GroupConvolution::validate_and_infer_types()
{
    const PartialShape& data_batch_pshape = get_input_partial_shape(0);

@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
    constructor_validate_and_infer_types();
}

bool ngraph::op::v1::GroupConvolutionBackpropData::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("strides", m_strides);
    visitor.on_attribute("pads_begin", m_pads_begin);
    visitor.on_attribute("pads_end", m_pads_end);
    visitor.on_attribute("dilations", m_dilations);
    visitor.on_attribute("auto_pad", m_auto_pad);
    visitor.on_attribute("output_padding", m_output_padding);
    return true;
}

bool op::v1::GroupConvolutionBackpropData::is_dynamic() const
{
    bool is_dynamic = Node::is_dynamic();
src/ngraph/op/fused/group_conv.hpp

@@ -61,6 +61,8 @@ namespace ngraph
                                 const CoordinateDiff& pads_end,
                                 const Strides& dilations,
                                 const PadType& auto_pad = PadType::EXPLICIT);

                bool visit_attributes(AttributeVisitor& visitor) override;
                // TODO - Remove supports_decompose and validate_and_infer_type once op supports
                // decomposition
                bool supports_decompose() const override { return false; }

@@ -187,6 +189,7 @@ namespace ngraph
                    const PadType& auto_pad = PadType::EXPLICIT,
                    const CoordinateDiff& output_padding = {});

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual bool is_dynamic() const override;
                virtual NodeVector decompose_op() const override;
                virtual void pre_validate_and_infer_types() override;
src/ngraph/op/fused/hard_sigmoid.cpp

@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data,
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::HardSigmoid::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

void op::HardSigmoid::pre_validate_and_infer_types()
{
    const auto& alpha_pshape = get_input_partial_shape(1);
src/ngraph/op/fused/hard_sigmoid.hpp

@@ -46,6 +46,7 @@ namespace ngraph
                            const Output<Node>& alpha,
                            const Output<Node>& beta);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual void pre_validate_and_infer_types() override;
                virtual NodeVector decompose_op() const override;
                virtual std::shared_ptr<Node>
src/ngraph/op/fused/lstm_cell.cpp

@@ -17,6 +17,7 @@
#include <cmath>
#include <functional>

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"

@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("hidden_size", m_hidden_size);
    visitor.on_attribute("activations", m_activations);
    visitor.on_attribute("activations_alpha", m_activations_alpha);
    visitor.on_attribute("activations_beta", m_activations_beta);
    visitor.on_attribute("clip", m_clip);
    visitor.on_attribute("input_forget", m_input_forget);
    visitor.on_attribute("weights_format", m_weights_format);
    return true;
}

void op::LSTMCell::pre_validate_and_infer_types()
{
    set_output_size(2);

@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co
        throw ngraph_error("Incorrect number of new arguments");
    }
}

namespace ngraph
{
    template <>
    EnumNames<op::LSTMWeightsFormat>& EnumNames<op::LSTMWeightsFormat>::get()
    {
        static auto enum_names =
            EnumNames<op::LSTMWeightsFormat>("op::LSTMWeightsFormat",
                                             {{"fico", op::LSTMWeightsFormat::FICO},
                                              {"icof", op::LSTMWeightsFormat::ICOF},
                                              {"ifco", op::LSTMWeightsFormat::IFCO},
                                              {"ifoc", op::LSTMWeightsFormat::IFOC},
                                              {"iofc", op::LSTMWeightsFormat::IOFC}});
        return enum_names;
    }

    constexpr DiscreteTypeInfo AttributeAdapter<op::LSTMWeightsFormat>::type_info;

    std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type)
    {
        return s << as_string(type);
    }
} // namespace ngraph
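The EnumNames registration above is what backs both the enum attribute adapter and the stream operator: each LSTMWeightsFormat value maps to its registered lowercase name. A small sketch of the effect (hypothetical standalone program, assuming it links against this ngraph revision; the operator<< it relies on is declared in lstm_cell.hpp below):

#include <iostream>
#include <sstream>
#include "ngraph/op/fused/lstm_cell.hpp"

using namespace ngraph;

int main()
{
    // Streaming the enum goes through the operator<< defined in lstm_cell.cpp,
    // which renders the value via its name registered with EnumNames.
    std::ostringstream out;
    out << op::LSTMWeightsFormat::IFCO;
    std::cout << out.str() << "\n"; // expected to print "ifco"
    return 0;
}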
src/ngraph/op/fused/lstm_cell.hpp

@@ -224,6 +224,7 @@ namespace ngraph
                         float clip = 0.f,
                         bool input_forget = false);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual void pre_validate_and_infer_types() override;
                virtual NodeVector decompose_op() const override;
                virtual std::shared_ptr<Node>

@@ -284,4 +285,20 @@ namespace ngraph
        }
        using v0::LSTMCell;
    } // namespace op

    std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type);

    template <>
    class NGRAPH_API AttributeAdapter<op::LSTMWeightsFormat>
        : public EnumAttributeAdapterBase<op::LSTMWeightsFormat>
    {
    public:
        AttributeAdapter(op::LSTMWeightsFormat& value)
            : EnumAttributeAdapterBase<op::LSTMWeightsFormat>(value)
        {
        }
        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::LSTMWeightsFormat>", 1};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    };
} // namespace ngraph
src/ngraph/op/fused/lstm_sequence.cpp

@@ -16,6 +16,7 @@
#include "ngraph/op/fused/lstm_sequence.hpp"

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"

@@ -32,6 +33,19 @@ using namespace ngraph;
using namespace std;

constexpr NodeTypeInfo op::LSTMSequence::type_info;

bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("hidden_size", m_hidden_size);
    visitor.on_attribute("activations", m_activations);
    visitor.on_attribute("activations_alpha", m_activations_alpha);
    visitor.on_attribute("activations_beta", m_activations_beta);
    visitor.on_attribute("clip", m_clip_threshold);
    visitor.on_attribute("direction", m_direction);
    visitor.on_attribute("input_forget", m_input_forget);
    visitor.on_attribute("weights_format", m_weights_format);
    return true;
}

NodeVector op::LSTMSequence::decompose_op() const
{
    NodeVector results;

@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve
    // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
    return builder::squeeze(tmp);
}

namespace ngraph
{
    template <>
    EnumNames<op::v0::LSTMSequence::direction>& EnumNames<op::v0::LSTMSequence::direction>::get()
    {
        static auto enum_names = EnumNames<op::v0::LSTMSequence::direction>(
            "op::v0::LSTMSequence::direction",
            {{"forward", op::v0::LSTMSequence::direction::FORWARD},
             {"reverse", op::v0::LSTMSequence::direction::REVERSE},
             {"bidirectional", op::v0::LSTMSequence::direction::BIDIRECTIONAL}});
        return enum_names;
    }

    constexpr DiscreteTypeInfo AttributeAdapter<op::v0::LSTMSequence::direction>::type_info;

    std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type)
    {
        return s << as_string(type);
    }
} // namespace ngraph
src/ngraph/op/fused/lstm_sequence.hpp

@@ -135,6 +135,7 @@ namespace ngraph
                {
                }

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual NodeVector decompose_op() const override;
                virtual std::shared_ptr<Node>

@@ -185,4 +186,21 @@ namespace ngraph
        }
        using v0::LSTMSequence;
    } // namespace op

    std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type);

    template <>
    class NGRAPH_API AttributeAdapter<op::v0::LSTMSequence::direction>
        : public EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>
    {
    public:
        AttributeAdapter(op::v0::LSTMSequence::direction& value)
            : EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>(value)
        {
        }
        static constexpr DiscreteTypeInfo type_info{
            "AttributeAdapter<op::v0::LSTMSequence::direction>", 1};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    };
} // namespace ngraph
src/ngraph/op/gather.cpp

@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params,
    constructor_validate_and_infer_types();
}

bool ngraph::op::v1::Gather::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

void op::v1::Gather::validate_and_infer_types()
{
    const auto& input_rank = get_input_partial_shape(PARAMS).rank();
src/ngraph/op/gather.hpp

@@ -67,6 +67,7 @@ namespace ngraph
                       const Output<Node>& indices,
                       const Output<Node>& axis);

                bool visit_attributes(AttributeVisitor& visitor) override;
                int64_t get_axis() const;
                void validate_and_infer_types() override;
src/ngraph/op/gather_tree.cpp

@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar
        new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
}

bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

void op::v1::GatherTree::validate_and_infer_types()
{
    const auto& step_ids_rank = get_input_partial_shape(0);
src/ngraph/op/gather_tree.hpp

@@ -44,6 +44,7 @@ namespace ngraph
                           const Output<Node>& max_seq_len,
                           const Output<Node>& end_token);

                bool visit_attributes(AttributeVisitor& visitor) override;
                void validate_and_infer_types() override;
                virtual std::shared_ptr<Node>
src/ngraph/op/log.cpp

@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v0::Log::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
src/ngraph/op/log.hpp

@@ -37,6 +37,7 @@ namespace ngraph
                /// \param arg Node that produces the input tensor.
                Log(const Output<Node>& arg);
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
src/ngraph/op/lrn.cpp

@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/lrn.hpp"

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"

@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types()
                          ").");
}

bool ngraph::op::v0::LRN::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("alpha", m_alpha);
    visitor.on_attribute("beta", m_beta);
    visitor.on_attribute("bias", m_bias);
    visitor.on_attribute("size", m_size);
    return true;
}

shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
src/ngraph/op/lrn.hpp

@@ -58,6 +58,7 @@ namespace ngraph
                    double bias,
                    size_t size);

                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
                void validate_and_infer_types() override;
src/ngraph/op/not.cpp

@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
    constructor_validate_and_infer_types();
}

bool ngraph::op::v1::LogicalNot::visit_attributes(AttributeVisitor& visitor)
{
    return true;
}

// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::v1::LogicalNot::validate_and_infer_types()
{
src/ngraph/op/not.hpp

@@ -37,6 +37,7 @@ namespace ngraph
                /// \param arg Node that produces the input tensor.
                LogicalNot(const Output<Node>& arg);
                bool visit_attributes(AttributeVisitor& visitor) override;
                void validate_and_infer_types() override;
                virtual std::shared_ptr<Node>
test/attributes.cpp

@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset1.hpp"

using namespace std;
using namespace ngraph;

@@ -140,16 +141,26 @@ public:
    double get_double(const string& name) { return m_doubles.at(name); }
    int64_t get_signed(const string& name) { return m_signeds.at(name); }
    uint64_t get_unsigned(const string& name) { return m_unsigneds.at(name); }
    vector<float>& get_float_vector(const string& name) { return m_float_vectors.at(name); }
    vector<int64_t>& get_signed_vector(const string& name) { return m_signed_vectors.at(name); }
    vector<string>& get_string_vector(const string& name) { return m_string_vectors.at(name); }
    void set_string(const string& name, const string& value) { m_strings[name] = value; }
    void set_bool(const string& name, bool value) { m_bools[name] = value; }
    void set_double(const string& name, double value) { m_doubles[name] = value; }
    void set_signed(const string& name, int64_t value) { m_signeds[name] = value; }
    void set_unsigned(const string& name, uint64_t value) { m_unsigneds[name] = value; }
    void set_float_vector(const string& name, const vector<float>& value)
    {
        m_float_vectors[name] = value;
    }
    void set_signed_vector(const string& name, const vector<int64_t>& value)
    {
        m_signed_vectors[name] = value;
    }
    void set_string_vector(const string& name, const vector<string>& value)
    {
        m_string_vectors[name] = value;
    }
    void on_attribute(const string& name, string& value) override { set_string(name, value); };
    void on_attribute(const string& name, bool& value) override { set_bool(name, value); }

@@ -162,10 +173,6 @@ public:
    {
        set_string(name, adapter.get());
    };
    void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
    {
        set_signed_vector(name, adapter.get());
    }
    void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
    {
        set_signed(name, adapter.get());

@@ -174,6 +181,18 @@ public:
    {
        set_double(name, adapter.get());
    }
    void on_adapter(const string& name, ValueAccessor<vector<float>>& adapter) override
    {
        set_float_vector(name, adapter.get());
    }
    void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
    {
        set_signed_vector(name, adapter.get());
    }
    void on_adapter(const string& name, ValueAccessor<vector<string>>& adapter) override
    {
        set_string_vector(name, adapter.get());
    }

protected:
    NodeTypeInfo m_node_type_info;

@@ -183,6 +202,8 @@ protected:
    map<string, int64_t> m_signeds;
    map<string, uint64_t> m_unsigneds;
    map<string, vector<int64_t>> m_signed_vectors;
    map<string, vector<float>> m_float_vectors;
    map<string, vector<std::string>> m_string_vectors;
};

class NodeBuilder : public AttributeVisitor

@@ -197,7 +218,6 @@ public:
    {
        shared_ptr<Node> node(FactoryRegistry<Node>::get().create(m_values.get_node_type_info()));
        node->visit_attributes(*this);
        node->validate_and_infer_types();
        return node;
    }

@@ -215,10 +235,6 @@ public:
    {
        adapter.set(m_values.get_string(name));
    };
    void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
    {
        adapter.set(m_values.get_signed_vector(name));
    }
    void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
    {
        adapter.set(m_values.get_signed(name));

@@ -227,6 +243,18 @@ public:
    {
        adapter.set(m_values.get_double(name));
    }
    void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
    {
        adapter.set(m_values.get_signed_vector(name));
    }
    void on_adapter(const string& name, ValueAccessor<vector<string>>& adapter) override
    {
        adapter.set(m_values.get_string_vector(name));
    }
    void on_adapter(const string& name, ValueAccessor<vector<float>>& adapter) override
    {
        adapter.set(m_values.get_float_vector(name));
    }

protected:
    NodeSaver m_values;

@@ -256,3 +284,217 @@ TEST(attributes, user_op)
    EXPECT_EQ(g_oracle->get_hyper_parameters(), oracle->get_hyper_parameters());
    EXPECT_EQ(g_oracle->get_ultra_parameters(), oracle->get_ultra_parameters());
}

TEST(attributes, elu_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::Elu>();
    auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});

    double alpha = 0.1;

    const auto elu = make_shared<opset1::Elu>(data, alpha);
    NodeBuilder builder(elu);
    auto g_elu = as_type_ptr<opset1::Elu>(builder.create());

    EXPECT_EQ(g_elu->get_alpha(), elu->get_alpha());
}

TEST(attributes, fake_quantize_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::FakeQuantize>();
    const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
    const auto input_low = make_shared<op::Parameter>(element::f32, Shape{});
    const auto input_high = make_shared<op::Parameter>(element::f32, Shape{});
    const auto output_low = make_shared<op::Parameter>(element::f32, Shape{});
    const auto output_high = make_shared<op::Parameter>(element::f32, Shape{});

    auto levels = 5;
    auto auto_broadcast = op::AutoBroadcastType::NUMPY;

    const auto fake_quantize = make_shared<op::FakeQuantize>(
        data, input_low, input_high, output_low, output_high, levels, auto_broadcast);
    NodeBuilder builder(fake_quantize);
    auto g_fake_quantize = as_type_ptr<opset1::FakeQuantize>(builder.create());

    EXPECT_EQ(g_fake_quantize->get_levels(), fake_quantize->get_levels());
    EXPECT_EQ(g_fake_quantize->get_auto_broadcast(), fake_quantize->get_auto_broadcast());
}

TEST(attributes, grn_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::GRN>();
    auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});

    float bias = 1.25f;

    auto grn = make_shared<opset1::GRN>(data, bias);
    NodeBuilder builder(grn);
    auto g_grn = as_type_ptr<opset1::GRN>(builder.create());

    EXPECT_EQ(g_grn->get_bias(), grn->get_bias());
}

TEST(attributes, group_conv_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::GroupConvolution>();
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 12, 224, 224});
    auto filters = make_shared<op::Parameter>(element::f32, Shape{4, 1, 3, 5, 5});
    auto strides = Strides{1, 1};
    auto pads_begin = CoordinateDiff{1, 2};
    auto pads_end = CoordinateDiff{1, 2};
    auto dilations = Strides{1, 1};
    auto group_conv = make_shared<opset1::GroupConvolution>(
        data, filters, strides, pads_begin, pads_end, dilations, op::PadType::VALID);
    NodeBuilder builder(group_conv);
    auto g_group_conv = as_type_ptr<opset1::GroupConvolution>(builder.create());
    EXPECT_EQ(g_group_conv->get_strides(), group_conv->get_strides());
    EXPECT_EQ(g_group_conv->get_pads_begin(), group_conv->get_pads_begin());
    EXPECT_EQ(g_group_conv->get_pads_end(), group_conv->get_pads_end());
    EXPECT_EQ(g_group_conv->get_dilations(), group_conv->get_dilations());
    EXPECT_EQ(g_group_conv->get_auto_pad(), group_conv->get_auto_pad());
}

TEST(attributes, group_conv_backprop_data_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::GroupConvolutionBackpropData>();
    const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 20, 224, 224});
    const auto filter = make_shared<op::Parameter>(element::f32, Shape{4, 5, 2, 3, 3});
    const auto output_shape = make_shared<op::Parameter>(element::f32, Shape{1, 8, 447, 447});

    const auto strides = Strides{2, 1};
    const auto pads_begin = CoordinateDiff{3, 4};
    const auto pads_end = CoordinateDiff{4, 6};
    const auto dilations = Strides{3, 1};
    const auto auto_pad = op::PadType::EXPLICIT;
    const auto output_padding = CoordinateDiff{3, 4};

    const auto gcbd = make_shared<opset1::GroupConvolutionBackpropData>(data,
                                                                        filter,
                                                                        output_shape,
                                                                        strides,
                                                                        pads_begin,
                                                                        pads_end,
                                                                        dilations,
                                                                        auto_pad,
                                                                        output_padding);
    NodeBuilder builder(gcbd);
    const auto g_gcbd = as_type_ptr<opset1::GroupConvolutionBackpropData>(builder.create());

    EXPECT_EQ(g_gcbd->get_strides(), gcbd->get_strides());
    EXPECT_EQ(g_gcbd->get_pads_begin(), gcbd->get_pads_begin());
    EXPECT_EQ(g_gcbd->get_pads_end(), gcbd->get_pads_end());
    EXPECT_EQ(g_gcbd->get_dilations(), gcbd->get_dilations());
    EXPECT_EQ(g_gcbd->get_auto_pad(), gcbd->get_auto_pad());
    EXPECT_EQ(g_gcbd->get_output_padding(), gcbd->get_output_padding());
}

TEST(attributes, lrn_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::LRN>();
    const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
    const auto axes = make_shared<op::Parameter>(element::i32, Shape{2});

    const double alpha = 1.1;
    const double beta = 2.2;
    const double bias = 3.3;
    const size_t size = 4;

    const auto lrn = make_shared<opset1::LRN>(arg, axes, alpha, beta, bias, size);
    NodeBuilder builder(lrn);
    auto g_lrn = as_type_ptr<opset1::LRN>(builder.create());

    EXPECT_EQ(g_lrn->get_alpha(), lrn->get_alpha());
    EXPECT_EQ(g_lrn->get_beta(), lrn->get_beta());
    EXPECT_EQ(g_lrn->get_bias(), lrn->get_bias());
    EXPECT_EQ(g_lrn->get_nsize(), lrn->get_nsize());
}

TEST(attributes, lstm_cell_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::LSTMCell>();
    auto X = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto H = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto W = make_shared<op::Parameter>(element::f32, Shape{12, 3});
    auto R = make_shared<op::Parameter>(element::f32, Shape{12, 3});
    const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{2, 3});

    const auto hidden_size = 3;
    const auto weights_format = op::LSTMWeightsFormat::ICOF;
    const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
    auto activations_alpha = std::vector<float>{1.0, 1.5};
    auto activations_beta = std::vector<float>{2.0, 1.0};
    const float clip = 0.5f;
    bool input_forget = true;

    const auto lstm_cell = make_shared<opset1::LSTMCell>(X,
                                                         initial_hidden_state,
                                                         initial_cell_state,
                                                         W,
                                                         R,
                                                         hidden_size,
                                                         weights_format,
                                                         activations,
                                                         activations_alpha,
                                                         activations_beta,
                                                         clip,
                                                         input_forget);
    NodeBuilder builder(lstm_cell);
    auto g_lstm_cell = as_type_ptr<opset1::LSTMCell>(builder.create());

    EXPECT_EQ(g_lstm_cell->get_hidden_size(), lstm_cell->get_hidden_size());
    EXPECT_EQ(g_lstm_cell->get_activations(), lstm_cell->get_activations());
    EXPECT_EQ(g_lstm_cell->get_activations_alpha(), lstm_cell->get_activations_alpha());
    EXPECT_EQ(g_lstm_cell->get_activations_beta(), lstm_cell->get_activations_beta());
    EXPECT_EQ(g_lstm_cell->get_clip(), lstm_cell->get_clip());
    EXPECT_EQ(g_lstm_cell->get_input_forget(), lstm_cell->get_input_forget());
    EXPECT_EQ(g_lstm_cell->get_weights_format(), lstm_cell->get_weights_format());
}

TEST(attributes, lstm_sequence_op)
{
    FactoryRegistry<Node>::get().register_factory<opset1::LSTMSequence>();

    const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
    const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
    const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});

    const auto hidden_size = 3;
    const auto lstm_direction = op::LSTMSequence::direction::FORWARD;
    const auto weights_format = op::LSTMWeightsFormat::ICOF;
    const std::vector<float> activations_alpha = {1, 2, 3};
    const std::vector<float> activations_beta = {4, 5, 6};
    const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
    const float clip_threshold = 0.5f;
    const bool input_forget = true;

    const auto lstm_sequence = make_shared<opset1::LSTMSequence>(X,
                                                                 initial_hidden_state,
                                                                 initial_cell_state,
                                                                 sequence_lengths,
                                                                 W,
                                                                 R,
                                                                 B,
                                                                 hidden_size,
                                                                 lstm_direction,
                                                                 weights_format,
                                                                 activations_alpha,
                                                                 activations_beta,
                                                                 activations,
                                                                 clip_threshold,
                                                                 input_forget);
    NodeBuilder builder(lstm_sequence);
    auto g_lstm_sequence = as_type_ptr<opset1::LSTMSequence>(builder.create());

    EXPECT_EQ(g_lstm_sequence->get_hidden_size(), lstm_sequence->get_hidden_size());
    EXPECT_EQ(g_lstm_sequence->get_activations(), lstm_sequence->get_activations());
    EXPECT_EQ(g_lstm_sequence->get_activations_alpha(), lstm_sequence->get_activations_alpha());
    EXPECT_EQ(g_lstm_sequence->get_activations_beta(), lstm_sequence->get_activations_beta());
    EXPECT_EQ(g_lstm_sequence->get_clip_threshold(), lstm_sequence->get_clip_threshold());
    EXPECT_EQ(g_lstm_sequence->get_direction(), lstm_sequence->get_direction());
    EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget());
    EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format());
}