Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
6f511762
Unverified
Commit
6f511762
authored
Nov 09, 2018
by
Robert Kimball
Committed by
GitHub
Nov 09, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Interpreter rework (#2030)
* all tests passing * rename a few vars to be consistent with new tensor names
parent
b52a7798
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
463 additions
and
420 deletions
+463
-420
int_backend.cpp
src/ngraph/runtime/interpreter/int_backend.cpp
+56
-51
int_backend.hpp
src/ngraph/runtime/interpreter/int_backend.hpp
+404
-366
relu.hpp
src/ngraph/runtime/reference/relu.hpp
+1
-1
reverse_sequence.hpp
src/ngraph/runtime/reference/reverse_sequence.hpp
+1
-1
sigmoid.hpp
src/ngraph/runtime/reference/sigmoid.hpp
+1
-1
No files found.
src/ngraph/runtime/interpreter/int_backend.cpp
View file @
6f511762
...
...
@@ -24,6 +24,7 @@
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
...
...
@@ -31,6 +32,8 @@ using namespace ngraph;
using
descriptor
::
layout
::
DenseTensorLayout
;
const
int
runtime
::
interpreter
::
INTBackend
::
m_alignment
=
64
;
extern
"C"
const
char
*
get_ngraph_version_string
()
{
return
NGRAPH_VERSION
;
...
...
@@ -63,8 +66,12 @@ bool runtime::interpreter::INTBackend::compile(shared_ptr<Function> function)
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
register_pass
<
pass
::
MemoryLayout
>
(
m_alignment
);
pass_manager
.
run_passes
(
function
);
size_t
memory_pool_size
=
function
->
get_temporary_pool_size
();
instance
.
m_temporary_memory
.
reset
(
new
AlignedBuffer
(
memory_pool_size
,
m_alignment
));
for
(
const
shared_ptr
<
Node
>&
node
:
function
->
get_ordered_ops
())
{
instance
.
m_wrapped_nodes
.
emplace_back
(
node
);
...
...
@@ -84,32 +91,36 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
FunctionInstance
&
instance
=
m_function_map
[
function
];
// convert inputs to HostTensor
vector
<
shared_ptr
<
runtime
::
HostTensor
>>
func_inputs
;
for
(
auto
tv
:
inputs
)
vector
<
void
*>
func_inputs
;
vector
<
shared_ptr
<
runtime
::
HostTensor
>>
htv_inputs
;
for
(
auto
tensor
:
inputs
)
{
func_inputs
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensor
>
(
tv
));
auto
host_tensor
=
static_pointer_cast
<
runtime
::
HostTensor
>
(
tensor
);
func_inputs
.
push_back
(
static_cast
<
void
*>
(
host_tensor
->
get_data_ptr
()));
htv_inputs
.
push_back
(
host_tensor
);
}
if
(
instance
.
m_nan_check_enabled
)
{
perform_nan_check
(
func
_inputs
);
perform_nan_check
(
htv
_inputs
);
}
// convert outputs to HostTensor
vector
<
shared_ptr
<
runtime
::
HostTensor
>
>
func_outputs
;
for
(
auto
t
v
:
outputs
)
vector
<
void
*
>
func_outputs
;
for
(
auto
t
ensor
:
outputs
)
{
func_outputs
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensor
>
(
tv
));
auto
host_tensor
=
static_pointer_cast
<
runtime
::
HostTensor
>
(
tensor
);
func_outputs
.
push_back
(
static_cast
<
void
*>
(
host_tensor
->
get_data_ptr
()));
}
// map function params -> HostTensor
unordered_map
<
descriptor
::
Tensor
*
,
shared_ptr
<
runtime
::
HostTensor
>
>
tensor_map
;
unordered_map
<
descriptor
::
Tensor
*
,
void
*
>
tensor_map
;
size_t
input_count
=
0
;
for
(
auto
param
:
function
->
get_parameters
())
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
descriptor
::
Tensor
*
t
v
=
param
->
get_output_tensor_ptr
(
i
).
get
();
tensor_map
.
insert
({
t
v
,
func_inputs
[
input_count
++
]});
descriptor
::
Tensor
*
t
ensor
=
param
->
get_output_tensor_ptr
(
i
).
get
();
tensor_map
.
insert
({
t
ensor
,
func_inputs
[
input_count
++
]});
}
}
...
...
@@ -121,8 +132,8 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
}
descriptor
::
Tensor
*
t
v
=
output
->
get_output_tensor_ptr
(
0
).
get
();
tensor_map
.
insert
({
t
v
,
func_outputs
[
output_count
]});
descriptor
::
Tensor
*
t
ensor
=
output
->
get_output_tensor_ptr
(
0
).
get
();
tensor_map
.
insert
({
t
ensor
,
func_outputs
[
output_count
]});
}
// for each ordered op in the graph
...
...
@@ -134,35 +145,42 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
{
continue
;
}
if
(
type_id
==
OP_TYPEID
::
Constant
)
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
op
);
descriptor
::
Tensor
*
tensor
=
op
->
get_output_tensor_ptr
(
0
).
get
();
tensor_map
.
insert
({
tensor
,
const_cast
<
void
*>
(
c
->
get_data_ptr
())});
continue
;
}
// get op inputs from map
vector
<
shared_ptr
<
runtime
::
HostTensor
>
>
op_inputs
;
vector
<
const
void
*
>
op_inputs
;
for
(
const
descriptor
::
Input
&
input
:
op
->
get_inputs
())
{
descriptor
::
Tensor
*
t
v
=
input
.
get_output
().
get_tensor_ptr
().
get
();
op_inputs
.
push_back
(
tensor_map
.
at
(
t
v
));
descriptor
::
Tensor
*
t
ensor
=
input
.
get_output
().
get_tensor_ptr
().
get
();
op_inputs
.
push_back
(
tensor_map
.
at
(
t
ensor
));
}
// get op outputs from map or create
vector
<
shared_ptr
<
runtime
::
HostTensor
>>
op_outputs
;
vector
<
void
*>
op_outputs
;
vector
<
shared_ptr
<
runtime
::
HostTensor
>>
htv_outputs
;
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
{
descriptor
::
Tensor
*
t
v
=
op
->
get_output_tensor_ptr
(
i
).
get
();
shared_ptr
<
runtime
::
HostTensor
>
htv
;
auto
it
=
tensor_map
.
find
(
t
v
);
descriptor
::
Tensor
*
t
ensor
=
op
->
get_output_tensor_ptr
(
i
).
get
();
void
*
host_tensor
=
nullptr
;
auto
it
=
tensor_map
.
find
(
t
ensor
);
if
(
it
==
tensor_map
.
end
())
{
// the output tensor is not in the tensor map so create a new tensor
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
element
::
Type
&
type
=
op
->
get_output_element_type
(
i
);
string
name
=
op
->
get_output_tensor
(
i
).
get_name
();
htv
=
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
name
);
tensor_map
.
insert
({
tv
,
htv
});
auto
offset
=
op
->
get_output_tensor
(
i
).
get_pool_offset
();
host_tensor
=
instance
.
get_temporary_pointer
(
offset
);
tensor_map
.
insert
({
tensor
,
host_tensor
});
}
else
{
h
tv
=
it
->
second
;
h
ost_tensor
=
it
->
second
;
}
op_outputs
.
push_back
(
htv
);
op_outputs
.
push_back
(
host_tensor
);
htv_outputs
.
push_back
(
make_shared
<
runtime
::
HostTensor
>
(
tensor
->
get_element_type
(),
tensor
->
get_shape
(),
host_tensor
));
}
// get op type
...
...
@@ -202,20 +220,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
}
if
(
instance
.
m_nan_check_enabled
)
{
perform_nan_check
(
op_outputs
,
op
);
}
// delete any obsolete tensors
for
(
const
descriptor
::
Tensor
*
t
:
op
->
liveness_free_list
)
{
for
(
auto
it
=
tensor_map
.
begin
();
it
!=
tensor_map
.
end
();
++
it
)
{
if
(
it
->
second
->
get_name
()
==
t
->
get_name
())
{
tensor_map
.
erase
(
it
);
break
;
}
}
perform_nan_check
(
htv_outputs
,
op
);
}
}
...
...
@@ -224,8 +229,8 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
void
runtime
::
interpreter
::
INTBackend
::
generate_calls
(
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
vector
<
shared_ptr
<
HostTensor
>
>&
outputs
,
const
vector
<
shared_ptr
<
HostTensor
>
>&
inputs
,
const
vector
<
void
*
>&
outputs
,
const
vector
<
const
void
*
>&
inputs
,
FunctionInstance
&
instance
)
{
if
(
type
==
element
::
boolean
)
...
...
@@ -307,17 +312,17 @@ vector<runtime::PerformanceCounter>
return
rc
;
}
void
runtime
::
interpreter
::
INTBackend
::
perform_nan_check
(
const
vector
<
shared_ptr
<
HostTensor
>>&
tvs
,
const
Node
*
op
)
void
runtime
::
interpreter
::
INTBackend
::
perform_nan_check
(
const
vector
<
shared_ptr
<
HostTensor
>>&
tensors
,
const
Node
*
op
)
{
size_t
arg_number
=
1
;
for
(
shared_ptr
<
HostTensor
>
tv
:
tv
s
)
for
(
const
shared_ptr
<
HostTensor
>&
tensor
:
tensor
s
)
{
const
element
::
Type
&
type
=
t
v
->
get_element_type
();
const
element
::
Type
&
type
=
t
ensor
->
get_element_type
();
if
(
type
==
element
::
f32
)
{
const
float
*
data
=
t
v
->
get_data_ptr
<
float
>
();
for
(
size_t
i
=
0
;
i
<
t
v
->
get_element_count
();
i
++
)
const
float
*
data
=
t
ensor
->
get_data_ptr
<
float
>
();
for
(
size_t
i
=
0
;
i
<
t
ensor
->
get_element_count
();
i
++
)
{
if
(
std
::
isnan
(
data
[
i
]))
{
...
...
@@ -335,8 +340,8 @@ void runtime::interpreter::INTBackend::perform_nan_check(const vector<shared_ptr
}
else
if
(
type
==
element
::
f64
)
{
const
double
*
data
=
t
v
->
get_data_ptr
<
double
>
();
for
(
size_t
i
=
0
;
i
<
t
v
->
get_element_count
();
i
++
)
const
double
*
data
=
t
ensor
->
get_data_ptr
<
double
>
();
for
(
size_t
i
=
0
;
i
<
t
ensor
->
get_element_count
();
i
++
)
{
if
(
std
::
isnan
(
data
[
i
]))
{
...
...
src/ngraph/runtime/interpreter/int_backend.hpp
View file @
6f511762
...
...
@@ -54,6 +54,7 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
...
...
@@ -165,6 +166,7 @@ public:
bool
is_supported
(
const
Node
&
node
)
const
override
{
return
true
;
}
private
:
static
const
int
m_alignment
;
class
FunctionInstance
{
public
:
...
...
@@ -173,6 +175,9 @@ private:
bool
m_performance_counters_enabled
=
false
;
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
std
::
vector
<
NodeWrapper
>
m_wrapped_nodes
;
std
::
unique_ptr
<
AlignedBuffer
>
m_temporary_memory
;
void
*
get_temporary_pointer
(
size_t
offset
)
{
return
m_temporary_memory
->
get_ptr
(
offset
);
}
};
std
::
map
<
std
::
shared_ptr
<
Function
>
,
FunctionInstance
>
m_function_map
;
...
...
@@ -181,14 +186,14 @@ private:
void
generate_calls
(
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
inputs
,
const
std
::
vector
<
void
*
>&
outputs
,
const
std
::
vector
<
const
void
*
>&
inputs
,
FunctionInstance
&
instance
);
template
<
typename
T
>
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
args
,
const
std
::
vector
<
void
*
>&
out
,
const
std
::
vector
<
const
void
*
>&
args
,
FunctionInstance
&
instance
)
{
const
Node
&
node
=
node_wrapper
.
get_node
();
...
...
@@ -205,58 +210,63 @@ private:
{
case
OP_TYPEID
:
:
Abs
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
abs
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Acos
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
acos
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Add
:
{
reference
::
add
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
add
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
AllReduce
:
{
#ifdef NGRAPH_DISTRIBUTED
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_element_type
(
),
static_cast
<
int
>
(
args
[
0
]
->
get_element_count
(
)));
reference
::
allreduce
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)
)));
#endif
break
;
}
case
OP_TYPEID
:
:
And
:
{
reference
::
logical_and
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_and
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ArgMin
:
{
const
op
::
ArgMin
*
argmin
=
static_cast
<
const
op
::
ArgMin
*>
(
&
node
);
if
(
out
[
0
]
->
get_element_type
()
==
element
::
i64
)
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
{
reference
::
argmin
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
argmin
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
}
else
if
(
out
[
0
]
->
get_element_type
()
==
element
::
i32
)
else
if
(
element_type
==
element
::
i32
)
{
reference
::
argmin
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
argmin
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
}
else
...
...
@@ -268,20 +278,21 @@ private:
case
OP_TYPEID
:
:
ArgMax
:
{
const
op
::
ArgMax
*
argmax
=
static_cast
<
const
op
::
ArgMax
*>
(
&
node
);
if
(
out
[
0
]
->
get_element_type
()
==
element
::
i64
)
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
{
reference
::
argmax
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
argmax
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
}
else
if
(
out
[
0
]
->
get_element_type
()
==
element
::
i32
)
else
if
(
element_type
==
element
::
i32
)
{
reference
::
argmax
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
argmax
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
}
else
...
...
@@ -292,24 +303,26 @@ private:
}
case
OP_TYPEID
:
:
Asin
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
asin
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Atan
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
atan
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
AvgPool
:
{
const
op
::
AvgPool
*
avg_pool
=
static_cast
<
const
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
avg_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_movement_strides
(),
avg_pool
->
get_padding_below
(),
...
...
@@ -327,8 +340,9 @@ private:
const
op
::
GetOutputElement
*
get_output_element
=
static_cast
<
const
op
::
GetOutputElement
*>
(
&
node
);
size_t
n
=
get_output_element
->
get_n
();
size_t
num_bytes
=
out
[
0
]
->
get_element_count
()
*
out
[
0
]
->
get_element_type
().
size
();
std
::
memcpy
(
out
[
0
]
->
get_data_ptr
(),
args
[
n
]
->
get_data_ptr
(),
num_bytes
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
num_bytes
=
element_count
*
node
.
get_output_element_type
(
0
).
size
();
std
::
memcpy
(
static_cast
<
T
*>
(
out
[
0
]),
args
[
n
],
num_bytes
);
break
;
}
case
OP_TYPEID
:
:
BatchNormTraining
:
...
...
@@ -337,26 +351,25 @@ private:
static_cast
<
const
ngraph
::
op
::
BatchNormTraining
*>
(
&
node
);
if
(
bn
->
get_output_size
()
==
3
)
{
reference
::
batch_norm_three_outputs
<
T
>
(
bn
->
get_eps_value
(),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
2
]
->
get_data_ptr
()),
args
[
2
]
->
get_shape
());
reference
::
batch_norm_three_outputs
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
const
T
*>
(
args
[
2
]),
static_cast
<
T
*>
(
out
[
0
]),
static_cast
<
T
*>
(
out
[
1
]),
static_cast
<
T
*>
(
out
[
2
]),
node
.
get_input_shape
(
2
));
}
else
{
reference
::
batch_norm_one_output
<
T
>
(
bn
->
get_eps_value
(),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
3
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
4
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()
),
args
[
2
]
->
get_shape
(
));
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
const
T
*>
(
args
[
3
]
),
static_cast
<
const
T
*>
(
args
[
4
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
2
));
}
break
;
}
...
...
@@ -365,13 +378,13 @@ private:
const
ngraph
::
op
::
BatchNormInference
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormInference
*>
(
&
node
);
reference
::
batch_norm_one_output
<
T
>
(
bn
->
get_eps_value
(),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
3
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
4
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()
),
args
[
2
]
->
get_shape
(
));
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
const
T
*>
(
args
[
3
]
),
static_cast
<
const
T
*>
(
args
[
4
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
2
));
break
;
}
case
OP_TYPEID
:
:
BatchNormTrainingBackprop
:
...
...
@@ -379,25 +392,25 @@ private:
const
ngraph
::
op
::
BatchNormTrainingBackprop
*
bn_bprop
=
static_cast
<
const
ngraph
::
op
::
BatchNormTrainingBackprop
*>
(
&
node
);
reference
::
batch_norm_backprop
(
bn_bprop
->
get_eps_value
(),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
3
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
4
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
args
[
5
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
out
[
1
]
->
get_data_ptr
()
),
reinterpret_cast
<
T
*>
(
out
[
2
]
->
get_data_ptr
()
),
args
[
2
]
->
get_shape
(
));
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
const
T
*>
(
args
[
3
]
),
static_cast
<
const
T
*>
(
args
[
4
]
),
static_cast
<
const
T
*>
(
args
[
5
]
),
static_cast
<
T
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
static_cast
<
T
*>
(
out
[
2
]
),
node
.
get_input_shape
(
2
));
break
;
}
case
OP_TYPEID
:
:
AvgPoolBackprop
:
{
const
op
::
AvgPoolBackprop
*
apb
=
static_cast
<
const
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
avg_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
apb
->
get_window_shape
(),
apb
->
get_window_movement_strides
(),
apb
->
get_padding_below
(),
...
...
@@ -408,11 +421,11 @@ private:
case
OP_TYPEID
:
:
Broadcast
:
{
const
op
::
Broadcast
*
broadcast
=
static_cast
<
const
op
::
Broadcast
*>
(
&
node
);
Shape
in_shape
=
args
[
0
]
->
get_shape
(
);
Shape
out_shape
=
out
[
0
]
->
get_shape
(
);
Shape
in_shape
=
node
.
get_input_shape
(
0
);
Shape
out_shape
=
node
.
get_output_shape
(
0
);
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
reference
::
broadcast
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
reference
::
broadcast
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
in_shape
,
out_shape
,
broadcast_axes
);
...
...
@@ -420,8 +433,9 @@ private:
}
case
OP_TYPEID
:
:
Ceiling
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
ceiling
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Concat
:
...
...
@@ -429,94 +443,82 @@ private:
const
op
::
Concat
*
concat
=
static_cast
<
const
op
::
Concat
*>
(
&
node
);
std
::
vector
<
const
T
*>
in_args
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
s
td
::
shared_ptr
<
HostTensor
>
arg
:
args
)
for
(
s
ize_t
i
=
0
;
i
<
node
.
get_input_size
();
i
++
)
{
in_args
.
push_back
(
arg
->
get_data_ptr
<
T
>
(
));
in_shapes
.
push_back
(
arg
->
get_shape
(
));
in_args
.
push_back
(
static_cast
<
const
T
*>
(
args
[
i
]
));
in_shapes
.
push_back
(
node
.
get_input_shape
(
i
));
}
reference
::
concat
<
T
>
(
in_args
,
out
[
0
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
in_shapes
,
out
[
0
]
->
get_shape
(
),
node
.
get_output_shape
(
0
),
concat
->
get_concatenation_axis
());
break
;
}
case
OP_TYPEID
:
:
Constant
:
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
// Constant is handled in the main loop
break
;
}
case
OP_TYPEID
:
:
Convert
:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element
::
Type
type
=
node
.
get_element_type
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
if
(
type
==
element
::
boolean
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
char
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
f32
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
float
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
float
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
f64
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
double
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
double
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
i8
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int8_t
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
i16
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
int16_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int16_t
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
i32
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int32_t
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
i64
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int64_t
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
u8
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint8_t
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
u16
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint16_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint16_t
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
u32
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint32_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint32_t
*>
(
out
[
0
]),
element_count
);
}
else
if
(
type
==
element
::
u64
)
{
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(),
out
[
0
]
->
get_element_count
());
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint64_t
*>
(
out
[
0
]),
element_count
);
}
else
{
...
...
@@ -529,12 +531,12 @@ private:
case
OP_TYPEID
:
:
Convolution
:
{
const
op
::
Convolution
*
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
c
->
get_window_movement_strides
(),
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
...
...
@@ -553,12 +555,12 @@ private:
{
const
op
::
ConvolutionBackpropFilters
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
...
...
@@ -578,12 +580,12 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const
op
::
ConvolutionBackpropData
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
args
[
1
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_shape
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
...
...
@@ -600,14 +602,16 @@ private:
}
case
OP_TYPEID
:
:
Cos
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cos
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Cosh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cosh
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Dequantize
:
...
...
@@ -617,22 +621,22 @@ private:
if
(
type
==
element
::
f32
)
{
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
float
>
(
),
args
[
2
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
float
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
float
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
double
>
(
),
args
[
2
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
double
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
double
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
double
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
}
else
...
...
@@ -646,43 +650,47 @@ private:
}
case
OP_TYPEID
:
:
Divide
:
{
reference
::
divide
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
divide
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Dot
:
{
const
op
::
Dot
*
dot
=
static_cast
<
const
op
::
Dot
*>
(
&
node
);
reference
::
dot
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
dot
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
dot
->
get_reduction_axes_count
());
break
;
}
case
OP_TYPEID
:
:
Equal
:
{
reference
::
equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
char
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Exp
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
exp
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Floor
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
floor
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
FunctionCall
:
...
...
@@ -690,15 +698,24 @@ private:
std
::
shared_ptr
<
Function
>
function
=
node
.
get_functions
()[
0
];
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>
outputs
;
for
(
auto
tv
:
out
)
for
(
size_t
i
=
0
;
i
<
function
->
get_output_size
();
i
++
)
{
outputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
Tensor
>
(
tv
));
element
::
Type
et
=
function
->
get_output_element_type
(
i
);
Shape
shape
=
function
->
get_output_shape
(
i
);
auto
host_tensor
=
std
::
make_shared
<
HostTensor
>
(
et
,
shape
,
out
[
i
]);
outputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
Tensor
>
(
host_tensor
));
}
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>
inputs
;
for
(
auto
tv
:
args
)
auto
parameters
=
function
->
get_parameters
();
for
(
size_t
i
=
0
;
i
<
parameters
.
size
();
i
++
)
{
inputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
Tensor
>
(
tv
));
auto
parameter
=
parameters
[
i
];
element
::
Type
et
=
parameter
->
get_element_type
();
Shape
shape
=
parameter
->
get_shape
();
auto
host_tensor
=
std
::
make_shared
<
HostTensor
>
(
et
,
shape
,
const_cast
<
void
*>
(
args
[
i
]));
inputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
Tensor
>
(
host_tensor
));
}
call
(
function
,
outputs
,
inputs
);
...
...
@@ -706,48 +723,53 @@ private:
}
case
OP_TYPEID
:
:
Greater
:
{
reference
::
greater
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
char
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
GreaterEq
:
{
reference
::
greater_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
char
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Less
:
{
reference
::
less
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
char
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
LessEq
:
{
reference
::
less_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
char
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Log
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
log
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
LRN
:
{
const
op
::
LRN
*
lrn
=
static_cast
<
const
op
::
LRN
*>
(
&
node
);
reference
::
lrn
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
reference
::
lrn
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
lrn
->
get_alpha
(),
lrn
->
get_beta
(),
lrn
->
get_bias
(),
...
...
@@ -757,29 +779,30 @@ private:
case
OP_TYPEID
:
:
Max
:
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
max
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
max
->
get_reduction_axes
());
break
;
}
case
OP_TYPEID
:
:
Maximum
:
{
reference
::
maximum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
maximum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
MaxPool
:
{
const
op
::
MaxPool
*
max_pool
=
static_cast
<
const
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
max_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_movement_strides
(),
max_pool
->
get_padding_below
(),
...
...
@@ -791,11 +814,11 @@ private:
const
op
::
MaxPoolBackprop
*
max_pool_backprop
=
static_cast
<
const
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
max_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_movement_strides
(),
max_pool_backprop
->
get_padding_below
(),
...
...
@@ -805,65 +828,71 @@ private:
case
OP_TYPEID
:
:
Min
:
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
min
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
min
->
get_reduction_axes
());
break
;
}
case
OP_TYPEID
:
:
Minimum
:
{
reference
::
minimum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
minimum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Multiply
:
{
reference
::
multiply
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
multiply
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Negative
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
negate
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Not
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_not
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
NotEqual
:
{
reference
::
not_equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
not_equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
char
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
OneHot
:
{
const
op
::
OneHot
*
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
one_hot
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
oh
->
get_one_hot_axis
());
break
;
}
case
OP_TYPEID
:
:
Or
:
{
reference
::
logical_or
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_or
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Parameter
:
break
;
...
...
@@ -871,9 +900,9 @@ private:
{
const
op
::
Pad
*
pad
=
static_cast
<
const
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
reference
::
pad
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
pad
->
get_padding_below
(),
...
...
@@ -883,19 +912,20 @@ private:
}
case
OP_TYPEID
:
:
Power
:
{
reference
::
power
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
power
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Product
:
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
product
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
product
->
get_reduction_axes
());
break
;
}
...
...
@@ -906,23 +936,23 @@ private:
if
(
type
==
element
::
u8
)
{
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
args
[
2
]
->
get_data_ptr
<
uint8_t
>
(
),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
uint8_t
*>
(
args
[
2
]
),
static_cast
<
uint8_t
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_round_mode
());
}
else
if
(
type
==
element
::
i8
)
{
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
args
[
2
]
->
get_data_ptr
<
int8_t
>
(
),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
int8_t
*>
(
args
[
2
]
),
static_cast
<
int8_t
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_round_mode
());
}
...
...
@@ -942,20 +972,18 @@ private:
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_temp_x"
);
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
&
x
,
"reduce_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_temp_y"
);
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
&
y
,
"reduce_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_temp_r"
);
*
(
tx
->
get_data_ptr
<
T
>
())
=
x
;
*
(
ty
->
get_data_ptr
<
T
>
())
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
tr
->
get_data_ptr
<
T
>
());
};
reference
::
reduce
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
reference
::
reduce
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reduce
->
get_reduction_axes
(),
...
...
@@ -968,21 +996,23 @@ private:
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce_window
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_window_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_window_temp_y"
);
auto
tx
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
&
x
,
"reduce_window_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
&
y
,
"reduce_window_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensor
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_window_temp_r"
);
*
(
tx
->
get_data_ptr
<
T
>
())
=
x
;
*
(
ty
->
get_data_ptr
<
T
>
())
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
tr
->
get_data_ptr
<
T
>
());
};
reference
::
reduce_window
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
reference
::
reduce_window
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
f
,
...
...
@@ -992,56 +1022,58 @@ private:
}
case
OP_TYPEID
:
:
Relu
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ReluBackprop
:
{
reference
::
relu_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ReplaceSlice
:
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_shape
(
),
reference
::
replace_slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
1
),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
(
));
node
.
get_output_shape
(
0
));
break
;
}
case
OP_TYPEID
:
:
Reshape
:
{
const
op
::
Reshape
*
reshape
=
static_cast
<
const
op
::
Reshape
*>
(
&
node
);
reference
::
reshape
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
reference
::
reshape
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
reshape
->
get_input_order
(),
out
[
0
]
->
get_shape
(
));
node
.
get_output_shape
(
0
));
break
;
}
case
OP_TYPEID
:
:
Result
:
{
const
op
::
Result
*
res
=
static_cast
<
const
op
::
Result
*>
(
&
node
);
reference
::
result
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
reference
::
result
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
shape_size
(
res
->
get_shape
()));
break
;
}
case
OP_TYPEID
:
:
Reverse
:
{
const
op
::
Reverse
*
reverse
=
static_cast
<
const
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
reverse
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
reverse
->
get_reversed_axes
());
break
;
}
...
...
@@ -1049,14 +1081,14 @@ private:
{
const
op
::
ReverseSequence
*
reverse
=
static_cast
<
const
op
::
ReverseSequence
*>
(
&
node
);
if
(
args
[
1
]
->
get_element_type
(
)
==
element
::
i32
)
if
(
node
.
get_input_element_type
(
1
)
==
element
::
i32
)
{
reference
::
reverse_sequence
<
T
,
int
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
reference
::
reverse_sequence
<
T
,
int
32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
reverse
->
get_batch_axis
(),
reverse
->
get_sequence_axis
(),
args
[
1
]
->
get_data_ptr
<
int
>
(
));
static_cast
<
const
int32_t
*>
(
args
[
1
]
));
}
else
{
...
...
@@ -1066,11 +1098,12 @@ private:
}
case
OP_TYPEID
:
:
Select
:
{
reference
::
select
<
T
>
(
args
[
0
]
->
get_data_ptr
<
char
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
args
[
2
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
select
<
T
>
(
static_cast
<
const
char
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
const
T
*>
(
args
[
2
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
SelectAndScatter
:
...
...
@@ -1083,13 +1116,11 @@ private:
std
::
function
<
bool
(
T
,
T
)
>
f_selection
=
[
this
,
&
node
,
selection_function
](
T
x
,
T
y
)
->
bool
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensor
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"selection_temp_x"
);
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
&
x
,
"selection_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensor
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"selection_temp_y"
);
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
&
y
,
"selection_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensor
>
(
element
::
boolean
,
Shape
{},
"selection_temp_r"
);
*
(
tx
->
get_data_ptr
<
T
>
())
=
x
;
*
(
ty
->
get_data_ptr
<
T
>
())
=
y
;
call
(
selection_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
tr
->
get_data_ptr
<
char
>
());
};
...
...
@@ -1098,24 +1129,22 @@ private:
select_and_scatter
->
get_functions
()[
1
];
std
::
function
<
T
(
T
,
T
)
>
f_scatter
=
[
this
,
&
node
,
scatter_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensor
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"scatter_temp_x"
);
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
&
x
,
"scatter_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensor
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"scatter_temp_y"
);
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
&
y
,
"scatter_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensor
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"scatter_temp_r"
);
*
(
tx
->
get_data_ptr
<
T
>
())
=
x
;
*
(
ty
->
get_data_ptr
<
T
>
())
=
y
;
call
(
scatter_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
tr
->
get_data_ptr
<
T
>
());
};
reference
::
select_and_scatter
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
1
]
->
get_data_ptr
<
T
>
(
),
args
[
2
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
args
[
1
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
select_and_scatter
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
f_selection
,
f_scatter
,
select_and_scatter
->
get_window_shape
(),
...
...
@@ -1124,116 +1153,125 @@ private:
}
case
OP_TYPEID
:
:
Sigmoid
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
SigmoidBackprop
:
{
reference
::
sigmoid_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sign
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sign
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sin
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sin
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sinh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sinh
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Slice
:
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
reference
::
slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
(
));
node
.
get_output_shape
(
0
));
break
;
}
case
OP_TYPEID
:
:
Softmax
:
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_shape
(
),
reference
::
softmax
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_output_shape
(
0
),
softmax
->
get_axes
());
break
;
}
case
OP_TYPEID
:
:
Sqrt
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sqrt
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
StopGradient
:
{
throw
unsupported_op
(
"Unsupported op 'StopGradient'"
);
}
case
OP_TYPEID
:
:
Subtract
:
{
reference
::
subtract
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
1
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
());
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
subtract
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sum
:
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
sum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
sum
->
get_reduction_axes
());
break
;
}
case
OP_TYPEID
:
:
Tan
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tan
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Tanh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tanh
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_element_count
()
);
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
break
;
}
case
OP_TYPEID
:
:
TopK
:
{
const
op
::
TopK
*
topk
=
static_cast
<
const
op
::
TopK
*>
(
&
node
);
if
(
out
[
0
]
->
get_element_type
(
)
==
element
::
i64
)
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i64
)
{
reference
::
topk
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
topk
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
topk
->
get_k
(),
topk
->
get_compute_max
());
}
else
if
(
out
[
0
]
->
get_element_type
(
)
==
element
::
i32
)
else
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i32
)
{
reference
::
topk
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
args
[
0
]
->
get_shape
(
),
out
[
0
]
->
get_shape
(
),
reference
::
topk
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
topk
->
get_k
(),
topk
->
get_compute_max
());
...
...
src/ngraph/runtime/reference/relu.hpp
View file @
6f511762
...
...
@@ -34,7 +34,7 @@ namespace ngraph
}
}
template
<
typename
T
>
void
relu_backprop
(
const
T
*
arg
,
T
*
delta_arg
,
T
*
out
,
size_t
count
)
void
relu_backprop
(
const
T
*
arg
,
const
T
*
delta_arg
,
T
*
out
,
size_t
count
)
{
T
zero
=
0
;
for
(
size_t
i
=
0
;
i
<
count
;
i
++
)
...
...
src/ngraph/runtime/reference/reverse_sequence.hpp
View file @
6f511762
...
...
@@ -34,7 +34,7 @@ namespace ngraph
const
Shape
&
arg_shape
,
size_t
batch_axis
,
size_t
sequence_axis
,
U
*
sequence_lengths
)
const
U
*
sequence_lengths
)
{
CoordinateTransform
input_transform
(
arg_shape
);
for
(
const
Coordinate
&
in_coord
:
input_transform
)
...
...
src/ngraph/runtime/reference/sigmoid.hpp
View file @
6f511762
...
...
@@ -37,7 +37,7 @@ namespace ngraph
}
template
<
typename
T
>
void
sigmoid_backprop
(
const
T
*
arg
,
T
*
delta_arg
,
T
*
out
,
size_t
count
)
void
sigmoid_backprop
(
const
T
*
arg
,
const
T
*
delta_arg
,
T
*
out
,
size_t
count
)
{
T
exp_value
;
T
func_x
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment