Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6e6c23ff
Commit
6e6c23ff
authored
Aug 07, 2019
by
Robert Kimball
Committed by
Scott Cyphers
Aug 07, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add method so that element::Type can be cast into enum (#3386)
parent
34ae1ee4
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
64 additions
and
63 deletions
+64
-63
compiler.cpp
src/contrib/mlir/compiler.cpp
+1
-1
constant.cpp
src/ngraph/op/constant.cpp
+3
-3
constant.hpp
src/ngraph/op/constant.hpp
+1
-1
range.cpp
src/ngraph/op/experimental/range.cpp
+1
-1
max.cpp
src/ngraph/op/max.cpp
+1
-1
min.cpp
src/ngraph/op/min.cpp
+1
-1
constant_folding.cpp
src/ngraph/pass/constant_folding.cpp
+20
-20
dyn_elimination.cpp
src/ngraph/pass/dyn_elimination.cpp
+1
-1
allreduce.cpp
src/ngraph/runtime/cpu/builder/allreduce.cpp
+1
-1
broadcast_distributed.cpp
src/ngraph/runtime/cpu/builder/broadcast_distributed.cpp
+1
-1
cpu_external_function.cpp
src/ngraph/runtime/cpu/cpu_external_function.cpp
+1
-1
gcpu_executable.cpp
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
+1
-1
gcpu_executable.hpp
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
+7
-9
intelgpu_layout.cpp
src/ngraph/runtime/intelgpu/intelgpu_layout.cpp
+2
-2
intelgpu_op_custom_kernels.cpp
src/ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.cpp
+4
-5
int_executable.cpp
src/ngraph/runtime/interpreter/int_executable.cpp
+1
-1
int_executable.hpp
src/ngraph/runtime/interpreter/int_executable.hpp
+7
-9
element_type.hpp
src/ngraph/type/element_type.hpp
+7
-1
benchmark_utils.cpp
src/tools/nbench/benchmark_utils.cpp
+1
-1
test_case.cpp
test/util/test_case.cpp
+2
-2
No files found.
src/contrib/mlir/compiler.cpp
View file @
6e6c23ff
...
...
@@ -193,7 +193,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
ngraph
:
:
element
::
Type_t
::
undefined
:
case
ngraph
:
:
element
::
Type_t
::
dynamic
:
...
...
src/ngraph/op/constant.cpp
View file @
6e6c23ff
...
...
@@ -59,7 +59,7 @@ string op::Constant::convert_value_to_string(size_t index) const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
case
element
:
:
Type_t
::
boolean
:
rc
=
to_string
(
get_vector
<
char
>
()[
index
]);
break
;
case
element
:
:
Type_t
::
bf16
:
...
...
@@ -96,7 +96,7 @@ vector<string> op::Constant::get_value_strings() const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
case
element
:
:
Type_t
::
boolean
:
for
(
int
value
:
get_vector
<
char
>
())
...
...
@@ -292,7 +292,7 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
i8
:
...
...
src/ngraph/op/constant.hpp
View file @
6e6c23ff
...
...
@@ -289,7 +289,7 @@ namespace ngraph
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
target_type
.
get_type_enum
()
)
switch
(
target_type
)
{
case
element
:
:
Type_t
::
boolean
:
write_buffer
<
char
,
T
>
(
target
,
source
,
target_element_count
);
...
...
src/ngraph/op/experimental/range.cpp
View file @
6e6c23ff
...
...
@@ -219,7 +219,7 @@ void op::Range::validate_and_infer_types()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
result_et
.
get_type_enum
()
)
switch
(
result_et
)
{
case
element
:
:
Type_t
::
bf16
:
result_shape
=
infer_output_shape
<
bfloat16
>
(
this
,
result_et
);
break
;
case
element
:
:
Type_t
::
f16
:
result_shape
=
infer_output_shape
<
float16
>
(
this
,
result_et
);
break
;
...
...
src/ngraph/op/max.cpp
View file @
6e6c23ff
...
...
@@ -42,7 +42,7 @@ shared_ptr<Node> op::Max::copy_with_new_args(const NodeVector& new_args) const
shared_ptr
<
Node
>
op
::
Max
::
get_default_value
()
const
{
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
case
element
:
:
Type_t
::
boolean
:
return
make_constant_from_string
(
"0"
,
get_element_type
(),
get_shape
());
...
...
src/ngraph/op/min.cpp
View file @
6e6c23ff
...
...
@@ -42,7 +42,7 @@ shared_ptr<Node> op::Min::copy_with_new_args(const NodeVector& new_args) const
shared_ptr
<
Node
>
op
::
Min
::
get_default_value
()
const
{
switch
(
get_element_type
()
.
get_type_enum
()
)
switch
(
get_element_type
())
{
case
element
:
:
Type_t
::
boolean
:
return
make_constant_from_string
(
"1"
,
get_element_type
(),
get_shape
());
...
...
src/ngraph/pass/constant_folding.cpp
View file @
6e6c23ff
...
...
@@ -185,7 +185,7 @@ void pass::ConstantFolding::construct_constant_reshape()
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
constant_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
...
...
@@ -316,7 +316,7 @@ void pass::ConstantFolding::construct_constant_pad()
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
constant_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_pad_callback"
);
...
...
@@ -418,7 +418,7 @@ void pass::ConstantFolding::construct_constant_dyn_reshape()
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
dyn_reshape_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
...
...
@@ -532,7 +532,7 @@ void pass::ConstantFolding::construct_constant_transpose()
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
transpose_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
...
...
@@ -664,7 +664,7 @@ void pass::ConstantFolding::construct_constant_broadcast()
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
broadcast_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
...
...
@@ -774,7 +774,7 @@ void pass::ConstantFolding::construct_constant_dyn_broadcast()
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
dyn_broadcast_match
->
get_output_element_type
(
0
);
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
...
...
@@ -1079,7 +1079,7 @@ shared_ptr<op::Constant> fold_constant_binary_helper(const element::Type& et_out
shared_ptr
<
Node
>
binary
,
NodeExecutorTy
func
)
{
switch
(
et_out
.
get_type_enum
()
)
switch
(
et_out
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_binary_callback"
);
...
...
@@ -1159,7 +1159,7 @@ void pass::ConstantFolding::construct_constant_binary()
std
::
shared_ptr
<
Node
>
replacement
;
auto
in_type
=
a_match
->
get_output_element_type
(
0
);
auto
out_type
=
binary_match
->
get_output_element_type
(
0
);
switch
(
in_type
.
get_type_enum
()
)
switch
(
in_type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_binary_callback"
);
...
...
@@ -1355,7 +1355,7 @@ void pass::ConstantFolding::construct_constant_unary()
std
::
shared_ptr
<
Node
>
replacement
;
auto
type
=
constant_match
->
get_element_type
();
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_unary_callback"
);
...
...
@@ -1598,7 +1598,7 @@ shared_ptr<op::Constant> fold_constant_convert_helper0(shared_ptr<op::Constant>
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
output_element_type
.
get_type_enum
()
)
switch
(
output_element_type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
...
...
@@ -1655,7 +1655,7 @@ static shared_ptr<op::Constant> fold_constant_convert(shared_ptr<op::Constant> c
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
input_element_type
.
get_type_enum
()
)
switch
(
input_element_type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
...
...
@@ -1788,7 +1788,7 @@ static shared_ptr<op::Constant> fold_constant_reverse(shared_ptr<op::Constant> c
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
input_element_type
.
get_type_enum
()
)
switch
(
input_element_type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_convert"
);
...
...
@@ -1912,7 +1912,7 @@ static shared_ptr<op::Constant>
{
auto
&
input_element_type
=
constant
->
get_output_element_type
(
0
);
switch
(
input_element_type
.
get_type_enum
()
)
switch
(
input_element_type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
...
...
@@ -2113,7 +2113,7 @@ void pass::ConstantFolding::construct_constant_concat()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
concat_node
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
concat_node
->
get_output_element_type
(
0
))
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_concat"
);
...
...
@@ -2199,7 +2199,7 @@ static shared_ptr<op::Constant> fold_constant_gather(const shared_ptr<op::Consta
{
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
switch
(
indices_type
.
get_type_enum
()
)
switch
(
indices_type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
...
...
@@ -2255,7 +2255,7 @@ void pass::ConstantFolding::construct_constant_gather()
std
::
shared_ptr
<
Node
>
replacement
;
auto
data_type
=
data
->
get_output_element_type
(
0
);
auto
indices_type
=
indices
->
get_output_element_type
(
0
);
switch
(
data_type
.
get_type_enum
()
)
switch
(
data_type
)
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_gather_callback"
);
...
...
@@ -2351,7 +2351,7 @@ void pass::ConstantFolding::construct_constant_slice()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
slice
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
slice
->
get_output_element_type
(
0
))
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_slice"
);
...
...
@@ -2489,7 +2489,7 @@ void pass::ConstantFolding::construct_constant_dyn_slice()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
dyn_slice
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
dyn_slice
->
get_output_element_type
(
0
))
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in fold_constant_dyn_slice"
);
...
...
@@ -2600,7 +2600,7 @@ void pass::ConstantFolding::construct_constant_range()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
range
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
range
->
get_output_element_type
(
0
))
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_range_callback"
);
...
...
@@ -2700,7 +2700,7 @@ void pass::ConstantFolding::construct_constant_select()
std
::
shared_ptr
<
op
::
Constant
>
replacement
;
switch
(
select
->
get_output_element_type
(
0
)
.
get_type_enum
()
)
switch
(
select
->
get_output_element_type
(
0
))
{
case
element
:
:
Type_t
::
undefined
:
NGRAPH_CHECK
(
false
,
"Encountered 'undefined' element type in constant_select_callback"
);
...
...
src/ngraph/pass/dyn_elimination.cpp
View file @
6e6c23ff
...
...
@@ -390,7 +390,7 @@ void pass::DynElimination::construct_range()
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
et
.
get_type_enum
()
)
switch
(
et
)
{
case
element
:
:
Type_t
::
bf16
:
replacement
=
make_range_replacement
<
bfloat16
>
(
et
,
shape
,
start_arg
,
step_arg
);
...
...
src/ngraph/runtime/cpu/builder/allreduce.cpp
View file @
6e6c23ff
...
...
@@ -36,7 +36,7 @@ namespace ngraph
auto
arg_buffer_index
=
external_function
->
get_buffer_index
(
args
[
0
].
get_name
());
auto
out_buffer_index
=
external_function
->
get_buffer_index
(
out
[
0
].
get_name
());
auto
count
=
static_cast
<
int
>
(
out
[
0
].
get_size
());
auto
data_type
=
args
[
0
].
get_element_type
()
.
get_type_enum
()
;
auto
data_type
=
args
[
0
].
get_element_type
();
const
ngraph
::
op
::
AllReduce
*
allreduce
=
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
node
);
auto
reduce_type
=
allreduce
->
get_reduce_type
();
...
...
src/ngraph/runtime/cpu/builder/broadcast_distributed.cpp
View file @
6e6c23ff
...
...
@@ -33,7 +33,7 @@ namespace ngraph
auto
arg_buffer_index
=
external_function
->
get_buffer_index
(
args
[
0
].
get_name
());
auto
count
=
static_cast
<
int
>
(
args
[
0
].
get_size
());
auto
data_type
=
args
[
0
].
get_element_type
()
.
get_type_enum
()
;
auto
data_type
=
args
[
0
].
get_element_type
();
auto
broadcast
=
static_cast
<
const
ngraph
::
op
::
BroadcastDistributed
*>
(
node
);
auto
root_id
=
broadcast
->
get_root_id
();
auto
functor
=
[
&
,
count
,
data_type
,
arg_buffer_index
,
root_id
](
...
...
src/ngraph/runtime/cpu/cpu_external_function.cpp
View file @
6e6c23ff
...
...
@@ -1261,7 +1261,7 @@ static void dump_one_kernel_with_type(runtime::cpu::CPU_DebugTracer& debug_trace
const
std
::
string
&
tensor_name
,
const
std
::
string
&
in_out
)
{
switch
(
t_attrs
.
m_type_of_element
.
get_type_enum
()
)
switch
(
t_attrs
.
m_type_of_element
)
{
case
element
:
:
Type_t
::
f32
:
debug_tracer
.
dump_one_tensor
<
float
>
(
kernel_name
,
...
...
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
View file @
6e6c23ff
...
...
@@ -213,7 +213,7 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
{
stringstream
ss
;
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
out
,
in
);
break
;
...
...
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
View file @
6e6c23ff
...
...
@@ -267,7 +267,7 @@ private:
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
allreduce
->
get_reduce_type
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
break
;
...
...
@@ -504,7 +504,7 @@ private:
{
reference
::
broadcastdistributed
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
...
...
@@ -514,7 +514,7 @@ private:
{
reference
::
broadcastdistributed
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
}
...
...
@@ -559,7 +559,7 @@ private:
element
::
Type
type
=
node
.
get_element_type
();
std
::
stringstream
ss
;
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
boolean
:
reference
::
convert_to_bool
<
T
>
(
...
...
@@ -1300,10 +1300,8 @@ private:
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
int
src_id
=
op
->
get_src_id
();
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
src_id
);
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
),
element_count
,
src_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
...
...
@@ -1467,7 +1465,7 @@ private:
int
dest_id
=
op
->
get_dest_id
();
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
element_count
,
dest_id
);
...
...
src/ngraph/runtime/intelgpu/intelgpu_layout.cpp
View file @
6e6c23ff
...
...
@@ -54,7 +54,7 @@ bool runtime::intelgpu::IntelGPULayout::
cldnn
::
data_types
runtime
::
intelgpu
::
IntelGPULayout
::
get_cldnn_type
(
const
element
::
Type
&
element_type
)
{
switch
(
element_type
.
get_type_enum
()
)
switch
(
element_type
)
{
case
element
:
:
Type_t
::
i8
:
case
element
:
:
Type_t
::
boolean
:
return
cldnn
::
data_types
::
i8
;
...
...
@@ -118,7 +118,7 @@ cldnn::layout runtime::intelgpu::IntelGPULayout::create_cldnn_layout(
const
cldnn
::
tensor
tensor
=
create_cldnn_tensor
(
element_shape
);
cldnn
::
data_types
data_type
;
switch
(
element_type
.
get_type_enum
()
)
switch
(
element_type
)
{
case
element
:
:
Type_t
::
i16
:
case
element
:
:
Type_t
::
u16
:
...
...
src/ngraph/runtime/intelgpu/intelgpu_op_custom_kernels.cpp
View file @
6e6c23ff
...
...
@@ -33,7 +33,7 @@ using namespace ngraph::runtime::intelgpu;
string
runtime
::
intelgpu
::
get_opencl_type_name
(
const
element
::
Type
&
ngraph_type
)
{
switch
(
ngraph_type
.
get_type_enum
()
)
switch
(
ngraph_type
)
{
case
element
:
:
Type_t
::
i64
:
return
"long"
;
case
element
:
:
Type_t
::
u64
:
return
"ulong"
;
...
...
@@ -52,7 +52,7 @@ string runtime::intelgpu::get_opencl_type_name(const element::Type& ngraph_type)
string
runtime
::
intelgpu
::
get_opencl_type_min_max_value
(
const
element
::
Type
&
ngraph_type
,
bool
is_min
)
{
switch
(
ngraph_type
.
get_type_enum
()
)
switch
(
ngraph_type
)
{
case
element
:
:
Type_t
::
f32
:
return
is_min
?
"-INFINITY"
:
"INFINITY"
;
case
element
:
:
Type_t
::
f64
:
return
is_min
?
"-INFINITY"
:
"INFINITY"
;
...
...
@@ -1839,9 +1839,8 @@ void runtime::intelgpu::do_convert_operation(cldnn::topology& topology,
{
gws
=
generate_loops
(
writer
,
output_shape
,
true
);
if
(((
input_type
.
get_type_enum
()
==
element
::
Type_t
::
f64
)
||
(
input_type
.
get_type_enum
()
==
element
::
Type_t
::
f32
))
&&
(
output_type
.
get_type_enum
()
!=
element
::
Type_t
::
boolean
))
if
(((
input_type
==
element
::
Type_t
::
f64
)
||
(
input_type
==
element
::
Type_t
::
f32
))
&&
(
output_type
!=
element
::
Type_t
::
boolean
))
{
// this is the workaround for OpenCL to be same as with CPU floating point operations
writer
<<
input_type_name
<<
" input_var = input0"
<<
access_dims
(
output_shape
)
<<
";
\n
"
...
...
src/ngraph/runtime/interpreter/int_executable.cpp
View file @
6e6c23ff
...
...
@@ -215,7 +215,7 @@ void runtime::interpreter::INTExecutable::generate_calls(const element::Type& ty
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
{
stringstream
ss
;
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
out
,
in
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
out
,
in
);
break
;
...
...
src/ngraph/runtime/interpreter/int_executable.hpp
View file @
6e6c23ff
...
...
@@ -294,7 +294,7 @@ private:
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
allreduce
->
get_reduce_type
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
break
;
...
...
@@ -530,7 +530,7 @@ private:
{
reference
::
broadcastdistributed
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
...
...
@@ -540,7 +540,7 @@ private:
{
reference
::
broadcastdistributed
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
}
...
...
@@ -585,7 +585,7 @@ private:
element
::
Type
type
=
node
.
get_element_type
();
std
::
stringstream
ss
;
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
switch
(
type
.
get_type_enum
()
)
switch
(
type
)
{
case
element
:
:
Type_t
::
boolean
:
reference
::
convert_to_bool
<
T
>
(
...
...
@@ -1349,10 +1349,8 @@ private:
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
int
src_id
=
op
->
get_src_id
();
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
src_id
);
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
),
element_count
,
src_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
...
...
@@ -1516,7 +1514,7 @@ private:
int
dest_id
=
op
->
get_dest_id
();
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
node
.
get_input_element_type
(
0
),
element_count
,
dest_id
);
...
...
src/ngraph/type/element_type.hpp
View file @
6e6c23ff
...
...
@@ -26,6 +26,7 @@
#include <string>
#include <vector>
#include "ngraph/deprecated.hpp"
#include "ngraph/except.hpp"
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/type/bfloat16.hpp"
...
...
@@ -73,7 +74,10 @@ namespace ngraph
const
std
::
string
&
cname
);
~
Type
()
{}
Type
&
operator
=
(
const
Type
&
)
=
default
;
Type_t
get_type_enum
()
const
{
return
m_type
;
}
NGRAPH_DEPRECATED
(
"Use operator Type_t()"
)
Type_t
get_type_enum
()
const
{
return
m_type
;
}
const
std
::
string
&
c_type_string
()
const
;
size_t
size
()
const
;
size_t
hash
()
const
;
...
...
@@ -119,6 +123,8 @@ namespace ngraph
/// does nothing to dst, and returns false
static
bool
merge
(
element
::
Type
&
dst
,
const
element
::
Type
&
t1
,
const
element
::
Type
&
t2
);
// \brief This allows switch(element_type)
operator
Type_t
()
const
{
return
m_type
;
}
private
:
Type_t
m_type
{
Type_t
::
undefined
};
};
...
...
src/tools/nbench/benchmark_utils.cpp
View file @
6e6c23ff
...
...
@@ -85,7 +85,7 @@ void random_init(shared_ptr<runtime::Tensor> tensor)
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch
(
et
.
get_type_enum
()
)
switch
(
et
)
{
case
element
:
:
Type_t
::
boolean
:
init_int_tensor
<
char
>
(
tensor
,
0
,
1
);
break
;
case
element
:
:
Type_t
::
f32
:
init_real_tensor
<
float
>
(
tensor
,
-
1
,
1
);
break
;
...
...
test/util/test_case.cpp
View file @
6e6c23ff
...
...
@@ -39,14 +39,14 @@ void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
auto
result_shape
=
result_tensor
->
get_shape
();
EXPECT_EQ
(
expected_shape
,
result_shape
);
if
(
m_value_comparators
.
count
(
element_type
.
get_type_enum
()
)
==
0
)
if
(
m_value_comparators
.
count
(
element_type
)
==
0
)
{
NGRAPH_FAIL
()
<<
"Please add support for "
<<
element_type
<<
" to ngraph::test::NgraphTestCase::run()"
;
}
else
{
auto
values_match
=
m_value_comparators
.
at
(
element_type
.
get_type_enum
()
);
auto
values_match
=
m_value_comparators
.
at
(
element_type
);
EXPECT_TRUE
(
values_match
(
expected_result_constant
,
result_tensor
));
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment