Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6e6c23ff
Commit
6e6c23ff
authored
Aug 07, 2019
by
Robert Kimball
Committed by
Scott Cyphers
Aug 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add method so that element::Type can be cast into enum (#3386)
parent
34ae1ee4
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
64 additions
and
63 deletions
+64
-63
compiler.cpp
src/contrib/mlir/compiler.cpp
+1
-1
constant.cpp
src/ngraph/op/constant.cpp
+3
-3
constant.hpp
src/ngraph/op/constant.hpp
+1
-1
range.cpp
src/ngraph/op/experimental/range.cpp
+1
-1
max.cpp
src/ngraph/op/max.cpp
+1
-1
min.cpp
src/ngraph/op/min.cpp
+1
-1
constant_folding.cpp
src/ngraph/pass/constant_folding.cpp
+20
-20
dyn_elimination.cpp
src/ngraph/pass/dyn_elimination.cpp
+1
-1
allreduce.cpp
src/ngraph/runtime/cpu/builder/allreduce.cpp
+1
-1
broadcast_distributed.cpp
src/ngraph/runtime/cpu/builder/broadcast_distributed.cpp
+1
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+1
-1
gcpu_executable.cpp
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
+1
-1
gcpu_executable.hpp
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
+7
-9
intelgpu_layout.cpp
src/ngraph/runtime/intelgpu/intelgpu_layout.cpp
+2
-2
intelgpu_op_custom_kernels.cpp
src/ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.cpp
+4
-5
int_executable.cpp
src/ngraph/runtime/interpreter/int_executable.cpp
+1
-1
int_executable.hpp
src/ngraph/runtime/interpreter/int_executable.hpp
+7
-9
element_type.hpp
src/ngraph/type/element_type.hpp
+7
-1
benchmark_utils.cpp
src/tools/nbench/benchmark_utils.cpp
+1
-1
test_case.cpp
test/util/test_case.cpp
+2
-2
No files found.
src/contrib/mlir/compiler.cpp
View file @
6e6c23ff
...
@@ -193,7 +193,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
...
@@ -193,7 +193,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
ngraph
:
:
element
::
Type_t
::
undefined
:
case
ngraph
:
:
element
::
Type_t
::
undefined
:
case
ngraph
:
:
element
::
Type_t
::
dynamic
:
case
ngraph
:
:
element
::
Type_t
::
dynamic
:
...
...
src/ngraph/op/constant.cpp
View file @
6e6c23ff
...
@@ -59,7 +59,7 @@ string op::Constant::convert_value_to_string(size_t index) const
...
@@ -59,7 +59,7 @@ string op::Constant::convert_value_to_string(size_t index) const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
{
case
element
:
:
Type_t
::
boolean
:
rc
=
to_string
(
get_vector
<
char
>
()[
index
]);
break
;
case
element
:
:
Type_t
::
boolean
:
rc
=
to_string
(
get_vector
<
char
>
()[
index
]);
break
;
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
bf16
:
...
@@ -96,7 +96,7 @@ vector<string> op::Constant::get_value_strings() const
...
@@ -96,7 +96,7 @@ vector<string> op::Constant::get_value_strings() const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
for
(
int
value
:
get_vector
<
char
>
())
for
(
int
value
:
get_vector
<
char
>
())
...
@@ -292,7 +292,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
...
@@ -292,7 +292,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
i8
:
case
element
:
:
Type_t
::
i8
:
...
...
src/ngraph/op/constant.hpp
View file @
6e6c23ff
...
@@ -289,7 +289,7 @@ namespace ngraph
...
@@ -289,7 +289,7 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
target_type
.
get_type_enum
()
)
switch
(
target_type
)
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
write_buffer
<
char
,
T
>
(
target
,
source
,
target_element_count
);
write_buffer
<
char
,
T
>
(
target
,
source
,
target_element_count
);
...
...
src/ngraph/op/experimental/range.cpp
View file @
6e6c23ff
...
@@ -219,7 +219,7 @@ void op::Range::validate_and_infer_types()
...
@@ -219,7 +219,7 @@ void op::Range::validate_and_infer_types()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
result_et
.
get_type_enum
()
)
switch
(
result_et
)
{
{
case
element
:
:
Type_t
::
bf16
:
result_shape
=
infer_output_shape
<
bfloat16
>
(
this
,
result_et
);
break
;
case
element
:
:
Type_t
::
bf16
:
result_shape
=
infer_output_shape
<
bfloat16
>
(
this
,
result_et
);
break
;
case
element
:
:
Type_t
::
f16
:
result_shape
=
infer_output_shape
<
float16
>
(
this
,
result_et
);
break
;
case
element
:
:
Type_t
::
f16
:
result_shape
=
infer_output_shape
<
float16
>
(
this
,
result_et
);
break
;
...
...
src/ngraph/op/max.cpp
View file @
6e6c23ff
...
@@ -42,7 +42,7 @@ shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
...
@@ -42,7 +42,7 @@ shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
shared_ptr
<
Node
>
op
::
Max
::
get_default_value
()
const
shared_ptr
<
Node
>
op
::
Max
::
get_default_value
()
const
{
{
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
return
make_constant_from_string
(
"0"
,
get_element_type
(),
get_shape
());
return
make_constant_from_string
(
"0"
,
get_element_type
(),
get_shape
());
...
...
src/ngraph/op/min.cpp
View file @
6e6c23ff
...
@@ -42,7 +42,7 @@ shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
...
@@ -42,7 +42,7 @@ shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
shared_ptr
<
Node
>
op
::
Min
::
get_default_value
()
const
shared_ptr
<
Node
>
op
::
Min
::
get_default_value
()
const
{
{
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
return
make_constant_from_string
(
"1"
,
get_element_type
(),
get_shape
());
return
make_constant_from_string
(
"1"
,
get_element_type
(),
get_shape
());
...
...
src/ngraph/pass/constant_folding.cpp
View file @
6e6c23ff
...
@@ -185,7 +185,7 @@ void pass::ConstantFolding::construct_constant_reshape()
...
@@ -185,7 +185,7 @@ void pass::ConstantFolding::construct_constant_reshape()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
constant_match
->
get_element_type
();
auto
type
=
constant_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
NGRAPH_CHECK
(
false
,
...
@@ -316,7 +316,7 @@ void pass::ConstantFolding::construct_constant_pad()
...
@@ -316,7 +316,7 @@ void pass::ConstantFolding::construct_constant_pad()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
constant_match
->
get_element_type
();
auto
type
=
constant_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_pad_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_pad_callback"
);
...
@@ -418,7 +418,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
...
@@ -418,7 +418,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
dyn_reshape_match
->
get_element_type
();
auto
type
=
dyn_reshape_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
NGRAPH_CHECK
(
false
,
...
@@ -532,7 +532,7 @@ void pass::ConstantFolding::construct_constant_transpose()
...
@@ -532,7 +532,7 @@ void pass::ConstantFolding::construct_constant_transpose()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
transpose_match
->
get_element_type
();
auto
type
=
transpose_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
NGRAPH_CHECK
(
false
,
...
@@ -664,7 +664,7 @@ void pass::ConstantFolding::construct_constant_broadcast()
...
@@ -664,7 +664,7 @@ void pass::ConstantFolding::construct_constant_broadcast()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
broadcast_match
->
get_element_type
();
auto
type
=
broadcast_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
NGRAPH_CHECK
(
false
,
...
@@ -774,7 +774,7 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
...
@@ -774,7 +774,7 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
dyn_broadcast_match
->
get_output_element_type
(
0
);
auto
type
=
dyn_broadcast_match
->
get_output_element_type
(
0
);
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
NGRAPH_CHECK
(
false
,
...
@@ -1079,7 +1079,7 @@ shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out
...
@@ -1079,7 +1079,7 @@ shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out
shared_ptr
<
Node
>
binary
,
shared_ptr
<
Node
>
binary
,
NodeExecutorTy
func
)
NodeExecutorTy
func
)
{
{
switch
(
et_out
.
get_type_enum
()
)
switch
(
et_out
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_binary_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_binary_callback"
);
...
@@ -1159,7 +1159,7 @@ void pass::ConstantFolding::construct_constant_binary()
...
@@ -1159,7 +1159,7 @@ void pass::ConstantFolding::construct_constant_binary()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
in_type
=
a_match
->
get_output_element_type
(
0
);
auto
in_type
=
a_match
->
get_output_element_type
(
0
);
auto
out_type
=
binary_match
->
get_output_element_type
(
0
);
auto
out_type
=
binary_match
->
get_output_element_type
(
0
);
switch
(
in_type
.
get_type_enum
()
)
switch
(
in_type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_binary_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_binary_callback"
);
...
@@ -1355,7 +1355,7 @@ void pass::ConstantFolding::construct_constant_unary()
...
@@ -1355,7 +1355,7 @@ void pass::ConstantFolding::construct_constant_unary()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
constant_match
->
get_element_type
();
auto
type
=
constant_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_unary_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_unary_callback"
);
...
@@ -1598,7 +1598,7 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
...
@@ -1598,7 +1598,7 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
output_element_type
.
get_type_enum
()
)
switch
(
output_element_type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
...
@@ -1655,7 +1655,7 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
...
@@ -1655,7 +1655,7 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
input_element_type
.
get_type_enum
()
)
switch
(
input_element_type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
...
@@ -1788,7 +1788,7 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
...
@@ -1788,7 +1788,7 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
input_element_type
.
get_type_enum
()
)
switch
(
input_element_type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
...
@@ -1912,7 +1912,7 @@ static shared_ptr<op::Constant>
...
@@ -1912,7 +1912,7 @@ static shared_ptr<op::Constant>
{
{
auto
&
input_element_type
=
constant
->
get_output_element_type
(
0
);
auto
&
input_element_type
=
constant
->
get_output_element_type
(
0
);
switch
(
input_element_type
.
get_type_enum
()
)
switch
(
input_element_type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
NGRAPH_CHECK
(
false
,
...
@@ -2113,7 +2113,7 @@ void pass::ConstantFolding::construct_constant_concat()
...
@@ -2113,7 +2113,7 @@ void pass::ConstantFolding::construct_constant_concat()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
concat_node
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
concat_node
->
get_output_element_type
(
0
))
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_concat"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_concat"
);
...
@@ -2199,7 +2199,7 @@ static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Consta
...
@@ -2199,7 +2199,7 @@ static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Consta
{
{
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
switch
(
indices_type
.
get_type_enum
()
)
switch
(
indices_type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
...
@@ -2255,7 +2255,7 @@ void pass::ConstantFolding::construct_constant_gather()
...
@@ -2255,7 +2255,7 @@ void pass::ConstantFolding::construct_constant_gather()
std
::
shared_ptr
<
Node
>
replacement
;
std
::
shared_ptr
<
Node
>
replacement
;
auto
data_type
=
data
->
get_output_element_type
(
0
);
auto
data_type
=
data
->
get_output_element_type
(
0
);
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
switch
(
data_type
.
get_type_enum
()
)
switch
(
data_type
)
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
...
@@ -2351,7 +2351,7 @@ void pass::ConstantFolding::construct_constant_slice()
...
@@ -2351,7 +2351,7 @@ void pass::ConstantFolding::construct_constant_slice()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
slice
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
slice
->
get_output_element_type
(
0
))
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_slice"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_slice"
);
...
@@ -2489,7 +2489,7 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
...
@@ -2489,7 +2489,7 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
dyn_slice
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
dyn_slice
->
get_output_element_type
(
0
))
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_dyn_slice"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_dyn_slice"
);
...
@@ -2600,7 +2600,7 @@ void pass::ConstantFolding::construct_constant_range()
...
@@ -2600,7 +2600,7 @@ void pass::ConstantFolding::construct_constant_range()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
range
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
range
->
get_output_element_type
(
0
))
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_range_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_range_callback"
);
...
@@ -2700,7 +2700,7 @@ void pass::ConstantFolding::construct_constant_select()
...
@@ -2700,7 +2700,7 @@ void pass::ConstantFolding::construct_constant_select()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
select
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
select
->
get_output_element_type
(
0
))
{
{
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_select_callback"
);
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_select_callback"
);
...
...
src/ngraph/pass/dyn_elimination.cpp
View file @
6e6c23ff
...
@@ -390,7 +390,7 @@ void pass::DynElimination::construct_range()
...
@@ -390,7 +390,7 @@ void pass::DynElimination::construct_range()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
et
.
get_type_enum
()
)
switch
(
et
)
{
{
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
bf16
:
replacement
=
make_range_replacement
<
bfloat16
>
(
et
,
shape
,
start_arg
,
step_arg
);
replacement
=
make_range_replacement
<
bfloat16
>
(
et
,
shape
,
start_arg
,
step_arg
);
...
...
src/ngraph/runtime/cpu/builder/allreduce.cpp
View file @
6e6c23ff
...
@@ -36,7 +36,7 @@ namespace ngraph
...
@@ -36,7 +36,7 @@ namespace ngraph
auto
arg_buffer_index
=
external_function
->
get_buffer_index
(
args
[
0
].
get_name
());
auto
arg_buffer_index
=
external_function
->
get_buffer_index
(
args
[
0
].
get_name
());
auto
out_buffer_index
=
external_function
->
get_buffer_index
(
out
[
0
].
get_name
());
auto
out_buffer_index
=
external_function
->
get_buffer_index
(
out
[
0
].
get_name
());
auto
count
=
static_cast
<
int
>
(
out
[
0
].
get_size
());
auto
count
=
static_cast
<
int
>
(
out
[
0
].
get_size
());
auto
data_type
=
args
[
0
].
get_element_type
()
.
get_type_enum
()
;
auto
data_type
=
args
[
0
].
get_element_type
();
const
ngraph
::
op
::
AllReduce
*
allreduce
=
const
ngraph
::
op
::
AllReduce
*
allreduce
=
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
node
);
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
node
);
auto
reduce_type
=
allreduce
->
get_reduce_type
();
auto
reduce_type
=
allreduce
->
get_reduce_type
();
...
...
src/ngraph/runtime/cpu/builder/broadcast_distributed.cpp
View file @
6e6c23ff
...
@@ -33,7 +33,7 @@ namespace ngraph
...
@@ -33,7 +33,7 @@ namespace ngraph
auto
arg_buffer_index
=
external_function
->
get_buffer_index
(
args
[
0
].
get_name
());
auto
arg_buffer_index
=
external_function
->
get_buffer_index
(
args
[
0
].
get_name
());
auto
count
=
static_cast
<
int
>
(
args
[
0
].
get_size
());
auto
count
=
static_cast
<
int
>
(
args
[
0
].
get_size
());
auto
data_type
=
args
[
0
].
get_element_type
()
.
get_type_enum
()
;
auto
data_type
=
args
[
0
].
get_element_type
();
auto
broadcast
=
static_cast
<
const
ngraph
::
op
::
BroadcastDistributed
*>
(
node
);
auto
broadcast
=
static_cast
<
const
ngraph
::
op
::
BroadcastDistributed
*>
(
node
);
auto
root_id
=
broadcast
->
get_root_id
();
auto
root_id
=
broadcast
->
get_root_id
();
auto
functor
=
[
&
,
count
,
data_type
,
arg_buffer_index
,
root_id
](
auto
functor
=
[
&
,
count
,
data_type
,
arg_buffer_index
,
root_id
](
...
...
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
6e6c23ff
...
@@ -1261,7 +1261,7 @@ static void dump_one_kernel_with_type(runtime::cpu::CPU_DebugTracer& debug_trace
...
@@ -1261,7 +1261,7 @@ static void dump_one_kernel_with_type(runtime::cpu::CPU_DebugTracer& debug_trace
const
std
::
string
&
tensor_name
,
const
std
::
string
&
tensor_name
,
const
std
::
string
&
in_out
)
const
std
::
string
&
in_out
)
{
{
switch
(
t_attrs
.
m_type_of_element
.
get_type_enum
()
)
switch
(
t_attrs
.
m_type_of_element
)
{
{
case
element
:
:
Type_t
::
f32
:
case
element
:
:
Type_t
::
f32
:
debug_tracer
.
dump_one_tensor
<
float
>
(
kernel_name
,
debug_tracer
.
dump_one_tensor
<
float
>
(
kernel_name
,
...
...
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
View file @
6e6c23ff
...
@@ -213,7 +213,7 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
...
@@ -213,7 +213,7 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
{
{
stringstream
ss
;
stringstream
ss
;
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
out
,
in
);
break
;
...
...
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
View file @
6e6c23ff
...
@@ -267,7 +267,7 @@ private:
...
@@ -267,7 +267,7 @@ private:
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
allreduce
->
get_reduce_type
(),
allreduce
->
get_reduce_type
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
break
;
break
;
...
@@ -504,7 +504,7 @@ private:
...
@@ -504,7 +504,7 @@ private:
{
{
reference
::
broadcastdistributed
<
T
>
(
reference
::
broadcastdistributed
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
root_id
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
...
@@ -514,7 +514,7 @@ private:
...
@@ -514,7 +514,7 @@ private:
{
{
reference
::
broadcastdistributed
<
T
>
(
reference
::
broadcastdistributed
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
root_id
);
}
}
...
@@ -559,7 +559,7 @@ private:
...
@@ -559,7 +559,7 @@ private:
element
::
Type
type
=
node
.
get_element_type
();
element
::
Type
type
=
node
.
get_element_type
();
std
::
stringstream
ss
;
std
::
stringstream
ss
;
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
reference
::
convert_to_bool
<
T
>
(
reference
::
convert_to_bool
<
T
>
(
...
@@ -1300,10 +1300,8 @@ private:
...
@@ -1300,10 +1300,8 @@ private:
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
int
src_id
=
op
->
get_src_id
();
int
src_id
=
op
->
get_src_id
();
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
reference
::
recv
<
T
>
(
node
.
get_input_element_type
(
0
).
get_type_enum
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
),
element_count
,
src_id
);
element_count
,
src_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
break
;
...
@@ -1467,7 +1465,7 @@ private:
...
@@ -1467,7 +1465,7 @@ private:
int
dest_id
=
op
->
get_dest_id
();
int
dest_id
=
op
->
get_dest_id
();
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
element_count
,
element_count
,
dest_id
);
dest_id
);
...
...
src/ngraph/runtime/intelgpu/intelgpu_layout.cpp
View file @
6e6c23ff
...
@@ -54,7 +54,7 @@ bool runtime::intelgpu::IntelGPULayout::
...
@@ -54,7 +54,7 @@ bool runtime::intelgpu::IntelGPULayout::
cldnn
::
data_types
cldnn
::
data_types
runtime
::
intelgpu
::
IntelGPULayout
::
get_cldnn_type
(
const
element
::
Type
&
element_type
)
runtime
::
intelgpu
::
IntelGPULayout
::
get_cldnn_type
(
const
element
::
Type
&
element_type
)
{
{
switch
(
element_type
.
get_type_enum
()
)
switch
(
element_type
)
{
{
case
element
:
:
Type_t
::
i8
:
case
element
:
:
Type_t
::
i8
:
case
element
:
:
Type_t
::
boolean
:
return
cldnn
::
data_types
::
i8
;
case
element
:
:
Type_t
::
boolean
:
return
cldnn
::
data_types
::
i8
;
...
@@ -118,7 +118,7 @@ cldnn::layout runtime::intelgpu::IntelGPULayout::create_cldnn_layout(
...
@@ -118,7 +118,7 @@ cldnn::layout runtime::intelgpu::IntelGPULayout::create_cldnn_layout(
const
cldnn
::
tensor
tensor
=
create_cldnn_tensor
(
element_shape
);
const
cldnn
::
tensor
tensor
=
create_cldnn_tensor
(
element_shape
);
cldnn
::
data_types
data_type
;
cldnn
::
data_types
data_type
;
switch
(
element_type
.
get_type_enum
()
)
switch
(
element_type
)
{
{
case
element
:
:
Type_t
::
i16
:
case
element
:
:
Type_t
::
i16
:
case
element
:
:
Type_t
::
u16
:
case
element
:
:
Type_t
::
u16
:
...
...
src/ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.cpp
View file @
6e6c23ff
...
@@ -33,7 +33,7 @@ using namespace ngraph::runtime::intelgpu;
...
@@ -33,7 +33,7 @@ using namespace ngraph::runtime::intelgpu;
string
runtime
::
intelgpu
::
get_opencl_type_name
(
const
element
::
Type
&
ngraph_type
)
string
runtime
::
intelgpu
::
get_opencl_type_name
(
const
element
::
Type
&
ngraph_type
)
{
{
switch
(
ngraph_type
.
get_type_enum
()
)
switch
(
ngraph_type
)
{
{
case
element
:
:
Type_t
::
i64
:
return
"long"
;
case
element
:
:
Type_t
::
i64
:
return
"long"
;
case
element
:
:
Type_t
::
u64
:
return
"ulong"
;
case
element
:
:
Type_t
::
u64
:
return
"ulong"
;
...
@@ -52,7 +52,7 @@ string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type)
...
@@ -52,7 +52,7 @@ string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type)
string
runtime
::
intelgpu
::
get_opencl_type_min_max_value
(
const
element
::
Type
&
ngraph_type
,
string
runtime
::
intelgpu
::
get_opencl_type_min_max_value
(
const
element
::
Type
&
ngraph_type
,
bool
is_min
)
bool
is_min
)
{
{
switch
(
ngraph_type
.
get_type_enum
()
)
switch
(
ngraph_type
)
{
{
case
element
:
:
Type_t
::
f32
:
return
is_min
?
"-INFINITY"
:
"INFINITY"
;
case
element
:
:
Type_t
::
f32
:
return
is_min
?
"-INFINITY"
:
"INFINITY"
;
case
element
:
:
Type_t
::
f64
:
return
is_min
?
"-INFINITY"
:
"INFINITY"
;
case
element
:
:
Type_t
::
f64
:
return
is_min
?
"-INFINITY"
:
"INFINITY"
;
...
@@ -1839,9 +1839,8 @@ void runtime::intelgpu::do_convert_operation(cldnn::topology& topology,
...
@@ -1839,9 +1839,8 @@ void runtime::intelgpu::do_convert_operation(cldnn::topology& topology,
{
{
gws
=
generate_loops
(
writer
,
output_shape
,
true
);
gws
=
generate_loops
(
writer
,
output_shape
,
true
);
if
(((
input_type
.
get_type_enum
()
==
element
::
Type_t
::
f64
)
||
if
(((
input_type
==
element
::
Type_t
::
f64
)
||
(
input_type
==
element
::
Type_t
::
f32
))
&&
(
input_type
.
get_type_enum
()
==
element
::
Type_t
::
f32
))
&&
(
output_type
!=
element
::
Type_t
::
boolean
))
(
output_type
.
get_type_enum
()
!=
element
::
Type_t
::
boolean
))
{
{
// this is the workaround for OpenCL to be same as with CPU floating point operations
// this is the workaround for OpenCL to be same as with CPU floating point operations
writer
<<
input_type_name
<<
" input_var = input0"
<<
access_dims
(
output_shape
)
<<
";
\n
"
writer
<<
input_type_name
<<
" input_var = input0"
<<
access_dims
(
output_shape
)
<<
";
\n
"
...
...
src/ngraph/runtime/interpreter/int_executable.cpp
View file @
6e6c23ff
...
@@ -215,7 +215,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
...
@@ -215,7 +215,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
{
{
stringstream
ss
;
stringstream
ss
;
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
out
,
in
);
break
;
...
...
src/ngraph/runtime/interpreter/int_executable.hpp
View file @
6e6c23ff
...
@@ -294,7 +294,7 @@ private:
...
@@ -294,7 +294,7 @@ private:
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
allreduce
->
get_reduce_type
(),
allreduce
->
get_reduce_type
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
break
;
break
;
...
@@ -530,7 +530,7 @@ private:
...
@@ -530,7 +530,7 @@ private:
{
{
reference
::
broadcastdistributed
<
T
>
(
reference
::
broadcastdistributed
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
root_id
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
...
@@ -540,7 +540,7 @@ private:
...
@@ -540,7 +540,7 @@ private:
{
{
reference
::
broadcastdistributed
<
T
>
(
reference
::
broadcastdistributed
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
root_id
);
}
}
...
@@ -585,7 +585,7 @@ private:
...
@@ -585,7 +585,7 @@ private:
element
::
Type
type
=
node
.
get_element_type
();
element
::
Type
type
=
node
.
get_element_type
();
std
::
stringstream
ss
;
std
::
stringstream
ss
;
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
reference
::
convert_to_bool
<
T
>
(
reference
::
convert_to_bool
<
T
>
(
...
@@ -1349,10 +1349,8 @@ private:
...
@@ -1349,10 +1349,8 @@ private:
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
int
src_id
=
op
->
get_src_id
();
int
src_id
=
op
->
get_src_id
();
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
reference
::
recv
<
T
>
(
node
.
get_input_element_type
(
0
).
get_type_enum
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
),
element_count
,
src_id
);
element_count
,
src_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
break
;
...
@@ -1516,7 +1514,7 @@ private:
...
@@ -1516,7 +1514,7 @@ private:
int
dest_id
=
op
->
get_dest_id
();
int
dest_id
=
op
->
get_dest_id
();
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
element_count
,
element_count
,
dest_id
);
dest_id
);
...
...
src/ngraph/type/element_type.hpp
View file @
6e6c23ff
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "ngraph/deprecated.hpp"
#include "ngraph/except.hpp"
#include "ngraph/except.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/type/bfloat16.hpp"
#include "ngraph/type/bfloat16.hpp"
...
@@ -73,7 +74,10 @@ namespace ngraph
...
@@ -73,7 +74,10 @@ namespace ngraph
const
std
::
string
&
cname
);
const
std
::
string
&
cname
);
~
Type
()
{}
~
Type
()
{}
Type
&
operator
=
(
const
Type
&
)
=
default
;
Type
&
operator
=
(
const
Type
&
)
=
default
;
Type_t
get_type_enum
()
const
{
return
m_type
;
}
NGRAPH_DEPRECATED
(
"Use operator Type_t()"
)
Type_t
get_type_enum
()
const
{
return
m_type
;
}
const
std
::
string
&
c_type_string
()
const
;
const
std
::
string
&
c_type_string
()
const
;
size_t
size
()
const
;
size_t
size
()
const
;
size_t
hash
()
const
;
size_t
hash
()
const
;
...
@@ -119,6 +123,8 @@ namespace ngraph
...
@@ -119,6 +123,8 @@ namespace ngraph
/// does nothing to dst, and returns false
/// does nothing to dst, and returns false
static
bool
merge
(
element
::
Type
&
dst
,
const
element
::
Type
&
t1
,
const
element
::
Type
&
t2
);
static
bool
merge
(
element
::
Type
&
dst
,
const
element
::
Type
&
t1
,
const
element
::
Type
&
t2
);
// \brief This allows switch(element_type)
operator
Type_t
()
const
{
return
m_type
;
}
private
:
private
:
Type_t
m_type
{
Type_t
::
undefined
};
Type_t
m_type
{
Type_t
::
undefined
};
};
};
...
...
src/tools/nbench/benchmark_utils.cpp
View file @
6e6c23ff
...
@@ -85,7 +85,7 @@ void random_init(shared_ptr<runtime::Tensor> tensor)
...
@@ -85,7 +85,7 @@ void random_init(shared_ptr<runtime::Tensor> tensor)
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
#endif
switch
(
et
.
get_type_enum
()
)
switch
(
et
)
{
{
case
element
:
:
Type_t
::
boolean
:
init_int_tensor
<
char
>
(
tensor
,
0
,
1
);
break
;
case
element
:
:
Type_t
::
boolean
:
init_int_tensor
<
char
>
(
tensor
,
0
,
1
);
break
;
case
element
:
:
Type_t
::
f32
:
init_real_tensor
<
float
>
(
tensor
,
-
1
,
1
);
break
;
case
element
:
:
Type_t
::
f32
:
init_real_tensor
<
float
>
(
tensor
,
-
1
,
1
);
break
;
...
...
test/util/test_case.cpp
View file @
6e6c23ff
...
@@ -39,14 +39,14 @@ void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
...
@@ -39,14 +39,14 @@ void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
auto
result_shape
=
result_tensor
->
get_shape
();
auto
result_shape
=
result_tensor
->
get_shape
();
EXPECT_EQ
(
expected_shape
,
result_shape
);
EXPECT_EQ
(
expected_shape
,
result_shape
);
if
(
m_value_comparators
.
count
(
element_type
.
get_type_enum
()
)
==
0
)
if
(
m_value_comparators
.
count
(
element_type
)
==
0
)
{
{
NGRAPH_FAIL
()
<<
"Please add support for "
<<
element_type
NGRAPH_FAIL
()
<<
"Please add support for "
<<
element_type
<<
" to ngraph::test::NgraphTestCase::run()"
;
<<
" to ngraph::test::NgraphTestCase::run()"
;
}
}
else
else
{
{
auto
values_match
=
m_value_comparators
.
at
(
element_type
.
get_type_enum
()
);
auto
values_match
=
m_value_comparators
.
at
(
element_type
);
EXPECT_TRUE
(
values_match
(
expected_result_constant
,
result_tensor
));
EXPECT_TRUE
(
values_match
(
expected_result_constant
,
result_tensor
));
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment