Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
205ff94a
Unverified
Commit
205ff94a
authored
Mar 02, 2020
by
Scott Cyphers
Committed by
GitHub
Mar 02, 2020
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into cloudhan_fix-mkldnn_v1
parents
4973b763
61b4e6f9
Hide whitespace changes
Inline
Side-by-side
Showing
35 changed files
with
555 additions
and
12 deletions
+555
-12
attribute_adapter.cpp
src/ngraph/attribute_adapter.cpp
+36
-0
attribute_adapter.hpp
src/ngraph/attribute_adapter.hpp
+34
-0
attribute_visitor.hpp
src/ngraph/attribute_visitor.hpp
+17
-2
coordinate_diff.cpp
src/ngraph/coordinate_diff.cpp
+1
-1
erf.cpp
src/ngraph/op/erf.cpp
+5
-0
erf.hpp
src/ngraph/op/erf.hpp
+1
-0
exp.cpp
src/ngraph/op/exp.cpp
+5
-0
exp.hpp
src/ngraph/op/exp.hpp
+1
-0
floor.cpp
src/ngraph/op/floor.cpp
+5
-0
floor.hpp
src/ngraph/op/floor.hpp
+1
-0
elu.cpp
src/ngraph/op/fused/elu.cpp
+7
-0
elu.hpp
src/ngraph/op/fused/elu.hpp
+1
-0
fake_quantize.cpp
src/ngraph/op/fused/fake_quantize.cpp
+8
-0
fake_quantize.hpp
src/ngraph/op/fused/fake_quantize.hpp
+1
-0
grn.cpp
src/ngraph/op/fused/grn.cpp
+7
-0
grn.hpp
src/ngraph/op/fused/grn.hpp
+1
-0
group_conv.cpp
src/ngraph/op/fused/group_conv.cpp
+22
-0
group_conv.hpp
src/ngraph/op/fused/group_conv.hpp
+3
-0
hard_sigmoid.cpp
src/ngraph/op/fused/hard_sigmoid.cpp
+5
-0
hard_sigmoid.hpp
src/ngraph/op/fused/hard_sigmoid.hpp
+1
-0
lstm_cell.cpp
src/ngraph/op/fused/lstm_cell.cpp
+37
-0
lstm_cell.hpp
src/ngraph/op/fused/lstm_cell.hpp
+17
-0
lstm_sequence.cpp
src/ngraph/op/fused/lstm_sequence.cpp
+35
-0
lstm_sequence.hpp
src/ngraph/op/fused/lstm_sequence.hpp
+18
-0
gather.cpp
src/ngraph/op/gather.cpp
+5
-0
gather.hpp
src/ngraph/op/gather.hpp
+1
-0
gather_tree.cpp
src/ngraph/op/gather_tree.cpp
+5
-0
gather_tree.hpp
src/ngraph/op/gather_tree.hpp
+1
-0
log.cpp
src/ngraph/op/log.cpp
+5
-0
log.hpp
src/ngraph/op/log.hpp
+1
-0
lrn.cpp
src/ngraph/op/lrn.cpp
+10
-0
lrn.hpp
src/ngraph/op/lrn.hpp
+1
-0
not.cpp
src/ngraph/op/not.cpp
+5
-0
not.hpp
src/ngraph/op/not.hpp
+1
-0
attributes.cpp
test/attributes.cpp
+251
-9
No files found.
src/ngraph/attribute_adapter.cpp
View file @
205ff94a
...
@@ -242,4 +242,40 @@ namespace ngraph
...
@@ -242,4 +242,40 @@ namespace ngraph
m_value
=
copy_from
<
vector
<
uint64_t
>>
(
value
);
m_value
=
copy_from
<
vector
<
uint64_t
>>
(
value
);
m_buffer_valid
=
false
;
m_buffer_valid
=
false
;
}
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
vector
<
float
>>::
type_info
;
const
vector
<
float
>&
AttributeAdapter
<
vector
<
float
>>::
get
()
{
if
(
!
m_buffer_valid
)
{
m_buffer
=
copy_from
<
vector
<
float
>>
(
m_value
);
m_buffer_valid
=
true
;
}
return
m_buffer
;
}
void
AttributeAdapter
<
vector
<
float
>>::
set
(
const
vector
<
float
>&
value
)
{
m_value
=
copy_from
<
vector
<
float
>>
(
value
);
m_buffer_valid
=
false
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
vector
<
string
>>::
type_info
;
const
vector
<
string
>&
AttributeAdapter
<
vector
<
string
>>::
get
()
{
if
(
!
m_buffer_valid
)
{
m_buffer
=
copy_from
<
vector
<
string
>>
(
m_value
);
m_buffer_valid
=
true
;
}
return
m_buffer
;
}
void
AttributeAdapter
<
vector
<
string
>>::
set
(
const
vector
<
string
>&
value
)
{
m_value
=
copy_from
<
vector
<
string
>>
(
value
);
m_buffer_valid
=
false
;
}
}
}
src/ngraph/attribute_adapter.hpp
View file @
205ff94a
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#pragma once
#pragma once
#include <string>
#include <type_traits>
#include <type_traits>
#include <vector>
#include <vector>
...
@@ -299,6 +300,39 @@ namespace ngraph
...
@@ -299,6 +300,39 @@ namespace ngraph
void
set
(
const
std
::
vector
<
int64_t
>&
value
)
override
;
void
set
(
const
std
::
vector
<
int64_t
>&
value
)
override
;
};
};
template
<>
class
NGRAPH_API
AttributeAdapter
<
std
::
vector
<
float
>>
:
public
ValueReference
<
std
::
vector
<
float
>>
,
public
ValueAccessor
<
std
::
vector
<
float
>>
{
public
:
AttributeAdapter
(
std
::
vector
<
float
>&
value
)
:
ValueReference
<
std
::
vector
<
float
>>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<vector<float>>"
,
0
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
const
std
::
vector
<
float
>&
get
()
override
;
void
set
(
const
std
::
vector
<
float
>&
value
)
override
;
};
template
<>
class
NGRAPH_API
AttributeAdapter
<
std
::
vector
<
std
::
string
>>
:
public
ValueReference
<
std
::
vector
<
std
::
string
>>
,
public
ValueAccessor
<
std
::
vector
<
std
::
string
>>
{
public
:
AttributeAdapter
(
std
::
vector
<
std
::
string
>&
value
)
:
ValueReference
<
std
::
vector
<
std
::
string
>>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<vector<string>>"
,
0
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
const
std
::
vector
<
std
::
string
>&
get
()
override
;
void
set
(
const
std
::
vector
<
std
::
string
>&
value
)
override
;
};
template
<
typename
A
,
typename
B
>
template
<
typename
A
,
typename
B
>
A
copy_from
(
B
&
b
)
A
copy_from
(
B
&
b
)
{
{
...
...
src/ngraph/attribute_visitor.hpp
View file @
205ff94a
...
@@ -46,16 +46,25 @@ namespace ngraph
...
@@ -46,16 +46,25 @@ namespace ngraph
{
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
};
};
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
double
>&
adapter
)
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
std
::
vector
<
int64_t
>>&
adapter
)
ValueAccessor
<
std
::
vector
<
int64_t
>>&
adapter
)
{
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
std
::
vector
<
float
>
>&
adapter
)
{
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
}
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
double
>&
adapter
)
virtual
void
on_adapter
(
const
std
::
string
&
name
,
ValueAccessor
<
std
::
vector
<
std
::
string
>>&
adapter
)
{
{
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
on_adapter
(
name
,
static_cast
<
ValueAccessor
<
void
>&>
(
adapter
));
}
}
...
@@ -68,5 +77,11 @@ namespace ngraph
...
@@ -68,5 +77,11 @@ namespace ngraph
AttributeAdapter
<
T
>
adapter
(
value
);
AttributeAdapter
<
T
>
adapter
(
value
);
on_adapter
(
name
,
adapter
);
on_adapter
(
name
,
adapter
);
}
}
void
on_attribute
(
const
std
::
string
&
name
,
op
::
AutoBroadcastSpec
&
value
)
{
AttributeAdapter
<
op
::
AutoBroadcastType
>
adapter
(
value
.
m_type
);
on_adapter
(
name
,
adapter
);
}
};
};
}
}
src/ngraph/coordinate_diff.cpp
View file @
205ff94a
...
@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
...
@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
void
AttributeAdapter
<
CoordinateDiff
>::
set
(
const
vector
<
int64_t
>&
value
)
void
AttributeAdapter
<
CoordinateDiff
>::
set
(
const
vector
<
int64_t
>&
value
)
{
{
m_value
=
copy_from
<
CoordinateDiff
>
(
m_
value
);
m_value
=
copy_from
<
CoordinateDiff
>
(
value
);
m_buffer_valid
=
false
;
m_buffer_valid
=
false
;
}
}
...
...
src/ngraph/op/erf.cpp
View file @
205ff94a
...
@@ -23,6 +23,11 @@ using namespace ngraph;
...
@@ -23,6 +23,11 @@ using namespace ngraph;
constexpr
NodeTypeInfo
op
::
Erf
::
type_info
;
constexpr
NodeTypeInfo
op
::
Erf
::
type_info
;
bool
ngraph
::
op
::
v0
::
Erf
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Erf
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Erf
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/erf.hpp
View file @
205ff94a
...
@@ -32,6 +32,7 @@ namespace ngraph
...
@@ -32,6 +32,7 @@ namespace ngraph
Erf
()
=
default
;
Erf
()
=
default
;
Erf
(
const
Output
<
Node
>&
arg
);
Erf
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
};
};
...
...
src/ngraph/op/exp.cpp
View file @
205ff94a
...
@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg)
...
@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg)
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v0
::
Exp
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Exp
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Exp
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/exp.hpp
View file @
205ff94a
...
@@ -37,6 +37,7 @@ namespace ngraph
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Exp
(
const
Output
<
Node
>&
arg
);
Exp
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/floor.cpp
View file @
205ff94a
...
@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg)
...
@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg)
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v0
::
Floor
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Floor
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Floor
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/floor.hpp
View file @
205ff94a
...
@@ -37,6 +37,7 @@ namespace ngraph
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Floor
(
const
Output
<
Node
>&
arg
);
Floor
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
};
};
...
...
src/ngraph/op/fused/elu.cpp
View file @
205ff94a
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
//*****************************************************************************
//*****************************************************************************
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/op/fused/elu.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/add.hpp"
...
@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha)
...
@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha)
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v0
::
Elu
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"alpha"
,
m_alpha
);
return
true
;
}
NodeVector
op
::
Elu
::
decompose_op
()
const
NodeVector
op
::
Elu
::
decompose_op
()
const
{
{
auto
data
=
input_value
(
0
);
auto
data
=
input_value
(
0
);
...
...
src/ngraph/op/fused/elu.hpp
View file @
205ff94a
...
@@ -42,6 +42,7 @@ namespace ngraph
...
@@ -42,6 +42,7 @@ namespace ngraph
/// \param alpha Multiplier for negative values
/// \param alpha Multiplier for negative values
Elu
(
const
Output
<
Node
>&
data
,
const
double
alpha
);
Elu
(
const
Output
<
Node
>&
data
,
const
double
alpha
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/fake_quantize.cpp
View file @
205ff94a
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <memory>
#include <memory>
#include "fake_quantize.hpp"
#include "fake_quantize.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/constant.hpp"
...
@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types()
...
@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types()
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_partial_shape
(
0
));
set_output_type
(
0
,
get_input_element_type
(
0
),
get_input_partial_shape
(
0
));
}
}
bool
ngraph
::
op
::
v0
::
FakeQuantize
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"levels"
,
m_levels
);
visitor
.
on_attribute
(
"auto_broadcast"
,
m_auto_broadcast
);
return
true
;
}
NodeVector
op
::
FakeQuantize
::
decompose_op
()
const
NodeVector
op
::
FakeQuantize
::
decompose_op
()
const
{
{
Output
<
Node
>
data
{
input_value
(
0
)};
Output
<
Node
>
data
{
input_value
(
0
)};
...
...
src/ngraph/op/fused/fake_quantize.hpp
View file @
205ff94a
...
@@ -67,6 +67,7 @@ namespace ngraph
...
@@ -67,6 +67,7 @@ namespace ngraph
const
AutoBroadcastSpec
&
auto_broadcast
=
const
AutoBroadcastSpec
&
auto_broadcast
=
AutoBroadcastSpec
(
AutoBroadcastType
::
NUMPY
));
AutoBroadcastSpec
(
AutoBroadcastType
::
NUMPY
));
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
void
validate_and_infer_types
()
override
;
virtual
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/fused/grn.cpp
View file @
205ff94a
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <iterator>
#include <iterator>
#include "grn.hpp"
#include "grn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/reshape.hpp"
...
@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias)
...
@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias)
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v0
::
GRN
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"bias"
,
m_bias
);
return
true
;
}
void
op
::
GRN
::
pre_validate_and_infer_types
()
void
op
::
GRN
::
pre_validate_and_infer_types
()
{
{
const
auto
&
data_pshape
=
get_input_partial_shape
(
0
);
const
auto
&
data_pshape
=
get_input_partial_shape
(
0
);
...
...
src/ngraph/op/fused/grn.hpp
View file @
205ff94a
...
@@ -42,6 +42,7 @@ namespace ngraph
...
@@ -42,6 +42,7 @@ namespace ngraph
///
///
GRN
(
const
Output
<
Node
>&
data
,
float
bias
);
GRN
(
const
Output
<
Node
>&
data
,
float
bias
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
float
get_bias
()
const
{
return
m_bias
;
}
float
get_bias
()
const
{
return
m_bias
;
}
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
...
...
src/ngraph/op/fused/group_conv.cpp
View file @
205ff94a
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "group_conv.hpp"
#include "group_conv.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/concat.hpp"
...
@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
...
@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v1
::
GroupConvolution
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"strides"
,
m_strides
);
visitor
.
on_attribute
(
"pads_begin"
,
m_pads_begin
);
visitor
.
on_attribute
(
"pads_end"
,
m_pads_end
);
visitor
.
on_attribute
(
"dilations"
,
m_dilations
);
visitor
.
on_attribute
(
"auto_pad"
,
m_auto_pad
);
return
true
;
}
void
op
::
v1
::
GroupConvolution
::
validate_and_infer_types
()
void
op
::
v1
::
GroupConvolution
::
validate_and_infer_types
()
{
{
const
PartialShape
&
data_batch_pshape
=
get_input_partial_shape
(
0
);
const
PartialShape
&
data_batch_pshape
=
get_input_partial_shape
(
0
);
...
@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
...
@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v1
::
GroupConvolutionBackpropData
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"strides"
,
m_strides
);
visitor
.
on_attribute
(
"pads_begin"
,
m_pads_begin
);
visitor
.
on_attribute
(
"pads_end"
,
m_pads_end
);
visitor
.
on_attribute
(
"dilations"
,
m_dilations
);
visitor
.
on_attribute
(
"auto_pad"
,
m_auto_pad
);
visitor
.
on_attribute
(
"output_padding"
,
m_output_padding
);
return
true
;
}
bool
op
::
v1
::
GroupConvolutionBackpropData
::
is_dynamic
()
const
bool
op
::
v1
::
GroupConvolutionBackpropData
::
is_dynamic
()
const
{
{
bool
is_dynamic
=
Node
::
is_dynamic
();
bool
is_dynamic
=
Node
::
is_dynamic
();
...
...
src/ngraph/op/fused/group_conv.hpp
View file @
205ff94a
...
@@ -61,6 +61,8 @@ namespace ngraph
...
@@ -61,6 +61,8 @@ namespace ngraph
const
CoordinateDiff
&
pads_end
,
const
CoordinateDiff
&
pads_end
,
const
Strides
&
dilations
,
const
Strides
&
dilations
,
const
PadType
&
auto_pad
=
PadType
::
EXPLICIT
);
const
PadType
&
auto_pad
=
PadType
::
EXPLICIT
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
// TODO - Remove supports_decompose and validate_and_infer_type once op supports
// TODO - Remove supports_decompose and validate_and_infer_type once op supports
// decomposition
// decomposition
bool
supports_decompose
()
const
override
{
return
false
;
}
bool
supports_decompose
()
const
override
{
return
false
;
}
...
@@ -187,6 +189,7 @@ namespace ngraph
...
@@ -187,6 +189,7 @@ namespace ngraph
const
PadType
&
auto_pad
=
PadType
::
EXPLICIT
,
const
PadType
&
auto_pad
=
PadType
::
EXPLICIT
,
const
CoordinateDiff
&
output_padding
=
{});
const
CoordinateDiff
&
output_padding
=
{});
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
bool
is_dynamic
()
const
override
;
virtual
bool
is_dynamic
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
...
...
src/ngraph/op/fused/hard_sigmoid.cpp
View file @
205ff94a
...
@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data,
...
@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v0
::
HardSigmoid
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
HardSigmoid
::
pre_validate_and_infer_types
()
void
op
::
HardSigmoid
::
pre_validate_and_infer_types
()
{
{
const
auto
&
alpha_pshape
=
get_input_partial_shape
(
1
);
const
auto
&
alpha_pshape
=
get_input_partial_shape
(
1
);
...
...
src/ngraph/op/fused/hard_sigmoid.hpp
View file @
205ff94a
...
@@ -46,6 +46,7 @@ namespace ngraph
...
@@ -46,6 +46,7 @@ namespace ngraph
const
Output
<
Node
>&
alpha
,
const
Output
<
Node
>&
alpha
,
const
Output
<
Node
>&
beta
);
const
Output
<
Node
>&
beta
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/fused/lstm_cell.cpp
View file @
205ff94a
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include <cmath>
#include <cmath>
#include <functional>
#include <functional>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/add.hpp"
...
@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
...
@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v0
::
LSTMCell
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"hidden_size"
,
m_hidden_size
);
visitor
.
on_attribute
(
"activations"
,
m_activations
);
visitor
.
on_attribute
(
"activations_alpha"
,
m_activations_alpha
);
visitor
.
on_attribute
(
"activations_beta"
,
m_activations_beta
);
visitor
.
on_attribute
(
"clip"
,
m_clip
);
visitor
.
on_attribute
(
"input_forget"
,
m_input_forget
);
visitor
.
on_attribute
(
"weights_format"
,
m_weights_format
);
return
true
;
}
void
op
::
LSTMCell
::
pre_validate_and_infer_types
()
void
op
::
LSTMCell
::
pre_validate_and_infer_types
()
{
{
set_output_size
(
2
);
set_output_size
(
2
);
...
@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co
...
@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co
throw
ngraph_error
(
"Incorrect number of new arguments"
);
throw
ngraph_error
(
"Incorrect number of new arguments"
);
}
}
}
}
namespace
ngraph
{
template
<>
EnumNames
<
op
::
LSTMWeightsFormat
>&
EnumNames
<
op
::
LSTMWeightsFormat
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
LSTMWeightsFormat
>
(
"op::LSTMWeightsFormat"
,
{{
"fico"
,
op
::
LSTMWeightsFormat
::
FICO
},
{
"icof"
,
op
::
LSTMWeightsFormat
::
ICOF
},
{
"ifco"
,
op
::
LSTMWeightsFormat
::
IFCO
},
{
"ifoc"
,
op
::
LSTMWeightsFormat
::
IFOC
},
{
"iofc"
,
op
::
LSTMWeightsFormat
::
IOFC
}});
return
enum_names
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
op
::
LSTMWeightsFormat
>::
type_info
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
LSTMWeightsFormat
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
// namespace ngraph
src/ngraph/op/fused/lstm_cell.hpp
View file @
205ff94a
...
@@ -224,6 +224,7 @@ namespace ngraph
...
@@ -224,6 +224,7 @@ namespace ngraph
float
clip
=
0.
f
,
float
clip
=
0.
f
,
bool
input_forget
=
false
);
bool
input_forget
=
false
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
void
pre_validate_and_infer_types
()
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
@@ -284,4 +285,20 @@ namespace ngraph
...
@@ -284,4 +285,20 @@ namespace ngraph
}
}
using
v0
::
LSTMCell
;
using
v0
::
LSTMCell
;
}
// namespace op
}
// namespace op
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
LSTMWeightsFormat
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
LSTMWeightsFormat
>
:
public
EnumAttributeAdapterBase
<
op
::
LSTMWeightsFormat
>
{
public
:
AttributeAdapter
(
op
::
LSTMWeightsFormat
&
value
)
:
EnumAttributeAdapterBase
<
op
::
LSTMWeightsFormat
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::LSTMWeightsFormat>"
,
1
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
// namespace ngraph
}
// namespace ngraph
src/ngraph/op/fused/lstm_sequence.cpp
View file @
205ff94a
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/builder/split.hpp"
...
@@ -32,6 +33,19 @@ using namespace ngraph;
...
@@ -32,6 +33,19 @@ using namespace ngraph;
using
namespace
std
;
using
namespace
std
;
constexpr
NodeTypeInfo
op
::
LSTMSequence
::
type_info
;
constexpr
NodeTypeInfo
op
::
LSTMSequence
::
type_info
;
bool
ngraph
::
op
::
v0
::
LSTMSequence
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"hidden_size"
,
m_hidden_size
);
visitor
.
on_attribute
(
"activations"
,
m_activations
);
visitor
.
on_attribute
(
"activations_alpha"
,
m_activations_alpha
);
visitor
.
on_attribute
(
"activations_beta"
,
m_activations_beta
);
visitor
.
on_attribute
(
"clip"
,
m_clip_threshold
);
visitor
.
on_attribute
(
"direction"
,
m_direction
);
visitor
.
on_attribute
(
"input_forget"
,
m_input_forget
);
visitor
.
on_attribute
(
"weights_format"
,
m_weights_format
);
return
true
;
}
NodeVector
op
::
LSTMSequence
::
decompose_op
()
const
NodeVector
op
::
LSTMSequence
::
decompose_op
()
const
{
{
NodeVector
results
;
NodeVector
results
;
...
@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve
...
@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
return
builder
::
squeeze
(
tmp
);
return
builder
::
squeeze
(
tmp
);
}
}
namespace
ngraph
{
template
<>
EnumNames
<
op
::
v0
::
LSTMSequence
::
direction
>&
EnumNames
<
op
::
v0
::
LSTMSequence
::
direction
>::
get
()
{
static
auto
enum_names
=
EnumNames
<
op
::
v0
::
LSTMSequence
::
direction
>
(
"op::v0::LSTMSequence::direction"
,
{{
"forward"
,
op
::
v0
::
LSTMSequence
::
direction
::
FORWARD
},
{
"reverse"
,
op
::
v0
::
LSTMSequence
::
direction
::
REVERSE
},
{
"bidirectional"
,
op
::
v0
::
LSTMSequence
::
direction
::
BIDIRECTIONAL
}});
return
enum_names
;
}
constexpr
DiscreteTypeInfo
AttributeAdapter
<
op
::
v0
::
LSTMSequence
::
direction
>::
type_info
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v0
::
LSTMSequence
::
direction
&
type
)
{
return
s
<<
as_string
(
type
);
}
}
// namespace ngraph
src/ngraph/op/fused/lstm_sequence.hpp
View file @
205ff94a
...
@@ -135,6 +135,7 @@ namespace ngraph
...
@@ -135,6 +135,7 @@ namespace ngraph
{
{
}
}
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
NodeVector
decompose_op
()
const
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
@@ -185,4 +186,21 @@ namespace ngraph
...
@@ -185,4 +186,21 @@ namespace ngraph
}
}
using
v0
::
LSTMSequence
;
using
v0
::
LSTMSequence
;
}
// namespace op
}
// namespace op
std
::
ostream
&
operator
<<
(
std
::
ostream
&
s
,
const
op
::
v0
::
LSTMSequence
::
direction
&
type
);
template
<>
class
NGRAPH_API
AttributeAdapter
<
op
::
v0
::
LSTMSequence
::
direction
>
:
public
EnumAttributeAdapterBase
<
op
::
v0
::
LSTMSequence
::
direction
>
{
public
:
AttributeAdapter
(
op
::
v0
::
LSTMSequence
::
direction
&
value
)
:
EnumAttributeAdapterBase
<
op
::
v0
::
LSTMSequence
::
direction
>
(
value
)
{
}
static
constexpr
DiscreteTypeInfo
type_info
{
"AttributeAdapter<op::v0::LSTMSequence::direction>"
,
1
};
const
DiscreteTypeInfo
&
get_type_info
()
const
override
{
return
type_info
;
}
};
}
// namespace ngraph
}
// namespace ngraph
src/ngraph/op/gather.cpp
View file @
205ff94a
...
@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params,
...
@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params,
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v1
::
Gather
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
v1
::
Gather
::
validate_and_infer_types
()
void
op
::
v1
::
Gather
::
validate_and_infer_types
()
{
{
const
auto
&
input_rank
=
get_input_partial_shape
(
PARAMS
).
rank
();
const
auto
&
input_rank
=
get_input_partial_shape
(
PARAMS
).
rank
();
...
...
src/ngraph/op/gather.hpp
View file @
205ff94a
...
@@ -67,6 +67,7 @@ namespace ngraph
...
@@ -67,6 +67,7 @@ namespace ngraph
const
Output
<
Node
>&
indices
,
const
Output
<
Node
>&
indices
,
const
Output
<
Node
>&
axis
);
const
Output
<
Node
>&
axis
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
int64_t
get_axis
()
const
;
int64_t
get_axis
()
const
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/gather_tree.cpp
View file @
205ff94a
...
@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar
...
@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
));
new_args
.
at
(
0
),
new_args
.
at
(
1
),
new_args
.
at
(
2
),
new_args
.
at
(
3
));
}
}
bool
ngraph
::
op
::
v1
::
GatherTree
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
void
op
::
v1
::
GatherTree
::
validate_and_infer_types
()
void
op
::
v1
::
GatherTree
::
validate_and_infer_types
()
{
{
const
auto
&
step_ids_rank
=
get_input_partial_shape
(
0
);
const
auto
&
step_ids_rank
=
get_input_partial_shape
(
0
);
...
...
src/ngraph/op/gather_tree.hpp
View file @
205ff94a
...
@@ -44,6 +44,7 @@ namespace ngraph
...
@@ -44,6 +44,7 @@ namespace ngraph
const
Output
<
Node
>&
max_seq_len
,
const
Output
<
Node
>&
max_seq_len
,
const
Output
<
Node
>&
end_token
);
const
Output
<
Node
>&
end_token
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
...
src/ngraph/op/log.cpp
View file @
205ff94a
...
@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg)
...
@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg)
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v0
::
Log
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
shared_ptr
<
Node
>
op
::
Log
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
Log
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/log.hpp
View file @
205ff94a
...
@@ -37,6 +37,7 @@ namespace ngraph
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
Log
(
const
Output
<
Node
>&
arg
);
Log
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
...
...
src/ngraph/op/lrn.cpp
View file @
205ff94a
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
//*****************************************************************************
//*****************************************************************************
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/multiply.hpp"
...
@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types()
...
@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types()
")."
);
")."
);
}
}
bool
ngraph
::
op
::
v0
::
LRN
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
visitor
.
on_attribute
(
"alpha"
,
m_alpha
);
visitor
.
on_attribute
(
"beta"
,
m_beta
);
visitor
.
on_attribute
(
"bias"
,
m_bias
);
visitor
.
on_attribute
(
"size"
,
m_size
);
return
true
;
}
shared_ptr
<
Node
>
op
::
LRN
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
shared_ptr
<
Node
>
op
::
LRN
::
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
{
{
check_new_args_count
(
this
,
new_args
);
check_new_args_count
(
this
,
new_args
);
...
...
src/ngraph/op/lrn.hpp
View file @
205ff94a
...
@@ -58,6 +58,7 @@ namespace ngraph
...
@@ -58,6 +58,7 @@ namespace ngraph
double
bias
,
double
bias
,
size_t
size
);
size_t
size
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
copy_with_new_args
(
const
NodeVector
&
new_args
)
const
override
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
...
...
src/ngraph/op/not.cpp
View file @
205ff94a
...
@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
...
@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
constructor_validate_and_infer_types
();
constructor_validate_and_infer_types
();
}
}
bool
ngraph
::
op
::
v1
::
LogicalNot
::
visit_attributes
(
AttributeVisitor
&
visitor
)
{
return
true
;
}
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void
op
::
v1
::
LogicalNot
::
validate_and_infer_types
()
void
op
::
v1
::
LogicalNot
::
validate_and_infer_types
()
{
{
...
...
src/ngraph/op/not.hpp
View file @
205ff94a
...
@@ -37,6 +37,7 @@ namespace ngraph
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
/// \param arg Node that produces the input tensor.
LogicalNot
(
const
Output
<
Node
>&
arg
);
LogicalNot
(
const
Output
<
Node
>&
arg
);
bool
visit_attributes
(
AttributeVisitor
&
visitor
)
override
;
void
validate_and_infer_types
()
override
;
void
validate_and_infer_types
()
override
;
virtual
std
::
shared_ptr
<
Node
>
virtual
std
::
shared_ptr
<
Node
>
...
...
test/attributes.cpp
View file @
205ff94a
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset1.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
...
@@ -140,16 +141,26 @@ public:
...
@@ -140,16 +141,26 @@ public:
double
get_double
(
const
string
&
name
)
{
return
m_doubles
.
at
(
name
);
}
double
get_double
(
const
string
&
name
)
{
return
m_doubles
.
at
(
name
);
}
int64_t
get_signed
(
const
string
&
name
)
{
return
m_signeds
.
at
(
name
);
}
int64_t
get_signed
(
const
string
&
name
)
{
return
m_signeds
.
at
(
name
);
}
uint64_t
get_unsigned
(
const
string
&
name
)
{
return
m_unsigneds
.
at
(
name
);
}
uint64_t
get_unsigned
(
const
string
&
name
)
{
return
m_unsigneds
.
at
(
name
);
}
vector
<
float
>&
get_float_vector
(
const
string
&
name
)
{
return
m_float_vectors
.
at
(
name
);
}
vector
<
int64_t
>&
get_signed_vector
(
const
string
&
name
)
{
return
m_signed_vectors
.
at
(
name
);
}
vector
<
int64_t
>&
get_signed_vector
(
const
string
&
name
)
{
return
m_signed_vectors
.
at
(
name
);
}
vector
<
string
>&
get_string_vector
(
const
string
&
name
)
{
return
m_string_vectors
.
at
(
name
);
}
void
set_string
(
const
string
&
name
,
const
string
&
value
)
{
m_strings
[
name
]
=
value
;
}
void
set_string
(
const
string
&
name
,
const
string
&
value
)
{
m_strings
[
name
]
=
value
;
}
void
set_bool
(
const
string
&
name
,
bool
value
)
{
m_bools
[
name
]
=
value
;
}
void
set_bool
(
const
string
&
name
,
bool
value
)
{
m_bools
[
name
]
=
value
;
}
void
set_double
(
const
string
&
name
,
double
value
)
{
m_doubles
[
name
]
=
value
;
}
void
set_double
(
const
string
&
name
,
double
value
)
{
m_doubles
[
name
]
=
value
;
}
void
set_signed
(
const
string
&
name
,
int64_t
value
)
{
m_signeds
[
name
]
=
value
;
}
void
set_signed
(
const
string
&
name
,
int64_t
value
)
{
m_signeds
[
name
]
=
value
;
}
void
set_unsigned
(
const
string
&
name
,
uint64_t
value
)
{
m_unsigneds
[
name
]
=
value
;
}
void
set_unsigned
(
const
string
&
name
,
uint64_t
value
)
{
m_unsigneds
[
name
]
=
value
;
}
void
set_float_vector
(
const
string
&
name
,
const
vector
<
float
>&
value
)
{
m_float_vectors
[
name
]
=
value
;
}
void
set_signed_vector
(
const
string
&
name
,
const
vector
<
int64_t
>&
value
)
void
set_signed_vector
(
const
string
&
name
,
const
vector
<
int64_t
>&
value
)
{
{
m_signed_vectors
[
name
]
=
value
;
m_signed_vectors
[
name
]
=
value
;
}
}
void
set_string_vector
(
const
string
&
name
,
const
vector
<
string
>&
value
)
{
m_string_vectors
[
name
]
=
value
;
}
void
on_attribute
(
const
string
&
name
,
string
&
value
)
override
{
set_string
(
name
,
value
);
};
void
on_attribute
(
const
string
&
name
,
string
&
value
)
override
{
set_string
(
name
,
value
);
};
void
on_attribute
(
const
string
&
name
,
bool
&
value
)
override
{
set_bool
(
name
,
value
);
}
void
on_attribute
(
const
string
&
name
,
bool
&
value
)
override
{
set_bool
(
name
,
value
);
}
...
@@ -162,10 +173,6 @@ public:
...
@@ -162,10 +173,6 @@ public:
{
{
set_string
(
name
,
adapter
.
get
());
set_string
(
name
,
adapter
.
get
());
};
};
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
int64_t
>>&
adapter
)
override
{
set_signed_vector
(
name
,
adapter
.
get
());
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
override
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
override
{
{
set_signed
(
name
,
adapter
.
get
());
set_signed
(
name
,
adapter
.
get
());
...
@@ -174,6 +181,18 @@ public:
...
@@ -174,6 +181,18 @@ public:
{
{
set_double
(
name
,
adapter
.
get
());
set_double
(
name
,
adapter
.
get
());
}
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
float
>>&
adapter
)
override
{
set_float_vector
(
name
,
adapter
.
get
());
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
int64_t
>>&
adapter
)
override
{
set_signed_vector
(
name
,
adapter
.
get
());
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
string
>>&
adapter
)
override
{
set_string_vector
(
name
,
adapter
.
get
());
}
protected
:
protected
:
NodeTypeInfo
m_node_type_info
;
NodeTypeInfo
m_node_type_info
;
...
@@ -183,6 +202,8 @@ protected:
...
@@ -183,6 +202,8 @@ protected:
map
<
string
,
int64_t
>
m_signeds
;
map
<
string
,
int64_t
>
m_signeds
;
map
<
string
,
uint64_t
>
m_unsigneds
;
map
<
string
,
uint64_t
>
m_unsigneds
;
map
<
string
,
vector
<
int64_t
>>
m_signed_vectors
;
map
<
string
,
vector
<
int64_t
>>
m_signed_vectors
;
map
<
string
,
vector
<
float
>>
m_float_vectors
;
map
<
string
,
vector
<
std
::
string
>>
m_string_vectors
;
};
};
class
NodeBuilder
:
public
AttributeVisitor
class
NodeBuilder
:
public
AttributeVisitor
...
@@ -197,7 +218,6 @@ public:
...
@@ -197,7 +218,6 @@ public:
{
{
shared_ptr
<
Node
>
node
(
FactoryRegistry
<
Node
>::
get
().
create
(
m_values
.
get_node_type_info
()));
shared_ptr
<
Node
>
node
(
FactoryRegistry
<
Node
>::
get
().
create
(
m_values
.
get_node_type_info
()));
node
->
visit_attributes
(
*
this
);
node
->
visit_attributes
(
*
this
);
node
->
validate_and_infer_types
();
return
node
;
return
node
;
}
}
...
@@ -215,10 +235,6 @@ public:
...
@@ -215,10 +235,6 @@ public:
{
{
adapter
.
set
(
m_values
.
get_string
(
name
));
adapter
.
set
(
m_values
.
get_string
(
name
));
};
};
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
int64_t
>>&
adapter
)
override
{
adapter
.
set
(
m_values
.
get_signed_vector
(
name
));
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
override
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
int64_t
>&
adapter
)
override
{
{
adapter
.
set
(
m_values
.
get_signed
(
name
));
adapter
.
set
(
m_values
.
get_signed
(
name
));
...
@@ -227,6 +243,18 @@ public:
...
@@ -227,6 +243,18 @@ public:
{
{
adapter
.
set
(
m_values
.
get_double
(
name
));
adapter
.
set
(
m_values
.
get_double
(
name
));
}
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
int64_t
>>&
adapter
)
override
{
adapter
.
set
(
m_values
.
get_signed_vector
(
name
));
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
string
>>&
adapter
)
override
{
adapter
.
set
(
m_values
.
get_string_vector
(
name
));
}
void
on_adapter
(
const
string
&
name
,
ValueAccessor
<
vector
<
float
>>&
adapter
)
override
{
adapter
.
set
(
m_values
.
get_float_vector
(
name
));
}
protected
:
protected
:
NodeSaver
m_values
;
NodeSaver
m_values
;
...
@@ -256,3 +284,217 @@ TEST(attributes, user_op)
...
@@ -256,3 +284,217 @@ TEST(attributes, user_op)
EXPECT_EQ
(
g_oracle
->
get_hyper_parameters
(),
oracle
->
get_hyper_parameters
());
EXPECT_EQ
(
g_oracle
->
get_hyper_parameters
(),
oracle
->
get_hyper_parameters
());
EXPECT_EQ
(
g_oracle
->
get_ultra_parameters
(),
oracle
->
get_ultra_parameters
());
EXPECT_EQ
(
g_oracle
->
get_ultra_parameters
(),
oracle
->
get_ultra_parameters
());
}
}
TEST
(
attributes
,
elu_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
Elu
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
4
});
double
alpha
=
0.1
;
const
auto
elu
=
make_shared
<
opset1
::
Elu
>
(
data
,
alpha
);
NodeBuilder
builder
(
elu
);
auto
g_elu
=
as_type_ptr
<
opset1
::
Elu
>
(
builder
.
create
());
EXPECT_EQ
(
g_elu
->
get_alpha
(),
elu
->
get_alpha
());
}
TEST
(
attributes
,
fake_quantize_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
FakeQuantize
>
();
const
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
,
4
});
const
auto
input_low
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{});
const
auto
input_high
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{});
const
auto
output_low
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{});
const
auto
output_high
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{});
auto
levels
=
5
;
auto
auto_broadcast
=
op
::
AutoBroadcastType
::
NUMPY
;
const
auto
fake_quantize
=
make_shared
<
op
::
FakeQuantize
>
(
data
,
input_low
,
input_high
,
output_low
,
output_high
,
levels
,
auto_broadcast
);
NodeBuilder
builder
(
fake_quantize
);
auto
g_fake_quantize
=
as_type_ptr
<
opset1
::
FakeQuantize
>
(
builder
.
create
());
EXPECT_EQ
(
g_fake_quantize
->
get_levels
(),
fake_quantize
->
get_levels
());
EXPECT_EQ
(
g_fake_quantize
->
get_auto_broadcast
(),
fake_quantize
->
get_auto_broadcast
());
}
TEST
(
attributes
,
grn_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
GRN
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
,
4
,
5
});
float
bias
=
1.25
f
;
auto
grn
=
make_shared
<
opset1
::
GRN
>
(
data
,
bias
);
NodeBuilder
builder
(
grn
);
auto
g_grn
=
as_type_ptr
<
opset1
::
GRN
>
(
builder
.
create
());
EXPECT_EQ
(
g_grn
->
get_bias
(),
grn
->
get_bias
());
}
TEST
(
attributes
,
group_conv_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
GroupConvolution
>
();
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
12
,
224
,
224
});
auto
filters
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
4
,
1
,
3
,
5
,
5
});
auto
strides
=
Strides
{
1
,
1
};
auto
pads_begin
=
CoordinateDiff
{
1
,
2
};
auto
pads_end
=
CoordinateDiff
{
1
,
2
};
auto
dilations
=
Strides
{
1
,
1
};
auto
group_conv
=
make_shared
<
opset1
::
GroupConvolution
>
(
data
,
filters
,
strides
,
pads_begin
,
pads_end
,
dilations
,
op
::
PadType
::
VALID
);
NodeBuilder
builder
(
group_conv
);
auto
g_group_conv
=
as_type_ptr
<
opset1
::
GroupConvolution
>
(
builder
.
create
());
EXPECT_EQ
(
g_group_conv
->
get_strides
(),
group_conv
->
get_strides
());
EXPECT_EQ
(
g_group_conv
->
get_pads_begin
(),
group_conv
->
get_pads_begin
());
EXPECT_EQ
(
g_group_conv
->
get_pads_end
(),
group_conv
->
get_pads_end
());
EXPECT_EQ
(
g_group_conv
->
get_dilations
(),
group_conv
->
get_dilations
());
EXPECT_EQ
(
g_group_conv
->
get_auto_pad
(),
group_conv
->
get_auto_pad
());
}
TEST
(
attributes
,
group_conv_backprop_data_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
GroupConvolutionBackpropData
>
();
const
auto
data
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
20
,
224
,
224
});
const
auto
filter
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
4
,
5
,
2
,
3
,
3
});
const
auto
output_shape
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
8
,
447
,
447
});
const
auto
strides
=
Strides
{
2
,
1
};
const
auto
pads_begin
=
CoordinateDiff
{
3
,
4
};
const
auto
pads_end
=
CoordinateDiff
{
4
,
6
};
const
auto
dilations
=
Strides
{
3
,
1
};
const
auto
auto_pad
=
op
::
PadType
::
EXPLICIT
;
const
auto
output_padding
=
CoordinateDiff
{
3
,
4
};
const
auto
gcbd
=
make_shared
<
opset1
::
GroupConvolutionBackpropData
>
(
data
,
filter
,
output_shape
,
strides
,
pads_begin
,
pads_end
,
dilations
,
auto_pad
,
output_padding
);
NodeBuilder
builder
(
gcbd
);
const
auto
g_gcbd
=
as_type_ptr
<
opset1
::
GroupConvolutionBackpropData
>
(
builder
.
create
());
EXPECT_EQ
(
g_gcbd
->
get_strides
(),
gcbd
->
get_strides
());
EXPECT_EQ
(
g_gcbd
->
get_pads_begin
(),
gcbd
->
get_pads_begin
());
EXPECT_EQ
(
g_gcbd
->
get_pads_end
(),
gcbd
->
get_pads_end
());
EXPECT_EQ
(
g_gcbd
->
get_dilations
(),
gcbd
->
get_dilations
());
EXPECT_EQ
(
g_gcbd
->
get_auto_pad
(),
gcbd
->
get_auto_pad
());
EXPECT_EQ
(
g_gcbd
->
get_output_padding
(),
gcbd
->
get_output_padding
());
}
TEST
(
attributes
,
lrn_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
LRN
>
();
const
auto
arg
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
,
4
});
const
auto
axes
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
2
});
const
double
alpha
=
1.1
;
const
double
beta
=
2.2
;
const
double
bias
=
3.3
;
const
size_t
size
=
4
;
const
auto
lrn
=
make_shared
<
opset1
::
LRN
>
(
arg
,
axes
,
alpha
,
beta
,
bias
,
size
);
NodeBuilder
builder
(
lrn
);
auto
g_lrn
=
as_type_ptr
<
opset1
::
LRN
>
(
builder
.
create
());
EXPECT_EQ
(
g_lrn
->
get_alpha
(),
lrn
->
get_alpha
());
EXPECT_EQ
(
g_lrn
->
get_beta
(),
lrn
->
get_beta
());
EXPECT_EQ
(
g_lrn
->
get_bias
(),
lrn
->
get_bias
());
EXPECT_EQ
(
g_lrn
->
get_nsize
(),
lrn
->
get_nsize
());
}
TEST
(
attributes
,
lstm_cell_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
LSTMCell
>
();
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
auto
H
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
auto
W
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
12
,
3
});
auto
R
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
12
,
3
});
const
auto
initial_hidden_state
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
const
auto
initial_cell_state
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
2
,
3
});
const
auto
hidden_size
=
3
;
const
auto
weights_format
=
op
::
LSTMWeightsFormat
::
ICOF
;
const
std
::
vector
<
std
::
string
>
activations
=
{
"tanh"
,
"sigmoid"
,
"tanh"
};
auto
activations_alpha
=
std
::
vector
<
float
>
{
1.0
,
1.5
};
auto
activations_beta
=
std
::
vector
<
float
>
{
2.0
,
1.0
};
const
float
clip
=
0.5
f
;
bool
input_forget
=
true
;
const
auto
lstm_cell
=
make_shared
<
opset1
::
LSTMCell
>
(
X
,
initial_hidden_state
,
initial_cell_state
,
W
,
R
,
hidden_size
,
weights_format
,
activations
,
activations_alpha
,
activations_beta
,
clip
,
input_forget
);
NodeBuilder
builder
(
lstm_cell
);
auto
g_lstm_cell
=
as_type_ptr
<
opset1
::
LSTMCell
>
(
builder
.
create
());
EXPECT_EQ
(
g_lstm_cell
->
get_hidden_size
(),
lstm_cell
->
get_hidden_size
());
EXPECT_EQ
(
g_lstm_cell
->
get_activations
(),
lstm_cell
->
get_activations
());
EXPECT_EQ
(
g_lstm_cell
->
get_activations_alpha
(),
lstm_cell
->
get_activations_alpha
());
EXPECT_EQ
(
g_lstm_cell
->
get_activations_beta
(),
lstm_cell
->
get_activations_beta
());
EXPECT_EQ
(
g_lstm_cell
->
get_clip
(),
lstm_cell
->
get_clip
());
EXPECT_EQ
(
g_lstm_cell
->
get_input_forget
(),
lstm_cell
->
get_input_forget
());
EXPECT_EQ
(
g_lstm_cell
->
get_weights_format
(),
lstm_cell
->
get_weights_format
());
}
TEST
(
attributes
,
lstm_sequence_op
)
{
FactoryRegistry
<
Node
>::
get
().
register_factory
<
opset1
::
LSTMSequence
>
();
const
auto
X
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
2
,
4
});
const
auto
initial_hidden_state
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
const
auto
initial_cell_state
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
2
,
3
});
const
auto
sequence_lengths
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
Shape
{
2
});
const
auto
W
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
12
,
4
});
const
auto
R
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
12
,
3
});
const
auto
B
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
Shape
{
1
,
12
});
const
auto
hidden_size
=
3
;
const
auto
lstm_direction
=
op
::
LSTMSequence
::
direction
::
FORWARD
;
const
auto
weights_format
=
op
::
LSTMWeightsFormat
::
ICOF
;
const
std
::
vector
<
float
>
activations_alpha
=
{
1
,
2
,
3
};
const
std
::
vector
<
float
>
activations_beta
=
{
4
,
5
,
6
};
const
std
::
vector
<
std
::
string
>
activations
=
{
"tanh"
,
"sigmoid"
,
"tanh"
};
const
float
clip_threshold
=
0.5
f
;
const
bool
input_forget
=
true
;
const
auto
lstm_sequence
=
make_shared
<
opset1
::
LSTMSequence
>
(
X
,
initial_hidden_state
,
initial_cell_state
,
sequence_lengths
,
W
,
R
,
B
,
hidden_size
,
lstm_direction
,
weights_format
,
activations_alpha
,
activations_beta
,
activations
,
clip_threshold
,
input_forget
);
NodeBuilder
builder
(
lstm_sequence
);
auto
g_lstm_sequence
=
as_type_ptr
<
opset1
::
LSTMSequence
>
(
builder
.
create
());
EXPECT_EQ
(
g_lstm_sequence
->
get_hidden_size
(),
lstm_sequence
->
get_hidden_size
());
EXPECT_EQ
(
g_lstm_sequence
->
get_activations
(),
lstm_sequence
->
get_activations
());
EXPECT_EQ
(
g_lstm_sequence
->
get_activations_alpha
(),
lstm_sequence
->
get_activations_alpha
());
EXPECT_EQ
(
g_lstm_sequence
->
get_activations_beta
(),
lstm_sequence
->
get_activations_beta
());
EXPECT_EQ
(
g_lstm_sequence
->
get_clip_threshold
(),
lstm_sequence
->
get_clip_threshold
());
EXPECT_EQ
(
g_lstm_sequence
->
get_direction
(),
lstm_sequence
->
get_direction
());
EXPECT_EQ
(
g_lstm_sequence
->
get_input_forget
(),
lstm_sequence
->
get_input_forget
());
EXPECT_EQ
(
g_lstm_sequence
->
get_weights_format
(),
lstm_sequence
->
get_weights_format
());
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment