ngraph / Commits / c1680ce3

Commit c1680ce3, authored Jun 26, 2019 by fenglei

Merge branch 'tfl/send_recv_op' of github.com:NervanaSystems/ngraph into tfl/send_recv_op

Parents: c643cb5e, 603cbdab

Showing 6 changed files with 232 additions and 154 deletions.
src/ngraph/op/fused_op_tbl.hpp                      +3    -3
src/ngraph/op/op_tbl.hpp                            +5    -5
src/ngraph/pass/fused_op_decomposition.cpp          +24   -14
src/ngraph/pass/fused_op_decomposition.hpp          +15   -1
src/ngraph/runtime/cpu/cpu_external_function.cpp    +0    -1
src/ngraph/serializer.cpp                           +185  -130
src/ngraph/op/fused_op_tbl.hpp

@@ -24,8 +24,8 @@ NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
 NGRAPH_OP(DepthToSpace, ngraph::op)
 NGRAPH_OP(Elu, ngraph::op)
 NGRAPH_OP(FakeQuantize, ngraph::op)
+NGRAPH_OP(GRN, ngraph::op)
 NGRAPH_OP(Gemm, ngraph::op)
-NGRAPH_OP(GRN, ngraph::op)
 NGRAPH_OP(GroupConvolution, ngraph::op)
 NGRAPH_OP(GroupConvolutionTranspose, ngraph::op)
 NGRAPH_OP(HardSigmoid, ngraph::op)
@@ -35,9 +35,9 @@ NGRAPH_OP(MVN, ngraph::op)
 NGRAPH_OP(Normalize, ngraph::op)
 NGRAPH_OP(PRelu, ngraph::op)
 NGRAPH_OP(ScaleShift, ngraph::op)
-NGRAPH_OP(SpaceToDepth, ngraph::op)
 NGRAPH_OP(ShuffleChannels, ngraph::op)
+NGRAPH_OP(SpaceToDepth, ngraph::op)
+NGRAPH_OP(Split, ngraph::op)
 NGRAPH_OP(SquaredDifference, ngraph::op)
 NGRAPH_OP(Squeeze, ngraph::op)
-NGRAPH_OP(Split, ngraph::op)
 NGRAPH_OP(Unsqueeze, ngraph::op)
src/ngraph/op/op_tbl.hpp

@@ -81,11 +81,12 @@ NGRAPH_OP(Cos, ngraph::op)
 NGRAPH_OP(Cosh, ngraph::op)
 NGRAPH_OP(Dequantize, ngraph::op)
 NGRAPH_OP(Divide, ngraph::op)
-NGRAPH_OP(DynBroadcast, ngraph::op)
 NGRAPH_OP(Dot, ngraph::op)
+NGRAPH_OP(DynBroadcast, ngraph::op)
 NGRAPH_OP(DynPad, ngraph::op)
 NGRAPH_OP(DynReshape, ngraph::op)
 NGRAPH_OP(DynSlice, ngraph::op)
+NGRAPH_OP(EmbeddingLookup, ngraph::op)
 NGRAPH_OP(Equal, ngraph::op)
 NGRAPH_OP(Erf, ngraph::op)
 NGRAPH_OP(Exp, ngraph::op)
@@ -119,13 +120,13 @@ NGRAPH_OP(Power, ngraph::op)
 NGRAPH_OP(Product, ngraph::op)
 NGRAPH_OP(Quantize, ngraph::op)
 NGRAPH_OP(QuantizedAvgPool, ngraph::op)
+NGRAPH_OP(QuantizedConvolution, ngraph::op)
 NGRAPH_OP(QuantizedConvolutionBias, ngraph::op)
 NGRAPH_OP(QuantizedConvolutionBiasAdd, ngraph::op)
 NGRAPH_OP(QuantizedConvolutionBiasSignedAdd, ngraph::op)
 NGRAPH_OP(QuantizedConvolutionRelu, ngraph::op)
-NGRAPH_OP(QuantizedConvolution, ngraph::op)
-NGRAPH_OP(QuantizedDotBias, ngraph::op)
 NGRAPH_OP(QuantizedDot, ngraph::op)
+NGRAPH_OP(QuantizedDotBias, ngraph::op)
 NGRAPH_OP(QuantizedMaxPool, ngraph::op)
 NGRAPH_OP(Recv, ngraph::op)
 NGRAPH_OP(Range, ngraph::op)
@@ -155,7 +156,6 @@ NGRAPH_OP(Subtract, ngraph::op)
 NGRAPH_OP(Sum, ngraph::op)
 NGRAPH_OP(Tan, ngraph::op)
 NGRAPH_OP(Tanh, ngraph::op)
-NGRAPH_OP(TopK, ngraph::op)
 NGRAPH_OP(Tile, ngraph::op)
+NGRAPH_OP(TopK, ngraph::op)
 NGRAPH_OP(Transpose, ngraph::op)
-NGRAPH_OP(EmbeddingLookup, ngraph::op)
src/ngraph/pass/fused_op_decomposition.cpp

@@ -13,36 +13,51 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //*****************************************************************************
 
 #include "ngraph/pass/fused_op_decomposition.hpp"
 #include "ngraph/graph_util.hpp"
 #include "ngraph/op/get_output_element.hpp"
 #include "ngraph/op/util/fused_op.hpp"
 
 using namespace std;
 using namespace ngraph;
 
-bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Node> node)
+pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
+    : m_has_direct_support{callback}
+{
+}
+
+bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
 {
     bool modified = false;
 
-    if (auto fused_op = std::dynamic_pointer_cast<ngraph::op::util::FusedOp>(node))
+    if (auto fused_op = dynamic_pointer_cast<op::util::FusedOp>(node))
     {
-        if (m_callback && m_callback(*node))
+        if (m_has_direct_support && m_has_direct_support(*node))
        {
             // Op supported by backend. Do not decompose
             return modified;
         }
         auto subgraph_outputs = fused_op->decompose_op();
+        // Run recursively untill no more fused ops
+        auto subgraph = extract_subgraph(subgraph_outputs, fused_op->get_arguments());
+        for (auto subgraph_node : subgraph)
+        {
+            if (auto nested_fused_op = dynamic_pointer_cast<op::util::FusedOp>(subgraph_node))
+            {
+                if (!(m_has_direct_support && m_has_direct_support(*nested_fused_op)))
+                {
+                    run_on_node(nested_fused_op);
+                }
+            }
+        }
+
         size_t i = 0;
         for (auto output_node : subgraph_outputs)
         {
             for (size_t j = 0; j < output_node->get_outputs().size(); j++, i++)
             {
                 // TODO: Provenance
-                std::set<ngraph::descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
+                set<descriptor::Input*> fop_users{begin(fused_op->get_outputs().at(i).get_inputs()),
                                                   end(fused_op->get_outputs().at(i).get_inputs())};
 
                 for (auto fop_user : fop_users)
                 {
@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
                     if (goe->get_n() == i && !goe->get_output_inputs(0).empty())
                     {
                         // Replace GOE users
-                        std::set<ngraph::descriptor::Input*> goe_users{
+                        set<descriptor::Input*> goe_users{
                             begin(goe->get_outputs().at(0).get_inputs()),
                             end(goe->get_outputs().at(0).get_inputs())};
                         for (auto goe_user : goe_users)
@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
 
     return modified;
 }
-
-pass::FusedOpDecomposition::FusedOpDecomposition(op_query_t callback)
-    : m_callback{callback}
-{
-}
src/ngraph/pass/fused_op_decomposition.hpp

@@ -16,6 +16,9 @@
 
 #pragma once
 
+#include <memory>
+
+#include "ngraph/op/util/fused_op.hpp"
 #include "ngraph/pass/pass.hpp"
 
 namespace ngraph
@@ -25,13 +28,24 @@ namespace ngraph
         class FusedOpDecomposition : public NodePass
         {
         public:
+            /// \brief Function signature type for callback used to check whether provided node
+            ///        is supported by backend.
             using op_query_t = std::function<bool(const Node& node)>;
+            ///
+            /// \brief Constructor for the Fused operation decomposition pass.
+            ///
+            /// \param[in] callback The function object used to determine whether current backend
+            ///                     provide direct support for passed node. Should have signature:
+            ///                     bool fn(const Node&)
+            ///
            FusedOpDecomposition(op_query_t callback = nullptr);
             bool run_on_node(std::shared_ptr<ngraph::Node> node) override;
 
         private:
-            op_query_t m_callback = nullptr;
+            /// \brief A function returning whether provided Node is supported by current backend.
+            ///        The returned bool value is used to control whether decompose operator or not.
+            op_query_t m_has_direct_support = nullptr;
         };
     }
 }
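
Editor's note, not part of the commit: the header above documents the op_query_t callback a backend can supply. Below is a minimal sketch of how a backend might wire the pass up, assuming the ngraph pass-manager API of this era (ngraph::pass::Manager with register_pass and run_passes); the is_backend_supported helper and the Gemm check are illustrative assumptions, not code from this change.

#include <memory>

#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp"

// Hypothetical support query; any callable matching
// pass::FusedOpDecomposition::op_query_t, i.e. bool(const ngraph::Node&), works.
static bool is_backend_supported(const ngraph::Node& node)
{
    // Pretend the backend has a direct kernel only for Gemm; everything else
    // gets decomposed into core ops.
    return node.description() == "Gemm";
}

void decompose_fused_ops(std::shared_ptr<ngraph::Function> f)
{
    ngraph::pass::Manager pass_manager;
    // With a callback, fused ops the backend reports as supported are left intact.
    pass_manager.register_pass<ngraph::pass::FusedOpDecomposition>(is_backend_supported);
    // Without arguments the callback defaults to nullptr and every fused op is decomposed.
    pass_manager.run_passes(f);
}

The CPU backend does the equivalent in the next file through REGISTER_KNOBBED_PASS_WITH_ARGS, passing its own is_supported query.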
src/ngraph/runtime/cpu/cpu_external_function.cpp

@@ -1180,7 +1180,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
    REGISTER_KNOBBED_PASS(RecurrentReshapeElimination, false, ngraph::pass);
    REGISTER_KNOBBED_PASS_WITH_ARGS(CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
    REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
    REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
    REGISTER_KNOBBED_PASS(CPUQuantFusion, true, runtime::cpu::pass);
    REGISTER_KNOBBED_PASS(CPUHorizontalFusion, true, runtime::cpu::pass);
src/ngraph/serializer.cpp

@@ -193,10 +193,15 @@ static OP_TYPEID get_typeid(const string& s)
     return rc;
 }
 
+bool has_key(json j, const std::string& key)
+{
+    return j.count(key) != 0;
+}
+
 template <typename T>
-T get_or_default(nlohmann::json& j, const std::string& key, const T& default_value)
+T get_or_default(json j, const std::string& key, const T& default_value)
 {
-    return j.count(key) != 0 ? j.at(key).get<T>() : default_value;
+    return has_key(j, key) ? j.at(key).get<T>() : default_value;
 }
 
 class JSONSerializer
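
Editor's aside, not part of the diff: the two helpers above capture a standard optional-attribute lookup over nlohmann::json (the json alias used throughout serializer.cpp). A self-contained sketch of the same pattern follows; the sample node object and its keys are invented for illustration.

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Mirrors the helpers introduced in the hunk above.
bool has_key(json j, const std::string& key)
{
    return j.count(key) != 0;
}

template <typename T>
T get_or_default(json j, const std::string& key, const T& default_value)
{
    return has_key(j, key) ? j.at(key).get<T>() : default_value;
}

int main()
{
    // Hypothetical serialized node: "pythondiv" is present, "ceil_mode" is not.
    json node_js = {{"name", "Divide_3"}, {"pythondiv", false}};
    bool pythondiv = get_or_default<bool>(node_js, "pythondiv", true);  // false, read from the document
    bool ceil_mode = get_or_default<bool>(node_js, "ceil_mode", false); // false, the supplied default
    std::cout << pythondiv << " " << ceil_mode << "\n";
    return 0;
}

Later hunks in this file replace ad hoc count()/empty() checks with exactly these calls, for example get_or_default(node_js, "pythondiv", true) in the Divide case.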
@@ -215,8 +220,11 @@ public:
     json serialize_function(const Function& function);
     json serialize_output(const Output<Node>& output);
+    json serialize_parameter_vector(const ParameterVector& parameters);
+    json serialize_output_vector(const OutputVector& output_vector);
     json serialize_node_reference(const Node& node);
     json serialize_node(const Node& node);
+    json serialize_axis_set(const AxisSet& axis_set);
 
 protected:
     size_t m_indent{0};
@@ -235,10 +243,13 @@ public:
         m_const_data_callback = const_data_callback;
     }
 
-    shared_ptr<Function> deserialize_function(json& j);
-    Output<Node> deserialize_output(json& j);
-    shared_ptr<Node> deserialize_node_reference(json& j);
-    shared_ptr<Node> deserialize_node(json& j);
+    shared_ptr<Function> deserialize_function(json j);
+    Output<Node> deserialize_output(json j);
+    OutputVector deserialize_output_vector(json j);
+    ParameterVector deserialize_parameter_vector(json j);
+    shared_ptr<Node> deserialize_node_reference(json j);
+    shared_ptr<Node> deserialize_node(json j);
+    AxisSet deserialize_axis_set(json j);
 
 protected:
     unordered_map<string, shared_ptr<Node>> m_node_map;
@@ -261,7 +272,7 @@ static json write_dimension(Dimension d)
     }
 }
 
-static Dimension read_dimension(const json& j)
+static Dimension read_dimension(json j)
 {
     if (j.is_null())
     {
@@ -290,7 +301,7 @@ static json write_partial_shape(const PartialShape& s)
     }
 }
 
-static PartialShape read_partial_shape(const json& j)
+static PartialShape read_partial_shape(json j)
 {
     if (j.is_null())
     {
@@ -315,19 +326,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
     return j;
 }
 
-static op::AutoBroadcastSpec read_auto_broadcast(const json& j)
+static op::AutoBroadcastSpec read_auto_broadcast(json js_node, const std::string& attr)
 {
-    if (!j.is_object())
+    if (has_key(js_node, attr))
     {
-        return op::AutoBroadcastSpec();
+        json j = js_node[attr];
+        return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")),
+                                     j.at("axis").get<size_t>());
     }
     else
     {
-        return op::AutoBroadcastSpec(static_cast<op::AutoBroadcastType>(j.at("type")),
-                                     j.at("axis").get<size_t>());
+        return op::AutoBroadcastSpec();
     }
 }
 
+static op::PadType read_pad_type(json node_js)
+{
+    return has_key(node_js, "pad_type") ? static_cast<op::PadType>(node_js.at("pad_type"))
+                                        : op::PadType::EXPLICIT;
+}
+
+static op::PadMode read_pad_mode(json node_js)
+{
+    return has_key(node_js, "pad_mode") ? static_cast<op::PadMode>(node_js.at("pad_mode"))
+                                        : op::PadMode::CONSTANT;
+}
+
 static json write_element_type(const ngraph::element::Type& n)
 {
     json j;
@@ -335,7 +359,7 @@ static json write_element_type(const ngraph::element::Type& n)
     return j;
 }
 
-static element::Type read_element_type(const json& j)
+static element::Type read_element_type(json j)
 {
     size_t bitwidth = 0;
     bool is_real = false;
@@ -495,21 +519,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
             rc = deserializer.deserialize_function(func);
         }
     }
 
     return rc;
 }
 
+json JSONSerializer::serialize_parameter_vector(const ParameterVector& parameters)
+{
+    json json_parameters = json::array();
+    for (auto param : parameters)
+    {
+        json_parameters.push_back(serialize_node_reference(*param));
+    }
+    return json_parameters;
+}
+
 json JSONSerializer::serialize_function(const Function& f)
 {
     json function;
     function["name"] = f.get_name();
-    vector<string> parameter_list;
-    for (auto param : f.get_parameters())
-    {
-        parameter_list.push_back(serialize_node_reference(*param));
-    }
-    function["parameters"] = parameter_list;
+    function["parameters"] = serialize_parameter_vector(f.get_parameters());
 
     // TODO Functions can return multiple results
     for (size_t i = 0; i < f.get_output_size(); ++i)
@@ -521,7 +548,7 @@ json JSONSerializer::serialize_function(const Function& f)
 }
 
 template <typename T>
-T get_value(nlohmann::json js, const string& key)
+T get_value(json js, const string& key)
 {
     T rc;
     auto it = js.find(key);
@@ -532,13 +559,13 @@ T get_value(nlohmann::json js, const string& key)
     return rc;
 }
 
-shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json& j)
+shared_ptr<Node> JSONDeserializer::deserialize_node_reference(json j)
 {
     const string& name = j;
     return m_node_map.at(name);
 }
 
-Output<Node> JSONDeserializer::deserialize_output(json& j)
+Output<Node> JSONDeserializer::deserialize_output(json j)
 {
     size_t index;
     json json_node_reference;
@@ -559,10 +586,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j)
     return Output<Node>(deserialize_node_reference(json_node_reference), index);
 }
 
-shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
+OutputVector JSONDeserializer::deserialize_output_vector(json j)
+{
+    OutputVector result;
+    if (j.is_array())
+    {
+        for (json jelt : j)
+        {
+            result.push_back(deserialize_output(jelt));
+        }
+    }
+    return result;
+}
+
+json JSONSerializer::serialize_axis_set(const AxisSet& axis_set)
+{
+    return static_cast<set<size_t>>(axis_set);
+}
+
+AxisSet JSONDeserializer::deserialize_axis_set(json j)
+{
+    AxisSet result;
+    if (j.is_array())
+    {
+        result = j.get<set<size_t>>();
+    }
+    return result;
+}
+
+ParameterVector JSONDeserializer::deserialize_parameter_vector(json json_parameters)
+{
+    std::vector<std::shared_ptr<op::Parameter>> params;
+    for (auto& param_ref : json_parameters)
+    {
+        params.push_back(
+            dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
+    }
+    return params;
+}
+
+shared_ptr<Function> JSONDeserializer::deserialize_function(json func_js)
 {
     string func_name = func_js.at("name").get<string>();
-    vector<json> func_parameters = func_js.at("parameters");
     vector<json> func_result = func_js.at("result");
     for (json node_js : func_js.at("ops"))
     {
@@ -594,12 +659,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
             "Graph serialization is inconsistent. Some op::Results appear to be missing");
     }
 
-    std::vector<std::shared_ptr<op::Parameter>> params;
-    for (auto& param_ref : func_parameters)
-    {
-        params.push_back(
-            dynamic_pointer_cast<op::Parameter>(deserialize_node_reference(param_ref)));
-    }
+    ParameterVector params = deserialize_parameter_vector(func_js.at("parameters"));
 
     shared_ptr<Function> rc{make_shared<Function>(result, params, func_name)};
     m_function_map[func_name] = rc;
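
Editor's aside, not part of the diff: serialize_axis_set and deserialize_axis_set introduced above round-trip an AxisSet through std::set<size_t>, which nlohmann::json stores as a plain array. A minimal sketch of that round-trip with invented axis values; the is_array() guard mirrors deserialize_axis_set, which yields an empty set for anything that is not an array.

#include <cassert>
#include <set>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main()
{
    // Serialize: a set of axes becomes a JSON array such as [0, 2].
    std::set<size_t> axes{0, 2};
    json j = axes;

    // Deserialize: accept only arrays, mirroring the guard in deserialize_axis_set.
    std::set<size_t> round_tripped;
    if (j.is_array())
    {
        round_tripped = j.get<std::set<size_t>>();
    }
    assert(round_tripped == axes);
    return 0;
}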
@@ -632,7 +692,12 @@ struct OutputHelper
 // when all op constructors use the new style arguments.
 struct OutputVectorHelper
 {
-    const OutputHelper& operator[](size_t i) const { return m_vector[i]; }
+    OutputVectorHelper(const OutputVector& output_vector)
+        : m_vector(output_vector)
+    {
+    }
+    OutputVectorHelper() = default;
+    OutputHelper operator[](size_t i) const { return OutputHelper(m_vector[i]); }
     void push_back(const Output<Node>& output) { m_vector.push_back(output); }
     size_t size() const { return m_vector.size(); }
     operator vector<shared_ptr<Node>>() const
@@ -640,14 +705,15 @@ struct OutputVectorHelper
         vector<shared_ptr<Node>> result;
         for (auto& o : m_vector)
         {
-            result.push_back(o);
+            result.push_back(OutputHelper(o));
         }
         return result;
     }
-    vector<OutputHelper> m_vector;
+    operator const OutputVector&() const { return m_vector; }
+    OutputVector m_vector;
 };
 
-shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
+shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
 {
     shared_ptr<Node> node;
     try
@@ -655,14 +721,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         string node_name = node_js.at("name").get<string>();
         string node_op = node_js.at("op").get<string>();
         string friendly_name = get_value<string>(node_js, "friendly_name");
-        vector<json> node_inputs = get_value<vector<json>>(node_js, "inputs");
         vector<json> control_deps_inputs = get_value<vector<json>>(node_js, "control_deps");
         vector<string> node_outputs = get_value<vector<string>>(node_js, "outputs");
-        OutputVectorHelper args;
-        for (auto& node_input : node_inputs)
-        {
-            args.push_back(deserialize_output(node_input));
-        }
+        OutputVectorHelper args(deserialize_output_vector(node_js["inputs"]));
 #if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic push
 #pragma GCC diagnostic error "-Wswitch"
@@ -683,12 +744,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Add:
         {
-            node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+            node = make_shared<op::Add>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::All:
         {
-            auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+            auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
             node = make_shared<op::All>(args[0], reduction_axes);
             break;
         }
@@ -699,12 +760,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::And:
         {
-            node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+            node = make_shared<op::And>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Any:
         {
-            auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+            auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
             node = make_shared<op::Any>(args[0], reduction_axes);
             break;
         }
@@ -741,12 +802,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
             auto padding_above = node_js.at("padding_above").get<vector<size_t>>();
             auto include_padding_in_avg_computation =
                 node_js.at("include_padding_in_avg_computation").get<bool>();
-            op::PadType pad_type = node_js["pad_type"].empty()
-                                       ? op::PadType::EXPLICIT
-                                       : static_cast<op::PadType>(node_js.at("pad_type"));
-            bool ceil_mode =
-                node_js["ceil_mode"].empty() ? false : node_js.at("ceil_mode").get<bool>();
-            ;
+            op::PadType pad_type = read_pad_type(node_js);
+            bool ceil_mode = get_or_default<bool>(node_js, "ceil_mode", false);
             node = make_shared<op::AvgPool>(args[0],
                                             window_shape,
                                             window_movement_strides,
@@ -808,7 +865,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::Broadcast:
         {
             auto shape = node_js.at("shape").get<vector<size_t>>();
-            auto axes = node_js.at("axes").get<set<size_t>>();
+            auto axes = deserialize_axis_set(node_js.at("axes"));
             node = make_shared<op::Broadcast>(args[0], shape, axes);
             break;
         }
@@ -819,7 +876,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::BroadcastLike:
         {
-            auto initial_axes = node_js.at("initial_axes").get<set<size_t>>();
+            auto initial_axes = deserialize_axis_set(node_js.at("initial_axes"));
             node = make_shared<op::BroadcastLike>(args[0], args[1], initial_axes);
             break;
         }
@@ -838,13 +895,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::Concat:
         {
             auto axis = node_js.at("axis").get<size_t>();
-            node = make_shared<op::Concat>(args, axis);
+            node = make_shared<op::Concat>(static_cast<OutputVector>(args), axis);
             break;
         }
         case OP_TYPEID::Constant:
         {
             auto type_node_js =
-                node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
+                has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
             auto element_type = read_element_type(type_node_js.at("element_type"));
             auto shape = type_node_js.at("shape");
             auto value = node_js.at("value").get<vector<string>>();
@@ -868,17 +925,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
             // For backwards compatibility, we accept "image_dilation_strides" in place of
             // "data_dilation_strides", and we also allow it to be omitted altogether.
-            auto data_dilation_strides_maybe = node_js["data_dilation_strides"];
-            if (data_dilation_strides_maybe.empty())
+            json data_dilation_strides;
+            if (has_key(node_js, "data_dilation_strides"))
+            {
+                data_dilation_strides = node_js["data_dilation_strides"];
+            }
+            else if (has_key(node_js, "image_dilation_strides"))
             {
-                data_dilation_strides_maybe = node_js["image_dilation_strides"];
+                data_dilation_strides = node_js["image_dilation_strides"];
             }
 
-            op::PadType pad_type = node_js["pad_type"].empty()
-                                       ? op::PadType::EXPLICIT
-                                       : static_cast<op::PadType>(node_js.at("pad_type"));
+            op::PadType pad_type = read_pad_type(node_js);
 
-            if (data_dilation_strides_maybe.empty())
+            if (data_dilation_strides.empty())
             {
                 node = make_shared<op::Convolution>(args[0],
                                                     args[1],
@@ -889,14 +948,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
             }
             else
             {
-                node = make_shared<op::Convolution>(args[0],
+                node = make_shared<op::Convolution>(args[0],
                                                     args[1],
                                                     window_movement_strides,
                                                     window_dilation_strides,
                                                     padding_below,
                                                     padding_above,
-                                                    data_dilation_strides_maybe.get<std::vector<size_t>>(),
+                                                    data_dilation_strides.get<std::vector<size_t>>(),
                                                     pad_type);
             }
             break;
@@ -1033,33 +1092,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::Dequantize:
         {
             auto type = read_element_type(node_js.at("type"));
-            auto axes = node_js.at("axes").get<set<size_t>>();
+            auto axes = deserialize_axis_set(node_js.at("axes"));
             node = make_shared<op::Dequantize>(args[0], args[1], args[2], type, axes);
             break;
         }
         case OP_TYPEID::Divide:
         {
-            bool pythondiv = true;
-            if (node_js["pythondiv"].is_object())
-            {
-                pythondiv = node_js.at("pythondiv").get<bool>();
-            }
+            bool pythondiv = get_or_default(node_js, "pythondiv", true);
             node = make_shared<op::Divide>(
-                args[0], args[1], pythondiv, read_auto_broadcast(node_js["autob"]));
+                args[0], args[1], pythondiv, read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Dot:
         {
             // For backwards compatibility, reduction_axes_count is optional.
-            auto obj = node_js["reduction_axes_count"];
-            if (obj.empty())
+            if (has_key(node_js, "reduction_axes_count"))
             {
-                node = make_shared<op::Dot>(args[0], args[1]);
+                size_t reduction_axes_count = node_js["reduction_axes_count"].get<size_t>();
+                node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
             }
             else
             {
-                size_t reduction_axes_count = obj.get<size_t>();
-                node = make_shared<op::Dot>(args[0], args[1], reduction_axes_count);
+                node = make_shared<op::Dot>(args[0], args[1]);
             }
             break;
         }
@@ -1095,7 +1149,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Equal:
         {
-            node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+            node = make_shared<op::Equal>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Erf:
@@ -1160,13 +1214,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::Greater:
         {
             node =
-                make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+                make_shared<op::Greater>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::GreaterEq:
         {
             node =
-                make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+                make_shared<op::GreaterEq>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::GRN:
@@ -1193,10 +1247,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
             auto data_dilation_strides =
                 node_js.at("data_dilation_strides").get<vector<size_t>>();
             auto groups = node_js.at("groups").get<size_t>();
-            op::PadType pad_type = node_js["pad_type"].empty()
-                                       ? op::PadType::EXPLICIT
-                                       : static_cast<op::PadType>(node_js.at("pad_type"));
+            op::PadType pad_type = read_pad_type(node_js);
             node = make_shared<op::GroupConvolution>(args[0],
                                                      args[1],
                                                      window_movement_strides,
@@ -1216,9 +1267,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
             auto padding_end = node_js.at("padding_end").get<vector<ptrdiff_t>>();
             auto output_padding = node_js.at("output_padding").get<vector<ptrdiff_t>>();
             auto groups = node_js.at("groups").get<size_t>();
-            op::PadType pad_type = node_js["pad_type"].empty()
-                                       ? op::PadType::EXPLICIT
-                                       : static_cast<op::PadType>(node_js.at("pad_type"));
+            op::PadType pad_type = read_pad_type(node_js);
             auto output_shape = node_js.at("output_shape").get<vector<size_t>>();
             node = make_shared<op::GroupConvolutionTranspose>(args[0],
@@ -1240,12 +1289,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Less:
         {
-            node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+            node = make_shared<op::Less>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::LessEq:
         {
-            node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+            node = make_shared<op::LessEq>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Log:
@@ -1287,7 +1336,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Max:
         {
-            auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+            auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
             node = make_shared<op::Max>(args[0], reduction_axes);
             break;
         }
@@ -1298,11 +1347,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
                 node_js.at("window_movement_strides").get<vector<size_t>>();
             // For backwards compatibility, both (but not just one) of the padding_ fields may be
             // omitted.
-            auto padding_below_maybe = node_js["padding_below"];
-            auto padding_above_maybe = node_js["padding_above"];
-            op::PadType pad_type = node_js["pad_type"].empty()
-                                       ? op::PadType::EXPLICIT
-                                       : static_cast<op::PadType>(node_js.at("pad_type"));
+            auto padding_below_maybe = get_or_default(node_js, "padding_below", json{});
+            auto padding_above_maybe = get_or_default(node_js, "padding_above", json{});
+            op::PadType pad_type = read_pad_type(node_js);
             if (padding_below_maybe.empty() && !padding_above_maybe.empty())
             {
                 throw runtime_error(
@@ -1361,31 +1408,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::Maximum:
         {
             node =
-                make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+                make_shared<op::Maximum>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Min:
         {
-            auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+            auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
             node = make_shared<op::Min>(args[0], reduction_axes);
             break;
         }
         case OP_TYPEID::Minimum:
         {
             node =
-                make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+                make_shared<op::Minimum>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Multiply:
         {
             node =
-                make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+                make_shared<op::Multiply>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::MVN:
         {
             auto normalize_variance = node_js.at("normalize_variance").get<bool>();
-            auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+            auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
             auto eps = node_js.at("eps").get<double>();
             node = make_shared<op::MVN>(args[0], normalize_variance, normalize_variance, eps);
             break;
@@ -1407,7 +1454,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::NotEqual:
         {
             node =
-                make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+                make_shared<op::NotEqual>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Not:
@@ -1424,7 +1471,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Or:
         {
-            node = make_shared<op::Or>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+            node = make_shared<op::Or>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Pad:
@@ -1441,9 +1488,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
                            [](size_t s) { return s == 0; }),
                 "Legacy padding_interior field must be zero everywhere.");
-            auto pad_mode = node_js.count("pad_mode") == 0
-                                ? op::PadMode::CONSTANT
-                                : static_cast<op::PadMode>(node_js.at("pad_mode"));
+            auto pad_mode = read_pad_mode(node_js);
             node =
                 make_shared<op::Pad>(args[0], args[1], padding_below, padding_above, pad_mode);
             break;
@@ -1451,7 +1496,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::Parameter:
         {
             auto type_node_js =
-                node_js.count("element_type") == 0 ? node_js.at("value_type") : node_js;
+                has_key(node_js, "element_type") ? node_js : node_js.at("value_type");
             auto element_type = read_element_type(type_node_js.at("element_type"));
             auto shape = type_node_js.at("shape");
             auto cacheable = get_or_default<bool>(node_js, "cacheable", false);
@@ -1476,7 +1521,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Power:
         {
-            node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+            node = make_shared<op::Power>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::PRelu:
@@ -1486,14 +1531,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Product:
         {
-            auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+            auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
             node = make_shared<op::Product>(args[0], reduction_axes);
             break;
         }
         case OP_TYPEID::Quantize:
         {
             auto type = read_element_type(node_js.at("type"));
-            auto axes = node_js.at("axes").get<set<size_t>>();
+            auto axes = deserialize_axis_set(node_js.at("axes"));
             auto round_mode = node_js.at("round_mode").get<op::Quantize::RoundMode>();
             node = make_shared<op::Quantize>(args[0], args[1], args[2], type, axes, round_mode);
             break;
@@ -1552,8 +1597,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
                 node_js.at("window_movement_strides").get<vector<size_t>>();
             // For backwards compatibility, both (but not just one) of the padding_ fields may be
             // omitted.
-            auto padding_below_maybe = node_js["padding_below"];
-            auto padding_above_maybe = node_js["padding_above"];
+            auto padding_below_maybe = get_or_default(node_js, "padding_below", json{});
+            auto padding_above_maybe = get_or_default(node_js, "padding_above", json{});
             auto padding_below = padding_below_maybe.get<vector<size_t>>();
             auto padding_above = padding_above_maybe.get<vector<size_t>>();
             node = make_shared<op::QuantizedMaxPool>(
@@ -1607,7 +1652,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Reverse:
         {
-            auto reversed_axes = node_js.at("reversed_axes").get<set<size_t>>();
+            auto reversed_axes = deserialize_axis_set(node_js.at("reversed_axes"));
             node = make_shared<op::Reverse>(args[0], reversed_axes);
             break;
         }
@@ -1697,7 +1742,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         }
         case OP_TYPEID::Softmax:
        {
-            auto softmax_axes = node_js.at("softmax_axes").get<set<size_t>>();
+            auto softmax_axes = deserialize_axis_set(node_js.at("softmax_axes"));
             node = make_shared<op::Softmax>(args[0], softmax_axes);
             break;
         }
@@ -1732,12 +1777,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
         case OP_TYPEID::Subtract:
         {
             node =
-                make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js["autob"]));
+                make_shared<op::Subtract>(args[0], args[1], read_auto_broadcast(node_js, "autob"));
             break;
         }
         case OP_TYPEID::Sum:
         {
-            auto reduction_axes = node_js.at("reduction_axes").get<set<size_t>>();
+            auto reduction_axes = deserialize_axis_set(node_js.at("reduction_axes"));
             node = make_shared<op::Sum>(args[0], reduction_axes);
             break;
         }
@@ -1873,6 +1918,16 @@ json JSONSerializer::serialize_output(const Output<Node>& output)
     return result;
 }
 
+json JSONSerializer::serialize_output_vector(const OutputVector& output_vector)
+{
+    json result;
+    for (const Output<Node>& output : output_vector)
+    {
+        result.push_back(serialize_output(output));
+    }
+    return result;
+}
+
 json JSONSerializer::serialize_node(const Node& n)
 {
     m_nodes_serialized.insert(&n);
@@ -1972,7 +2027,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::All:
     {
         auto tmp = dynamic_cast<const op::All*>(&n);
-        node["reduction_axes"] = tmp->get_reduction_axes();
+        node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
         break;
     }
     case OP_TYPEID::AllReduce: { break;
@@ -1989,7 +2044,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Any:
     {
         auto tmp = dynamic_cast<const op::Any*>(&n);
-        node["reduction_axes"] = tmp->get_reduction_axes();
+        node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
         break;
     }
     case OP_TYPEID::Asin: { break;
@@ -2045,7 +2100,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Broadcast:
     {
         auto tmp = dynamic_cast<const op::Broadcast*>(&n);
-        node["axes"] = tmp->get_broadcast_axes();
+        node["axes"] = serialize_axis_set(tmp->get_broadcast_axes());
         node["shape"] = tmp->get_broadcast_shape();
         break;
     }
@@ -2054,7 +2109,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::BroadcastLike:
     {
         auto tmp = dynamic_cast<const op::BroadcastLike*>(&n);
-        node["initial_axes"] = tmp->get_initial_broadcast_axes();
+        node["initial_axes"] = serialize_axis_set(tmp->get_initial_broadcast_axes());
         break;
     }
     case OP_TYPEID::Ceiling: { break;
@@ -2168,7 +2223,7 @@ json JSONSerializer::serialize_node(const Node& n)
     {
         auto tmp = dynamic_cast<const op::Dequantize*>(&n);
         node["type"] = write_element_type(tmp->get_element_type());
-        node["axes"] = tmp->get_axes();
+        node["axes"] = serialize_axis_set(tmp->get_axes());
         break;
     }
     case OP_TYPEID::DepthToSpace:
@@ -2361,7 +2416,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Max:
     {
         auto tmp = dynamic_cast<const op::Max*>(&n);
-        node["reduction_axes"] = tmp->get_reduction_axes();
+        node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
         break;
     }
     case OP_TYPEID::MaxPool:
@@ -2395,7 +2450,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Min:
     {
         auto tmp = dynamic_cast<const op::Min*>(&n);
-        node["reduction_axes"] = tmp->get_reduction_axes();
+        node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
         break;
     }
     case OP_TYPEID::Minimum:
@@ -2419,7 +2474,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::MVN:
     {
         auto tmp = dynamic_cast<const op::MVN*>(&n);
-        node["reduction_axes"] = tmp->get_reduction_axes();
+        node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
         node["normalize_variance"] = tmp->get_normalize_variance();
         node["eps"] = tmp->get_eps();
         break;
@@ -2499,7 +2554,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Product:
     {
         auto tmp = dynamic_cast<const op::Product*>(&n);
-        node["reduction_axes"] = tmp->get_reduction_axes();
+        node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
         break;
     }
     case OP_TYPEID::Power:
@@ -2515,7 +2570,7 @@ json JSONSerializer::serialize_node(const Node& n)
     {
         auto tmp = dynamic_cast<const op::Quantize*>(&n);
         node["type"] = write_element_type(tmp->get_element_type());
-        node["axes"] = tmp->get_axes();
+        node["axes"] = serialize_axis_set(tmp->get_axes());
         node["round_mode"] = tmp->get_round_mode();
         break;
     }
@@ -2596,7 +2651,7 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Reverse:
    {
         auto tmp = dynamic_cast<const op::Reverse*>(&n);
-        node["reversed_axes"] = tmp->get_reversed_axes();
+        node["reversed_axes"] = serialize_axis_set(tmp->get_reversed_axes());
         break;
     }
     case OP_TYPEID::ReverseSequence:
@@ -2689,13 +2744,13 @@ json JSONSerializer::serialize_node(const Node& n)
     case OP_TYPEID::Sum:
     {
         auto tmp = dynamic_cast<const op::Sum*>(&n);
-        node["reduction_axes"] = tmp->get_reduction_axes();
+        node["reduction_axes"] = serialize_axis_set(tmp->get_reduction_axes());
         break;
     }
     case OP_TYPEID::Softmax:
     {
         auto tmp = dynamic_cast<const op::Softmax*>(&n);
-        node["softmax_axes"] = tmp->get_axes();
+        node["softmax_axes"] = serialize_axis_set(tmp->get_axes());
         break;
     }
     case OP_TYPEID::Tan: { break;