Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
72df5da0
Unverified
Commit
72df5da0
authored
Jun 26, 2019
by
Fenglei Tian
Committed by
GitHub
Jun 26, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into tfl/send_recv_op
parents
db2de1b3
5e19c25c
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
319 additions
and
59 deletions
+319
-59
types.py
python/ngraph/utils/types.py
+1
-0
element_type.cpp
python/pyngraph/types/element_type.cpp
+1
-0
CMakeLists.txt
src/ngraph/frontend/onnx_import/CMakeLists.txt
+3
-0
node.cpp
src/ngraph/frontend/onnx_import/core/node.cpp
+16
-0
node.hpp
src/ngraph/frontend/onnx_import/core/node.hpp
+2
-0
value_info.hpp
src/ngraph/frontend/onnx_import/core/value_info.hpp
+5
-29
cast.cpp
src/ngraph/frontend/onnx_import/op/cast.cpp
+2
-21
eye_like.cpp
src/ngraph/frontend/onnx_import/op/eye_like.cpp
+63
-0
eye_like.hpp
src/ngraph/frontend/onnx_import/op/eye_like.hpp
+37
-0
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+2
-0
common.cpp
src/ngraph/frontend/onnx_import/utils/common.cpp
+51
-0
common.hpp
src/ngraph/frontend/onnx_import/utils/common.hpp
+43
-9
eye_like.prototxt
test/models/onnx/eye_like.prototxt
+50
-0
onnx_import.in.cpp
test/onnx/onnx_import.in.cpp
+43
-0
No files found.
python/ngraph/utils/types.py
View file @
72df5da0
...
@@ -37,6 +37,7 @@ NodeInput = Union[Node, NumericData]
...
@@ -37,6 +37,7 @@ NodeInput = Union[Node, NumericData]
ngraph_to_numpy_types_map
=
[
ngraph_to_numpy_types_map
=
[
(
NgraphType
.
boolean
,
np
.
bool
),
(
NgraphType
.
boolean
,
np
.
bool
),
(
NgraphType
.
f16
,
np
.
float16
),
(
NgraphType
.
f32
,
np
.
float32
),
(
NgraphType
.
f32
,
np
.
float32
),
(
NgraphType
.
f64
,
np
.
float64
),
(
NgraphType
.
f64
,
np
.
float64
),
(
NgraphType
.
i8
,
np
.
int8
),
(
NgraphType
.
i8
,
np
.
int8
),
...
...
python/pyngraph/types/element_type.cpp
View file @
72df5da0
...
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
...
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
py
::
class_
<
ngraph
::
element
::
Type
,
std
::
shared_ptr
<
ngraph
::
element
::
Type
>>
type
(
m
,
"Type"
);
py
::
class_
<
ngraph
::
element
::
Type
,
std
::
shared_ptr
<
ngraph
::
element
::
Type
>>
type
(
m
,
"Type"
);
type
.
doc
()
=
"ngraph.impl.Type wraps ngraph::element::Type"
;
type
.
doc
()
=
"ngraph.impl.Type wraps ngraph::element::Type"
;
type
.
attr
(
"boolean"
)
=
ngraph
::
element
::
boolean
;
type
.
attr
(
"boolean"
)
=
ngraph
::
element
::
boolean
;
type
.
attr
(
"f16"
)
=
ngraph
::
element
::
f16
;
type
.
attr
(
"f32"
)
=
ngraph
::
element
::
f32
;
type
.
attr
(
"f32"
)
=
ngraph
::
element
::
f32
;
type
.
attr
(
"f64"
)
=
ngraph
::
element
::
f64
;
type
.
attr
(
"f64"
)
=
ngraph
::
element
::
f64
;
type
.
attr
(
"i8"
)
=
ngraph
::
element
::
i8
;
type
.
attr
(
"i8"
)
=
ngraph
::
element
::
i8
;
...
...
src/ngraph/frontend/onnx_import/CMakeLists.txt
View file @
72df5da0
...
@@ -85,6 +85,8 @@ add_library(onnx_import STATIC
...
@@ -85,6 +85,8 @@ add_library(onnx_import STATIC
op/equal.hpp
op/equal.hpp
op/erf.hpp
op/erf.hpp
op/exp.hpp
op/exp.hpp
op/eye_like.cpp
op/eye_like.hpp
op/flatten.cpp
op/flatten.cpp
op/flatten.hpp
op/flatten.hpp
op/floor.hpp
op/floor.hpp
...
@@ -191,6 +193,7 @@ add_library(onnx_import STATIC
...
@@ -191,6 +193,7 @@ add_library(onnx_import STATIC
op/xor.hpp
op/xor.hpp
ops_bridge.cpp
ops_bridge.cpp
ops_bridge.hpp
ops_bridge.hpp
utils/common.cpp
utils/common.hpp
utils/common.hpp
utils/convpool.cpp
utils/convpool.cpp
utils/convpool.hpp
utils/convpool.hpp
...
...
src/ngraph/frontend/onnx_import/core/node.cpp
View file @
72df5da0
...
@@ -52,6 +52,8 @@ namespace ngraph
...
@@ -52,6 +52,8 @@ namespace ngraph
const
std
::
string
&
output
(
int
index
)
const
;
const
std
::
string
&
output
(
int
index
)
const
;
std
::
size_t
get_outputs_size
()
const
;
std
::
size_t
get_outputs_size
()
const
;
bool
has_attribute
(
const
std
::
string
&
name
)
const
;
template
<
typename
T
>
template
<
typename
T
>
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
...
@@ -87,6 +89,15 @@ namespace ngraph
...
@@ -87,6 +89,15 @@ namespace ngraph
}
}
std
::
size_t
Node
::
Impl
::
get_outputs_size
()
const
{
return
m_output_names
.
size
();
}
std
::
size_t
Node
::
Impl
::
get_outputs_size
()
const
{
return
m_output_names
.
size
();
}
bool
Node
::
Impl
::
has_attribute
(
const
std
::
string
&
name
)
const
{
auto
it
=
std
::
find_if
(
std
::
begin
(
m_attributes
),
std
::
end
(
m_attributes
),
[
&
](
const
Attribute
&
attribute
)
{
return
attribute
.
get_name
()
==
name
;
});
return
it
!=
std
::
end
(
m_attributes
);
}
template
<
typename
T
>
template
<
typename
T
>
T
Node
::
Impl
::
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
T
Node
::
Impl
::
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
{
{
...
@@ -185,6 +196,11 @@ namespace ngraph
...
@@ -185,6 +196,11 @@ namespace ngraph
const
std
::
string
&
Node
::
output
(
int
index
)
const
{
return
m_pimpl
->
output
(
index
);
}
const
std
::
string
&
Node
::
output
(
int
index
)
const
{
return
m_pimpl
->
output
(
index
);
}
std
::
size_t
Node
::
get_outputs_size
()
const
{
return
m_pimpl
->
get_outputs_size
();
}
std
::
size_t
Node
::
get_outputs_size
()
const
{
return
m_pimpl
->
get_outputs_size
();
}
bool
Node
::
has_attribute
(
const
std
::
string
&
name
)
const
{
return
m_pimpl
->
has_attribute
(
name
);
}
template
<>
template
<>
float
Node
::
get_attribute_value
(
const
std
::
string
&
name
,
float
default_value
)
const
float
Node
::
get_attribute_value
(
const
std
::
string
&
name
,
float
default_value
)
const
{
{
...
...
src/ngraph/frontend/onnx_import/core/node.hpp
View file @
72df5da0
...
@@ -78,6 +78,8 @@ namespace ngraph
...
@@ -78,6 +78,8 @@ namespace ngraph
const
std
::
string
&
output
(
int
index
)
const
;
const
std
::
string
&
output
(
int
index
)
const
;
std
::
size_t
get_outputs_size
()
const
;
std
::
size_t
get_outputs_size
()
const
;
bool
has_attribute
(
const
std
::
string
&
name
)
const
;
template
<
typename
T
>
template
<
typename
T
>
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
...
...
src/ngraph/frontend/onnx_import/core/value_info.hpp
View file @
72df5da0
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type.hpp"
#include "node.hpp"
#include "node.hpp"
#include "tensor.hpp"
#include "tensor.hpp"
#include "utils/common.hpp"
#include "weight.hpp"
#include "weight.hpp"
namespace
ngraph
namespace
ngraph
...
@@ -41,17 +42,8 @@ namespace ngraph
...
@@ -41,17 +42,8 @@ namespace ngraph
{
{
}
}
};
};
struct
unsupported_element_type
:
ngraph_error
}
// namespace value_info
{
}
// namespace error
explicit
unsupported_element_type
(
TensorProto_DataType
type
)
:
ngraph_error
{
"unsupported value info element type: "
+
onnx
::
TensorProto_DataType_Name
(
static_cast
<
onnx
::
TensorProto_DataType
>
(
type
))}
{
}
};
}
}
class
ValueInfo
class
ValueInfo
{
{
...
@@ -83,24 +75,8 @@ namespace ngraph
...
@@ -83,24 +75,8 @@ namespace ngraph
{
{
throw
error
::
value_info
::
unspecified_element_type
{};
throw
error
::
value_info
::
unspecified_element_type
{};
}
}
switch
(
m_value_info_proto
->
type
().
tensor_type
().
elem_type
())
return
common
::
get_ngraph_element_type
(
{
m_value_info_proto
->
type
().
tensor_type
().
elem_type
());
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT
:
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT16
:
return
element
::
f32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_DOUBLE
:
return
element
::
f64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT8
:
return
element
::
i8
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT16
:
return
element
::
i16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT32
:
return
element
::
i32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT64
:
return
element
::
i64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT8
:
return
element
::
u8
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
default
:
throw
error
::
value_info
::
unsupported_element_type
{
m_value_info_proto
->
type
().
tensor_type
().
elem_type
()};
}
}
}
std
::
shared_ptr
<
ngraph
::
Node
>
std
::
shared_ptr
<
ngraph
::
Node
>
...
...
src/ngraph/frontend/onnx_import/op/cast.cpp
View file @
72df5da0
...
@@ -13,14 +13,13 @@
...
@@ -13,14 +13,13 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
//*****************************************************************************
//*****************************************************************************
#include <memory>
#include <memory>
#include <onnx-ml.pb.h>
#include "cast.hpp"
#include "cast.hpp"
#include "exceptions.hpp"
#include "exceptions.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/type/element_type.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -34,25 +33,7 @@ namespace ngraph
...
@@ -34,25 +33,7 @@ namespace ngraph
{
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
int64_t
target_type
=
node
.
get_attribute_value
<
int64_t
>
(
"to"
);
int64_t
target_type
=
node
.
get_attribute_value
<
int64_t
>
(
"to"
);
element
::
Type
elem_type
;
element
::
Type
elem_type
=
common
::
get_ngraph_element_type
(
target_type
);
switch
(
target_type
)
{
case
onnx
:
:
TensorProto_DataType_BOOL
:
elem_type
=
element
::
boolean
;
break
;
case
onnx
:
:
TensorProto_DataType_DOUBLE
:
elem_type
=
element
::
f64
;
break
;
case
onnx
:
:
TensorProto_DataType_FLOAT16
:
case
onnx
:
:
TensorProto_DataType_FLOAT
:
elem_type
=
element
::
f32
;
break
;
case
onnx
:
:
TensorProto_DataType_INT8
:
elem_type
=
element
::
i8
;
break
;
case
onnx
:
:
TensorProto_DataType_INT16
:
elem_type
=
element
::
i16
;
break
;
case
onnx
:
:
TensorProto_DataType_INT32
:
elem_type
=
element
::
i32
;
break
;
case
onnx
:
:
TensorProto_DataType_INT64
:
elem_type
=
element
::
i64
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT8
:
elem_type
=
element
::
u8
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT16
:
elem_type
=
element
::
u16
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT32
:
elem_type
=
element
::
u32
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT64
:
elem_type
=
element
::
u64
;
break
;
case
onnx
:
:
TensorProto_DataType_UNDEFINED
:
elem_type
=
element
::
dynamic
;
break
;
default
:
ASSERT_IS_SUPPORTED
(
node
,
false
)
<<
"unsupported type"
;
}
return
{
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
data
,
elem_type
)};
return
{
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
data
,
elem_type
)};
}
}
...
...
src/ngraph/frontend/onnx_import/op/eye_like.cpp
0 → 100644
View file @
72df5da0
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "eye_like.hpp"
#include "exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
eye_like
(
const
Node
&
node
)
{
const
auto
input
=
node
.
get_ng_inputs
().
at
(
0
);
const
auto
&
input_shape
=
input
->
get_shape
();
std
::
int64_t
dtype
;
element
::
Type
target_type
;
std
::
int64_t
shift
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"k"
,
0
);
if
(
node
.
has_attribute
(
"dtype"
))
{
dtype
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"dtype"
);
target_type
=
common
::
get_ngraph_element_type
(
dtype
);
}
else
{
target_type
=
input
->
get_element_type
();
}
ASSERT_VALID_ARGUMENT
(
node
,
input_shape
.
size
()
==
2
)
<<
"The provided shape rank: "
<<
input_shape
.
size
()
<<
" is unsupported, only 2D shapes are supported"
;
std
::
shared_ptr
<
ngraph
::
Node
>
eye_like_matrix
=
common
::
shifted_square_identity
(
input_shape
,
target_type
,
shift
);
return
{
eye_like_matrix
};
}
}
// namespace set_1
}
// namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/op/eye_like.hpp
0 → 100644
View file @
72df5da0
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
eye_like
(
const
Node
&
node
);
}
// namespace set_1
}
//namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
72df5da0
...
@@ -53,6 +53,7 @@
...
@@ -53,6 +53,7 @@
#include "op/equal.hpp"
#include "op/equal.hpp"
#include "op/erf.hpp"
#include "op/erf.hpp"
#include "op/exp.hpp"
#include "op/exp.hpp"
#include "op/eye_like.hpp"
#include "op/flatten.hpp"
#include "op/flatten.hpp"
#include "op/floor.hpp"
#include "op/floor.hpp"
#include "op/gather.hpp"
#include "op/gather.hpp"
...
@@ -260,6 +261,7 @@ namespace ngraph
...
@@ -260,6 +261,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Equal"
,
1
,
equal
);
REGISTER_OPERATOR
(
"Equal"
,
1
,
equal
);
REGISTER_OPERATOR
(
"Erf"
,
1
,
erf
);
REGISTER_OPERATOR
(
"Erf"
,
1
,
erf
);
REGISTER_OPERATOR
(
"Exp"
,
1
,
exp
);
REGISTER_OPERATOR
(
"Exp"
,
1
,
exp
);
REGISTER_OPERATOR
(
"EyeLike"
,
1
,
eye_like
);
REGISTER_OPERATOR
(
"Flatten"
,
1
,
flatten
);
REGISTER_OPERATOR
(
"Flatten"
,
1
,
flatten
);
REGISTER_OPERATOR
(
"Floor"
,
1
,
floor
);
REGISTER_OPERATOR
(
"Floor"
,
1
,
floor
);
REGISTER_OPERATOR
(
"Gather"
,
1
,
gather
);
REGISTER_OPERATOR
(
"Gather"
,
1
,
gather
);
...
...
src/ngraph/frontend/onnx_import/utils/common.cpp
0 → 100644
View file @
72df5da0
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h> // onnx types
#include "common.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
common
{
const
ngraph
::
element
::
Type
&
get_ngraph_element_type
(
int64_t
onnx_type
)
{
switch
(
onnx_type
)
{
case
onnx
:
:
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType_DOUBLE
:
return
element
::
f64
;
case
onnx
:
:
TensorProto_DataType_FLOAT16
:
return
element
::
f16
;
case
onnx
:
:
TensorProto_DataType_FLOAT
:
return
element
::
f32
;
case
onnx
:
:
TensorProto_DataType_INT8
:
return
element
::
i8
;
case
onnx
:
:
TensorProto_DataType_INT16
:
return
element
::
i16
;
case
onnx
:
:
TensorProto_DataType_INT32
:
return
element
::
i32
;
case
onnx
:
:
TensorProto_DataType_INT64
:
return
element
::
i64
;
case
onnx
:
:
TensorProto_DataType_UINT8
:
return
element
::
u8
;
case
onnx
:
:
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType_UINT64
:
return
element
::
u64
;
case
onnx
:
:
TensorProto_DataType_UNDEFINED
:
return
element
::
dynamic
;
}
throw
ngraph_error
(
"unsupported element type: "
+
onnx
::
TensorProto_DataType_Name
(
static_cast
<
onnx
::
TensorProto_DataType
>
(
onnx_type
)));
}
}
// namespace common
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/utils/common.hpp
View file @
72df5da0
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <algorithm> // std::generate
#include <algorithm> // std::generate
#include <cmath> // std::floor, std::min
#include <cmath> // std::floor, std::min
#include <cstddef> // std::size_t
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <iterator> // std::begin, std::end
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
#include <memory> // std::shared_ptr, std::make_shared
#include <type_traits> // std::enable_if
#include <type_traits> // std::enable_if
...
@@ -27,6 +28,7 @@
...
@@ -27,6 +28,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace
ngraph
namespace
ngraph
{
{
...
@@ -34,6 +36,8 @@ namespace ngraph
...
@@ -34,6 +36,8 @@ namespace ngraph
{
{
namespace
common
namespace
common
{
{
const
ngraph
::
element
::
Type
&
get_ngraph_element_type
(
std
::
int64_t
onnx_type
);
/// \brief Return a monotonic sequence.
/// \brief Return a monotonic sequence.
///
///
/// \note Limitations: this function may not work for very large integer values
/// \note Limitations: this function may not work for very large integer values
...
@@ -86,6 +90,43 @@ namespace ngraph
...
@@ -86,6 +90,43 @@ namespace ngraph
}
}
}
}
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main diagonal.
/// Shifting adds an offset and moves the diagonal up or down
///
/// \param[in] output_shape Shape of the resulting matrix.
/// \param[in] output_type Element type of the resulting matrix.
/// \param[in] shift Shifting of diagonal.
///
/// \return A Constant node representing shifted identity matrix.
template
<
typename
T
=
double
>
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
shifted_square_identity
(
const
Shape
output_shape
,
const
element
::
Type
&
output_type
,
const
std
::
int64_t
shift
)
{
std
::
vector
<
T
>
identity_matrix
(
shape_size
(
output_shape
),
T
{
0
});
std
::
int64_t
rows
=
output_shape
[
0
];
std
::
int64_t
cols
=
output_shape
[
1
];
for
(
std
::
int64_t
row
=
0
;
row
<
rows
;
++
row
)
{
const
std
::
int64_t
diagonal_element_idx
=
(
row
*
cols
)
+
row
+
shift
;
if
(
row
+
shift
<
0
)
{
continue
;
}
else
if
(
row
+
shift
>=
cols
)
{
break
;
}
identity_matrix
.
at
(
diagonal_element_idx
)
=
T
{
1
};
}
return
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
output_type
,
output_shape
,
identity_matrix
);
}
/// \brief Creates a square identity matrix.
/// \brief Creates a square identity matrix.
///
///
/// \param[in] n Order of the resulting matrix.
/// \param[in] n Order of the resulting matrix.
...
@@ -95,16 +136,9 @@ namespace ngraph
...
@@ -95,16 +136,9 @@ namespace ngraph
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
square_identity
(
const
size_t
n
,
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
square_identity
(
const
size_t
n
,
const
element
::
Type
&
type
)
const
element
::
Type
&
type
)
{
{
std
::
vector
<
T
>
identity_matrix
(
n
*
n
,
T
{
0
});
return
shifted_square_identity
(
Shape
{
n
,
n
},
type
,
0
);
for
(
size_t
row
=
0
;
row
<
n
;
++
row
)
{
const
size_t
diagonal_element
=
(
n
*
row
)
+
row
;
identity_matrix
.
at
(
diagonal_element
)
=
T
{
1
};
}
return
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
type
,
Shape
{{
n
,
n
}},
identity_matrix
);
}
}
}
// namespace common
}
// namespace common
}
// namespace onnx_import
}
// namespace onnx_import
}
// namespace ngraph
}
// namespace ngraph
test/models/onnx/eye_like.prototxt
0 → 100644
View file @
72df5da0
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "EyeLike"
attribute {
name: "k"
i: -1
type: INT
}
}
name: "hardmax_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 9
}
test/onnx/onnx_import.in.cpp
View file @
72df5da0
...
@@ -1482,3 +1482,46 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
...
@@ -1482,3 +1482,46 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
test_case
.
run
();
test_case
.
run
();
}
}
NGRAPH_TEST
(
onnx_
$
{
BACKEND_NAME
},
model_eye_like
)
{
const
auto
eye_like_fn
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/eye_like.prototxt"
));
auto
test_case
=
ngraph
::
test
::
NgraphTestCase
(
eye_like_fn
,
"${BACKEND_NAME}"
);
test_case
.
add_input
<
float
>
({
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
});
test_case
.
add_expected_output
<
float
>
(
Shape
{
3
,
4
},
{
0.
f
,
0.
f
,
0.
f
,
0.
f
,
1.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
1.
f
,
0.
f
,
0.
f
,
});
test_case
.
run
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment