Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
24c715f4
Unverified
Commit
24c715f4
authored
Jun 26, 2019
by
Scott Cyphers
Committed by
GitHub
Jun 26, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into mlir
parents
3b5bfdab
5e19c25c
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
373 additions
and
86 deletions
+373
-86
types.py
python/ngraph/utils/types.py
+1
-0
element_type.cpp
python/pyngraph/types/element_type.cpp
+1
-0
CMakeLists.txt
src/ngraph/frontend/onnx_import/CMakeLists.txt
+3
-0
node.cpp
src/ngraph/frontend/onnx_import/core/node.cpp
+16
-0
node.hpp
src/ngraph/frontend/onnx_import/core/node.hpp
+2
-0
value_info.hpp
src/ngraph/frontend/onnx_import/core/value_info.hpp
+5
-29
cast.cpp
src/ngraph/frontend/onnx_import/op/cast.cpp
+2
-21
eye_like.cpp
src/ngraph/frontend/onnx_import/op/eye_like.cpp
+63
-0
eye_like.hpp
src/ngraph/frontend/onnx_import/op/eye_like.hpp
+37
-0
ops_bridge.cpp
src/ngraph/frontend/onnx_import/ops_bridge.cpp
+2
-0
common.cpp
src/ngraph/frontend/onnx_import/utils/common.cpp
+51
-0
common.hpp
src/ngraph/frontend/onnx_import/utils/common.hpp
+43
-9
node.hpp
src/ngraph/node.hpp
+1
-0
generate_mask.hpp
src/ngraph/op/experimental/generate_mask.hpp
+2
-0
fused_op_tbl.hpp
src/ngraph/op/fused_op_tbl.hpp
+3
-3
op_tbl.hpp
src/ngraph/op/op_tbl.hpp
+5
-5
fused_op_decomposition.cpp
src/ngraph/pass/fused_op_decomposition.cpp
+25
-15
fused_op_decomposition.hpp
src/ngraph/pass/fused_op_decomposition.hpp
+15
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+3
-3
serializer.cpp
src/ngraph/serializer.cpp
+0
-0
eye_like.prototxt
test/models/onnx/eye_like.prototxt
+50
-0
onnx_import.in.cpp
test/onnx/onnx_import.in.cpp
+43
-0
No files found.
python/ngraph/utils/types.py
View file @
24c715f4
...
...
@@ -37,6 +37,7 @@ NodeInput = Union[Node, NumericData]
ngraph_to_numpy_types_map
=
[
(
NgraphType
.
boolean
,
np
.
bool
),
(
NgraphType
.
f16
,
np
.
float16
),
(
NgraphType
.
f32
,
np
.
float32
),
(
NgraphType
.
f64
,
np
.
float64
),
(
NgraphType
.
i8
,
np
.
int8
),
...
...
python/pyngraph/types/element_type.cpp
View file @
24c715f4
...
...
@@ -28,6 +28,7 @@ void regclass_pyngraph_Type(py::module m)
py
::
class_
<
ngraph
::
element
::
Type
,
std
::
shared_ptr
<
ngraph
::
element
::
Type
>>
type
(
m
,
"Type"
);
type
.
doc
()
=
"ngraph.impl.Type wraps ngraph::element::Type"
;
type
.
attr
(
"boolean"
)
=
ngraph
::
element
::
boolean
;
type
.
attr
(
"f16"
)
=
ngraph
::
element
::
f16
;
type
.
attr
(
"f32"
)
=
ngraph
::
element
::
f32
;
type
.
attr
(
"f64"
)
=
ngraph
::
element
::
f64
;
type
.
attr
(
"i8"
)
=
ngraph
::
element
::
i8
;
...
...
src/ngraph/frontend/onnx_import/CMakeLists.txt
View file @
24c715f4
...
...
@@ -85,6 +85,8 @@ add_library(onnx_import STATIC
op/equal.hpp
op/erf.hpp
op/exp.hpp
op/eye_like.cpp
op/eye_like.hpp
op/flatten.cpp
op/flatten.hpp
op/floor.hpp
...
...
@@ -191,6 +193,7 @@ add_library(onnx_import STATIC
op/xor.hpp
ops_bridge.cpp
ops_bridge.hpp
utils/common.cpp
utils/common.hpp
utils/convpool.cpp
utils/convpool.hpp
...
...
src/ngraph/frontend/onnx_import/core/node.cpp
View file @
24c715f4
...
...
@@ -52,6 +52,8 @@ namespace ngraph
const
std
::
string
&
output
(
int
index
)
const
;
std
::
size_t
get_outputs_size
()
const
;
bool
has_attribute
(
const
std
::
string
&
name
)
const
;
template
<
typename
T
>
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
...
...
@@ -87,6 +89,15 @@ namespace ngraph
}
std
::
size_t
Node
::
Impl
::
get_outputs_size
()
const
{
return
m_output_names
.
size
();
}
bool
Node
::
Impl
::
has_attribute
(
const
std
::
string
&
name
)
const
{
auto
it
=
std
::
find_if
(
std
::
begin
(
m_attributes
),
std
::
end
(
m_attributes
),
[
&
](
const
Attribute
&
attribute
)
{
return
attribute
.
get_name
()
==
name
;
});
return
it
!=
std
::
end
(
m_attributes
);
}
template
<
typename
T
>
T
Node
::
Impl
::
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
{
...
...
@@ -185,6 +196,11 @@ namespace ngraph
const
std
::
string
&
Node
::
output
(
int
index
)
const
{
return
m_pimpl
->
output
(
index
);
}
std
::
size_t
Node
::
get_outputs_size
()
const
{
return
m_pimpl
->
get_outputs_size
();
}
bool
Node
::
has_attribute
(
const
std
::
string
&
name
)
const
{
return
m_pimpl
->
has_attribute
(
name
);
}
template
<>
float
Node
::
get_attribute_value
(
const
std
::
string
&
name
,
float
default_value
)
const
{
...
...
src/ngraph/frontend/onnx_import/core/node.hpp
View file @
24c715f4
...
...
@@ -78,6 +78,8 @@ namespace ngraph
const
std
::
string
&
output
(
int
index
)
const
;
std
::
size_t
get_outputs_size
()
const
;
bool
has_attribute
(
const
std
::
string
&
name
)
const
;
template
<
typename
T
>
T
get_attribute_value
(
const
std
::
string
&
name
,
T
default_value
)
const
;
...
...
src/ngraph/frontend/onnx_import/core/value_info.hpp
View file @
24c715f4
...
...
@@ -24,6 +24,7 @@
#include "ngraph/type/element_type.hpp"
#include "node.hpp"
#include "tensor.hpp"
#include "utils/common.hpp"
#include "weight.hpp"
namespace
ngraph
...
...
@@ -41,17 +42,8 @@ namespace ngraph
{
}
};
struct
unsupported_element_type
:
ngraph_error
{
explicit
unsupported_element_type
(
TensorProto_DataType
type
)
:
ngraph_error
{
"unsupported value info element type: "
+
onnx
::
TensorProto_DataType_Name
(
static_cast
<
onnx
::
TensorProto_DataType
>
(
type
))}
{
}
};
}
}
}
// namespace value_info
}
// namespace error
class
ValueInfo
{
...
...
@@ -83,24 +75,8 @@ namespace ngraph
{
throw
error
::
value_info
::
unspecified_element_type
{};
}
switch
(
m_value_info_proto
->
type
().
tensor_type
().
elem_type
())
{
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT
:
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_FLOAT16
:
return
element
::
f32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_DOUBLE
:
return
element
::
f64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT8
:
return
element
::
i8
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT16
:
return
element
::
i16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT32
:
return
element
::
i32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_INT64
:
return
element
::
i64
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT8
:
return
element
::
u8
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType
::
TensorProto_DataType_UINT64
:
return
element
::
u64
;
default
:
throw
error
::
value_info
::
unsupported_element_type
{
m_value_info_proto
->
type
().
tensor_type
().
elem_type
()};
}
return
common
::
get_ngraph_element_type
(
m_value_info_proto
->
type
().
tensor_type
().
elem_type
());
}
std
::
shared_ptr
<
ngraph
::
Node
>
...
...
src/ngraph/frontend/onnx_import/op/cast.cpp
View file @
24c715f4
...
...
@@ -13,14 +13,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <onnx-ml.pb.h>
#include "cast.hpp"
#include "exceptions.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/type/element_type.hpp"
#include "utils/common.hpp"
namespace
ngraph
{
...
...
@@ -34,25 +33,7 @@ namespace ngraph
{
auto
data
=
node
.
get_ng_inputs
().
at
(
0
);
int64_t
target_type
=
node
.
get_attribute_value
<
int64_t
>
(
"to"
);
element
::
Type
elem_type
;
switch
(
target_type
)
{
case
onnx
:
:
TensorProto_DataType_BOOL
:
elem_type
=
element
::
boolean
;
break
;
case
onnx
:
:
TensorProto_DataType_DOUBLE
:
elem_type
=
element
::
f64
;
break
;
case
onnx
:
:
TensorProto_DataType_FLOAT16
:
case
onnx
:
:
TensorProto_DataType_FLOAT
:
elem_type
=
element
::
f32
;
break
;
case
onnx
:
:
TensorProto_DataType_INT8
:
elem_type
=
element
::
i8
;
break
;
case
onnx
:
:
TensorProto_DataType_INT16
:
elem_type
=
element
::
i16
;
break
;
case
onnx
:
:
TensorProto_DataType_INT32
:
elem_type
=
element
::
i32
;
break
;
case
onnx
:
:
TensorProto_DataType_INT64
:
elem_type
=
element
::
i64
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT8
:
elem_type
=
element
::
u8
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT16
:
elem_type
=
element
::
u16
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT32
:
elem_type
=
element
::
u32
;
break
;
case
onnx
:
:
TensorProto_DataType_UINT64
:
elem_type
=
element
::
u64
;
break
;
case
onnx
:
:
TensorProto_DataType_UNDEFINED
:
elem_type
=
element
::
dynamic
;
break
;
default
:
ASSERT_IS_SUPPORTED
(
node
,
false
)
<<
"unsupported type"
;
}
element
::
Type
elem_type
=
common
::
get_ngraph_element_type
(
target_type
);
return
{
std
::
make_shared
<
ngraph
::
op
::
Convert
>
(
data
,
elem_type
)};
}
...
...
src/ngraph/frontend/onnx_import/op/eye_like.cpp
0 → 100644
View file @
24c715f4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "eye_like.hpp"
#include "exceptions.hpp"
#include "ngraph/frontend/onnx_import/utils/common.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
eye_like
(
const
Node
&
node
)
{
const
auto
input
=
node
.
get_ng_inputs
().
at
(
0
);
const
auto
&
input_shape
=
input
->
get_shape
();
std
::
int64_t
dtype
;
element
::
Type
target_type
;
std
::
int64_t
shift
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"k"
,
0
);
if
(
node
.
has_attribute
(
"dtype"
))
{
dtype
=
node
.
get_attribute_value
<
std
::
int64_t
>
(
"dtype"
);
target_type
=
common
::
get_ngraph_element_type
(
dtype
);
}
else
{
target_type
=
input
->
get_element_type
();
}
ASSERT_VALID_ARGUMENT
(
node
,
input_shape
.
size
()
==
2
)
<<
"The provided shape rank: "
<<
input_shape
.
size
()
<<
" is unsupported, only 2D shapes are supported"
;
std
::
shared_ptr
<
ngraph
::
Node
>
eye_like_matrix
=
common
::
shifted_square_identity
(
input_shape
,
target_type
,
shift
);
return
{
eye_like_matrix
};
}
}
// namespace set_1
}
// namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/op/eye_like.hpp
0 → 100644
View file @
24c715f4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
op
{
namespace
set_1
{
NodeVector
eye_like
(
const
Node
&
node
);
}
// namespace set_1
}
//namespace op
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/ops_bridge.cpp
View file @
24c715f4
...
...
@@ -53,6 +53,7 @@
#include "op/equal.hpp"
#include "op/erf.hpp"
#include "op/exp.hpp"
#include "op/eye_like.hpp"
#include "op/flatten.hpp"
#include "op/floor.hpp"
#include "op/gather.hpp"
...
...
@@ -260,6 +261,7 @@ namespace ngraph
REGISTER_OPERATOR
(
"Equal"
,
1
,
equal
);
REGISTER_OPERATOR
(
"Erf"
,
1
,
erf
);
REGISTER_OPERATOR
(
"Exp"
,
1
,
exp
);
REGISTER_OPERATOR
(
"EyeLike"
,
1
,
eye_like
);
REGISTER_OPERATOR
(
"Flatten"
,
1
,
flatten
);
REGISTER_OPERATOR
(
"Floor"
,
1
,
floor
);
REGISTER_OPERATOR
(
"Gather"
,
1
,
gather
);
...
...
src/ngraph/frontend/onnx_import/utils/common.cpp
0 → 100644
View file @
24c715f4
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <onnx-ml.pb.h> // onnx types
#include "common.hpp"
namespace
ngraph
{
namespace
onnx_import
{
namespace
common
{
const
ngraph
::
element
::
Type
&
get_ngraph_element_type
(
int64_t
onnx_type
)
{
switch
(
onnx_type
)
{
case
onnx
:
:
TensorProto_DataType_BOOL
:
return
element
::
boolean
;
case
onnx
:
:
TensorProto_DataType_DOUBLE
:
return
element
::
f64
;
case
onnx
:
:
TensorProto_DataType_FLOAT16
:
return
element
::
f16
;
case
onnx
:
:
TensorProto_DataType_FLOAT
:
return
element
::
f32
;
case
onnx
:
:
TensorProto_DataType_INT8
:
return
element
::
i8
;
case
onnx
:
:
TensorProto_DataType_INT16
:
return
element
::
i16
;
case
onnx
:
:
TensorProto_DataType_INT32
:
return
element
::
i32
;
case
onnx
:
:
TensorProto_DataType_INT64
:
return
element
::
i64
;
case
onnx
:
:
TensorProto_DataType_UINT8
:
return
element
::
u8
;
case
onnx
:
:
TensorProto_DataType_UINT16
:
return
element
::
u16
;
case
onnx
:
:
TensorProto_DataType_UINT32
:
return
element
::
u32
;
case
onnx
:
:
TensorProto_DataType_UINT64
:
return
element
::
u64
;
case
onnx
:
:
TensorProto_DataType_UNDEFINED
:
return
element
::
dynamic
;
}
throw
ngraph_error
(
"unsupported element type: "
+
onnx
::
TensorProto_DataType_Name
(
static_cast
<
onnx
::
TensorProto_DataType
>
(
onnx_type
)));
}
}
// namespace common
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/frontend/onnx_import/utils/common.hpp
View file @
24c715f4
...
...
@@ -19,6 +19,7 @@
#include <algorithm> // std::generate
#include <cmath> // std::floor, std::min
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <iterator> // std::begin, std::end
#include <memory> // std::shared_ptr, std::make_shared
#include <type_traits> // std::enable_if
...
...
@@ -27,6 +28,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/broadcasting.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace
ngraph
{
...
...
@@ -34,6 +36,8 @@ namespace ngraph
{
namespace
common
{
const
ngraph
::
element
::
Type
&
get_ngraph_element_type
(
std
::
int64_t
onnx_type
);
/// \brief Return a monotonic sequence.
///
/// \note Limitations: this function may not work for very large integer values
...
...
@@ -86,6 +90,43 @@ namespace ngraph
}
}
/// \brief Creates a shifted square identity matrix.
/// \note Shifting in the context of this operator means that
/// the matrix can be created with elements equal to 1 not only in the main diagonal.
/// Shifting adds an offset and moves the diagonal up or down
///
/// \param[in] output_shape Shape of the resulting matrix.
/// \param[in] output_type Element type of the resulting matrix.
/// \param[in] shift Shifting of diagonal.
///
/// \return A Constant node representing shifted identity matrix.
template
<
typename
T
=
double
>
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
shifted_square_identity
(
const
Shape
output_shape
,
const
element
::
Type
&
output_type
,
const
std
::
int64_t
shift
)
{
std
::
vector
<
T
>
identity_matrix
(
shape_size
(
output_shape
),
T
{
0
});
std
::
int64_t
rows
=
output_shape
[
0
];
std
::
int64_t
cols
=
output_shape
[
1
];
for
(
std
::
int64_t
row
=
0
;
row
<
rows
;
++
row
)
{
const
std
::
int64_t
diagonal_element_idx
=
(
row
*
cols
)
+
row
+
shift
;
if
(
row
+
shift
<
0
)
{
continue
;
}
else
if
(
row
+
shift
>=
cols
)
{
break
;
}
identity_matrix
.
at
(
diagonal_element_idx
)
=
T
{
1
};
}
return
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
output_type
,
output_shape
,
identity_matrix
);
}
/// \brief Creates a square identity matrix.
///
/// \param[in] n Order of the resulting matrix.
...
...
@@ -95,16 +136,9 @@ namespace ngraph
std
::
shared_ptr
<
ngraph
::
op
::
Constant
>
square_identity
(
const
size_t
n
,
const
element
::
Type
&
type
)
{
std
::
vector
<
T
>
identity_matrix
(
n
*
n
,
T
{
0
});
for
(
size_t
row
=
0
;
row
<
n
;
++
row
)
{
const
size_t
diagonal_element
=
(
n
*
row
)
+
row
;
identity_matrix
.
at
(
diagonal_element
)
=
T
{
1
};
}
return
std
::
make_shared
<
ngraph
::
op
::
Constant
>
(
type
,
Shape
{{
n
,
n
}},
identity_matrix
);
return
shifted_square_identity
(
Shape
{
n
,
n
},
type
,
0
);
}
}
// namespace common
}
// namespace onnx_import
}
// namespace ngraph
src/ngraph/node.hpp
View file @
24c715f4
...
...
@@ -216,6 +216,7 @@ namespace ngraph
virtual
bool
is_op
()
const
{
return
false
;
}
virtual
bool
is_commutative
()
{
return
false
;
}
virtual
bool
is_dynamic
()
const
;
virtual
bool
has_state
()
const
{
return
false
;
}
size_t
get_instance_id
()
const
{
return
m_instance_id
;
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Node
&
);
virtual
std
::
ostream
&
write_short_description
(
std
::
ostream
&
)
const
;
...
...
src/ngraph/op/experimental/generate_mask.hpp
View file @
24c715f4
...
...
@@ -72,6 +72,8 @@ namespace ngraph
/// \brief Returns the seed value supplied to a random generator
uint64_t
get_seed
()
const
{
return
m_seed
;
}
bool
get_use_seed
()
const
{
return
m_use_seed
;
}
/// GenerateMask has state.
bool
has_state
()
const
override
{
return
true
;
}
protected
:
virtual
void
generate_adjoints
(
autodiff
::
Adjoints
&
adjoints
,
const
NodeVector
&
deltas
)
override
...
...
src/ngraph/op/fused_op_tbl.hpp
View file @
24c715f4
...
...
@@ -24,8 +24,8 @@ NGRAPH_OP(ConvolutionBiasBackpropFiltersBias, ngraph::op)
NGRAPH_OP
(
DepthToSpace
,
ngraph
::
op
)
NGRAPH_OP
(
Elu
,
ngraph
::
op
)
NGRAPH_OP
(
FakeQuantize
,
ngraph
::
op
)
NGRAPH_OP
(
GRN
,
ngraph
::
op
)
NGRAPH_OP
(
Gemm
,
ngraph
::
op
)
NGRAPH_OP
(
GRN
,
ngraph
::
op
)
NGRAPH_OP
(
GroupConvolution
,
ngraph
::
op
)
NGRAPH_OP
(
GroupConvolutionTranspose
,
ngraph
::
op
)
NGRAPH_OP
(
HardSigmoid
,
ngraph
::
op
)
...
...
@@ -35,9 +35,9 @@ NGRAPH_OP(MVN, ngraph::op)
NGRAPH_OP
(
Normalize
,
ngraph
::
op
)
NGRAPH_OP
(
PRelu
,
ngraph
::
op
)
NGRAPH_OP
(
ScaleShift
,
ngraph
::
op
)
NGRAPH_OP
(
SpaceToDepth
,
ngraph
::
op
)
NGRAPH_OP
(
ShuffleChannels
,
ngraph
::
op
)
NGRAPH_OP
(
SpaceToDepth
,
ngraph
::
op
)
NGRAPH_OP
(
Split
,
ngraph
::
op
)
NGRAPH_OP
(
SquaredDifference
,
ngraph
::
op
)
NGRAPH_OP
(
Squeeze
,
ngraph
::
op
)
NGRAPH_OP
(
Split
,
ngraph
::
op
)
NGRAPH_OP
(
Unsqueeze
,
ngraph
::
op
)
src/ngraph/op/op_tbl.hpp
View file @
24c715f4
...
...
@@ -81,11 +81,12 @@ NGRAPH_OP(Cos, ngraph::op)
NGRAPH_OP
(
Cosh
,
ngraph
::
op
)
NGRAPH_OP
(
Dequantize
,
ngraph
::
op
)
NGRAPH_OP
(
Divide
,
ngraph
::
op
)
NGRAPH_OP
(
DynBroadcast
,
ngraph
::
op
)
NGRAPH_OP
(
Dot
,
ngraph
::
op
)
NGRAPH_OP
(
DynBroadcast
,
ngraph
::
op
)
NGRAPH_OP
(
DynPad
,
ngraph
::
op
)
NGRAPH_OP
(
DynReshape
,
ngraph
::
op
)
NGRAPH_OP
(
DynSlice
,
ngraph
::
op
)
NGRAPH_OP
(
EmbeddingLookup
,
ngraph
::
op
)
NGRAPH_OP
(
Equal
,
ngraph
::
op
)
NGRAPH_OP
(
Erf
,
ngraph
::
op
)
NGRAPH_OP
(
Exp
,
ngraph
::
op
)
...
...
@@ -119,13 +120,13 @@ NGRAPH_OP(Power, ngraph::op)
NGRAPH_OP
(
Product
,
ngraph
::
op
)
NGRAPH_OP
(
Quantize
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedAvgPool
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolution
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBias
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBiasAdd
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionBiasSignedAdd
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolutionRelu
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedConvolution
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedDotBias
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedDot
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedDotBias
,
ngraph
::
op
)
NGRAPH_OP
(
QuantizedMaxPool
,
ngraph
::
op
)
NGRAPH_OP
(
Range
,
ngraph
::
op
)
NGRAPH_OP
(
Relu
,
ngraph
::
op
)
...
...
@@ -153,7 +154,6 @@ NGRAPH_OP(Subtract, ngraph::op)
NGRAPH_OP
(
Sum
,
ngraph
::
op
)
NGRAPH_OP
(
Tan
,
ngraph
::
op
)
NGRAPH_OP
(
Tanh
,
ngraph
::
op
)
NGRAPH_OP
(
TopK
,
ngraph
::
op
)
NGRAPH_OP
(
Tile
,
ngraph
::
op
)
NGRAPH_OP
(
TopK
,
ngraph
::
op
)
NGRAPH_OP
(
Transpose
,
ngraph
::
op
)
NGRAPH_OP
(
EmbeddingLookup
,
ngraph
::
op
)
src/ngraph/pass/fused_op_decomposition.cpp
View file @
24c715f4
...
...
@@ -13,37 +13,52 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/fused_op.hpp"
using
namespace
std
;
using
namespace
ngraph
;
bool
ngraph
::
pass
::
FusedOpDecomposition
::
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
pass
::
FusedOpDecomposition
::
FusedOpDecomposition
(
op_query_t
callback
)
:
m_has_direct_support
{
callback
}
{
}
bool
pass
::
FusedOpDecomposition
::
run_on_node
(
shared_ptr
<
Node
>
node
)
{
bool
modified
=
false
;
if
(
auto
fused_op
=
std
::
dynamic_pointer_cast
<
ngraph
::
op
::
util
::
FusedOp
>
(
node
))
if
(
auto
fused_op
=
dynamic_pointer_cast
<
op
::
util
::
FusedOp
>
(
node
))
{
if
(
m_
callback
&&
m_callback
(
*
node
))
if
(
m_
has_direct_support
&&
m_has_direct_support
(
*
node
))
{
// Op supported by backend. Do not decompose
return
modified
;
}
auto
subgraph_outputs
=
fused_op
->
decompose_op
();
// Run recursively untill no more fused ops
auto
subgraph
=
extract_subgraph
(
subgraph_outputs
,
fused_op
->
get_arguments
());
for
(
auto
subgraph_node
:
subgraph
)
{
if
(
auto
nested_fused_op
=
dynamic_pointer_cast
<
op
::
util
::
FusedOp
>
(
subgraph_node
))
{
if
(
!
(
m_has_direct_support
&&
m_has_direct_support
(
*
nested_fused_op
)))
{
run_on_node
(
nested_fused_op
);
}
}
}
size_t
i
=
0
;
for
(
auto
output_node
:
subgraph_outputs
)
{
for
(
size_t
j
=
0
;
j
<
output_node
->
get_outputs
().
size
();
j
++
,
i
++
)
{
// TODO: Provenance
std
::
set
<
ngraph
::
descriptor
::
Input
*>
fop_users
{
begin
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
()),
end
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
())};
set
<
descriptor
::
Input
*>
fop_users
{
begin
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
()),
end
(
fused_op
->
get_outputs
().
at
(
i
).
get_inputs
())};
for
(
auto
fop_user
:
fop_users
)
{
if
(
auto
goe
=
...
...
@@ -52,7 +67,7 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
if
(
goe
->
get_n
()
==
i
&&
!
goe
->
get_output_inputs
(
0
).
empty
())
{
// Replace GOE users
s
td
::
set
<
ngraph
::
descriptor
::
Input
*>
goe_users
{
s
et
<
descriptor
::
Input
*>
goe_users
{
begin
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
()),
end
(
goe
->
get_outputs
().
at
(
0
).
get_inputs
())};
for
(
auto
goe_user
:
goe_users
)
...
...
@@ -80,8 +95,3 @@ bool ngraph::pass::FusedOpDecomposition::run_on_node(std::shared_ptr<ngraph::Nod
return
modified
;
}
pass
::
FusedOpDecomposition
::
FusedOpDecomposition
(
op_query_t
callback
)
:
m_callback
{
callback
}
{
}
src/ngraph/pass/fused_op_decomposition.hpp
View file @
24c715f4
...
...
@@ -16,6 +16,9 @@
#pragma once
#include <memory>
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp"
namespace
ngraph
...
...
@@ -25,13 +28,24 @@ namespace ngraph
class
FusedOpDecomposition
:
public
NodePass
{
public
:
/// \brief Function signature type for callback used to check whether provided node
/// is supported by backend.
using
op_query_t
=
std
::
function
<
bool
(
const
Node
&
node
)
>
;
///
/// \brief Constructor for the Fused operation decomposition pass.
///
/// \param[in] callback The function object used to determine whether current backend
/// provide direct support for passed node. Should have signature:
/// bool fn(const Node&)
///
FusedOpDecomposition
(
op_query_t
callback
=
nullptr
);
bool
run_on_node
(
std
::
shared_ptr
<
ngraph
::
Node
>
node
)
override
;
private
:
op_query_t
m_callback
=
nullptr
;
/// \brief A function returning whether provided Node is supported by current backend.
/// The returned bool value is used to control whether decompose operator or not.
op_query_t
m_has_direct_support
=
nullptr
;
};
}
}
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
24c715f4
...
...
@@ -924,7 +924,8 @@ using namespace ngraph::runtime;
// Always enable nodes computing output tensors or nodes whose outputs might get
// overwritten due to inplace kernels
// TODO (jbobba) - Do we need to handle cacheability
if
(
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
()))
if
(
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
())
||
node
->
has_state
())
{
writer
<<
" || 1"
;
}
...
...
@@ -1187,7 +1188,6 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
{
REGISTER_KNOBBED_PASS
(
CPUFusion
,
true
,
runtime
::
cpu
::
pass
);
}
REGISTER_KNOBBED_PASS
(
CPUQuantFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUHorizontalFusion
,
true
,
runtime
::
cpu
::
pass
);
REGISTER_KNOBBED_PASS
(
CPUCollapseDims
,
true
,
runtime
::
cpu
::
pass
);
...
...
@@ -1437,7 +1437,7 @@ void runtime::cpu::CPU_ExternalFunction::build(ngraph::pass::PassConfig& pass_co
bool
disable_caching
=
(
reuse_memory
&&
!
cacheable
)
// Check cacheability only if we are reusing intermediate tensors
||
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
());
||
computes_result
(
node
.
get
())
||
possibly_overwritten
(
node
.
get
())
||
node
->
has_state
()
;
vector
<
reference_wrapper
<
bool
>>
in_stale
,
out_stale
;
for
(
const
auto
&
name
:
in_names
)
...
...
src/ngraph/serializer.cpp
View file @
24c715f4
This diff is collapsed.
Click to expand it.
test/models/onnx/eye_like.prototxt
0 → 100644
View file @
24c715f4
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "x"
output: "y"
op_type: "EyeLike"
attribute {
name: "k"
i: -1
type: INT
}
}
name: "hardmax_graph"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
version: 9
}
test/onnx/onnx_import.in.cpp
View file @
24c715f4
...
...
@@ -1482,3 +1482,46 @@ NGRAPH_TEST(onnx_${BACKEND_NAME}, model_shrink_int)
test_case
.
run
();
}
NGRAPH_TEST
(
onnx_
$
{
BACKEND_NAME
},
model_eye_like
)
{
const
auto
eye_like_fn
=
onnx_import
::
import_onnx_model
(
file_util
::
path_join
(
SERIALIZED_ZOO
,
"onnx/eye_like.prototxt"
));
auto
test_case
=
ngraph
::
test
::
NgraphTestCase
(
eye_like_fn
,
"${BACKEND_NAME}"
);
test_case
.
add_input
<
float
>
({
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
});
test_case
.
add_expected_output
<
float
>
(
Shape
{
3
,
4
},
{
0.
f
,
0.
f
,
0.
f
,
0.
f
,
1.
f
,
0.
f
,
0.
f
,
0.
f
,
0.
f
,
1.
f
,
0.
f
,
0.
f
,
});
test_case
.
run
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment