Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
253d4cdf
Unverified
Commit
253d4cdf
authored
Feb 15, 2019
by
Robert Kimball
Committed by
GitHub
Feb 15, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Pass HostTensor all the way down (#2454)
* pass HostTensor all the way down * fix distributed build error
parent
0a3858a0
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
244 additions
and
247 deletions
+244
-247
int_executable.cpp
src/ngraph/runtime/interpreter/int_executable.cpp
+4
-15
int_executable.hpp
src/ngraph/runtime/interpreter/int_executable.hpp
+240
-232
No files found.
src/ngraph/runtime/interpreter/int_executable.cpp
View file @
253d4cdf
...
@@ -182,22 +182,11 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
...
@@ -182,22 +182,11 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
return
true
;
return
true
;
}
}
void
runtime
::
interpreter
::
INTExecutable
::
generate_calls
(
void
runtime
::
interpreter
::
INTExecutable
::
generate_calls
(
const
element
::
Type
&
type
,
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
NodeWrapper
&
op
,
const
vector
<
shared_ptr
<
HostTensor
>>&
out
,
const
vector
<
shared_ptr
<
HostTensor
>>&
outputs
,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
const
vector
<
shared_ptr
<
HostTensor
>>&
inputs
)
{
{
vector
<
void
*>
out
;
vector
<
const
void
*>
in
;
for
(
auto
t
:
outputs
)
{
out
.
push_back
(
t
->
get_data_ptr
());
}
for
(
auto
t
:
inputs
)
{
in
.
push_back
(
t
->
get_data_ptr
());
}
stringstream
ss
;
stringstream
ss
;
switch
(
type
.
get_type_enum
())
switch
(
type
.
get_type_enum
())
{
{
...
...
src/ngraph/runtime/interpreter/int_executable.hpp
View file @
253d4cdf
...
@@ -181,8 +181,8 @@ private:
...
@@ -181,8 +181,8 @@ private:
template
<
typename
T
>
template
<
typename
T
>
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
const
std
::
vector
<
void
*
>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
out
,
const
std
::
vector
<
const
void
*
>&
args
)
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
args
)
{
{
const
Node
&
node
=
node_wrapper
.
get_node
();
const
Node
&
node
=
node_wrapper
.
get_node
();
std
::
string
node_op
=
node
.
description
();
std
::
string
node_op
=
node
.
description
();
...
@@ -200,30 +200,30 @@ private:
...
@@ -200,30 +200,30 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
abs
<
T
>
(
reference
::
abs
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Acos
:
case
OP_TYPEID
:
:
Acos
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
acos
<
T
>
(
reference
::
acos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Add
:
case
OP_TYPEID
:
:
Add
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
add
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
add
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
All
:
case
OP_TYPEID
:
:
All
:
{
{
const
op
::
All
*
all
=
static_cast
<
const
op
::
All
*>
(
&
node
);
const
op
::
All
*
all
=
static_cast
<
const
op
::
All
*>
(
&
node
);
reference
::
all
(
static_cast
<
const
char
*>
(
args
[
0
]
),
reference
::
all
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
all
->
get_reduction_axes
());
all
->
get_reduction_axes
());
...
@@ -231,8 +231,8 @@ private:
...
@@ -231,8 +231,8 @@ private:
}
}
case
OP_TYPEID
:
:
AllReduce
:
{
case
OP_TYPEID
:
:
AllReduce
:
{
#ifdef NGRAPH_DISTRIBUTED_ENABLE
#ifdef NGRAPH_DISTRIBUTED_ENABLE
reference
::
allreduce
<
T
>
(
static_cast
<
T
*>
(
const_cast
<
void
*>
(
args
[
0
])
),
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_element_type
(
0
),
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
#endif
#endif
...
@@ -241,17 +241,17 @@ private:
...
@@ -241,17 +241,17 @@ private:
case
OP_TYPEID
:
:
And
:
case
OP_TYPEID
:
:
And
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_and
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
logical_and
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Any
:
case
OP_TYPEID
:
:
Any
:
{
{
const
op
::
Any
*
any
=
static_cast
<
const
op
::
Any
*>
(
&
node
);
const
op
::
Any
*
any
=
static_cast
<
const
op
::
Any
*>
(
&
node
);
reference
::
any
(
static_cast
<
const
char
*>
(
args
[
0
]
),
reference
::
any
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
any
->
get_reduction_axes
());
any
->
get_reduction_axes
());
...
@@ -263,16 +263,16 @@ private:
...
@@ -263,16 +263,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
if
(
element_type
==
element
::
i64
)
{
{
reference
::
argmin
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmin
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
argmin
->
get_reduction_axis
());
}
}
else
if
(
element_type
==
element
::
i32
)
else
if
(
element_type
==
element
::
i32
)
{
{
reference
::
argmin
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmin
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
argmin
->
get_reduction_axis
());
...
@@ -289,16 +289,16 @@ private:
...
@@ -289,16 +289,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
if
(
element_type
==
element
::
i64
)
{
{
reference
::
argmax
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmax
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
argmax
->
get_reduction_axis
());
}
}
else
if
(
element_type
==
element
::
i32
)
else
if
(
element_type
==
element
::
i32
)
{
{
reference
::
argmax
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmax
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
argmax
->
get_reduction_axis
());
...
@@ -313,22 +313,22 @@ private:
...
@@ -313,22 +313,22 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
asin
<
T
>
(
reference
::
asin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Atan
:
case
OP_TYPEID
:
:
Atan
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
atan
<
T
>
(
reference
::
atan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
AvgPool
:
case
OP_TYPEID
:
:
AvgPool
:
{
{
const
op
::
AvgPool
*
avg_pool
=
static_cast
<
const
op
::
AvgPool
*>
(
&
node
);
const
op
::
AvgPool
*
avg_pool
=
static_cast
<
const
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
avg_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_shape
(),
...
@@ -347,11 +347,10 @@ private:
...
@@ -347,11 +347,10 @@ private:
ngraph
::
RNGState
::
create_rng_state
(
gm
->
get_seed
(),
gm
->
get_probability
()));
ngraph
::
RNGState
::
create_rng_state
(
gm
->
get_seed
(),
gm
->
get_probability
()));
}
}
bool
training
=
static_cast
<
bool
>
(
static_cast
<
const
T
*>
(
args
[
0
]
)[
0
]);
bool
training
=
static_cast
<
bool
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
)[
0
]);
auto
state
=
m_states
.
at
(
&
node
).
get
();
auto
state
=
m_states
.
at
(
&
node
).
get
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
generate_mask
<
T
>
(
reference
::
generate_mask
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
,
state
,
training
);
reinterpret_cast
<
T
*>
(
out
[
0
]),
element_count
,
state
,
training
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
GetOutputElement
:
case
OP_TYPEID
:
:
GetOutputElement
:
...
@@ -361,7 +360,7 @@ private:
...
@@ -361,7 +360,7 @@ private:
size_t
n
=
get_output_element
->
get_n
();
size_t
n
=
get_output_element
->
get_n
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
num_bytes
=
element_count
*
node
.
get_output_element_type
(
0
).
size
();
size_t
num_bytes
=
element_count
*
node
.
get_output_element_type
(
0
).
size
();
std
::
memcpy
(
static_cast
<
T
*>
(
out
[
0
]),
args
[
n
]
,
num_bytes
);
std
::
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
n
]
->
get_data_ptr
<
T
>
()
,
num_bytes
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
BatchNormTraining
:
case
OP_TYPEID
:
:
BatchNormTraining
:
...
@@ -369,12 +368,12 @@ private:
...
@@ -369,12 +368,12 @@ private:
const
ngraph
::
op
::
BatchNormTraining
*
bn
=
const
ngraph
::
op
::
BatchNormTraining
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormTraining
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
BatchNormTraining
*>
(
&
node
);
reference
::
batch_norm_training
<
T
>
(
bn
->
get_eps_value
(),
reference
::
batch_norm_training
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
2
]
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
node
.
get_input_shape
(
2
));
break
;
break
;
}
}
...
@@ -383,12 +382,12 @@ private:
...
@@ -383,12 +382,12 @@ private:
const
ngraph
::
op
::
BatchNormInference
*
bn
=
const
ngraph
::
op
::
BatchNormInference
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormInference
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
BatchNormInference
*>
(
&
node
);
reference
::
batch_norm_inference
<
T
>
(
bn
->
get_eps_value
(),
reference
::
batch_norm_inference
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
3
]
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
4
]
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
node
.
get_input_shape
(
2
));
break
;
break
;
}
}
...
@@ -397,23 +396,23 @@ private:
...
@@ -397,23 +396,23 @@ private:
const
ngraph
::
op
::
BatchNormTrainingBackprop
*
bn_bprop
=
const
ngraph
::
op
::
BatchNormTrainingBackprop
*
bn_bprop
=
static_cast
<
const
ngraph
::
op
::
BatchNormTrainingBackprop
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
BatchNormTrainingBackprop
*>
(
&
node
);
reference
::
batch_norm_backprop
(
bn_bprop
->
get_eps_value
(),
reference
::
batch_norm_backprop
(
bn_bprop
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
3
]
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
4
]
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
5
]
),
args
[
5
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
2
]
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
node
.
get_input_shape
(
2
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
AvgPoolBackprop
:
case
OP_TYPEID
:
:
AvgPoolBackprop
:
{
{
const
op
::
AvgPoolBackprop
*
apb
=
static_cast
<
const
op
::
AvgPoolBackprop
*>
(
&
node
);
const
op
::
AvgPoolBackprop
*
apb
=
static_cast
<
const
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
avg_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
apb
->
get_window_shape
(),
apb
->
get_window_shape
(),
...
@@ -429,8 +428,8 @@ private:
...
@@ -429,8 +428,8 @@ private:
Shape
in_shape
=
node
.
get_input_shape
(
0
);
Shape
in_shape
=
node
.
get_input_shape
(
0
);
Shape
out_shape
=
node
.
get_output_shape
(
0
);
Shape
out_shape
=
node
.
get_output_shape
(
0
);
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
reference
::
broadcast
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
broadcast
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shape
,
in_shape
,
out_shape
,
out_shape
,
broadcast_axes
);
broadcast_axes
);
...
@@ -441,7 +440,7 @@ private:
...
@@ -441,7 +440,7 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
ceiling
<
T
>
(
reference
::
ceiling
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Concat
:
case
OP_TYPEID
:
:
Concat
:
...
@@ -451,11 +450,11 @@ private:
...
@@ -451,11 +450,11 @@ private:
std
::
vector
<
Shape
>
in_shapes
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
size_t
i
=
0
;
i
<
node
.
get_input_size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
node
.
get_input_size
();
i
++
)
{
{
in_args
.
push_back
(
static_cast
<
const
T
*>
(
args
[
i
]
));
in_args
.
push_back
(
args
[
i
]
->
get_data_ptr
<
const
T
>
(
));
in_shapes
.
push_back
(
node
.
get_input_shape
(
i
));
in_shapes
.
push_back
(
node
.
get_input_shape
(
i
));
}
}
reference
::
concat
<
T
>
(
in_args
,
reference
::
concat
<
T
>
(
in_args
,
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shapes
,
in_shapes
,
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
concat
->
get_concatenation_axis
());
concat
->
get_concatenation_axis
());
...
@@ -465,7 +464,7 @@ private:
...
@@ -465,7 +464,7 @@ private:
{
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ScalarConstantLike
:
break
;
case
OP_TYPEID
:
:
ScalarConstantLike
:
break
;
...
@@ -479,47 +478,56 @@ private:
...
@@ -479,47 +478,56 @@ private:
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
char
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
f32
:
case
element
:
:
Type_t
::
f32
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
float
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
f64
:
case
element
:
:
Type_t
::
f64
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
double
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
double
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i8
:
case
element
:
:
Type_t
::
i8
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int8_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i16
:
case
element
:
:
Type_t
::
i16
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int16_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int16_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i32
:
case
element
:
:
Type_t
::
i32
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int32_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i64
:
case
element
:
:
Type_t
::
i64
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int64_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int64_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u8
:
case
element
:
:
Type_t
::
u8
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint8_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u16
:
case
element
:
:
Type_t
::
u16
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint16_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint16_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u32
:
case
element
:
:
Type_t
::
u32
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint32_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint32_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u64
:
case
element
:
:
Type_t
::
u64
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint64_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
dynamic
:
...
@@ -532,9 +540,9 @@ private:
...
@@ -532,9 +540,9 @@ private:
case
OP_TYPEID
:
:
Convolution
:
case
OP_TYPEID
:
:
Convolution
:
{
{
const
op
::
Convolution
*
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
const
op
::
Convolution
*
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
convolution
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
...
@@ -556,9 +564,9 @@ private:
...
@@ -556,9 +564,9 @@ private:
{
{
const
op
::
ConvolutionBackpropFilters
*
c
=
const
op
::
ConvolutionBackpropFilters
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
convolution
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
...
@@ -581,9 +589,9 @@ private:
...
@@ -581,9 +589,9 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
// Note that args[1] and args[0] are switched here from the usual order.
const
op
::
ConvolutionBackpropData
*
c
=
const
op
::
ConvolutionBackpropData
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
1
]
),
reference
::
convolution
<
T
>
(
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
...
@@ -605,14 +613,14 @@ private:
...
@@ -605,14 +613,14 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cos
<
T
>
(
reference
::
cos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Cosh
:
case
OP_TYPEID
:
:
Cosh
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cosh
<
T
>
(
reference
::
cosh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Dequantize
:
case
OP_TYPEID
:
:
Dequantize
:
...
@@ -622,20 +630,20 @@ private:
...
@@ -622,20 +630,20 @@ private:
if
(
type
==
element
::
f32
)
if
(
type
==
element
::
f32
)
{
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
float
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
float
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
float
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
dequantize
->
get_axes
());
}
}
else
if
(
type
==
element
::
f64
)
else
if
(
type
==
element
::
f64
)
{
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
double
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
double
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
double
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
double
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
dequantize
->
get_axes
());
...
@@ -652,9 +660,9 @@ private:
...
@@ -652,9 +660,9 @@ private:
case
OP_TYPEID
:
:
Divide
:
case
OP_TYPEID
:
:
Divide
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
divide
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
divide
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -662,9 +670,9 @@ private:
...
@@ -662,9 +670,9 @@ private:
{
{
const
op
::
Dot
*
dot
=
static_cast
<
const
op
::
Dot
*>
(
&
node
);
const
op
::
Dot
*
dot
=
static_cast
<
const
op
::
Dot
*>
(
&
node
);
reference
::
dot
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
dot
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
...
@@ -679,33 +687,33 @@ private:
...
@@ -679,33 +687,33 @@ private:
if
(
type
==
element
::
f32
)
if
(
type
==
element
::
f32
)
{
{
reference
::
embedding
<
T
,
float
>
(
static_cast
<
const
float
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
float
>
(
args
[
0
]
->
get_data_ptr
<
const
float
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
else
if
(
type
==
element
::
f64
)
else
if
(
type
==
element
::
f64
)
{
{
reference
::
embedding
<
T
,
double
>
(
static_cast
<
const
double
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
double
>
(
args
[
0
]
->
get_data_ptr
<
const
double
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
else
if
(
type
==
element
::
i32
)
else
if
(
type
==
element
::
i32
)
{
{
reference
::
embedding
<
T
,
int
>
(
static_cast
<
const
int
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
int
>
(
args
[
0
]
->
get_data_ptr
<
const
int
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
else
if
(
type
==
element
::
i64
)
else
if
(
type
==
element
::
i64
)
{
{
reference
::
embedding
<
T
,
int64_t
>
(
static_cast
<
const
int64_t
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
int64_t
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
...
@@ -719,9 +727,9 @@ private:
...
@@ -719,9 +727,9 @@ private:
case
OP_TYPEID
:
:
Equal
:
case
OP_TYPEID
:
:
Equal
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -729,49 +737,49 @@ private:
...
@@ -729,49 +737,49 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
exp
<
T
>
(
reference
::
exp
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Floor
:
case
OP_TYPEID
:
:
Floor
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
floor
<
T
>
(
reference
::
floor
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Greater
:
case
OP_TYPEID
:
:
Greater
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
greater
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
GreaterEq
:
case
OP_TYPEID
:
:
GreaterEq
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
greater_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Less
:
case
OP_TYPEID
:
:
Less
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
less
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
LessEq
:
case
OP_TYPEID
:
:
LessEq
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
less_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -779,14 +787,14 @@ private:
...
@@ -779,14 +787,14 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
log
<
T
>
(
reference
::
log
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
LRN
:
case
OP_TYPEID
:
:
LRN
:
{
{
const
op
::
LRN
*
lrn
=
static_cast
<
const
op
::
LRN
*>
(
&
node
);
const
op
::
LRN
*
lrn
=
static_cast
<
const
op
::
LRN
*>
(
&
node
);
reference
::
lrn
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
lrn
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
lrn
->
get_alpha
(),
lrn
->
get_alpha
(),
lrn
->
get_beta
(),
lrn
->
get_beta
(),
...
@@ -797,8 +805,8 @@ private:
...
@@ -797,8 +805,8 @@ private:
case
OP_TYPEID
:
:
Max
:
case
OP_TYPEID
:
:
Max
:
{
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
max
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
max
->
get_reduction_axes
());
max
->
get_reduction_axes
());
...
@@ -807,9 +815,9 @@ private:
...
@@ -807,9 +815,9 @@ private:
case
OP_TYPEID
:
:
Maximum
:
case
OP_TYPEID
:
:
Maximum
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
maximum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
maximum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -817,8 +825,8 @@ private:
...
@@ -817,8 +825,8 @@ private:
{
{
const
op
::
MaxPool
*
max_pool
=
static_cast
<
const
op
::
MaxPool
*>
(
&
node
);
const
op
::
MaxPool
*
max_pool
=
static_cast
<
const
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
max_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_shape
(),
...
@@ -832,9 +840,9 @@ private:
...
@@ -832,9 +840,9 @@ private:
const
op
::
MaxPoolBackprop
*
max_pool_backprop
=
const
op
::
MaxPoolBackprop
*
max_pool_backprop
=
static_cast
<
const
op
::
MaxPoolBackprop
*>
(
&
node
);
static_cast
<
const
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
max_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_shape
(),
...
@@ -846,8 +854,8 @@ private:
...
@@ -846,8 +854,8 @@ private:
case
OP_TYPEID
:
:
Min
:
case
OP_TYPEID
:
:
Min
:
{
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
min
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
min
->
get_reduction_axes
());
min
->
get_reduction_axes
());
...
@@ -856,18 +864,18 @@ private:
...
@@ -856,18 +864,18 @@ private:
case
OP_TYPEID
:
:
Minimum
:
case
OP_TYPEID
:
:
Minimum
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
minimum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
minimum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Multiply
:
case
OP_TYPEID
:
:
Multiply
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
multiply
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
multiply
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -875,30 +883,30 @@ private:
...
@@ -875,30 +883,30 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
negate
<
T
>
(
reference
::
negate
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Not
:
case
OP_TYPEID
:
:
Not
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_not
(
reference
::
logical_not
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
NotEqual
:
case
OP_TYPEID
:
:
NotEqual
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
not_equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
not_equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
OneHot
:
case
OP_TYPEID
:
:
OneHot
:
{
{
const
op
::
OneHot
*
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
const
op
::
OneHot
*
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
one_hot
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
oh
->
get_one_hot_axis
());
oh
->
get_one_hot_axis
());
...
@@ -907,9 +915,9 @@ private:
...
@@ -907,9 +915,9 @@ private:
case
OP_TYPEID
:
:
Or
:
case
OP_TYPEID
:
:
Or
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_or
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
logical_or
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -918,9 +926,9 @@ private:
...
@@ -918,9 +926,9 @@ private:
{
{
const
op
::
Pad
*
pad
=
static_cast
<
const
op
::
Pad
*>
(
&
node
);
const
op
::
Pad
*
pad
=
static_cast
<
const
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
pad
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
pad
->
get_padding_below
(),
pad
->
get_padding_below
(),
...
@@ -931,17 +939,17 @@ private:
...
@@ -931,17 +939,17 @@ private:
case
OP_TYPEID
:
:
Power
:
case
OP_TYPEID
:
:
Power
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
power
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
power
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Product
:
case
OP_TYPEID
:
:
Product
:
{
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
product
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
product
->
get_reduction_axes
());
product
->
get_reduction_axes
());
...
@@ -954,10 +962,10 @@ private:
...
@@ -954,10 +962,10 @@ private:
if
(
type
==
element
::
u8
)
if
(
type
==
element
::
u8
)
{
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
uint8_t
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
uint8_t
>
(
),
static_cast
<
uint8_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_axes
(),
...
@@ -965,10 +973,10 @@ private:
...
@@ -965,10 +973,10 @@ private:
}
}
else
if
(
type
==
element
::
i8
)
else
if
(
type
==
element
::
i8
)
{
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
int8_t
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
int8_t
>
(
),
static_cast
<
int8_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_axes
(),
...
@@ -976,10 +984,10 @@ private:
...
@@ -976,10 +984,10 @@ private:
}
}
else
if
(
type
==
element
::
i32
)
else
if
(
type
==
element
::
i32
)
{
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
int32_t
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
int32_t
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_axes
(),
...
@@ -1009,24 +1017,24 @@ private:
...
@@ -1009,24 +1017,24 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu
<
T
>
(
reference
::
relu
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ReluBackprop
:
case
OP_TYPEID
:
:
ReluBackprop
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
relu_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ReplaceSlice
:
case
OP_TYPEID
:
:
ReplaceSlice
:
{
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
replace_slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
slice
->
get_lower_bounds
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_upper_bounds
(),
...
@@ -1037,8 +1045,8 @@ private:
...
@@ -1037,8 +1045,8 @@ private:
case
OP_TYPEID
:
:
Reshape
:
case
OP_TYPEID
:
:
Reshape
:
{
{
const
op
::
Reshape
*
reshape
=
static_cast
<
const
op
::
Reshape
*>
(
&
node
);
const
op
::
Reshape
*
reshape
=
static_cast
<
const
op
::
Reshape
*>
(
&
node
);
reference
::
reshape
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
reshape
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
reshape
->
get_input_order
(),
reshape
->
get_input_order
(),
node
.
get_output_shape
(
0
));
node
.
get_output_shape
(
0
));
...
@@ -1047,16 +1055,16 @@ private:
...
@@ -1047,16 +1055,16 @@ private:
case
OP_TYPEID
:
:
Result
:
case
OP_TYPEID
:
:
Result
:
{
{
const
op
::
Result
*
res
=
static_cast
<
const
op
::
Result
*>
(
&
node
);
const
op
::
Result
*
res
=
static_cast
<
const
op
::
Result
*>
(
&
node
);
reference
::
result
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
result
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
shape_size
(
res
->
get_shape
()));
shape_size
(
res
->
get_shape
()));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Reverse
:
case
OP_TYPEID
:
:
Reverse
:
{
{
const
op
::
Reverse
*
reverse
=
static_cast
<
const
op
::
Reverse
*>
(
&
node
);
const
op
::
Reverse
*
reverse
=
static_cast
<
const
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
reverse
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
reverse
->
get_reversed_axes
());
reverse
->
get_reversed_axes
());
...
@@ -1068,12 +1076,12 @@ private:
...
@@ -1068,12 +1076,12 @@ private:
if
(
node
.
get_input_element_type
(
1
)
==
element
::
i32
)
if
(
node
.
get_input_element_type
(
1
)
==
element
::
i32
)
{
{
reference
::
reverse_sequence
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
reverse_sequence
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
reverse
->
get_batch_axis
(),
reverse
->
get_batch_axis
(),
reverse
->
get_sequence_axis
(),
reverse
->
get_sequence_axis
(),
static_cast
<
const
int32_t
*>
(
args
[
1
]
));
args
[
1
]
->
get_data_ptr
<
const
int32_t
>
(
));
}
}
else
else
{
{
...
@@ -1084,31 +1092,31 @@ private:
...
@@ -1084,31 +1092,31 @@ private:
case
OP_TYPEID
:
:
Select
:
case
OP_TYPEID
:
:
Select
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
select
<
T
>
(
static_cast
<
const
char
*>
(
args
[
0
]
),
reference
::
select
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ShapeOf
:
case
OP_TYPEID
:
:
ShapeOf
:
{
{
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
static_cast
<
uint64_t
*>
(
out
[
0
]
));
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sigmoid
:
case
OP_TYPEID
:
:
Sigmoid
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid
<
T
>
(
reference
::
sigmoid
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
SigmoidBackprop
:
case
OP_TYPEID
:
:
SigmoidBackprop
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
sigmoid_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -1116,28 +1124,28 @@ private:
...
@@ -1116,28 +1124,28 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sign
<
T
>
(
reference
::
sign
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sin
:
case
OP_TYPEID
:
:
Sin
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sin
<
T
>
(
reference
::
sin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sinh
:
case
OP_TYPEID
:
:
Sinh
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sinh
<
T
>
(
reference
::
sinh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Slice
:
case
OP_TYPEID
:
:
Slice
:
{
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
slice
->
get_lower_bounds
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_upper_bounds
(),
...
@@ -1148,8 +1156,8 @@ private:
...
@@ -1148,8 +1156,8 @@ private:
case
OP_TYPEID
:
:
Softmax
:
case
OP_TYPEID
:
:
Softmax
:
{
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
softmax
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
softmax
->
get_axes
());
softmax
->
get_axes
());
break
;
break
;
...
@@ -1158,7 +1166,7 @@ private:
...
@@ -1158,7 +1166,7 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sqrt
<
T
>
(
reference
::
sqrt
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
StopGradient
:
{
throw
unsupported_op
(
"Unsupported op 'StopGradient'"
);
case
OP_TYPEID
:
:
StopGradient
:
{
throw
unsupported_op
(
"Unsupported op 'StopGradient'"
);
...
@@ -1166,17 +1174,17 @@ private:
...
@@ -1166,17 +1174,17 @@ private:
case
OP_TYPEID
:
:
Subtract
:
case
OP_TYPEID
:
:
Subtract
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
subtract
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
subtract
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sum
:
case
OP_TYPEID
:
:
Sum
:
{
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
sum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
sum
->
get_reduction_axes
());
sum
->
get_reduction_axes
());
...
@@ -1186,14 +1194,14 @@ private:
...
@@ -1186,14 +1194,14 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tan
<
T
>
(
reference
::
tan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Tanh
:
case
OP_TYPEID
:
:
Tanh
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tanh
<
T
>
(
reference
::
tanh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
TopK
:
case
OP_TYPEID
:
:
TopK
:
...
@@ -1201,9 +1209,9 @@ private:
...
@@ -1201,9 +1209,9 @@ private:
const
op
::
TopK
*
topk
=
static_cast
<
const
op
::
TopK
*>
(
&
node
);
const
op
::
TopK
*
topk
=
static_cast
<
const
op
::
TopK
*>
(
&
node
);
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i64
)
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i64
)
{
{
reference
::
topk
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
topk
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
topk
->
get_top_k_axis
(),
...
@@ -1212,9 +1220,9 @@ private:
...
@@ -1212,9 +1220,9 @@ private:
}
}
else
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i32
)
else
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i32
)
{
{
reference
::
topk
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
topk
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
topk
->
get_top_k_axis
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment