Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
24c715f4
Unverified
Commit
24c715f4
authored
Jun 26, 2019
by
Scott Cyphers
Committed by
GitHub
Jun 26, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into mlir
parents
3b5bfdab
5e19c25c
Show whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
556 additions
and
214 deletions
+556
-214
types.py
python/ngraph/utils/types.py
+1
-0
element_type.cpp
python/pyngraph/types/element_type.cpp
+1
-0
CMakeLists.txt
src/ngraph/frontend/onnx_import/CMakeLists.txt
+3
-0
node.cpp
src/ngraph/frontend/onnx_import/core/node.cpp
+16
-0
node.hpp
src/ngraph/frontend/onnx_import/core/node.hpp
+2
-0
value_info.hpp
src/ngraph/frontend/onnx_import/core/value_info.hpp
+5
-29
cast.cpp
src/ngraph/frontend/onnx_import/op/cast.cpp
+2
-21
eye_like.cpp
src/ngraph/frontend/onnx_import/op/eye_like.cpp
+63
-0
eye_like.hpp
src/ngraph/frontend/onnx_import/op/eye_like.hpp
+37
-0
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+2
-0
common.cpp
src/ngraph/frontend/onnx_import/utils/common.cpp
+51
-0
common.hpp
src/ngraph/frontend/onnx_import/utils/common.hpp
+42
-8
node.hpp
src/ngraph/node.hpp
+1
-0
generate_mask.hpp
src/ngraph/op/experimental/generate_mask.hpp
+2
-0
fused_op_tbl.hpp
src/ngraph/op/fused_op_tbl.hpp
+3
-3
op_tbl.hpp
src/ngraph/op/op_tbl.hpp
+5
-5
fused_op_decomposition.cpp
src/ngraph/pass/fused_op_decomposition.cpp
+24
-14
fused_op_decomposition.hpp
src/ngraph/pass/fused_op_decomposition.hpp
+15
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+3
-3
serializer.cpp
src/ngraph/serializer.cpp
+185
-130
eye_like.prototxt
test/models/onnx/eye_like.prototxt
+50
-0
onnx_import.in.cpp
test/onnx/onnx_import.in.cpp
+43
-0
No files found.
python/ngraph/utils/types.py
View file @
24c715f4
...
@@ -37,6 +37,7 @@ NodeInput = Union[Node, NumericData]
...
@@ -37,6 +37,7 @@ NodeInput = Union[Node, NumericData]
ngraph_to_numpy_types_map
=
[
ngraph_to_numpy_types_map
=
[
(
NgraphType
.
boolean
,
np
.
bool
),
(
NgraphType
.
boolean
,
np
.
bool
),
(
NgraphType
.
f16
,
np
.
float16
),
(
NgraphType
.
f32
,
np
.
float32
),
(
NgraphType
.
f32
,
np
.
float32
),
(
NgraphType
.
f64
,
np
.
float64
),
(
NgraphType
.
f64
,
np
.
float64
),
(
NgraphType
.
i8
,
np
.
int8
),
(
NgraphType
.
i8
,
np
.
int8
),
...
...
python/pyngraph/types/element_type.cpp
View file @
24c715f4
...
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
...
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
py
::
class_
<
ngraph
::
element
::
Type
,
std
::
shared_ptr
<
ngraph
::
element
::
Type
>>
type
(
m
,
"Type"
);
py
::
class_
<
ngraph
::
element
::
Type
,
std
::
shared_ptr
<
ngraph
::
element
::
Type
>>
type
(
m
,
"Type"
);
type
.
doc
()
=
"ngraph.impl.Type wraps ngraph::element::Type"
;
type
.
doc
()
=
"ngraph.impl.Type wraps ngraph::element::Type"
;
type
.
attr
(
"boolean"
)
=
ngraph
::
element
::
boolean
;
type
.
attr
(
"boolean"
)
=
ngraph
::
element
::
boolean
;
type
.
attr
(
"f16"
)
=
ngraph
::
element
::
f16
;
type
.
attr
(
"f32"
)
=
ngraph
::
element
::
f32
;
type
.
attr
(
"f32"
)
=
ngraph
::
element
::
f32
;
type
.
attr
(
"f64"
)
=
ngraph
::
element
::
f64
;
type
.
attr
(
"f64"
)
=
ngraph
::
element
::
f64
;
type
.
attr
(
"i8"
)
=
ngraph
::
element
::
i8
;
type
.
attr
(
"i8"
)
=
ngraph
::
element
::
i8
;
...
...
src/ngraph/frontend/onnx_import/CMakeLists.txt
View file @
24c715f4
...
@@ -85,6 +85,8 @@ add_library(onnx_import STATIC
...
@@ -85,6 +85,8 @@ add_library(onnx_import STATIC
op/equal.hpp
op/equal.hpp
op/erf.hpp
op/erf.hpp
op/exp.hpp
op/exp.hpp
op/eye_like.cpp
op/eye_like.hpp
op/flatten.cpp
op/flatten.cpp
op/flatten.hpp
op/flatten.hpp
op/floor.hpp
op/floor.hpp
...
@@ -191,6 +193,7 @@ add_library(onnx_import STATIC
...
@@ -191,6 +193,7 @@ add_library(onnx_import STATIC
op/xor.hpp
op/xor.hpp
ops_bridge.cpp
ops_bridge.cpp
ops_bridge.hpp
ops_bridge.hpp
utils/common.cpp
utils/common.hpp
utils/common.hpp
utils/convpool.cpp
utils/convpool.cpp
utils/convpool.hpp
utils/convpool.hpp
...
...
src/ngraph/frontend/onnx_import/core/node.cpp
View file @
24c715f4
...
@@ -52,6 +52,8 @@ namespace ngraph
...
@@ -52,6 +52,8 @@ namespace ngraph
const
std
::
string
&
output
(
int
index
)
const
;
const
std
::
string
&
output
(
int
index
)
const
;
std
::
size_t
get_outputs_size
()
const
;
std
::
size_t
get_outputs_size
()
const
;
bool
has_attribute
(
const
std
::
string
&
name
)
const
;
template
<
typename
T
>
template
<
typename
T
>
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
...
@@ -87,6 +89,15 @@ namespace ngraph
...
@@ -87,6 +89,15 @@ namespace ngraph
}
}
std
::
size_t
Node
::
Impl
::
get_outputs_size
()
const
{
return
m_output_names
.
size
();
}
std
::
size_t
Node
::
Impl
::
get_outputs_size
()
const
{
return
m_output_names
.
size
();
}
bool
Node
::
Impl
::
has_attribute
(
const
std
::
string
&
name
)
const
{
auto
it
=
std
::
find_if
(
std
::
begin
(
m_attributes
),
std
::
end
(
m_attributes
),
[
&
](
const
Attribute
&
attribute
)
{
return
attribute
.
get_name
()
==
name
;
});
return
it
!=
std
::
end
(
m_attributes
);
}
template
<
typename
T
>
template
<
typename
T
>
T
Node
::
Impl
::
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
T
Node
::
Impl
::
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
{
{
...
@@ -185,6 +196,11 @@ namespace ngraph
...
@@ -185,6 +196,11 @@ namespace ngraph
const
std
::
string
&
Node
::
output
(
int
index
)
const
{
return
m_pimpl
->
output
(
index
);
}
const
std
::
string
&
Node
::
output
(
int
index
)
const
{
return
m_pimpl
->
output
(
index
);
}
std
::
size_t
Node
::
get_outputs_size
()
const
{
return
m_pimpl
->
get_outputs_size
();
}
std
::
size_t
Node
::
get_outputs_size
()
const
{
return
m_pimpl
->
get_outputs_size
();
}
bool
Node
::
has_attribute
(
const
std
::
string
&
name
)
const
{
return
m_pimpl
->
has_attribute
(
name
);
}
template
<>
template
<>
float
Node
::
get_attribute_value
(
const
std
::
string
&
name
,
float
default_value
)
const
float
Node
::
get_attribute_value
(
const
std
::
string
&
name
,
float
default_value
)
const
{
{
...
...
src/ngraph/frontend/onnx_import/core/node.hpp
View file @
24c715f4
...
@@ -78,6 +78,8 @@ namespace ngraph
...
@@ -78,6 +78,8 @@ namespace ngraph
const
std
::
string
&
output
(
int
index
)
const
;
const
std
::
string
&
output
(
int
index
)
const
;
std
::
size_t
get_outputs_size
()
const
;
std
::
size_t
get_outputs_size
()
const
;
bool
has_attribute
(
const
std
::
string
&
name
)
const
;
template
<
typename
T
>
template
<
typename
T
>
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
...
...
src/ngraph/frontend/onnx_import/core/value_info.hpp
View file @
24c715f4
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type.hpp"
#include "node.hpp"
#include "node.hpp"
#include "tensor.hpp"
#include "tensor.hpp"
#include "utils/common.hpp"
#include "weight.hpp"
#include "weight.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -41,17 +42,8 @@ namespace ngraph
...
@@ -41,17 +42,8 @@ namespace ngraph
{
{
}
}
};
};
struct
unsupported_element_type
:
ngraph_error
}
// namespace value_info
{
}
// namespace error
explicit
unsupported_element_type
(
TensorProto_DataType
type
)
:
ngraph_error
{
"unsupported value info element type: "
+
onnx
::
TensorProto_DataType_Name
(
static_cast
<
onnx
::
TensorProto_DataType
>
(
type
))}
{
}
};
}
}
class
ValueInfo
class
ValueInfo
{
{
...
@@ -83,24 +75,8 @@ namespace ngraph
...
@@ -83,24 +75,8 @@ namespace ngraph
{
{
throw
error
::
value_info
::
unspecified_element_type
{};
throw
error
::
value_info
::
unspecified_element_type
{};
}
}
switch
(
m_value_info_proto
->
type
().
tensor_type
().
elem_type
())
return
common
::
get_ngraph_element_type
(
{
m_value_info_proto
->
type
().
tensor_type
().
elem_type
());
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT
:
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT16
:
return
element
::
f32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_DOUBLE
:
return
element
::
f64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT8
:
return
element
::
i8
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT16
:
return
element
::
i16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT32
:
return
element
::
i32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT64
:
return
element
::
i64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT8
:
return
element
::
u8
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
default
:
throw
error
::
value_info
::
unsupported_element_type
{
m_value_info_proto
->
type
().
tensor_type
().
elem_type
()};
}
}
}
std
::
shared_ptr
<
ngraph
::
Node
>
std
::
shared_ptr
<
ngraph
::
Node
>
...
...
src/ngraph/frontend/onnx_import/op/cast.cpp
View file @
24c715f4
...
@@ -13,14 +13,13 @@
...
@@ -13,14 +13,13 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//*****************************************************************************
//*****************************************************************************
#include <memory>
#include <memory>
#include <onnx-ml.pb.h>
#include "cast.hpp"
#include "cast.hpp"
#include "exceptions.hpp"
#include "exceptions.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -34,25 +33,7 @@ namespace ngraph
...
@@ -34,25 +33,7 @@ namespace ngraph
{
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
int64_t
target_type
=
node
.
get_attribute_value
<
int64_t
>
(
"to"
);
int64_t
target_type
=
node
.
get_attribute_value
<
int64_t
>
(
"to"
);
element
::
Type
elem_type
;
element
::
Type
elem_type
=
common
::
get_ngraph_element_type
(
target_type
);
switch
(
target_type
)
{
case
onnx
:
:
TensorProto_DataType_BOOL
:
elem_type
=
element
::
boolean
;
break
;
case
onnx
:
:
TensorProto_DataType_DOUBLE
:
elem_type
=
element
::
f64
;
break
;
case
onnx
:
:
TensorProto_DataType_FLOAT16
:
case
onnx
:
:
TensorProto_DataType_FLOAT
:
elem_type
=
element
::
f32
;
break
;
case
onnx
:
:
TensorProto_DataType_INT8
:
elem_type
=
element
::
i8
;
break
;
case
onnx
:
:
TensorProto_DataType_INT16
:
elem_type
=
element
::
i16
;
break
;
case
onnx
:
:
TensorProto_DataType_INT32
:
elem_type
=
element
::
i32
;
break
;
case
onnx
:
:
TensorProto_DataType_INT64
:
elem_type
=
element
::
i64
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT8
:
elem_type
=
element
::
u8
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT16
:
elem_type
=
element
::
u16
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT32
:
elem_type
=
element
::
u32
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT64
:
elem_type
=
element
::
u64
;
break
;
case
onnx
:
:
TensorProto_DataType_UNDEFINED
:
elem_type
=
element
::
dynamic
;
break
;
default
:
ASSERT_IS_SUPPORTED
(
node
,
false
)
<<
"unsupported type"
;
}
return
{
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
data
,
elem_type
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
data
,
elem_type
)};
}
}
...
...
src/ngraph/frontend/onnx_import/op/eye_like.cpp
0 → 100644
View file @
24c715f4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "eye_like.hpp"
#include "exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
eye_like
(
const
Node
&
node
)
{
const
auto
input
=
node
.
get_ng_inputs
().
at
(
0
);
const
auto
&
input_shape
=
input
->
get_shape
();
std
::
int64_t
dtype
;
element
::
Type
target_type
;
std
::
int64_t
shift
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"k"
,
0
);
if
(
node
.
has_attribute
(
"dtype"
))
{
dtype
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"dtype"
);
target_type
=
common
::
get_ngraph_element_type
(
dtype
);
}
else
{
target_type
=
input
->
get_element_type
();
}
ASSERT_VALID_ARGUMENT
(
node
,
input_shape
.
size
()
==
2
)
<<
"The provided shape rank: "
<<
input_shape
.
size
()
<<
" is unsupported, only 2D shapes are supported"
;
std
::
shared_ptr
<
ngraph
::
Node
>
eye_like_matrix
=
common
::
shifted_square_identity
(
input_shape
,
target_type
,
shift
);
return
{
eye_like_matrix
};
}
}
// namespace set_1
}
// namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/op/eye_like.hpp
0 → 100644
View file @
24c715f4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
eye_like
(
const
Node
&
node
);
}
// namespace set_1
}
//namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
24c715f4
...
@@ -53,6 +53,7 @@
...
@@ -53,6 +53,7 @@
#include "op/equal.hpp"
#include "op/equal.hpp"
#include "op/erf.hpp"
#include "op/erf.hpp"
#include "op/exp.hpp"
#include "op/exp.hpp"
#include "op/eye_like.hpp"
#include "op/flatten.hpp"
#include "op/flatten.hpp"
#include "op/floor.hpp"
#include "op/floor.hpp"
#include "op/gather.hpp"
#include "op/gather.hpp"
...
@@ -260,6 +261,7 @@ namespace ngraph
...
@@ -260,6 +261,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Equal"
,
1
,
equal
);
REGISTER_OPERATOR
(
"Equal"
,
1
,
equal
);
REGISTER_OPERATOR
(
"Erf"
,
1
,
erf
);
REGISTER_OPERATOR
(
"Erf"
,
1
,
erf
);
REGISTER_OPERATOR
(
"Exp"
,
1
,
exp
);
REGISTER_OPERATOR
(
"Exp"
,
1
,
exp
);
REGISTER_OPERATOR
(
"EyeLike"
,
1
,
eye_like
);
REGISTER_OPERATOR
(
"Flatten"
,
1
,
flatten
);
REGISTER_OPERATOR
(
"Flatten"
,
1
,
flatten
);
REGISTER_OPERATOR
(
"Floor"
,
1
,
floor
);
REGISTER_OPERATOR
(
"Floor"
,
1
,
floor
);
REGISTER_OPERATOR
(
"Gather"
,
1
,
gather
);
REGISTER_OPERATOR
(
"Gather"
,
1
,
gather
);
...
...
src/ngraph/frontend/onnx_import/utils/common.cpp
0 → 100644
View file @
24c715f4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h> // onnx types
#include "common.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
common
{
const
ngraph
::
element
::
Type
&
get_ngraph_element_type
(
int64_t
onnx_type
)
{
switch
(
onnx_type
)
{
case
onnx
:
:
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType_DOUBLE
:
return
element
::
f64
;
case
onnx
:
:
TensorProto_DataType_FLOAT16
:
return
element
::
f16
;
case
onnx
:
:
TensorProto_DataType_FLOAT
:
return
element
::
f32
;
case
onnx
:
:
TensorProto_DataType_INT8
:
return
element
::
i8
;
case
onnx
:
:
TensorProto_DataType_INT16
:
return
element
::
i16
;
case
onnx
:
:
TensorProto_DataType_INT32
:
return
element
::
i32
;
case
onnx
:
:
TensorProto_DataType_INT64
:
return
element
::
i64
;
case
onnx
:
:
TensorProto_DataType_UINT8
:
return
element
::
u8
;
case
onnx
:
:
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType_UINT64
:
return
element
::
u64
;
case
onnx
:
:
TensorProto_DataType_UNDEFINED
:
return
element
::
dynamic
;
}
throw
ngraph_error
(
"unsupported element type: "
+
onnx
::
TensorProto_DataType_Name
(
static_cast
<
onnx
::
TensorProto_DataType
>
(
onnx_type
)));
}
}
// namespace common
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/utils/common.hpp
View file @
24c715f4
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <algorithm> // std::generate
#include <algorithm> // std::generate
#include <cmath> // std::floor, std::min
#include <cmath> // std::floor, std::min
#include <cstddef> // std::size_t
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <iterator> // std::begin, std::end
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
#include <memory> // std::shared_ptr, std::make_shared
#include <type_traits> // std::enable_if
#include <type_traits> // std::enable_if
...
@@ -27,6 +28,7 @@
...
@@ -27,6 +28,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -34,6 +36,8 @@ namespace ngraph
...
@@ -34,6 +36,8 @@ namespace ngraph
{
{
namespace
common
namespace
common
{
{
const
ngraph
::
element
::
Type
&
get_ngraph_element_type
(
std
::
int64_t
onnx_type
);
/// \brief Return a monotonic sequence.
/// \brief Return a monotonic sequence.
///
///
/// \note Limitations: this function may not work for very large integer values
/// \note Limitations: this function may not work for very large integer values
...
@@ -86,6 +90,43 @@ namespace ngraph
...
@@ -86,6 +90,43 @@ namespace ngraph
}
}
}
}
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main diagonal.
/// Shifting adds an offset and moves the diagonal up or down
///
/// \param[in] output_shape Shape of the resulting matrix.
/// \param[in] output_type Element type of the resulting matrix.
/// \param[in] shift Shifting of diagonal.
///
/// \return A Constant node representing shifted identity matrix.
template
<
typename
T
=
double
>
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
shifted_square_identity
(
const
Shape
output_shape
,
const
element
::
Type
&
output_type
,
const
std
::
int64_t
shift
)
{
std
::
vector
<
T
>
identity_matrix
(
shape_size
(
output_shape
),
T
{
0
});
std
::
int64_t
rows
=
output_shape
[
0
];
std
::
int64_t
cols
=
output_shape
[
1
];
for
(
std
::
int64_t
row
=
0
;
row
<
rows
;
++
row
)
{
const
std
::
int64_t
diagonal_element_idx
=
(
row
*
cols
)
+
row
+
shift
;
if
(
row
+
shift
<
0
)
{
continue
;
}
else
if
(
row
+
shift
>=
cols
)
{
break
;
}
identity_matrix
.
at
(
diagonal_element_idx
)
=
T
{
1
};
}
return
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
output_type
,
output_shape
,
identity_matrix
);
}
/// \brief Creates a square identity matrix.
/// \brief Creates a square identity matrix.
///
///
/// \param[in] n Order of the resulting matrix.
/// \param[in] n Order of the resulting matrix.
...
@@ -95,16 +136,9 @@ namespace ngraph
...
@@ -95,16 +136,9 @@ namespace ngraph
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
square_identity
(
const
size_t
n
,
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
square_identity
(
const
size_t
n
,
const
element
::
Type
&
type
)
const
element
::
Type
&
type
)
{
{
std
::
vector
<
T
>
identity_matrix
(
n
*
n
,
T
{
0
});
return
shifted_square_identity
(
Shape
{
n
,
n
},
type
,
0
);
for
(
size_t
row
=
0
;
row
<
n
;
++
row
)
{
const
size_t
diagonal_element
=
(
n
*
row
)
+
row
;
identity_matrix
.
at
(
diagonal_element
)
=
T
{
1
};
}
}
return
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
type
,
Shape
{{
n
,
n
}},
identity_matrix
);
}
}
// namespace common
}
// namespace common
}
// namespace onnx_import
}
// namespace onnx_import
}
// namespace ngraph
}
// namespace ngraph
src/ngraph/node.hpp
View file @
24c715f4
...
@@ -216,6 +216,7 @@ namespace ngraph
...
@@ -216,6 +216,7 @@ namespace ngraph
virtual
bool
is_op
()
const
{
return
false
;
}
virtual
bool
is_op
()
const
{
return
false
;
}
virtual
bool
is_commutative
()
{
return
false
;
}
virtual
bool
is_commutative
()
{
return
false
;
}
virtual
bool
is_dynamic
()
const
;
virtual
bool
is_dynamic
()
const
;
virtual
bool
has_state
()
const
{
return
false
;
}
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
virtual
std
::
ostream
&
write_short_description
(
std
::
ostream
&
)
const
;
virtual
std
::
ostream
&
write_short_description
(
std
::
ostream
&
)
const
;
...
...
src/ngraph/op/experimental/generate_mask.hpp
View file @
24c715f4
...
@@ -72,6 +72,8 @@ namespace ngraph
...
@@ -72,6 +72,8 @@ namespace ngraph
/// \brief Returns the seed value supplied to a random generator
/// \brief Returns the seed value supplied to a random generator
uint64_t
get_seed
()
const
{
return
m_seed
;
}
uint64_t
get_seed
()
const
{
return
m_seed
;
}
bool
get_use_seed
()
const
{
return
m_use_seed
;
}
bool
get_use_seed
()
const
{
return
m_use_seed
;
}
/// GenerateMask has state.
bool
has_state
()
const
override
{
return
true
;
}
protected
:
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
const
NodeVector
&
deltas
)
override
...
...
src/ngraph/op/fused_op_tbl.hpp
View file @
24c715f4
...
@@ -24,8 +24,8 @@ NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
...
@@ -24,8 +24,8 @@ NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP
(
DepthToSpace
,
ngraph
::
op
)
NGRAPH_OP
(
DepthToSpace
,
ngraph
::
op
)
NGRAPH_OP
(
Elu
,
ngraph
::
op
)
NGRAPH_OP
(
Elu
,
ngraph
::
op
)
NGRAPH_OP
(
FakeQuantize
,
ngraph
::
op
)
NGRAPH_OP
(
FakeQuantize
,
ngraph
::
op
)
NGRAPH_OP
(
GRN
,
ngraph
::
op
)
NGRAPH_OP
(
Gemm
,
ngraph
::
op
)
NGRAPH_OP
(
Gemm
,
ngraph
::
op
)
NGRAPH_OP
(
GRN
,
ngraph
::
op
)
NGRAPH_OP
(
GroupConvolution
,
ngraph
::
op
)
NGRAPH_OP
(
GroupConvolution
,
ngraph
::
op
)
NGRAPH_OP
(
GroupConvolutionTranspose
,
ngraph
::
op
)
NGRAPH_OP
(
GroupConvolutionTranspose
,
ngraph
::
op
)
NGRAPH_OP
(
HardSigmoid
,
ngraph
::
op
)
NGRAPH_OP
(
HardSigmoid
,
ngraph
::
op
)
...
@@ -35,9 +35,9 @@ NGRAPH_OP(MVN, ngraph::op)
...
@@ -35,9 +35,9 @@ NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP
(
Normalize
,
ngraph
::
op
)
NGRAPH_OP
(
Normalize
,
ngraph
::
op
)
NGRAPH_OP
(
PRelu
,
ngraph
::
op
)
NGRAPH_OP
(
PRelu
,
ngraph
::
op
)
NGRAPH_OP
(
ScaleShift
,
ngraph
::
op
)
NGRAPH_OP
(
ScaleShift
,
ngraph
::
op
)
NGRAPH_OP
(
SpaceToDepth
,
ngraph
::
op
)
NGRAPH_OP
(
ShuffleChannels
,
ngraph
::
op
)
NGRAPH_OP
(
ShuffleChannels
,
ngraph
::
op
)
NGRAPH_OP
(
SpaceToDepth
,
ngraph
::
op
)
NGRAPH_OP
(
Split
,
ngraph
::
op
)
NGRAPH_OP
(
SquaredDifference
,
ngraph
::
op
)
NGRAPH_OP
(
SquaredDifference
,
ngraph
::
op
)
NGRAPH_OP
(
Squeeze
,
ngraph
::
op
)
NGRAPH_OP
(
Squeeze
,
ngraph
::
op
)
NGRAPH_OP
(
Split
,
ngraph
::
op
)
NGRAPH_OP
(
Unsqueeze
,
ngraph
::
op
)
NGRAPH_OP
(
Unsqueeze
,
ngraph
::
op
)
src/ngraph/op/op_tbl.hpp
View file @
24c715f4
...
@@ -81,11 +81,12 @@ NGRAPH_OP(Cos, ngraph::op)
...
@@ -81,11 +81,12 @@ NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP
(
Cosh
,
ngraph
::
op
)
NGRAPH_OP
(
Cosh
,
ngraph
::
op
)
NGRAPH_OP
(
Dequantize
,
ngraph
::
op
)
NGRAPH_OP
(
Dequantize
,
ngraph
::
op
)
NGRAPH_OP
(
Divide
,
ngraph
::
op
)
NGRAPH_OP
(
Divide
,
ngraph
::
op
)
NGRAPH_OP
(
DynBroadcast
,
ngraph
::
op
)
NGRAPH_OP
(
Dot
,
ngraph
::
op
)
NGRAPH_OP
(
Dot
,
ngraph
::
op
)
NGRAPH_OP
(
DynBroadcast
,
ngraph
::
op
)
NGRAPH_OP
(
DynPad
,
ngraph
::
op
)
NGRAPH_OP
(
DynPad
,
ngraph
::
op
)
NGRAPH_OP
(
DynReshape
,
ngraph
::
op
)
NGRAPH_OP
(
DynReshape
,
ngraph
::
op
)
NGRAPH_OP
(
DynSlice
,
ngraph
::
op
)
NGRAPH_OP
(
DynSlice
,
ngraph
::
op
)
NGRAPH_OP
(
EmbeddingLookup
,
ngraph
::
op
)
NGRAPH_OP
(
Equal
,
ngraph
::
op
)
NGRAPH_OP
(
Equal
,
ngraph
::
op
)
NGRAPH_OP
(
Erf
,
ngraph
::
op
)
NGRAPH_OP
(
Erf
,
ngraph
::
op
)
NGRAPH_OP
(
Exp
,
ngraph
::
op
)
NGRAPH_OP
(
Exp
,
ngraph
::
op
)
...
@@ -119,13 +120,13 @@ NGRAPH_OP(Power, ngraph::op)
...
@@ -119,13 +120,13 @@ NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP
(
Product
,
ngraph
::
op
)
NGRAPH_OP
(
Product
,
ngraph
::
op
)
NGRAPH_OP
(
Quantize
,
ngraph
::
op
)
NGRAPH_OP
(
Quantize
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedAvgPool
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedAvgPool
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolution
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBias
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBias
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBiasAdd
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBiasAdd
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBiasSignedAdd
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBiasSignedAdd
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionRelu
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionRelu
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolution
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedDotBias
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedDot
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedDot
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedDotBias
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedMaxPool
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedMaxPool
,
ngraph
::
op
)
NGRAPH_OP
(
Range
,
ngraph
::
op
)
NGRAPH_OP
(
Range
,
ngraph
::
op
)
NGRAPH_OP
(
Relu
,
ngraph
::
op
)
NGRAPH_OP
(
Relu
,
ngraph
::
op
)
...
@@ -153,7 +154,6 @@ NGRAPH_OP(Subtract, ngraph::op)
...
@@ -153,7 +154,6 @@ NGRAPH_OP(Subtract, ngraph::op)
NGRAPH_OP
(
Sum
,
ngraph
::
op
)
NGRAPH_OP
(
Sum
,
ngraph
::
op
)
NGRAPH_OP
(
Tan
,
ngraph
::
op
)
NGRAPH_OP
(
Tan
,
ngraph
::
op
)
NGRAPH_OP
(
Tanh
,
ngraph
::
op
)
NGRAPH_OP
(
Tanh
,
ngraph
::
op
)
NGRAPH_OP
(
TopK
,
ngraph
::
op
)
NGRAPH_OP
(
Tile
,
ngraph
::
op
)
NGRAPH_OP
(
Tile
,
ngraph
::
op
)
NGRAPH_OP
(
TopK
,
ngraph
::
op
)
NGRAPH_OP
(
Transpose
,
ngraph
::
op
)
NGRAPH_OP
(
Transpose
,
ngraph
::
op
)
NGRAPH_OP
(
EmbeddingLookup
,
ngraph
::
op
)
src/ngraph/pass/fused_op_decomposition.cpp
View file @
24c715f4
...
@@ -13,36 +13,51 @@
...
@@ -13,36 +13,51 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//*****************************************************************************
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/fused_op.hpp"
using
namespace
std
;
using
namespace
std
;
using
namespace
ngraph
;
using
namespace
ngraph
;
bool
ngraph
::
pass
::
FusedOpDecomposition
::
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
pass
::
FusedOpDecomposition
::
FusedOpDecomposition
(
op_query_t
callback
)
:
m_has_direct_support
{
callback
}
{
}
bool
pass
::
FusedOpDecomposition
::
run_on_node
(
shared_ptr
<
Node
>
node
)
{
{
bool
modified
=
false
;
bool
modified
=
false
;
if
(
auto
fused_op
=
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
util
::
FusedOp
>
(
node
))
if
(
auto
fused_op
=
dynamic_pointer_cast
<
op
::
util
::
FusedOp
>
(
node
))
{
{
if
(
m_
callback
&&
m_callback
(
*
node
))
if
(
m_
has_direct_support
&&
m_has_direct_support
(
*
node
))
{
{
// Op supported by backend. Do not decompose
// Op supported by backend. Do not decompose
return
modified
;
return
modified
;
}
}
auto
subgraph_outputs
=
fused_op
->
decompose_op
();
auto
subgraph_outputs
=
fused_op
->
decompose_op
();
// Run recursively untill no more fused ops
auto
subgraph
=
extract_subgraph
(
subgraph_outputs
,
fused_op
->
get_arguments
());
for
(
auto
subgraph_node
:
subgraph
)
{
if
(
auto
nested_fused_op
=
dynamic_pointer_cast
<
op
::
util
::
FusedOp
>
(
subgraph_node
))
{
if
(
!
(
m_has_direct_support
&&
m_has_direct_support
(
*
nested_fused_op
)))
{
run_on_node
(
nested_fused_op
);
}
}
}
size_t
i
=
0
;
size_t
i
=
0
;
for
(
auto
output_node
:
subgraph_outputs
)
for
(
auto
output_node
:
subgraph_outputs
)
{
{
for
(
size_t
j
=
0
;
j
<
output_node
->
get_outputs
().
size
();
j
++
,
i
++
)
for
(
size_t
j
=
0
;
j
<
output_node
->
get_outputs
().
size
();
j
++
,
i
++
)
{
{
// TODO: Provenance
// TODO: Provenance
std
::
set
<
ngraph
::
descriptor
::
Input
*>
fop_users
{
set
<
descriptor
::
Input
*>
fop_users
{
begin
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
()),
begin
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
()),
end
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
())};
end
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
())};
for
(
auto
fop_user
:
fop_users
)
for
(
auto
fop_user
:
fop_users
)
{
{
...
@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
...
@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
if
(
goe
->
get_n
()
==
i
&&
!
goe
->
get_output_inputs
(
0
).
empty
())
if
(
goe
->
get_n
()
==
i
&&
!
goe
->
get_output_inputs
(
0
).
empty
())
{
{
// Replace GOE users
// Replace GOE users
s
td
::
set
<
ngraph
::
descriptor
::
Input
*>
goe_users
{
s
et
<
descriptor
::
Input
*>
goe_users
{
begin
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
()),
begin
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
()),
end
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
())};
end
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
())};
for
(
auto
goe_user
:
goe_users
)
for
(
auto
goe_user
:
goe_users
)
...
@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
...
@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
return
modified
;
return
modified
;
}
}
pass
::
FusedOpDecomposition
::
FusedOpDecomposition
(
op_query_t
callback
)
:
m_callback
{
callback
}
{
}
src/ngraph/pass/fused_op_decomposition.hpp
View file @
24c715f4
...
@@ -16,6 +16,9 @@
...
@@ -16,6 +16,9 @@
#pragma once
#pragma once
#include <memory>
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -25,13 +28,24 @@ namespace ngraph
...
@@ -25,13 +28,24 @@ namespace ngraph
class
FusedOpDecomposition
:
public
NodePass
class
FusedOpDecomposition
:
public
NodePass
{
{
public
:
public
:
/// \brief Function signature type for callback used to check whether provided node
/// is supported by backend.
using
op_query_t
=
std
::
function
<
bool
(
const
Node
&
node
)
>
;
using
op_query_t
=
std
::
function
<
bool
(
const
Node
&
node
)
>
;
///
/// \brief Constructor for the Fused operation decomposition pass.
///
/// \param[in] callback The function object used to determine whether current backend
/// provide direct support for passed node. Should have signature:
/// bool fn(const Node&)
///
FusedOpDecomposition
(
op_query_t
callback
=
nullptr
);
FusedOpDecomposition
(
op_query_t
callback
=
nullptr
);
bool
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
override
;
bool
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
override
;
private
:
private
:
op_query_t
m_callback
=
nullptr
;
/// \brief A function returning whether provided Node is supported by current backend.
/// The returned bool value is used to control whether decompose operator or not.
op_query_t
m_has_direct_support
=
nullptr
;
};
};
}
}
}
}
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
24c715f4
...
@@ -924,7 +924,8 @@ using namespace ngraph::runtime;
...
@@ -924,7 +924,8 @@ using namespace ngraph::runtime;
// Always enable nodes computing output tensors or nodes whose outputs might get
// Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels
// overwritten due to inplace kernels
// TODO (jbobba) - Do we need to handle cacheability
// TODO (jbobba) - Do we need to handle cacheability
if
(
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
()))
if
(
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
())
||
node
->
has_state
())
{
{
writer
<<
" || 1"
;
writer
<<
" || 1"
;
}
}
...
@@ -1187,7 +1188,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
...
@@ -1187,7 +1188,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
{
{
REGISTER_KNOBBED_PASS
(
CPUFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUFusion
,
true
,
runtime
::
cpu
::
pass
);
}
}
REGISTER_KNOBBED_PASS
(
CPUQuantFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUQuantFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUHorizontalFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUHorizontalFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUCollapseDims
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUCollapseDims
,
true
,
runtime
::
cpu
::
pass
);
...
@@ -1437,7 +1437,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
...
@@ -1437,7 +1437,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
bool
disable_caching
=
bool
disable_caching
=
(
reuse_memory
&&
(
reuse_memory
&&
!
cacheable
)
// Check cacheability only if we are reusing intermediate tensors
!
cacheable
)
// Check cacheability only if we are reusing intermediate tensors
||
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
());
||
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
())
||
node
->
has_state
()
;
vector
<
reference_wrapper
<
bool
>>
in_stale
,
out_stale
;
vector
<
reference_wrapper
<
bool
>>
in_stale
,
out_stale
;
for
(
const
auto
&
name
:
in_names
)
for
(
const
auto
&
name
:
in_names
)
...
...
src/ngraph/serializer.cpp
View file @
24c715f4
...
@@ -192,10 +192,15 @@ static OP_TYPEID get_typeid(const string& s)
...
@@ -192,10 +192,15 @@ static OP_TYPEID get_typeid(const string& s)
return
rc
;
return
rc
;
}
}
bool
has_key
(
json
j
,
const
std
::
string
&
key
)
{
return
j
.
count
(
key
)
!=
0
;
}
template
<
typename
T
>
template
<
typename
T
>
T
get_or_default
(
nlohmann
::
json
&
j
,
const
std
::
string
&
key
,
const
T
&
default_value
)
T
get_or_default
(
json
j
,
const
std
::
string
&
key
,
const
T
&
default_value
)
{
{
return
j
.
count
(
key
)
!=
0
?
j
.
at
(
key
).
get
<
T
>
()
:
default_value
;
return
has_key
(
j
,
key
)
?
j
.
at
(
key
).
get
<
T
>
()
:
default_value
;
}
}
class
JSONSerializer
class
JSONSerializer
...
@@ -214,8 +219,11 @@ public:
...
@@ -214,8 +219,11 @@ public:
json
serialize_function
(
const
Function
&
function
);
json
serialize_function
(
const
Function
&
function
);
json
serialize_output
(
const
Output
<
Node
>&
output
);
json
serialize_output
(
const
Output
<
Node
>&
output
);
json
serialize_parameter_vector
(
const
ParameterVector
&
parameters
);
json
serialize_output_vector
(
const
OutputVector
&
output_vector
);
json
serialize_node_reference
(
const
Node
&
node
);
json
serialize_node_reference
(
const
Node
&
node
);
json
serialize_node
(
const
Node
&
node
);
json
serialize_node
(
const
Node
&
node
);
json
serialize_axis_set
(
const
AxisSet
&
axis_set
);
protected
:
protected
:
size_t
m_indent
{
0
};
size_t
m_indent
{
0
};
...
@@ -234,10 +242,13 @@ public:
...
@@ -234,10 +242,13 @@ public:
m_const_data_callback
=
const_data_callback
;
m_const_data_callback
=
const_data_callback
;
}
}
shared_ptr
<
Function
>
deserialize_function
(
json
&
j
);
shared_ptr
<
Function
>
deserialize_function
(
json
j
);
Output
<
Node
>
deserialize_output
(
json
&
j
);
Output
<
Node
>
deserialize_output
(
json
j
);
shared_ptr
<
Node
>
deserialize_node_reference
(
json
&
j
);
OutputVector
deserialize_output_vector
(
json
j
);
shared_ptr
<
Node
>
deserialize_node
(
json
&
j
);
ParameterVector
deserialize_parameter_vector
(
json
j
);
shared_ptr
<
Node
>
deserialize_node_reference
(
json
j
);
shared_ptr
<
Node
>
deserialize_node
(
json
j
);
AxisSet
deserialize_axis_set
(
json
j
);
protected
:
protected
:
unordered_map
<
string
,
shared_ptr
<
Node
>>
m_node_map
;
unordered_map
<
string
,
shared_ptr
<
Node
>>
m_node_map
;
...
@@ -260,7 +271,7 @@ static json write_dimension(Dimension d)
...
@@ -260,7 +271,7 @@ static json write_dimension(Dimension d)
}
}
}
}
static
Dimension
read_dimension
(
const
json
&
j
)
static
Dimension
read_dimension
(
json
j
)
{
{
if
(
j
.
is_null
())
if
(
j
.
is_null
())
{
{
...
@@ -289,7 +300,7 @@ static json write_partial_shape(const PartialShape& s)
...
@@ -289,7 +300,7 @@ static json write_partial_shape(const PartialShape& s)
}
}
}
}
static
PartialShape
read_partial_shape
(
const
json
&
j
)
static
PartialShape
read_partial_shape
(
json
j
)
{
{
if
(
j
.
is_null
())
if
(
j
.
is_null
())
{
{
...
@@ -314,19 +325,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
...
@@ -314,19 +325,32 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
return
j
;
return
j
;
}
}
static
op
::
AutoBroadcastSpec
read_auto_broadcast
(
const
json
&
j
)
static
op
::
AutoBroadcastSpec
read_auto_broadcast
(
json
js_node
,
const
std
::
string
&
attr
)
{
{
if
(
!
j
.
is_object
(
))
if
(
has_key
(
js_node
,
attr
))
{
{
return
op
::
AutoBroadcastSpec
();
json
j
=
js_node
[
attr
];
return
op
::
AutoBroadcastSpec
(
static_cast
<
op
::
AutoBroadcastType
>
(
j
.
at
(
"type"
)),
j
.
at
(
"axis"
).
get
<
size_t
>
());
}
}
else
else
{
{
return
op
::
AutoBroadcastSpec
(
static_cast
<
op
::
AutoBroadcastType
>
(
j
.
at
(
"type"
)),
return
op
::
AutoBroadcastSpec
();
j
.
at
(
"axis"
).
get
<
size_t
>
());
}
}
}
}
static
op
::
PadType
read_pad_type
(
json
node_js
)
{
return
has_key
(
node_js
,
"pad_type"
)
?
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
))
:
op
::
PadType
::
EXPLICIT
;
}
static
op
::
PadMode
read_pad_mode
(
json
node_js
)
{
return
has_key
(
node_js
,
"pad_mode"
)
?
static_cast
<
op
::
PadMode
>
(
node_js
.
at
(
"pad_mode"
))
:
op
::
PadMode
::
CONSTANT
;
}
static
json
write_element_type
(
const
ngraph
::
element
::
Type
&
n
)
static
json
write_element_type
(
const
ngraph
::
element
::
Type
&
n
)
{
{
json
j
;
json
j
;
...
@@ -334,7 +358,7 @@ static json write_element_type(const ngraph::element::Type& n)
...
@@ -334,7 +358,7 @@ static json write_element_type(const ngraph::element::Type& n)
return
j
;
return
j
;
}
}
static
element
::
Type
read_element_type
(
const
json
&
j
)
static
element
::
Type
read_element_type
(
json
j
)
{
{
size_t
bitwidth
=
0
;
size_t
bitwidth
=
0
;
bool
is_real
=
false
;
bool
is_real
=
false
;
...
@@ -494,21 +518,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
...
@@ -494,21 +518,24 @@ shared_ptr<ngraph::Function> ngraph::deserialize(const string& s)
rc
=
deserializer
.
deserialize_function
(
func
);
rc
=
deserializer
.
deserialize_function
(
func
);
}
}
}
}
return
rc
;
return
rc
;
}
}
json
JSONSerializer
::
serialize_parameter_vector
(
const
ParameterVector
&
parameters
)
{
json
json_parameters
=
json
::
array
();
for
(
auto
param
:
parameters
)
{
json_parameters
.
push_back
(
serialize_node_reference
(
*
param
));
}
return
json_parameters
;
}
json
JSONSerializer
::
serialize_function
(
const
Function
&
f
)
json
JSONSerializer
::
serialize_function
(
const
Function
&
f
)
{
{
json
function
;
json
function
;
function
[
"name"
]
=
f
.
get_name
();
function
[
"name"
]
=
f
.
get_name
();
function
[
"parameters"
]
=
serialize_parameter_vector
(
f
.
get_parameters
());
vector
<
string
>
parameter_list
;
for
(
auto
param
:
f
.
get_parameters
())
{
parameter_list
.
push_back
(
serialize_node_reference
(
*
param
));
}
function
[
"parameters"
]
=
parameter_list
;
// TODO Functions can return multiple results
// TODO Functions can return multiple results
for
(
size_t
i
=
0
;
i
<
f
.
get_output_size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
f
.
get_output_size
();
++
i
)
...
@@ -520,7 +547,7 @@ json JSONSerializer::serialize_function(const Function& f)
...
@@ -520,7 +547,7 @@ json JSONSerializer::serialize_function(const Function& f)
}
}
template
<
typename
T
>
template
<
typename
T
>
T
get_value
(
nlohmann
::
json
js
,
const
string
&
key
)
T
get_value
(
json
js
,
const
string
&
key
)
{
{
T
rc
;
T
rc
;
auto
it
=
js
.
find
(
key
);
auto
it
=
js
.
find
(
key
);
...
@@ -531,13 +558,13 @@ T get_value(nlohmann::json js, const string& key)
...
@@ -531,13 +558,13 @@ T get_value(nlohmann::json js, const string& key)
return
rc
;
return
rc
;
}
}
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node_reference
(
json
&
j
)
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node_reference
(
json
j
)
{
{
const
string
&
name
=
j
;
const
string
&
name
=
j
;
return
m_node_map
.
at
(
name
);
return
m_node_map
.
at
(
name
);
}
}
Output
<
Node
>
JSONDeserializer
::
deserialize_output
(
json
&
j
)
Output
<
Node
>
JSONDeserializer
::
deserialize_output
(
json
j
)
{
{
size_t
index
;
size_t
index
;
json
json_node_reference
;
json
json_node_reference
;
...
@@ -558,10 +585,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j)
...
@@ -558,10 +585,48 @@ Output<Node> JSONDeserializer::deserialize_output(json& j)
return
Output
<
Node
>
(
deserialize_node_reference
(
json_node_reference
),
index
);
return
Output
<
Node
>
(
deserialize_node_reference
(
json_node_reference
),
index
);
}
}
shared_ptr
<
Function
>
JSONDeserializer
::
deserialize_function
(
json
&
func_js
)
OutputVector
JSONDeserializer
::
deserialize_output_vector
(
json
j
)
{
OutputVector
result
;
if
(
j
.
is_array
())
{
for
(
json
jelt
:
j
)
{
result
.
push_back
(
deserialize_output
(
jelt
));
}
}
return
result
;
}
json
JSONSerializer
::
serialize_axis_set
(
const
AxisSet
&
axis_set
)
{
return
static_cast
<
set
<
size_t
>>
(
axis_set
);
}
AxisSet
JSONDeserializer
::
deserialize_axis_set
(
json
j
)
{
AxisSet
result
;
if
(
j
.
is_array
())
{
result
=
j
.
get
<
set
<
size_t
>>
();
}
return
result
;
}
ParameterVector
JSONDeserializer
::
deserialize_parameter_vector
(
json
json_parameters
)
{
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
params
;
for
(
auto
&
param_ref
:
json_parameters
)
{
params
.
push_back
(
dynamic_pointer_cast
<
op
::
Parameter
>
(
deserialize_node_reference
(
param_ref
)));
}
return
params
;
}
shared_ptr
<
Function
>
JSONDeserializer
::
deserialize_function
(
json
func_js
)
{
{
string
func_name
=
func_js
.
at
(
"name"
).
get
<
string
>
();
string
func_name
=
func_js
.
at
(
"name"
).
get
<
string
>
();
vector
<
json
>
func_parameters
=
func_js
.
at
(
"parameters"
);
vector
<
json
>
func_result
=
func_js
.
at
(
"result"
);
vector
<
json
>
func_result
=
func_js
.
at
(
"result"
);
for
(
json
node_js
:
func_js
.
at
(
"ops"
))
for
(
json
node_js
:
func_js
.
at
(
"ops"
))
{
{
...
@@ -593,12 +658,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
...
@@ -593,12 +658,7 @@ shared_ptr<Function> JSONDeserializer::deserialize_function(json& func_js)
"Graph serialization is inconsistent. Some op::Results appear to be missing"
);
"Graph serialization is inconsistent. Some op::Results appear to be missing"
);
}
}
std
::
vector
<
std
::
shared_ptr
<
op
::
Parameter
>>
params
;
ParameterVector
params
=
deserialize_parameter_vector
(
func_js
.
at
(
"parameters"
));
for
(
auto
&
param_ref
:
func_parameters
)
{
params
.
push_back
(
dynamic_pointer_cast
<
op
::
Parameter
>
(
deserialize_node_reference
(
param_ref
)));
}
shared_ptr
<
Function
>
rc
{
make_shared
<
Function
>
(
result
,
params
,
func_name
)};
shared_ptr
<
Function
>
rc
{
make_shared
<
Function
>
(
result
,
params
,
func_name
)};
m_function_map
[
func_name
]
=
rc
;
m_function_map
[
func_name
]
=
rc
;
...
@@ -631,7 +691,12 @@ struct OutputHelper
...
@@ -631,7 +691,12 @@ struct OutputHelper
// when all op constructors use the new style arguments.
// when all op constructors use the new style arguments.
struct
OutputVectorHelper
struct
OutputVectorHelper
{
{
const
OutputHelper
&
operator
[](
size_t
i
)
const
{
return
m_vector
[
i
];
}
OutputVectorHelper
(
const
OutputVector
&
output_vector
)
:
m_vector
(
output_vector
)
{
}
OutputVectorHelper
()
=
default
;
OutputHelper
operator
[](
size_t
i
)
const
{
return
OutputHelper
(
m_vector
[
i
]);
}
void
push_back
(
const
Output
<
Node
>&
output
)
{
m_vector
.
push_back
(
output
);
}
void
push_back
(
const
Output
<
Node
>&
output
)
{
m_vector
.
push_back
(
output
);
}
size_t
size
()
const
{
return
m_vector
.
size
();
}
size_t
size
()
const
{
return
m_vector
.
size
();
}
operator
vector
<
shared_ptr
<
Node
>>
()
const
operator
vector
<
shared_ptr
<
Node
>>
()
const
...
@@ -639,14 +704,15 @@ struct OutputVectorHelper
...
@@ -639,14 +704,15 @@ struct OutputVectorHelper
vector
<
shared_ptr
<
Node
>>
result
;
vector
<
shared_ptr
<
Node
>>
result
;
for
(
auto
&
o
:
m_vector
)
for
(
auto
&
o
:
m_vector
)
{
{
result
.
push_back
(
o
);
result
.
push_back
(
OutputHelper
(
o
)
);
}
}
return
result
;
return
result
;
}
}
vector
<
OutputHelper
>
m_vector
;
operator
const
OutputVector
&
()
const
{
return
m_vector
;
}
OutputVector
m_vector
;
};
};
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node
(
json
&
node_js
)
shared_ptr
<
Node
>
JSONDeserializer
::
deserialize_node
(
json
node_js
)
{
{
shared_ptr
<
Node
>
node
;
shared_ptr
<
Node
>
node
;
try
try
...
@@ -654,14 +720,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -654,14 +720,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
string
node_name
=
node_js
.
at
(
"name"
).
get
<
string
>
();
string
node_name
=
node_js
.
at
(
"name"
).
get
<
string
>
();
string
node_op
=
node_js
.
at
(
"op"
).
get
<
string
>
();
string
node_op
=
node_js
.
at
(
"op"
).
get
<
string
>
();
string
friendly_name
=
get_value
<
string
>
(
node_js
,
"friendly_name"
);
string
friendly_name
=
get_value
<
string
>
(
node_js
,
"friendly_name"
);
vector
<
json
>
node_inputs
=
get_value
<
vector
<
json
>>
(
node_js
,
"inputs"
);
vector
<
json
>
control_deps_inputs
=
get_value
<
vector
<
json
>>
(
node_js
,
"control_deps"
);
vector
<
json
>
control_deps_inputs
=
get_value
<
vector
<
json
>>
(
node_js
,
"control_deps"
);
vector
<
string
>
node_outputs
=
get_value
<
vector
<
string
>>
(
node_js
,
"outputs"
);
vector
<
string
>
node_outputs
=
get_value
<
vector
<
string
>>
(
node_js
,
"outputs"
);
OutputVectorHelper
args
;
OutputVectorHelper
args
(
deserialize_output_vector
(
node_js
[
"inputs"
]));
for
(
auto
&
node_input
:
node_inputs
)
{
args
.
push_back
(
deserialize_output
(
node_input
));
}
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#if !(defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
...
@@ -682,12 +743,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -682,12 +743,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Add
:
case
OP_TYPEID
:
:
Add
:
{
{
node
=
make_shared
<
op
::
Add
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Add
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
All
:
case
OP_TYPEID
:
:
All
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
All
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
All
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -698,12 +759,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -698,12 +759,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
And
:
case
OP_TYPEID
:
:
And
:
{
{
node
=
make_shared
<
op
::
And
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
And
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Any
:
case
OP_TYPEID
:
:
Any
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Any
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Any
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -740,12 +801,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -740,12 +801,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto
padding_above
=
node_js
.
at
(
"padding_above"
).
get
<
vector
<
size_t
>>
();
auto
padding_above
=
node_js
.
at
(
"padding_above"
).
get
<
vector
<
size_t
>>
();
auto
include_padding_in_avg_computation
=
auto
include_padding_in_avg_computation
=
node_js
.
at
(
"include_padding_in_avg_computation"
).
get
<
bool
>
();
node_js
.
at
(
"include_padding_in_avg_computation"
).
get
<
bool
>
();
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
bool
ceil_mode
=
get_or_default
<
bool
>
(
node_js
,
"ceil_mode"
,
false
);
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
bool
ceil_mode
=
node_js
[
"ceil_mode"
].
empty
()
?
false
:
node_js
.
at
(
"ceil_mode"
).
get
<
bool
>
();
;
node
=
make_shared
<
op
::
AvgPool
>
(
args
[
0
],
node
=
make_shared
<
op
::
AvgPool
>
(
args
[
0
],
window_shape
,
window_shape
,
window_movement_strides
,
window_movement_strides
,
...
@@ -807,7 +864,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -807,7 +864,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Broadcast
:
case
OP_TYPEID
:
:
Broadcast
:
{
{
auto
shape
=
node_js
.
at
(
"shape"
).
get
<
vector
<
size_t
>>
();
auto
shape
=
node_js
.
at
(
"shape"
).
get
<
vector
<
size_t
>>
();
auto
axes
=
node_js
.
at
(
"axes"
).
get
<
set
<
size_t
>>
(
);
auto
axes
=
deserialize_axis_set
(
node_js
.
at
(
"axes"
)
);
node
=
make_shared
<
op
::
Broadcast
>
(
args
[
0
],
shape
,
axes
);
node
=
make_shared
<
op
::
Broadcast
>
(
args
[
0
],
shape
,
axes
);
break
;
break
;
}
}
...
@@ -818,7 +875,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -818,7 +875,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
BroadcastLike
:
case
OP_TYPEID
:
:
BroadcastLike
:
{
{
auto
initial_axes
=
node_js
.
at
(
"initial_axes"
).
get
<
set
<
size_t
>>
(
);
auto
initial_axes
=
deserialize_axis_set
(
node_js
.
at
(
"initial_axes"
)
);
node
=
make_shared
<
op
::
BroadcastLike
>
(
args
[
0
],
args
[
1
],
initial_axes
);
node
=
make_shared
<
op
::
BroadcastLike
>
(
args
[
0
],
args
[
1
],
initial_axes
);
break
;
break
;
}
}
...
@@ -837,13 +894,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -837,13 +894,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Concat
:
case
OP_TYPEID
:
:
Concat
:
{
{
auto
axis
=
node_js
.
at
(
"axis"
).
get
<
size_t
>
();
auto
axis
=
node_js
.
at
(
"axis"
).
get
<
size_t
>
();
node
=
make_shared
<
op
::
Concat
>
(
args
,
axis
);
node
=
make_shared
<
op
::
Concat
>
(
static_cast
<
OutputVector
>
(
args
)
,
axis
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Constant
:
case
OP_TYPEID
:
:
Constant
:
{
{
auto
type_node_js
=
auto
type_node_js
=
node_js
.
count
(
"element_type"
)
==
0
?
node_js
.
at
(
"value_type"
)
:
node_js
;
has_key
(
node_js
,
"element_type"
)
?
node_js
:
node_js
.
at
(
"value_type"
)
;
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
value
=
node_js
.
at
(
"value"
).
get
<
vector
<
string
>>
();
auto
value
=
node_js
.
at
(
"value"
).
get
<
vector
<
string
>>
();
...
@@ -867,17 +924,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -867,17 +924,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
// For backwards compatibility, we accept "image_dilation_strides" in place of
// For backwards compatibility, we accept "image_dilation_strides" in place of
// "data_dilation_strides", and we also allow it to be omitted altogether.
// "data_dilation_strides", and we also allow it to be omitted altogether.
auto
data_dilation_strides_maybe
=
node_js
[
"data_dilation_strides"
];
json
data_dilation_strides
;
if
(
data_dilation_strides_maybe
.
empty
())
if
(
has_key
(
node_js
,
"data_dilation_strides"
))
{
data_dilation_strides
=
node_js
[
"data_dilation_strides"
];
}
else
if
(
has_key
(
node_js
,
"image_dilation_strides"
))
{
{
data_dilation_strides
_maybe
=
node_js
[
"image_dilation_strides"
];
data_dilation_strides
=
node_js
[
"image_dilation_strides"
];
}
}
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
if
(
data_dilation_strides
_maybe
.
empty
())
if
(
data_dilation_strides
.
empty
())
{
{
node
=
make_shared
<
op
::
Convolution
>
(
args
[
0
],
node
=
make_shared
<
op
::
Convolution
>
(
args
[
0
],
args
[
1
],
args
[
1
],
...
@@ -888,14 +947,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -888,14 +947,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
else
else
{
{
node
=
make_shared
<
op
::
Convolution
>
(
node
=
args
[
0
],
make_shared
<
op
::
Convolution
>
(
args
[
0
],
args
[
1
],
args
[
1
],
window_movement_strides
,
window_movement_strides
,
window_dilation_strides
,
window_dilation_strides
,
padding_below
,
padding_below
,
padding_above
,
padding_above
,
data_dilation_strides_maybe
.
get
<
std
::
vector
<
size_t
>>
(),
data_dilation_strides
.
get
<
std
::
vector
<
size_t
>>
(),
pad_type
);
pad_type
);
}
}
break
;
break
;
...
@@ -1032,33 +1091,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1032,33 +1091,28 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Dequantize
:
case
OP_TYPEID
:
:
Dequantize
:
{
{
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
axes
=
node_js
.
at
(
"axes"
).
get
<
set
<
size_t
>>
(
);
auto
axes
=
deserialize_axis_set
(
node_js
.
at
(
"axes"
)
);
node
=
make_shared
<
op
::
Dequantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
);
node
=
make_shared
<
op
::
Dequantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Divide
:
case
OP_TYPEID
:
:
Divide
:
{
{
bool
pythondiv
=
true
;
bool
pythondiv
=
get_or_default
(
node_js
,
"pythondiv"
,
true
);
if
(
node_js
[
"pythondiv"
].
is_object
())
{
pythondiv
=
node_js
.
at
(
"pythondiv"
).
get
<
bool
>
();
}
node
=
make_shared
<
op
::
Divide
>
(
node
=
make_shared
<
op
::
Divide
>
(
args
[
0
],
args
[
1
],
pythondiv
,
read_auto_broadcast
(
node_js
[
"autob"
]
));
args
[
0
],
args
[
1
],
pythondiv
,
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Dot
:
case
OP_TYPEID
:
:
Dot
:
{
{
// For backwards compatibility, reduction_axes_count is optional.
// For backwards compatibility, reduction_axes_count is optional.
auto
obj
=
node_js
[
"reduction_axes_count"
];
if
(
has_key
(
node_js
,
"reduction_axes_count"
))
if
(
obj
.
empty
())
{
{
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
]);
size_t
reduction_axes_count
=
node_js
[
"reduction_axes_count"
].
get
<
size_t
>
();
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
],
reduction_axes_count
);
}
}
else
else
{
{
size_t
reduction_axes_count
=
obj
.
get
<
size_t
>
();
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
]);
node
=
make_shared
<
op
::
Dot
>
(
args
[
0
],
args
[
1
],
reduction_axes_count
);
}
}
break
;
break
;
}
}
...
@@ -1094,7 +1148,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1094,7 +1148,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Equal
:
case
OP_TYPEID
:
:
Equal
:
{
{
node
=
make_shared
<
op
::
Equal
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Equal
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Erf
:
case
OP_TYPEID
:
:
Erf
:
...
@@ -1159,13 +1213,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1159,13 +1213,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Greater
:
case
OP_TYPEID
:
:
Greater
:
{
{
node
=
node
=
make_shared
<
op
::
Greater
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Greater
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
GreaterEq
:
case
OP_TYPEID
:
:
GreaterEq
:
{
{
node
=
node
=
make_shared
<
op
::
GreaterEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
GreaterEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
GRN
:
case
OP_TYPEID
:
:
GRN
:
...
@@ -1192,10 +1246,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1192,10 +1246,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto
data_dilation_strides
=
node_js
.
at
(
"data_dilation_strides"
).
get
<
vector
<
size_t
>>
();
auto
data_dilation_strides
=
node_js
.
at
(
"data_dilation_strides"
).
get
<
vector
<
size_t
>>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
node
=
make_shared
<
op
::
GroupConvolution
>
(
args
[
0
],
node
=
make_shared
<
op
::
GroupConvolution
>
(
args
[
0
],
args
[
1
],
args
[
1
],
window_movement_strides
,
window_movement_strides
,
...
@@ -1215,9 +1266,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1215,9 +1266,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
auto
padding_end
=
node_js
.
at
(
"padding_end"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
padding_end
=
node_js
.
at
(
"padding_end"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
output_padding
=
node_js
.
at
(
"output_padding"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
output_padding
=
node_js
.
at
(
"output_padding"
).
get
<
vector
<
ptrdiff_t
>>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
auto
groups
=
node_js
.
at
(
"groups"
).
get
<
size_t
>
();
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
auto
output_shape
=
node_js
.
at
(
"output_shape"
).
get
<
vector
<
size_t
>>
();
auto
output_shape
=
node_js
.
at
(
"output_shape"
).
get
<
vector
<
size_t
>>
();
node
=
make_shared
<
op
::
GroupConvolutionTranspose
>
(
args
[
0
],
node
=
make_shared
<
op
::
GroupConvolutionTranspose
>
(
args
[
0
],
...
@@ -1239,12 +1288,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1239,12 +1288,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Less
:
case
OP_TYPEID
:
:
Less
:
{
{
node
=
make_shared
<
op
::
Less
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Less
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
LessEq
:
case
OP_TYPEID
:
:
LessEq
:
{
{
node
=
make_shared
<
op
::
LessEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
LessEq
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Log
:
case
OP_TYPEID
:
:
Log
:
...
@@ -1286,7 +1335,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1286,7 +1335,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Max
:
case
OP_TYPEID
:
:
Max
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Max
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Max
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -1297,11 +1346,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1297,11 +1346,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
// omitted.
auto
padding_below_maybe
=
node_js
[
"padding_below"
];
auto
padding_below_maybe
=
get_or_default
(
node_js
,
"padding_below"
,
json
{});
auto
padding_above_maybe
=
node_js
[
"padding_above"
];
auto
padding_above_maybe
=
get_or_default
(
node_js
,
"padding_above"
,
json
{});
op
::
PadType
pad_type
=
node_js
[
"pad_type"
].
empty
()
op
::
PadType
pad_type
=
read_pad_type
(
node_js
);
?
op
::
PadType
::
EXPLICIT
:
static_cast
<
op
::
PadType
>
(
node_js
.
at
(
"pad_type"
));
if
(
padding_below_maybe
.
empty
()
&&
!
padding_above_maybe
.
empty
())
if
(
padding_below_maybe
.
empty
()
&&
!
padding_above_maybe
.
empty
())
{
{
throw
runtime_error
(
throw
runtime_error
(
...
@@ -1360,31 +1407,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1360,31 +1407,31 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Maximum
:
case
OP_TYPEID
:
:
Maximum
:
{
{
node
=
node
=
make_shared
<
op
::
Maximum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Maximum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Min
:
case
OP_TYPEID
:
:
Min
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Min
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Min
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Minimum
:
case
OP_TYPEID
:
:
Minimum
:
{
{
node
=
node
=
make_shared
<
op
::
Minimum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Minimum
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Multiply
:
case
OP_TYPEID
:
:
Multiply
:
{
{
node
=
node
=
make_shared
<
op
::
Multiply
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Multiply
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
MVN
:
case
OP_TYPEID
:
:
MVN
:
{
{
auto
normalize_variance
=
node_js
.
at
(
"normalize_variance"
).
get
<
bool
>
();
auto
normalize_variance
=
node_js
.
at
(
"normalize_variance"
).
get
<
bool
>
();
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
auto
eps
=
node_js
.
at
(
"eps"
).
get
<
double
>
();
auto
eps
=
node_js
.
at
(
"eps"
).
get
<
double
>
();
node
=
make_shared
<
op
::
MVN
>
(
args
[
0
],
normalize_variance
,
normalize_variance
,
eps
);
node
=
make_shared
<
op
::
MVN
>
(
args
[
0
],
normalize_variance
,
normalize_variance
,
eps
);
break
;
break
;
...
@@ -1406,7 +1453,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1406,7 +1453,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
NotEqual
:
case
OP_TYPEID
:
:
NotEqual
:
{
{
node
=
node
=
make_shared
<
op
::
NotEqual
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
NotEqual
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Not
:
case
OP_TYPEID
:
:
Not
:
...
@@ -1423,7 +1470,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1423,7 +1470,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Or
:
case
OP_TYPEID
:
:
Or
:
{
{
node
=
make_shared
<
op
::
Or
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Or
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Pad
:
case
OP_TYPEID
:
:
Pad
:
...
@@ -1440,9 +1487,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1440,9 +1487,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
[](
size_t
s
)
{
return
s
==
0
;
}),
[](
size_t
s
)
{
return
s
==
0
;
}),
"Legacy padding_interior field must be zero everywhere."
);
"Legacy padding_interior field must be zero everywhere."
);
auto
pad_mode
=
node_js
.
count
(
"pad_mode"
)
==
0
auto
pad_mode
=
read_pad_mode
(
node_js
);
?
op
::
PadMode
::
CONSTANT
:
static_cast
<
op
::
PadMode
>
(
node_js
.
at
(
"pad_mode"
));
node
=
make_shared
<
op
::
Pad
>
(
args
[
0
],
args
[
1
],
padding_below
,
padding_above
,
pad_mode
);
node
=
make_shared
<
op
::
Pad
>
(
args
[
0
],
args
[
1
],
padding_below
,
padding_above
,
pad_mode
);
break
;
break
;
...
@@ -1450,7 +1495,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1450,7 +1495,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Parameter
:
case
OP_TYPEID
:
:
Parameter
:
{
{
auto
type_node_js
=
auto
type_node_js
=
node_js
.
count
(
"element_type"
)
==
0
?
node_js
.
at
(
"value_type"
)
:
node_js
;
has_key
(
node_js
,
"element_type"
)
?
node_js
:
node_js
.
at
(
"value_type"
)
;
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
element_type
=
read_element_type
(
type_node_js
.
at
(
"element_type"
));
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
shape
=
type_node_js
.
at
(
"shape"
);
auto
cacheable
=
get_or_default
<
bool
>
(
node_js
,
"cacheable"
,
false
);
auto
cacheable
=
get_or_default
<
bool
>
(
node_js
,
"cacheable"
,
false
);
...
@@ -1475,7 +1520,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1475,7 +1520,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Power
:
case
OP_TYPEID
:
:
Power
:
{
{
node
=
make_shared
<
op
::
Power
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
node
=
make_shared
<
op
::
Power
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
PRelu
:
case
OP_TYPEID
:
:
PRelu
:
...
@@ -1485,14 +1530,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1485,14 +1530,14 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Product
:
case
OP_TYPEID
:
:
Product
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Product
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Product
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Quantize
:
case
OP_TYPEID
:
:
Quantize
:
{
{
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
type
=
read_element_type
(
node_js
.
at
(
"type"
));
auto
axes
=
node_js
.
at
(
"axes"
).
get
<
set
<
size_t
>>
(
);
auto
axes
=
deserialize_axis_set
(
node_js
.
at
(
"axes"
)
);
auto
round_mode
=
node_js
.
at
(
"round_mode"
).
get
<
op
::
Quantize
::
RoundMode
>
();
auto
round_mode
=
node_js
.
at
(
"round_mode"
).
get
<
op
::
Quantize
::
RoundMode
>
();
node
=
make_shared
<
op
::
Quantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
,
round_mode
);
node
=
make_shared
<
op
::
Quantize
>
(
args
[
0
],
args
[
1
],
args
[
2
],
type
,
axes
,
round_mode
);
break
;
break
;
...
@@ -1551,8 +1596,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1551,8 +1596,8 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
node_js
.
at
(
"window_movement_strides"
).
get
<
vector
<
size_t
>>
();
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// For backwards compatibility, both (but not just one) of the padding_ fields may be
// omitted.
// omitted.
auto
padding_below_maybe
=
node_js
[
"padding_below"
]
;
auto
padding_below_maybe
=
get_or_default
(
node_js
,
"padding_below"
,
json
{})
;
auto
padding_above_maybe
=
node_js
[
"padding_above"
]
;
auto
padding_above_maybe
=
get_or_default
(
node_js
,
"padding_above"
,
json
{})
;
auto
padding_below
=
padding_below_maybe
.
get
<
vector
<
size_t
>>
();
auto
padding_below
=
padding_below_maybe
.
get
<
vector
<
size_t
>>
();
auto
padding_above
=
padding_above_maybe
.
get
<
vector
<
size_t
>>
();
auto
padding_above
=
padding_above_maybe
.
get
<
vector
<
size_t
>>
();
node
=
make_shared
<
op
::
QuantizedMaxPool
>
(
node
=
make_shared
<
op
::
QuantizedMaxPool
>
(
...
@@ -1600,7 +1645,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1600,7 +1645,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Reverse
:
case
OP_TYPEID
:
:
Reverse
:
{
{
auto
reversed_axes
=
node_js
.
at
(
"reversed_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reversed_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reversed_axes"
)
);
node
=
make_shared
<
op
::
Reverse
>
(
args
[
0
],
reversed_axes
);
node
=
make_shared
<
op
::
Reverse
>
(
args
[
0
],
reversed_axes
);
break
;
break
;
}
}
...
@@ -1684,7 +1729,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1684,7 +1729,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
}
}
case
OP_TYPEID
:
:
Softmax
:
case
OP_TYPEID
:
:
Softmax
:
{
{
auto
softmax_axes
=
node_js
.
at
(
"softmax_axes"
).
get
<
set
<
size_t
>>
(
);
auto
softmax_axes
=
deserialize_axis_set
(
node_js
.
at
(
"softmax_axes"
)
);
node
=
make_shared
<
op
::
Softmax
>
(
args
[
0
],
softmax_axes
);
node
=
make_shared
<
op
::
Softmax
>
(
args
[
0
],
softmax_axes
);
break
;
break
;
}
}
...
@@ -1719,12 +1764,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
...
@@ -1719,12 +1764,12 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json& node_js)
case
OP_TYPEID
:
:
Subtract
:
case
OP_TYPEID
:
:
Subtract
:
{
{
node
=
node
=
make_shared
<
op
::
Subtract
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
[
"autob"
]
));
make_shared
<
op
::
Subtract
>
(
args
[
0
],
args
[
1
],
read_auto_broadcast
(
node_js
,
"autob"
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sum
:
case
OP_TYPEID
:
:
Sum
:
{
{
auto
reduction_axes
=
node_js
.
at
(
"reduction_axes"
).
get
<
set
<
size_t
>>
(
);
auto
reduction_axes
=
deserialize_axis_set
(
node_js
.
at
(
"reduction_axes"
)
);
node
=
make_shared
<
op
::
Sum
>
(
args
[
0
],
reduction_axes
);
node
=
make_shared
<
op
::
Sum
>
(
args
[
0
],
reduction_axes
);
break
;
break
;
}
}
...
@@ -1860,6 +1905,16 @@ json JSONSerializer::serialize_output(const Output<Node>& output)
...
@@ -1860,6 +1905,16 @@ json JSONSerializer::serialize_output(const Output<Node>& output)
return
result
;
return
result
;
}
}
json
JSONSerializer
::
serialize_output_vector
(
const
OutputVector
&
output_vector
)
{
json
result
;
for
(
const
Output
<
Node
>&
output
:
output_vector
)
{
result
.
push_back
(
serialize_output
(
output
));
}
return
result
;
}
json
JSONSerializer
::
serialize_node
(
const
Node
&
n
)
json
JSONSerializer
::
serialize_node
(
const
Node
&
n
)
{
{
m_nodes_serialized
.
insert
(
&
n
);
m_nodes_serialized
.
insert
(
&
n
);
...
@@ -1959,7 +2014,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -1959,7 +2014,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
All
:
case
OP_TYPEID
:
:
All
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
All
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
All
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
AllReduce
:
{
break
;
case
OP_TYPEID
:
:
AllReduce
:
{
break
;
...
@@ -1976,7 +2031,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -1976,7 +2031,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Any
:
case
OP_TYPEID
:
:
Any
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Any
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Any
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Asin
:
{
break
;
case
OP_TYPEID
:
:
Asin
:
{
break
;
...
@@ -2032,7 +2087,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2032,7 +2087,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Broadcast
:
case
OP_TYPEID
:
:
Broadcast
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Broadcast
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Broadcast
*>
(
&
n
);
node
[
"axes"
]
=
tmp
->
get_broadcast_axes
(
);
node
[
"axes"
]
=
serialize_axis_set
(
tmp
->
get_broadcast_axes
()
);
node
[
"shape"
]
=
tmp
->
get_broadcast_shape
();
node
[
"shape"
]
=
tmp
->
get_broadcast_shape
();
break
;
break
;
}
}
...
@@ -2041,7 +2096,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2041,7 +2096,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
BroadcastLike
:
case
OP_TYPEID
:
:
BroadcastLike
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
BroadcastLike
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
BroadcastLike
*>
(
&
n
);
node
[
"initial_axes"
]
=
tmp
->
get_initial_broadcast_axes
(
);
node
[
"initial_axes"
]
=
serialize_axis_set
(
tmp
->
get_initial_broadcast_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Ceiling
:
{
break
;
case
OP_TYPEID
:
:
Ceiling
:
{
break
;
...
@@ -2155,7 +2210,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2155,7 +2210,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Dequantize
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Dequantize
*>
(
&
n
);
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"axes"
]
=
tmp
->
get_axes
(
);
node
[
"axes"
]
=
serialize_axis_set
(
tmp
->
get_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
DepthToSpace
:
case
OP_TYPEID
:
:
DepthToSpace
:
...
@@ -2348,7 +2403,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2348,7 +2403,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Max
:
case
OP_TYPEID
:
:
Max
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Max
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Max
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
MaxPool
:
case
OP_TYPEID
:
:
MaxPool
:
...
@@ -2382,7 +2437,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2382,7 +2437,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Min
:
case
OP_TYPEID
:
:
Min
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Min
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Min
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Minimum
:
case
OP_TYPEID
:
:
Minimum
:
...
@@ -2406,7 +2461,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2406,7 +2461,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
MVN
:
case
OP_TYPEID
:
:
MVN
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
MVN
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
MVN
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
node
[
"normalize_variance"
]
=
tmp
->
get_normalize_variance
();
node
[
"normalize_variance"
]
=
tmp
->
get_normalize_variance
();
node
[
"eps"
]
=
tmp
->
get_eps
();
node
[
"eps"
]
=
tmp
->
get_eps
();
break
;
break
;
...
@@ -2486,7 +2541,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2486,7 +2541,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Product
:
case
OP_TYPEID
:
:
Product
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Product
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Product
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Power
:
case
OP_TYPEID
:
:
Power
:
...
@@ -2502,7 +2557,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2502,7 +2557,7 @@ json JSONSerializer::serialize_node(const Node& n)
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Quantize
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Quantize
*>
(
&
n
);
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"type"
]
=
write_element_type
(
tmp
->
get_element_type
());
node
[
"axes"
]
=
tmp
->
get_axes
(
);
node
[
"axes"
]
=
serialize_axis_set
(
tmp
->
get_axes
()
);
node
[
"round_mode"
]
=
tmp
->
get_round_mode
();
node
[
"round_mode"
]
=
tmp
->
get_round_mode
();
break
;
break
;
}
}
...
@@ -2577,7 +2632,7 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2577,7 +2632,7 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Reverse
:
case
OP_TYPEID
:
:
Reverse
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Reverse
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Reverse
*>
(
&
n
);
node
[
"reversed_axes"
]
=
tmp
->
get_reversed_axes
(
);
node
[
"reversed_axes"
]
=
serialize_axis_set
(
tmp
->
get_reversed_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ReverseSequence
:
case
OP_TYPEID
:
:
ReverseSequence
:
...
@@ -2664,13 +2719,13 @@ json JSONSerializer::serialize_node(const Node& n)
...
@@ -2664,13 +2719,13 @@ json JSONSerializer::serialize_node(const Node& n)
case
OP_TYPEID
:
:
Sum
:
case
OP_TYPEID
:
:
Sum
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Sum
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Sum
*>
(
&
n
);
node
[
"reduction_axes"
]
=
tmp
->
get_reduction_axes
(
);
node
[
"reduction_axes"
]
=
serialize_axis_set
(
tmp
->
get_reduction_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Softmax
:
case
OP_TYPEID
:
:
Softmax
:
{
{
auto
tmp
=
dynamic_cast
<
const
op
::
Softmax
*>
(
&
n
);
auto
tmp
=
dynamic_cast
<
const
op
::
Softmax
*>
(
&
n
);
node
[
"softmax_axes"
]
=
tmp
->
get_axes
(
);
node
[
"softmax_axes"
]
=
serialize_axis_set
(
tmp
->
get_axes
()
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Tan
:
{
break
;
case
OP_TYPEID
:
:
Tan
:
{
break
;
...
...
test/models/onnx/eye_like.prototxt
0 → 100644
View file @
24c715f4
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "EyeLike"
attribute {
name: "k"
i: -1
type: INT
}
}
name: "hardmax_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 9
}
test/onnx/onnx_import.in.cpp
View file @
24c715f4
...
@@ -1482,3 +1482,46 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
...
@@ -1482,3 +1482,46 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
test_case
.
run
();
test_case
.
run
();
}
}
NGRAPH_TEST
(
onnx_
$
{
BACKEND_NAME
},
model_eye_like
)
{
const
auto
eye_like_fn
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/eye_like.prototxt"
));
auto
test_case
=
ngraph
::
test
::
NgraphTestCase
(
eye_like_fn
,
"${BACKEND_NAME}"
);
test_case
.
add_input
<
float
>
({
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
});
test_case
.
add_expected_output
<
float
>
(
Shape
{
3
,
4
},
{
0.
f
,
0.
f
,
0.
f
,
0.
f
,
1.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
1.
f
,
0.
f
,
0.
f
,
});
test_case
.
run
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment