Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
603cbdab
Unverified
Commit
603cbdab
authored
Jun 26, 2019
by
Fenglei Tian
Committed by
GitHub
Jun 26, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into tfl/send_recv_op
parents
ac3743d3
79587d93
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
224 additions
and
146 deletions
+224
-146
fused_op_decomposition.cpp
src/ngraph/pass/fused_op_decomposition.cpp
+24
-14
fused_op_decomposition.hpp
src/ngraph/pass/fused_op_decomposition.hpp
+15
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+0
-1
serializer.cpp
src/ngraph/serializer.cpp
+185
-130
No files found.
src/ngraph/pass/fused_op_decomposition.cpp
View file @
603cbdab
...
@@ -13,36 +13,51 @@
...
@@ -13,36 +13,51 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//*****************************************************************************
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/fused_op.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
bool
ngraph
::
pass
::
FusedOpDecomposition
::
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
pass
::
FusedOpDecomposition
::
FusedOpDecomposition
(
op_query_t
callback
)
:
m_has_direct_support
{
callback
}
{
}
bool
pass
::
FusedOpDecomposition
::
run_on_node
(
shared_ptr
<
Node
>
node
)
{
{
bool
modified
=
false
;
bool
modified
=
false
;
if
(
auto
fused_op
=
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
util
::
FusedOp
>
(
node
))
if
(
auto
fused_op
=
dynamic_pointer_cast
<
op
::
util
::
FusedOp
>
(
node
))
{
{
if
(
m_
callback
&&
m_callback
(
*
node
))
if
(
m_
has_direct_support
&&
m_has_direct_support
(
*
node
))
{
{
// Op supported by backend. Do not decompose
// Op supported by backend. Do not decompose
return
modified
;
return
modified
;
}
}
auto
subgraph_outputs
=
fused_op
->
decompose_op
();
auto
subgraph_outputs
=
fused_op
->
decompose_op
();
// Run recursively untill no more fused ops
auto
subgraph
=
extract_subgraph
(
subgraph_outputs
,
fused_op
->
get_arguments
());
for
(
auto
subgraph_node
:
subgraph
)
{
if
(
auto
nested_fused_op
=
dynamic_pointer_cast
<
op
::
util
::
FusedOp
>
(
subgraph_node
))
{
if
(
!
(
m_has_direct_support
&&
m_has_direct_support
(
*
nested_fused_op
)))
{
run_on_node
(
nested_fused_op
);
}
}
}
size_t
i
=
0
;
size_t
i
=
0
;
for
(
auto
output_node
:
subgraph_outputs
)
for
(
auto
output_node
:
subgraph_outputs
)
{
{
for
(
size_t
j
=
0
;
j
<
output_node
->
get_outputs
().
size
();
j
++
,
i
++
)
for
(
size_t
j
=
0
;
j
<
output_node
->
get_outputs
().
size
();
j
++
,
i
++
)
{
{
// TODO: Provenance
// TODO: Provenance
std
::
set
<
ngraph
::
descriptor
::
Input
*>
fop_users
{
set
<
descriptor
::
Input
*>
fop_users
{
begin
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
()),
begin
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
()),
end
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
())};
end
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
())};
for
(
auto
fop_user
:
fop_users
)
for
(
auto
fop_user
:
fop_users
)
{
{
...
@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
...
@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
if
(
goe
->
get_n
()
==
i
&&
!
goe
->
get_output_inputs
(
0
).
empty
())
if
(
goe
->
get_n
()
==
i
&&
!
goe
->
get_output_inputs
(
0
).
empty
())
{
{
// Replace GOE users
// Replace GOE users
s
td
::
set
<
ngraph
::
descriptor
::
Input
*>
goe_users
{
s
et
<
descriptor
::
Input
*>
goe_users
{
begin
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
()),
begin
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
()),
end
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
())};
end
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
())};
for
(
auto
goe_user
:
goe_users
)
for
(
auto
goe_user
:
goe_users
)
...
@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
...
@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
return
modified
;
return
modified
;
}
}
pass
::
FusedOpDecomposition
::
FusedOpDecomposition
(
op_query_t
callback
)
:
m_callback
{
callback
}
{
}
src/ngraph/pass/fused_op_decomposition.hpp
View file @
603cbdab
...
@@ -16,6 +16,9 @@
...
@@ -16,6 +16,9 @@
#pragma once
#pragma once
#include <memory>
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -25,13 +28,24 @@ namespace ngraph
...
@@ -25,13 +28,24 @@ namespace ngraph
class
FusedOpDecomposition
:
public
NodePass
class
FusedOpDecomposition
:
public
NodePass
{
{
public
:
public
:
/// \brief Function signature type for callback used to check whether provided node
/// is supported by backend.
using
op_query_t
=
std
::
function
<
bool
(
const
Node
&
node
)
>
;
using
op_query_t
=
std
::
function
<
bool
(
const
Node
&
node
)
>
;
///
/// \brief Constructor for the Fused operation decomposition pass.
///
/// \param[in] callback The function object used to determine whether current backend
/// provide direct support for passed node. Should have signature:
/// bool fn(const Node&)
///
FusedOpDecomposition
(
op_query_t
callback
=
nullptr
);
FusedOpDecomposition
(
op_query_t
callback
=
nullptr
);
bool
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
override
;
bool
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
override
;
private
:
private
:
op_query_t
m_callback
=
nullptr
;
/// \brief A function returning whether provided Node is supported by current backend.
/// The returned bool value is used to control whether decompose operator or not.
op_query_t
m_has_direct_support
=
nullptr
;
};
};
}
}
}
}
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
603cbdab
...
@@ -1180,7 +1180,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
...
@@ -1180,7 +1180,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
REGISTER_KNOBBED_PASS
(
RecurrentReshapeElimination
,
false
,
ngraph
::
pass
);
REGISTER_KNOBBED_PASS
(
RecurrentReshapeElimination
,
false
,
ngraph
::
pass
);
REGISTER_KNOBBED_PASS_WITH_ARGS
(
REGISTER_KNOBBED_PASS_WITH_ARGS
(
CoreFusion
,
true
,
ngraph
::
pass
,
ngraph
::
pass
::
FusionType
::
ALL_FUSIONS
);
CoreFusion
,
true
,
ngraph
::
pass
,
ngraph
::
pass
::
FusionType
::
ALL_FUSIONS
);
REGISTER_KNOBBED_PASS_WITH_ARGS
(
FusedOpDecomposition
,
true
,
ngraph
::
pass
,
is_supported
);
REGISTER_KNOBBED_PASS
(
CPUFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUQuantFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUQuantFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUHorizontalFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUHorizontalFusion
,
true
,
runtime
::
cpu
::
pass
);
...
...
src/ngraph/serializer.cpp
View file @
603cbdab
...
@@ -193,10 +193,15 @@ static OP_TYPEID get_typeid(const string& s)
...
@@ -193,10 +193,15 @@ static OP_TYPEID get_typeid(const string& s)
return
rc
;
return
rc
;
}
}
bool
has_key
(
json
j
,
const
std
::
string
&
key
)
{
return
j
.
count
(
key
)
!=
0
;
}
template
<
typename
T
>
template
<
typename
T
>
T
get_or_default
(
nlohmann
::
json
&
j
,
const
std
::
string
&
key
,
const
T
&
default_value
)
T
get_or_default
(
json
j
,
const
std
::
string
&
key
,
const
T
&
default_value
)
{
{
return
j
.
count
(
key
)
!=
0
?
j
.
at
(
key
).
get
<
T
>
()
:
default_value
;
return
has_key
(
j
,
key
)
?
j
.
at
(
key
).
get
<
T
>
()
:
default_value
;
}
}
class
JSONSerializer
class
JSONSerializer
...
@@ -215,8 +220,11 @@ public:
...
@@ -215,8 +220,11 @@ public:
json
serialize_function
(
const
Function
&
function
);
json
serialize_function
(
const
Function
&
function
);
json
serialize_output
(
const
Output
<
Node
>&
output
);
json
serialize_output
(
const
Output
<
Node
>&
output
);
json
serialize_parameter_vector
(
const
ParameterVector
&
parameters
);
json
serialize_output_vector
(
const
OutputVector
&
output_vector
);
json
serialize_node_reference
(
const
Node
&
node
);
json
serialize_node_reference
(
const
Node
&
node
);
json
serialize_node
(
const
Node
&
node
);
json
serialize_node
(
const
Node
&
node
);
json
serialize_axis_set
(
const
AxisSet
&
axis_set
);
protected
:
protected
:
size_t
m_indent
{
0
};
size_t
m_indent
{
0
};
...
@@ -235,10 +243,13 @@ public:
...
@@ -235,10 +243,13 @@ public:
m_const_data_callback
=
const_data_callback
;
m_const_data_callback
=
const_data_callback
;
}
}
shared_ptr
<
Function
>
deserialize_function
(
json
&
j
);
shared_ptr
<
Function
>
deserialize_function
(
json
j
);
Output
<
Node
>
deserialize_output
(
json
&
j
);
Output
<
Node
>
deserialize_output
(
json
j
);
shared_ptr
<
Node
>
deserialize_node_reference
(
json
&
j
);
OutputVector
deserialize_output_vector
(
json
j
);
shared_ptr
<
Node
>
deserialize_node
(
json
&
j
);
ParameterVector
deserialize_parameter_vector
(
json
j
);
shared_ptr
<
Node
>
deserialize_node_reference
(
json
j
);
shared_ptr
<
Node
>
deserialize_node
(
json
j
);
AxisSet
deserialize_axis_set
(
json
j
);
protected
:
protected
:
unordered_map
<
string
,
shared_ptr
<
Node
>>
m_node_map
;
unordered_map
<
string
,
shared_ptr
<
Node
>>
m_node_map
;
...
@@ -261,7 +272,7 @@ static json write_dimension(Dimension d)
...
@@ -261,7 +272,7 @@ static json write_dimension(Dimension d)
}
}
}
}
static
Dimension
read_dimension
(
const
json
&
j
)
static
Dimension
read_dimension
(
json
j
)
{
{
if
(
j
.
is_null
())
if
(
j
.
is_null
())
{
{
...
@@ -290,7 +301,7 @@ static json write_partial_shape(const PartialShape& s)
...
@@ -290,7 +301,7 @@ static json write_partial_shape(const PartialShape& s)
}
}
}
}
static
PartialShape
read_partial_shape
(
const
json
&
j
)
static
PartialShape
read_partial_shape
(
json
j
)
{
{
if
(
j
.
is_null
())
if
(
j
.
is_null
())
{
{
...
@@ -315,19 +326,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
...
@@ -315,19 +326,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
return
j
;
return
j
;
}
}
static
op
::
AutoBroadcastSpec
read_auto_broadcast
(
const
json
&
j
)
static
op
::
AutoBroadcastSpec
read_auto_broadcast
(
json
js_node
,
const
std
::
string
&
attr
)
{
{
if
(
!
j
.
is_object
(
))
if
(
has_key
(
js_node
,
attr
))
{
{
return
op
::
AutoBroadcastSpec
();
json
j
=
js_node
[
attr
];
return
op
::
AutoBroadcastSpec
(
static_cast
<
op
::
AutoBroadcastType
>
(
j
.
at
(
"type"
)),
j
.
at
(
"axis"
).
get
<
size_t
>
());
}
}
else
else
{
{
return
op
::
AutoBroadcastSpec
(
static_cast
<
op
::
AutoBroadcastType
>
(
j
.
at
(
"type"
)),
return
op
::
AutoBroadcastSpec
();
j
.
at
(
"axis"
).
get
<
size_t
>
());
}
}
}
}
static
op
::
PadType
read_pad_type
(
json
node_js
)
{
return
has_key
(
node_js
,
"pad_type"
)
?
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
))
:
op
::
PadType
::
EXPLICIT
;
}
static
op
::
PadMode
read_pad_mode
(
json
node_js
)
{
return
has_key
(
node_js
,
"pad_mode"
)
?
static_cast
<
op
::
PadMode
>
(
node_js
.
at
(
"pad_mode"
))
:
op
::
PadMode
::
CONSTANT
;
}
static
json
write_element_type
(
const
ngraph
::
element
::
Type
&
n
)
static
json
write_element_type
(
const
ngraph
::
element
::
Type
&
n
)
{
{
json
j
;
json
j
;
...
@@ -335,7 +359,7 @@ static json write_element_type(const ngraph::element::Type& n)
...
@@ -335,7 +359,7 @@ static json write_element_type(const ngraph::element::Type& n)
return
j
;
return
j
;
}
}
static
element
::
Type
read_element_type
(
const
json
&
j
)
static
element
::
Type
read_element_type
(
json
j
)
{
{
size_t
bitwidth
=
0
;
size_t
bitwidth
=
0
;
bool
is_real
=
false
;
bool
is_real
=
false
;
...
@@ -495,21 +519,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
...
@@ -495,21 +519,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
rc
=
deserializer
.
deserialize_function
(
func
);
rc
=
deserializer
.
deserialize_function
(
func
);
}
}
}
}
return
rc
;
return
rc
;
}
}
json
JSONSerializer
::
serialize_parameter_vector
(
const
ParameterVector
&
parameters
)
{
json
json_parameters
=
json
::
array
();
for
(
auto
param
:
parameters
)
{
json_parameters
.
push_back
(
serialize_node_reference
(
*
param
));
}
return
json_parameters
;
}
json
JSONSerializer
::
serialize_function
(
const
Function
&
f
)
json
JSONSerializer
::
serialize_function
(
const
Function
&
f
)
{
{
json
function
;
json
function
;
function
[
"name"
]
=
f
.
get_name
();
function
[
"name"
]
=
f
.
get_name
();
function
[
"parameters"
]
=
serialize_parameter_vector
(
f
.
get_parameters
());
vector
<
string
>
parameter_list
;
for
(
auto
param
:
f
.
get_parameters
())
{
parameter_list
.
push_back
(
serialize_node_reference
(
*
param
));
}
function
[
"parameters"
]
=
parameter_list
;
// TODO Functions can return multiple results
// TODO Functions can return multiple results
for
(
size_t
i
=
0
;
i
<
f
.
get_output_size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
f
.
get_output_size
();
++
i
)
...
@@ -521,7 +548,7 @@ json JSONSerializer::serialize_function(const Function& f)
...
@@ -521,7 +548,7 @@ json JSONSerializer::serialize_function(const Function& f)
}
}
template
<
typename
T
>
template
<
typename
T
>
T
get_value
(
nlohmann
::
json
js
,
const
string
&
key
)
T
get_value
(
json
js
,
const
string
&
key
)
{
{
T
rc
;
T
rc
;
auto
it
=
js
.
find
(
key
);
auto
it
=
js
.
find
(
key
);
...
@@ -532,13 +559,13 @@ T get_value(nlohmann::json js, const string& key)
...
@@ -532,13 +559,13 @@ T get_value(nlohmann::json js, const string& key)
return
rc
;
return
rc
;
}
}
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node_reference
(
json
&
j
)
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node_reference
(
json
j
)
{
{
const
string
&
name
=
j
;
const
string
&
name
=
j
;
return
m_node_map
.
at
(
name
);
return
m_node_map
.
at
(
name
);
}
}
Output
<
Node
>
JSONDeserializer
::
deserialize_output
(
json
&
j
)
Output
<
Node
>
JSONDeserializer
::
deserialize_output
(
json
j
)
{
{
size_t
index
;
size_t
index
;
json
json_node_reference
;
json
json_node_reference
;
...
@@ -559,10 +586,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j)
...
@@ -559,10 +586,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j)
return
Output
<
Node
>
(
deserialize_node_reference
(
json_node_reference
),
index
);
return
Output
<
Node
>
(
deserialize_node_reference
(
json_node_reference
),
index
);
}
}
shared_ptr
<
Function
>
JSONDeserializer
::
deserialize_function
(
json
&
func_js
)
OutputVector
JSONDeserializer
::
deserialize_output_vector
(
json
j
)
{
OutputVector
result
;
if
(
j
.
is_array
())
{
for
(
json
jelt
:
j
)
{
result
.
push_back
(
deserialize_output
(
jelt
));
}
}
return
result
;
}
json
JSONSerializer
::
serialize_axis_set
(
const
AxisSet
&
axis_set
)
{
return
static_cast
<
set
<
size_t
>>
(
axis_set
);
}
AxisSet
JSONDeserializer
::
deserialize_axis_set
(
json
j
)
{
AxisSet
result
;
if
(
j
.
is_array
())
{
result
=
j
.
get
<
set
<
size_t
>>
();
}
return
result
;
}
ParameterVector
JSONDeserializer
::
deserialize_parameter_vector
(
json
json_parameters
)
{
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
params
;
for
(
auto
&
param_ref
:
json_parameters
)
{
params
.
push_back
(
dynamic_pointer_cast
<
op
::
Parameter
>
(
deserialize_node_reference
(
param_ref
)));
}
return
params
;
}
shared_ptr
<
Function
>
JSONDeserializer
::
deserialize_function
(
json
func_js
)
{
{
string
func_name
=
func_js
.
at
(
"name"
).
get
<
string
>
();
string
func_name
=
func_js
.
at
(
"name"
).
get
<
string
>
();
vector
<
json
>
func_parameters
=
func_js
.
at
(
"parameters"
);
vector
<
json
>
func_result
=
func_js
.
at
(
"result"
);
vector
<
json
>
func_result
=
func_js
.
at
(
"result"
);
for
(
json
node_js
:
func_js
.
at
(
"ops"
))
for
(
json
node_js
:
func_js
.
at
(
"ops"
))
{
{
...
@@ -594,12 +659,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
...
@@ -594,12 +659,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
"Graph serialization is inconsistent. Some op::Results appear to be missing"
);
"Graph serialization is inconsistent. Some op::Results appear to be missing"
);
}
}
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
params
;
ParameterVector
params
=
deserialize_parameter_vector
(
func_js
.
at
(
"parameters"
));
for
(
auto
&
param_ref
:
func_parameters
)
{
params
.
push_back
(
dynamic_pointer_cast
<
op
::
Parameter
>
(
deserialize_node_reference
(
param_ref
)));
}
shared_ptr
<
Function
>
rc
{
make_shared
<
Function
>
(
result
,
params
,
func_name
)};
shared_ptr
<
Function
>
rc
{
make_shared
<
Function
>
(
result
,
params
,
func_name
)};
m_function_map
[
func_name
]
=
rc
;
m_function_map
[
func_name
]
=
rc
;
...
@@ -632,7 +692,12 @@ struct OutputHelper
...
@@ -632,7 +692,12 @@ struct OutputHelper
// when all op constructors use the new style arguments.
// when all op constructors use the new style arguments.
struct
OutputVectorHelper
struct
OutputVectorHelper
{
{
const
OutputHelper
&
operator
[](
size_t
i
)
const
{
return
m_vector
[
i
];
}
OutputVectorHelper
(
const
OutputVector
&
output_vector
)
:
m_vector
(
output_vector
)
{
}
OutputVectorHelper
()
=
default
;
OutputHelper
operator
[](
size_t
i
)
const
{
return
OutputHelper
(
m_vector
[
i
]);
}
void
push_back
(
const
Output
<
Node
>&
output
)
{
m_vector
.
push_back
(
output
);
}
void
push_back
(
const
Output
<
Node
>&
output
)
{
m_vector
.
push_back
(
output
);
}
size_t
size
()
const
{
return
m_vector
.
size
();
}
size_t
size
()
const
{
return
m_vector
.
size
();
}
operator
vector
<
shared_ptr
<
Node
>>
()
const
operator
vector
<
shared_ptr
<
Node
>>
()
const
...
@@ -640,14 +705,15 @@ struct OutputVectorHelper
...
@@ -640,14 +705,15 @@ struct OutputVectorHelper
vector
<
shared_ptr
<
Node
>>
result
;
vector
<
shared_ptr
<
Node
>>
result
;
for
(
auto
&
o
:
m_vector
)
for
(
auto
&
o
:
m_vector
)
{
{
result
.
push_back
(
o
);
result
.
push_back
(
OutputHelper
(
o
)
);
}
}
return
result
;
return
result
;
}
}
vector
<
OutputHelper
>
m_vector
;
operator
const
OutputVector
&
()
const
{
return
m_vector
;
}
OutputVector
m_vector
;
};
};
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node
(
json
&
node_js
)
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node
(
json
node_js
)
{
{
shared_ptr
<
Node
>
node
;
shared_ptr
<
Node
>
node
;
try
try
...
@@ -655,14 +721,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -655,14 +721,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
string
node_name
=
node_js
.
at
(
"name"
).
get
<
string
>
();
string
node_name
=
node_js
.
at
(
"name"
).
get
<
string
>
();
string
node_op
=
node_js
.
at
(
"op"
).
get
<
string
>
();
string
node_op
=
node_js
.
at
(
"op"
).
get
<
string
>
();
string
friendly_name
=
get_value
<
string
>
(
node_js
,
"friendly_name"
);
string
friendly_name
=
get_value
<
string
>
(
node_js
,
"friendly_name"
);
vector
<
json
>
node_inputs
=
get_value
<
vector
<
json
>>
(
node_js
,
"inputs"
);
vector
<
json
>
control_deps_inputs
=
get_value
<
vector
<
json
>>
(
node_js
,
"control_deps"
);
vector
<
json
>
control_deps_inputs
=
get_value
<
vector
<
json
>>
(
node_js
,
"control_deps"
);
vector
<
string
>
node_outputs
=
get_value
<
vector
<
string
>>
(
node_js
,
"outputs"
);
vector
<
string
>
node_outputs
=
get_value
<
vector
<
string
>>
(
node_js
,
"outputs"
);
OutputVectorHelper
args
;
OutputVectorHelper
args
(
deserialize_output_vector
(
node_js
[
"inputs"
]));
for
(
auto
&
node_input
:
node_inputs
)
{
args
.
push_back
(
deserialize_output
(
node_input
));
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
...
@@ -683,12 +744,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -683,12 +744,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Add
:
case
OP_TYPEID
:
:
Add
:
{
{
node
=
make_shared
<
op
::
Add
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Add
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
All
:
case
OP_TYPEID
:
:
All
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
All
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
All
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -699,12 +760,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -699,12 +760,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
And
:
case
OP_TYPEID
:
:
And
:
{
{
node
=
make_shared
<
op
::
And
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
And
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Any
:
case
OP_TYPEID
:
:
Any
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Any
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Any
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -741,12 +802,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -741,12 +802,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto
padding_above
=
node_js
.
at
(
"padding_above"
).
get
<
vector
<
size_t
>>
();
auto
padding_above
=
node_js
.
at
(
"padding_above"
).
get
<
vector
<
size_t
>>
();
auto
include_padding_in_avg_computation
=
auto
include_padding_in_avg_computation
=
node_js
.
at
(
"include_padding_in_avg_computation"
).
get
<
bool
>
();
node_js
.
at
(
"include_padding_in_avg_computation"
).
get
<
bool
>
();
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
bool
ceil_mode
=
get_or_default
<
bool
>
(
node_js
,
"ceil_mode"
,
false
);
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
bool
ceil_mode
=
node_js
[
"ceil_mode"
].
empty
()
?
false
:
node_js
.
at
(
"ceil_mode"
).
get
<
bool
>
();
;
node
=
make_shared
<
op
::
AvgPool
>
(
args
[
0
],
node
=
make_shared
<
op
::
AvgPool
>
(
args
[
0
],
window_shape
,
window_shape
,
window_movement_strides
,
window_movement_strides
,
...
@@ -808,7 +865,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -808,7 +865,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Broadcast
:
case
OP_TYPEID
:
:
Broadcast
:
{
{
auto
shape
=
node_js
.
at
(
"shape"
).
get
<
vector
<
size_t
>>
();
auto
shape
=
node_js
.
at
(
"shape"
).
get
<
vector
<
size_t
>>
();
auto
axes
=
node_js
.
at
(
"axes"
).
get
<
set
<
size_t
>>
(
);
auto
axes
=
deserialize_axis_set
(
node_js
.
at
(
"axes"
)
);
node
=
make_shared
<
op
::
Broadcast
>
(
args
[
0
],
shape
,
axes
);
node
=
make_shared
<
op
::
Broadcast
>
(
args
[
0
],
shape
,
axes
);
break
;
break
;
}
}
...
@@ -819,7 +876,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -819,7 +876,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
BroadcastLike
:
case
OP_TYPEID
:
:
BroadcastLike
:
{
{
auto
initial_axes
=
node_js
.
at
(
"initial_axes"
).
get
<
set
<
size_t
>>
(
);
auto
initial_axes
=
deserialize_axis_set
(
node_js
.
at
(
"initial_axes"
)
);
node
=
make_shared
<
op
::
BroadcastLike
>
(
args
[
0
],
args
[
1
],
initial_axes
);
node
=
make_shared
<
op
::
BroadcastLike
>
(
args
[
0
],
args
[
1
],
initial_axes
);
break
;
break
;
}
}
...
@@ -838,13 +895,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -838,13 +895,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Concat
:
case
OP_TYPEID
:
:
Concat
:
{
{
auto
axis
=
node_js
.
at
(
"axis"
).
get
<
size_t
>
();
auto
axis
=
node_js
.
at
(
"axis"
).
get
<
size_t
>
();
node
=
make_shared
<
op
::
Concat
>
(
args
,
axis
);
node
=
make_shared
<
op
::
Concat
>
(
static_cast
<
OutputVector
>
(
args
)
,
axis
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Constant
:
case
OP_TYPEID
:
:
Constant
:
{
{
auto
type_node_js
=
auto
type_node_js
=
node_js
.
count
(
"element_type"
)
==
0
?
node_js
.
at
(
"value_type"
)
:
node_js
;
has_key
(
node_js
,
"element_type"
)
?
node_js
:
node_js
.
at
(
"value_type"
)
;
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
value
=
node_js
.
at
(
"value"
).
get
<
vector
<
string
>>
();
auto
value
=
node_js
.
at
(
"value"
).
get
<
vector
<
string
>>
();
...
@@ -868,17 +925,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -868,17 +925,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
// For backwards compatibility, we accept "image_dilation_strides" in place of
// For backwards compatibility, we accept "image_dilation_strides" in place of
// "data_dilation_strides", and we also allow it to be omitted altogether.
// "data_dilation_strides", and we also allow it to be omitted altogether.
auto
data_dilation_strides_maybe
=
node_js
[
"data_dilation_strides"
];
json
data_dilation_strides
;
if
(
data_dilation_strides_maybe
.
empty
())
if
(
has_key
(
node_js
,
"data_dilation_strides"
))
{
data_dilation_strides
=
node_js
[
"data_dilation_strides"
];
}
else
if
(
has_key
(
node_js
,
"image_dilation_strides"
))
{
{
data_dilation_strides
_maybe
=
node_js
[
"image_dilation_strides"
];
data_dilation_strides
=
node_js
[
"image_dilation_strides"
];
}
}
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
if
(
data_dilation_strides
_maybe
.
empty
())
if
(
data_dilation_strides
.
empty
())
{
{
node
=
make_shared
<
op
::
Convolution
>
(
args
[
0
],
node
=
make_shared
<
op
::
Convolution
>
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -889,14 +948,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -889,14 +948,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
else
else
{
{
node
=
make_shared
<
op
::
Convolution
>
(
node
=
args
[
0
],
make_shared
<
op
::
Convolution
>
(
args
[
0
],
args
[
1
],
args
[
1
],
window_movement_strides
,
window_movement_strides
,
window_dilation_strides
,
window_dilation_strides
,
padding_below
,
padding_below
,
padding_above
,
padding_above
,
data_dilation_strides_maybe
.
get
<
std
::
vector
<
size_t
>>
(),
data_dilation_strides
.
get
<
std
::
vector
<
size_t
>>
(),
pad_type
);
pad_type
);
}
}
break
;
break
;
...
@@ -1033,33 +1092,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1033,33 +1092,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Dequantize
:
case
OP_TYPEID
:
:
Dequantize
:
{
{
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
axes
=
node_js
.
at
(
"axes"
).
get
<
set
<
size_t
>>
(
);
auto
axes
=
deserialize_axis_set
(
node_js
.
at
(
"axes"
)
);
node
=
make_shared
<
op
::
Dequantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
);
node
=
make_shared
<
op
::
Dequantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Divide
:
case
OP_TYPEID
:
:
Divide
:
{
{
bool
pythondiv
=
true
;
bool
pythondiv
=
get_or_default
(
node_js
,
"pythondiv"
,
true
);
if
(
node_js
[
"pythondiv"
].
is_object
())
{
pythondiv
=
node_js
.
at
(
"pythondiv"
).
get
<
bool
>
();
}
node
=
make_shared
<
op
::
Divide
>
(
node
=
make_shared
<
op
::
Divide
>
(
args
[
0
],
args
[
1
],
pythondiv
,
read_auto_broadcast
(
node_js
[
"autob"
]
));
args
[
0
],
args
[
1
],
pythondiv
,
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Dot
:
case
OP_TYPEID
:
:
Dot
:
{
{
// For backwards compatibility, reduction_axes_count is optional.
// For backwards compatibility, reduction_axes_count is optional.
auto
obj
=
node_js
[
"reduction_axes_count"
];
if
(
has_key
(
node_js
,
"reduction_axes_count"
))
if
(
obj
.
empty
())
{
{
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
]);
size_t
reduction_axes_count
=
node_js
[
"reduction_axes_count"
].
get
<
size_t
>
();
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
],
reduction_axes_count
);
}
}
else
else
{
{
size_t
reduction_axes_count
=
obj
.
get
<
size_t
>
();
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
]);
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
],
reduction_axes_count
);
}
}
break
;
break
;
}
}
...
@@ -1095,7 +1149,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1095,7 +1149,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Equal
:
case
OP_TYPEID
:
:
Equal
:
{
{
node
=
make_shared
<
op
::
Equal
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Equal
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Erf
:
case
OP_TYPEID
:
:
Erf
:
...
@@ -1160,13 +1214,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1160,13 +1214,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Greater
:
case
OP_TYPEID
:
:
Greater
:
{
{
node
=
node
=
make_shared
<
op
::
Greater
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Greater
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
GreaterEq
:
case
OP_TYPEID
:
:
GreaterEq
:
{
{
node
=
node
=
make_shared
<
op
::
GreaterEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
GreaterEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
GRN
:
case
OP_TYPEID
:
:
GRN
:
...
@@ -1193,10 +1247,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1193,10 +1247,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto
data_dilation_strides
=
node_js
.
at
(
"data_dilation_strides"
).
get
<
vector
<
size_t
>>
();
auto
data_dilation_strides
=
node_js
.
at
(
"data_dilation_strides"
).
get
<
vector
<
size_t
>>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
node
=
make_shared
<
op
::
GroupConvolution
>
(
args
[
0
],
node
=
make_shared
<
op
::
GroupConvolution
>
(
args
[
0
],
args
[
1
],
args
[
1
],
window_movement_strides
,
window_movement_strides
,
...
@@ -1216,9 +1267,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1216,9 +1267,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto
padding_end
=
node_js
.
at
(
"padding_end"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
padding_end
=
node_js
.
at
(
"padding_end"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
output_padding
=
node_js
.
at
(
"output_padding"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
output_padding
=
node_js
.
at
(
"output_padding"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
auto
output_shape
=
node_js
.
at
(
"output_shape"
).
get
<
vector
<
size_t
>>
();
auto
output_shape
=
node_js
.
at
(
"output_shape"
).
get
<
vector
<
size_t
>>
();
node
=
make_shared
<
op
::
GroupConvolutionTranspose
>
(
args
[
0
],
node
=
make_shared
<
op
::
GroupConvolutionTranspose
>
(
args
[
0
],
...
@@ -1240,12 +1289,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1240,12 +1289,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Less
:
case
OP_TYPEID
:
:
Less
:
{
{
node
=
make_shared
<
op
::
Less
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Less
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
LessEq
:
case
OP_TYPEID
:
:
LessEq
:
{
{
node
=
make_shared
<
op
::
LessEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
LessEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Log
:
case
OP_TYPEID
:
:
Log
:
...
@@ -1287,7 +1336,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1287,7 +1336,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Max
:
case
OP_TYPEID
:
:
Max
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Max
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Max
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -1298,11 +1347,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1298,11 +1347,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
// omitted.
auto
padding_below_maybe
=
node_js
[
"padding_below"
];
auto
padding_below_maybe
=
get_or_default
(
node_js
,
"padding_below"
,
json
{});
auto
padding_above_maybe
=
node_js
[
"padding_above"
];
auto
padding_above_maybe
=
get_or_default
(
node_js
,
"padding_above"
,
json
{});
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
if
(
padding_below_maybe
.
empty
()
&&
!
padding_above_maybe
.
empty
())
if
(
padding_below_maybe
.
empty
()
&&
!
padding_above_maybe
.
empty
())
{
{
throw
runtime_error
(
throw
runtime_error
(
...
@@ -1361,31 +1408,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1361,31 +1408,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Maximum
:
case
OP_TYPEID
:
:
Maximum
:
{
{
node
=
node
=
make_shared
<
op
::
Maximum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Maximum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Min
:
case
OP_TYPEID
:
:
Min
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Min
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Min
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Minimum
:
case
OP_TYPEID
:
:
Minimum
:
{
{
node
=
node
=
make_shared
<
op
::
Minimum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Minimum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Multiply
:
case
OP_TYPEID
:
:
Multiply
:
{
{
node
=
node
=
make_shared
<
op
::
Multiply
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Multiply
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
MVN
:
case
OP_TYPEID
:
:
MVN
:
{
{
auto
normalize_variance
=
node_js
.
at
(
"normalize_variance"
).
get
<
bool
>
();
auto
normalize_variance
=
node_js
.
at
(
"normalize_variance"
).
get
<
bool
>
();
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
auto
eps
=
node_js
.
at
(
"eps"
).
get
<
double
>
();
auto
eps
=
node_js
.
at
(
"eps"
).
get
<
double
>
();
node
=
make_shared
<
op
::
MVN
>
(
args
[
0
],
normalize_variance
,
normalize_variance
,
eps
);
node
=
make_shared
<
op
::
MVN
>
(
args
[
0
],
normalize_variance
,
normalize_variance
,
eps
);
break
;
break
;
...
@@ -1407,7 +1454,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1407,7 +1454,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
NotEqual
:
case
OP_TYPEID
:
:
NotEqual
:
{
{
node
=
node
=
make_shared
<
op
::
NotEqual
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
NotEqual
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Not
:
case
OP_TYPEID
:
:
Not
:
...
@@ -1424,7 +1471,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1424,7 +1471,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Or
:
case
OP_TYPEID
:
:
Or
:
{
{
node
=
make_shared
<
op
::
Or
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Or
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Pad
:
case
OP_TYPEID
:
:
Pad
:
...
@@ -1441,9 +1488,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1441,9 +1488,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
[](
size_t
s
)
{
return
s
==
0
;
}),
[](
size_t
s
)
{
return
s
==
0
;
}),
"Legacy padding_interior field must be zero everywhere."
);
"Legacy padding_interior field must be zero everywhere."
);
auto
pad_mode
=
node_js
.
count
(
"pad_mode"
)
==
0
auto
pad_mode
=
read_pad_mode
(
node_js
);
?
op
::
PadMode
::
CONSTANT
:
static_cast
<
op
::
PadMode
>
(
node_js
.
at
(
"pad_mode"
));
node
=
make_shared
<
op
::
Pad
>
(
args
[
0
],
args
[
1
],
padding_below
,
padding_above
,
pad_mode
);
node
=
make_shared
<
op
::
Pad
>
(
args
[
0
],
args
[
1
],
padding_below
,
padding_above
,
pad_mode
);
break
;
break
;
...
@@ -1451,7 +1496,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1451,7 +1496,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Parameter
:
case
OP_TYPEID
:
:
Parameter
:
{
{
auto
type_node_js
=
auto
type_node_js
=
node_js
.
count
(
"element_type"
)
==
0
?
node_js
.
at
(
"value_type"
)
:
node_js
;
has_key
(
node_js
,
"element_type"
)
?
node_js
:
node_js
.
at
(
"value_type"
)
;
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
cacheable
=
get_or_default
<
bool
>
(
node_js
,
"cacheable"
,
false
);
auto
cacheable
=
get_or_default
<
bool
>
(
node_js
,
"cacheable"
,
false
);
...
@@ -1476,7 +1521,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1476,7 +1521,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Power
:
case
OP_TYPEID
:
:
Power
:
{
{
node
=
make_shared
<
op
::
Power
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Power
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
PRelu
:
case
OP_TYPEID
:
:
PRelu
:
...
@@ -1486,14 +1531,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1486,14 +1531,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Product
:
case
OP_TYPEID
:
:
Product
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Product
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Product
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Quantize
:
case
OP_TYPEID
:
:
Quantize
:
{
{
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
axes
=
node_js
.
at
(
"axes"
).
get
<
set
<
size_t
>>
(
);
auto
axes
=
deserialize_axis_set
(
node_js
.
at
(
"axes"
)
);
auto
round_mode
=
node_js
.
at
(
"round_mode"
).
get
<
op
::
Quantize
::
RoundMode
>
();
auto
round_mode
=
node_js
.
at
(
"round_mode"
).
get
<
op
::
Quantize
::
RoundMode
>
();
node
=
make_shared
<
op
::
Quantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
,
round_mode
);
node
=
make_shared
<
op
::
Quantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
,
round_mode
);
break
;
break
;
...
@@ -1552,8 +1597,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1552,8 +1597,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
// omitted.
auto
padding_below_maybe
=
node_js
[
"padding_below"
]
;
auto
padding_below_maybe
=
get_or_default
(
node_js
,
"padding_below"
,
json
{})
;
auto
padding_above_maybe
=
node_js
[
"padding_above"
]
;
auto
padding_above_maybe
=
get_or_default
(
node_js
,
"padding_above"
,
json
{})
;
auto
padding_below
=
padding_below_maybe
.
get
<
vector
<
size_t
>>
();
auto
padding_below
=
padding_below_maybe
.
get
<
vector
<
size_t
>>
();
auto
padding_above
=
padding_above_maybe
.
get
<
vector
<
size_t
>>
();
auto
padding_above
=
padding_above_maybe
.
get
<
vector
<
size_t
>>
();
node
=
make_shared
<
op
::
QuantizedMaxPool
>
(
node
=
make_shared
<
op
::
QuantizedMaxPool
>
(
...
@@ -1607,7 +1652,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1607,7 +1652,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Reverse
:
case
OP_TYPEID
:
:
Reverse
:
{
{
auto
reversed_axes
=
node_js
.
at
(
"reversed_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reversed_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reversed_axes"
)
);
node
=
make_shared
<
op
::
Reverse
>
(
args
[
0
],
reversed_axes
);
node
=
make_shared
<
op
::
Reverse
>
(
args
[
0
],
reversed_axes
);
break
;
break
;
}
}
...
@@ -1697,7 +1742,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1697,7 +1742,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Softmax
:
case
OP_TYPEID
:
:
Softmax
:
{
{
auto
softmax_axes
=
node_js
.
at
(
"softmax_axes"
).
get
<
set
<
size_t
>>
(
);
auto
softmax_axes
=
deserialize_axis_set
(
node_js
.
at
(
"softmax_axes"
)
);
node
=
make_shared
<
op
::
Softmax
>
(
args
[
0
],
softmax_axes
);
node
=
make_shared
<
op
::
Softmax
>
(
args
[
0
],
softmax_axes
);
break
;
break
;
}
}
...
@@ -1732,12 +1777,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1732,12 +1777,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Subtract
:
case
OP_TYPEID
:
:
Subtract
:
{
{
node
=
node
=
make_shared
<
op
::
Subtract
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Subtract
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sum
:
case
OP_TYPEID
:
:
Sum
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Sum
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Sum
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -1873,6 +1918,16 @@ json JSONSerializer::serialize_output(const Output<Node>& output)
...
@@ -1873,6 +1918,16 @@ json JSONSerializer::serialize_output(const Output<Node>& output)
return
result
;
return
result
;
}
}
json
JSONSerializer
::
serialize_output_vector
(
const
OutputVector
&
output_vector
)
{
json
result
;
for
(
const
Output
<
Node
>&
output
:
output_vector
)
{
result
.
push_back
(
serialize_output
(
output
));
}
return
result
;
}
json
JSONSerializer
::
serialize_node
(
const
Node
&
n
)
json
JSONSerializer
::
serialize_node
(
const
Node
&
n
)
{
{
m_nodes_serialized
.
insert
(
&
n
);
m_nodes_serialized
.
insert
(
&
n
);
...
@@ -1972,7 +2027,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -1972,7 +2027,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
All
:
case
OP_TYPEID
:
:
All
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
All
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
All
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
AllReduce
:
{
break
;
case
OP_TYPEID
:
:
AllReduce
:
{
break
;
...
@@ -1989,7 +2044,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -1989,7 +2044,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Any
:
case
OP_TYPEID
:
:
Any
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Any
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Any
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Asin
:
{
break
;
case
OP_TYPEID
:
:
Asin
:
{
break
;
...
@@ -2045,7 +2100,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2045,7 +2100,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Broadcast
:
case
OP_TYPEID
:
:
Broadcast
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Broadcast
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Broadcast
*>
(
&
n
);
node
[
"axes"
]
=
tmp
->
get_broadcast_axes
(
);
node
[
"axes"
]
=
serialize_axis_set
(
tmp
->
get_broadcast_axes
()
);
node
[
"shape"
]
=
tmp
->
get_broadcast_shape
();
node
[
"shape"
]
=
tmp
->
get_broadcast_shape
();
break
;
break
;
}
}
...
@@ -2054,7 +2109,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2054,7 +2109,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
BroadcastLike
:
case
OP_TYPEID
:
:
BroadcastLike
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
BroadcastLike
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
BroadcastLike
*>
(
&
n
);
node
[
"initial_axes"
]
=
tmp
->
get_initial_broadcast_axes
(
);
node
[
"initial_axes"
]
=
serialize_axis_set
(
tmp
->
get_initial_broadcast_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Ceiling
:
{
break
;
case
OP_TYPEID
:
:
Ceiling
:
{
break
;
...
@@ -2168,7 +2223,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2168,7 +2223,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Dequantize
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Dequantize
*>
(
&
n
);
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"axes"
]
=
tmp
->
get_axes
(
);
node
[
"axes"
]
=
serialize_axis_set
(
tmp
->
get_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
DepthToSpace
:
case
OP_TYPEID
:
:
DepthToSpace
:
...
@@ -2361,7 +2416,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2361,7 +2416,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Max
:
case
OP_TYPEID
:
:
Max
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Max
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Max
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
MaxPool
:
case
OP_TYPEID
:
:
MaxPool
:
...
@@ -2395,7 +2450,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2395,7 +2450,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Min
:
case
OP_TYPEID
:
:
Min
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Min
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Min
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Minimum
:
case
OP_TYPEID
:
:
Minimum
:
...
@@ -2419,7 +2474,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2419,7 +2474,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
MVN
:
case
OP_TYPEID
:
:
MVN
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
MVN
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
MVN
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
node
[
"normalize_variance"
]
=
tmp
->
get_normalize_variance
();
node
[
"normalize_variance"
]
=
tmp
->
get_normalize_variance
();
node
[
"eps"
]
=
tmp
->
get_eps
();
node
[
"eps"
]
=
tmp
->
get_eps
();
break
;
break
;
...
@@ -2499,7 +2554,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2499,7 +2554,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Product
:
case
OP_TYPEID
:
:
Product
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Product
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Product
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Power
:
case
OP_TYPEID
:
:
Power
:
...
@@ -2515,7 +2570,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2515,7 +2570,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Quantize
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Quantize
*>
(
&
n
);
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"axes"
]
=
tmp
->
get_axes
(
);
node
[
"axes"
]
=
serialize_axis_set
(
tmp
->
get_axes
()
);
node
[
"round_mode"
]
=
tmp
->
get_round_mode
();
node
[
"round_mode"
]
=
tmp
->
get_round_mode
();
break
;
break
;
}
}
...
@@ -2596,7 +2651,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2596,7 +2651,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Reverse
:
case
OP_TYPEID
:
:
Reverse
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Reverse
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Reverse
*>
(
&
n
);
node
[
"reversed_axes"
]
=
tmp
->
get_reversed_axes
(
);
node
[
"reversed_axes"
]
=
serialize_axis_set
(
tmp
->
get_reversed_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ReverseSequence
:
case
OP_TYPEID
:
:
ReverseSequence
:
...
@@ -2689,13 +2744,13 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2689,13 +2744,13 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Sum
:
case
OP_TYPEID
:
:
Sum
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Sum
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Sum
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Softmax
:
case
OP_TYPEID
:
:
Softmax
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Softmax
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Softmax
*>
(
&
n
);
node
[
"softmax_axes"
]
=
tmp
->
get_axes
(
);
node
[
"softmax_axes"
]
=
serialize_axis_set
(
tmp
->
get_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Tan
:
{
break
;
case
OP_TYPEID
:
:
Tan
:
{
break
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment