Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
89963725
Commit
89963725
authored
Apr 26, 2018
by
Robert Kimball
Committed by
Scott Cyphers
Apr 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Replace interpreter with Adam's simplified implementation (#915)
* wip * simplified interpreter backend
parent
66198b33
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1124 additions
and
2556 deletions
+1124
-2556
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-3
ie_backend.cpp
src/ngraph/runtime/ie/ie_backend.cpp
+0
-243
ie_backend.hpp
src/ngraph/runtime/ie/ie_backend.hpp
+0
-885
int_backend.cpp
src/ngraph/runtime/interpreter/int_backend.cpp
+266
-38
int_backend.hpp
src/ngraph/runtime/interpreter/int_backend.hpp
+854
-26
int_call_frame.cpp
src/ngraph/runtime/interpreter/int_call_frame.cpp
+0
-293
int_call_frame.hpp
src/ngraph/runtime/interpreter/int_call_frame.hpp
+0
-882
int_external_function.cpp
src/ngraph/runtime/interpreter/int_external_function.cpp
+0
-133
int_external_function.hpp
src/ngraph/runtime/interpreter/int_external_function.hpp
+0
-48
CMakeLists.txt
test/CMakeLists.txt
+0
-1
backend_debug_api.cpp
test/backend_debug_api.cpp
+4
-4
No files found.
src/ngraph/CMakeLists.txt
View file @
89963725
...
...
@@ -127,10 +127,7 @@ set (SRC
runtime/aligned_buffer.cpp
runtime/backend.cpp
runtime/host_tensor_view.cpp
runtime/ie/ie_backend.cpp
runtime/interpreter/int_backend.cpp
runtime/interpreter/int_call_frame.cpp
runtime/interpreter/int_external_function.cpp
runtime/tensor_view.cpp
serializer.cpp
shape.cpp
...
...
src/ngraph/runtime/ie/ie_backend.cpp
deleted
100644 → 0
View file @
66198b33
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/runtime/ie/ie_backend.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
ngraph
;
using
descriptor
::
layout
::
DenseTensorViewLayout
;
static
bool
static_init
()
{
runtime
::
Backend
::
register_backend
(
"IE"
,
make_shared
<
runtime
::
ie
::
IE_Backend
>
());
return
true
;
};
bool
runtime
::
ie
::
IE_Backend
::
init
=
static_init
();
shared_ptr
<
runtime
::
TensorView
>
runtime
::
ie
::
IE_Backend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
{
return
make_shared
<
runtime
::
HostTensorView
>
(
type
,
shape
,
"external"
);
}
shared_ptr
<
runtime
::
TensorView
>
runtime
::
ie
::
IE_Backend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
{
return
make_shared
<
runtime
::
HostTensorView
>
(
type
,
shape
,
memory_pointer
,
"external"
);
}
bool
runtime
::
ie
::
IE_Backend
::
compile
(
shared_ptr
<
Function
>
function
)
{
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorViewLayout
>>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
run_passes
(
function
);
return
true
;
}
bool
runtime
::
ie
::
IE_Backend
::
call
(
shared_ptr
<
Function
>
function
,
const
vector
<
shared_ptr
<
runtime
::
TensorView
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
TensorView
>>&
inputs
)
{
validate_call
(
function
,
outputs
,
inputs
);
// TODO: check if function already compiled?
compile
(
function
);
// convert inputs to HostTensorView
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
func_inputs
;
for
(
auto
tv
:
inputs
)
{
func_inputs
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensorView
>
(
tv
));
}
// convert outputs to HostTensorView
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
func_outputs
;
for
(
auto
tv
:
outputs
)
{
func_outputs
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensorView
>
(
tv
));
}
// map function params -> HostTensorView
unordered_map
<
descriptor
::
TensorView
*
,
shared_ptr
<
runtime
::
HostTensorView
>>
tensor_map
;
size_t
input_count
=
0
;
for
(
auto
param
:
function
->
get_parameters
())
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
descriptor
::
TensorView
*
tv
=
param
->
get_output_tensor_view
(
i
).
get
();
tensor_map
.
insert
({
tv
,
func_inputs
[
input_count
++
]});
}
}
// map function outputs -> HostTensorView
for
(
size_t
output_count
=
0
;
output_count
<
function
->
get_output_size
();
++
output_count
)
{
auto
output
=
function
->
get_output_op
(
output_count
);
if
(
!
dynamic_pointer_cast
<
op
::
Result
>
(
output
))
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
}
descriptor
::
TensorView
*
tv
=
output
->
get_output_tensor_view
(
0
).
get
();
tensor_map
.
insert
({
tv
,
func_outputs
[
output_count
]});
}
// for each ordered op in the graph
for
(
shared_ptr
<
Node
>
op
:
function
->
get_ordered_ops
())
{
if
(
op
->
description
()
==
"Parameter"
)
{
continue
;
}
// get op inputs from map
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
op_inputs
;
for
(
const
descriptor
::
Input
&
input
:
op
->
get_inputs
())
{
descriptor
::
TensorView
*
tv
=
input
.
get_output
().
get_tensor_view
().
get
();
op_inputs
.
push_back
(
tensor_map
.
at
(
tv
));
}
// get op outputs from map or create
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
op_outputs
;
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
{
descriptor
::
TensorView
*
tv
=
op
->
get_output_tensor_view
(
i
).
get
();
shared_ptr
<
runtime
::
HostTensorView
>
htv
;
if
(
!
contains_key
(
tensor_map
,
tv
))
{
// the output tensor is not in the tensor map so create a new tensor
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
element
::
Type
&
type
=
op
->
get_output_element_type
(
i
);
string
name
=
op
->
get_output_tensor
(
i
).
get_name
();
htv
=
make_shared
<
runtime
::
HostTensorView
>
(
type
,
shape
,
name
);
tensor_map
.
insert
({
tv
,
htv
});
}
else
{
htv
=
tensor_map
.
at
(
tv
);
}
op_outputs
.
push_back
(
htv
);
}
// get op type
element
::
Type
type
;
if
(
dynamic_pointer_cast
<
op
::
util
::
BinaryElementwiseComparison
>
(
op
)
||
dynamic_pointer_cast
<
op
::
Select
>
(
op
))
{
// Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type
=
op
->
get_inputs
().
at
(
1
).
get_tensor
().
get_element_type
();
}
else
if
(
dynamic_pointer_cast
<
op
::
Convert
>
(
op
))
{
type
=
op
->
get_inputs
().
at
(
0
).
get_tensor
().
get_element_type
();
}
else
{
type
=
op
->
get_element_type
();
}
generate_calls
(
type
,
*
op
,
op_outputs
,
op_inputs
);
// delete any obsolete tensors
for
(
const
descriptor
::
Tensor
*
t
:
op
->
liveness_free_list
)
{
for
(
auto
it
=
tensor_map
.
begin
();
it
!=
tensor_map
.
end
();
++
it
)
{
if
(
it
->
second
->
get_tensor
().
get_name
()
==
t
->
get_name
())
{
tensor_map
.
erase
(
it
);
break
;
}
}
}
}
return
true
;
}
void
runtime
::
ie
::
IE_Backend
::
generate_calls
(
const
element
::
Type
&
type
,
Node
&
op
,
const
vector
<
shared_ptr
<
HostTensorView
>>&
outputs
,
const
vector
<
shared_ptr
<
HostTensorView
>>&
inputs
)
{
if
(
type
==
element
::
boolean
)
{
op_engine
<
char
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
f32
)
{
op_engine
<
float
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
f64
)
{
op_engine
<
double
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i8
)
{
op_engine
<
int8_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i16
)
{
op_engine
<
int16_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i32
)
{
op_engine
<
int32_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i64
)
{
op_engine
<
int64_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u8
)
{
op_engine
<
uint8_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u16
)
{
op_engine
<
uint16_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u32
)
{
op_engine
<
uint32_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u64
)
{
op_engine
<
uint64_t
>
(
op
,
outputs
,
inputs
);
}
else
{
stringstream
ss
;
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_name
();
throw
ngraph_error
(
ss
.
str
());
}
}
src/ngraph/runtime/ie/ie_backend.hpp
deleted
100644 → 0
View file @
66198b33
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/op/reduce_window.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/runtime/reference/cos.hpp"
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/reduce.hpp"
#include "ngraph/runtime/reference/reduce_window.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp"
#endif
namespace
ngraph
{
namespace
runtime
{
namespace
ie
{
class
IE_Backend
;
}
}
}
class
ngraph
::
runtime
::
ie
::
IE_Backend
:
public
Backend
{
public
:
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
override
;
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
bool
compile
(
std
::
shared_ptr
<
Function
>
function
)
override
;
bool
call
(
std
::
shared_ptr
<
Function
>
function
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
intputs
)
override
;
private
:
static
bool
init
;
void
generate_calls
(
const
element
::
Type
&
type
,
Node
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
inputs
);
template
<
typename
T
>
void
op_engine
(
Node
&
node
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
)
{
std
::
string
node_op
=
node
.
description
();
if
(
node_op
==
"Abs"
)
{
reference
::
abs
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Acos"
)
{
reference
::
acos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Add"
)
{
reference
::
add
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
#ifdef NGRAPH_DISTRIBUTED
else
if
(
node_op
==
"AllReduce"
)
{
reference
::
allreduce
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_element_type
(),
static_cast
<
int
>
(
args
[
0
]
->
get_element_count
()));
}
#endif
else
if
(
node_op
==
"And"
)
{
reference
::
logical_and
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Asin"
)
{
reference
::
asin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Atan"
)
{
reference
::
atan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"AvgPool"
)
{
op
::
AvgPool
*
avg_pool
=
dynamic_cast
<
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_movement_strides
(),
avg_pool
->
get_padding_below
(),
avg_pool
->
get_padding_above
(),
avg_pool
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"AvgPoolBackprop"
)
{
op
::
AvgPoolBackprop
*
apb
=
dynamic_cast
<
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
apb
->
get_window_shape
(),
apb
->
get_window_movement_strides
(),
apb
->
get_padding_below
(),
apb
->
get_padding_above
(),
apb
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"Broadcast"
)
{
op
::
Broadcast
*
broadcast
=
dynamic_cast
<
op
::
Broadcast
*>
(
&
node
);
Shape
in_shape
=
args
[
0
]
->
get_shape
();
Shape
out_shape
=
out
[
0
]
->
get_shape
();
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
reference
::
broadcast
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shape
,
out_shape
,
broadcast_axes
);
}
else
if
(
node_op
==
"Ceiling"
)
{
reference
::
ceiling
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Concat"
)
{
const
op
::
Concat
*
concat
=
static_cast
<
const
op
::
Concat
*>
(
&
node
);
std
::
vector
<
const
T
*>
in_args
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
std
::
shared_ptr
<
HostTensorView
>
arg
:
args
)
{
in_args
.
push_back
(
reinterpret_cast
<
T
*>
(
arg
->
get_data_ptr
()));
in_shapes
.
push_back
(
arg
->
get_shape
());
}
reference
::
concat
<
T
>
(
in_args
,
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shapes
,
out
[
0
]
->
get_shape
(),
concat
->
get_concatenation_axis
());
}
else
if
(
node_op
==
"Constant"
)
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
reference
::
constant
<
T
>
(
reinterpret_cast
<
const
T
*>
(
c
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Convert"
)
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element
::
Type
type
=
node
.
get_element_type
();
if
(
type
==
element
::
boolean
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
f32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
float
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
double
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i8
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int8_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i16
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int16_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int64_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u8
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint8_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u16
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint16_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint32_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint64_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported element type "
<<
type
<<
" op Convert"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
}
else
if
(
node_op
==
"Convolution"
)
{
auto
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides
(),
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropFilters"
)
{
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
1
,
0
,
0
,
1
,
1
,
0
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropData"
)
{
// Note that args[1] and args[0] are switched here from the usual order.
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
0
,
1
,
0
,
1
,
0
,
1
,
true
);
}
else
if
(
node_op
==
"Cos"
)
{
reference
::
cos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Cosh"
)
{
reference
::
cosh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Divide"
)
{
reference
::
divide
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Dot"
)
{
op
::
Dot
*
dot
=
dynamic_cast
<
op
::
Dot
*>
(
&
node
);
reference
::
dot
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
dot
->
get_reduction_axes_count
());
}
else
if
(
node_op
==
"Equal"
)
{
reference
::
equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Exp"
)
{
reference
::
exp
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Floor"
)
{
reference
::
floor
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"FunctionCall"
)
{
std
::
shared_ptr
<
Function
>
function
=
node
.
get_functions
()[
0
];
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
outputs
;
for
(
auto
tv
:
out
)
{
outputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
inputs
;
for
(
auto
tv
:
args
)
{
inputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
call
(
function
,
outputs
,
inputs
);
}
else
if
(
node_op
==
"Greater"
)
{
reference
::
greater
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"GreaterEq"
)
{
reference
::
greater_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Less"
)
{
reference
::
less
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"LessEq"
)
{
reference
::
less_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Log"
)
{
reference
::
log
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Max"
)
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Maximum"
)
{
reference
::
maximum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"MaxPool"
)
{
op
::
MaxPool
*
max_pool
=
dynamic_cast
<
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_movement_strides
(),
max_pool
->
get_padding_below
(),
max_pool
->
get_padding_above
());
}
else
if
(
node_op
==
"MaxPoolBackprop"
)
{
op
::
MaxPoolBackprop
*
max_pool_backprop
=
dynamic_cast
<
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_movement_strides
(),
max_pool_backprop
->
get_padding_below
(),
max_pool_backprop
->
get_padding_above
());
}
else
if
(
node_op
==
"Min"
)
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
min
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Minimum"
)
{
reference
::
minimum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Multiply"
)
{
reference
::
multiply
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Negative"
)
{
reference
::
negate
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Not"
)
{
reference
::
logical_not
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"NotEqual"
)
{
reference
::
not_equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"OneHot"
)
{
auto
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
oh
->
get_one_hot_axis
());
}
else
if
(
node_op
==
"Or"
)
{
reference
::
logical_or
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Parameter"
)
{
}
else
if
(
node_op
==
"Pad"
)
{
op
::
Pad
*
pad
=
dynamic_cast
<
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_padding_interior
());
}
else
if
(
node_op
==
"Power"
)
{
reference
::
power
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Product"
)
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
product
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Reduce"
)
{
op
::
Reduce
*
reduce
=
dynamic_cast
<
op
::
Reduce
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reduce
->
get_reduction_axes
(),
f
);
}
else
if
(
node_op
==
"ReduceWindow"
)
{
op
::
ReduceWindow
*
reduce_window
=
dynamic_cast
<
op
::
ReduceWindow
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce_window
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_window_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_window_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_window_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce_window
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
f
,
reduce_window
->
get_window_shape
(),
reduce_window
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Relu"
)
{
reference
::
relu
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReluBackprop"
)
{
reference
::
relu_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReplaceSlice"
)
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Reshape"
)
{
op
::
Reshape
*
reshape
=
dynamic_cast
<
op
::
Reshape
*>
(
&
node
);
reference
::
reshape
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
reshape
->
get_input_order
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Result"
)
{
op
::
Result
*
res
=
dynamic_cast
<
op
::
Result
*>
(
&
node
);
reference
::
result
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
shape_size
(
res
->
get_shape
()));
}
else
if
(
node_op
==
"Reverse"
)
{
op
::
Reverse
*
reverse
=
dynamic_cast
<
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
reverse
->
get_reversed_axes
());
}
else
if
(
node_op
==
"Select"
)
{
reference
::
select
<
T
>
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"SelectAndScatter"
)
{
ngraph
::
op
::
SelectAndScatter
*
select_and_scatter
=
dynamic_cast
<
ngraph
::
op
::
SelectAndScatter
*>
(
&
node
);
std
::
shared_ptr
<
ngraph
::
Function
>
selection_function
=
select_and_scatter
->
get_functions
()[
0
];
std
::
function
<
bool
(
T
,
T
)
>
f_selection
=
[
this
,
&
node
,
selection_function
](
T
x
,
T
y
)
->
bool
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"selection_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"selection_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
element
::
boolean
,
Shape
{},
"selection_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
selection_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
char
*>
(
tr
->
get_data_ptr
()));
};
std
::
shared_ptr
<
ngraph
::
Function
>
scatter_function
=
select_and_scatter
->
get_functions
()[
1
];
std
::
function
<
T
(
T
,
T
)
>
f_scatter
=
[
this
,
&
node
,
scatter_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"scatter_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"scatter_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"scatter_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
scatter_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
select_and_scatter
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
f_selection
,
f_scatter
,
select_and_scatter
->
get_window_shape
(),
select_and_scatter
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Sign"
)
{
reference
::
sign
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sin"
)
{
reference
::
sin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sinh"
)
{
reference
::
sinh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Slice"
)
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Softmax"
)
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_shape
(),
softmax
->
get_axes
());
}
else
if
(
node_op
==
"Sqrt"
)
{
reference
::
sqrt
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Subtract"
)
{
reference
::
subtract
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sum"
)
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
sum
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Tan"
)
{
reference
::
tan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Tanh"
)
{
reference
::
tanh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported op "
<<
node_op
;
throw
ngraph_error
(
ss
.
str
());
}
}
};
src/ngraph/runtime/interpreter/int_backend.cpp
View file @
89963725
...
...
@@ -15,81 +15,260 @@
*******************************************************************************/
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/interpreter/int_call_frame.hpp"
#include "ngraph/runtime/interpreter/int_external_function.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/util.hpp"
using
namespace
ngraph
;
using
namespace
std
;
using
namespace
ngraph
;
using
descriptor
::
layout
::
DenseTensorViewLayout
;
static
bool
static_init
()
{
runtime
::
Backend
::
register_backend
(
"INTERPRETER"
,
make_shared
<
runtime
::
interpreter
::
INT
_
Backend
>
());
make_shared
<
runtime
::
interpreter
::
INTBackend
>
());
return
true
;
};
bool
runtime
::
interpreter
::
INT_Backend
::
init
=
static_init
();
shared_ptr
<
runtime
::
interpreter
::
INT_CallFrame
>
runtime
::
interpreter
::
INT_Backend
::
make_call_frame
(
const
shared_ptr
<
runtime
::
interpreter
::
ExternalFunction
>&
external_function
)
{
return
external_function
->
make_call_frame
();
}
bool
runtime
::
interpreter
::
INTBackend
::
init
=
static_init
();
shared_ptr
<
runtime
::
TensorView
>
runtime
::
interpreter
::
INT_Backend
::
create_tensor
(
const
element
::
Type
&
element_type
,
const
Shape
&
shape
)
runtime
::
interpreter
::
INTBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
{
return
make_shared
<
runtime
::
HostTensorView
>
(
element_
type
,
shape
,
"external"
);
return
make_shared
<
runtime
::
HostTensorView
>
(
type
,
shape
,
"external"
);
}
shared_ptr
<
runtime
::
TensorView
>
runtime
::
interpreter
::
INT
_
Backend
::
create_tensor
(
const
element
::
Type
&
element_
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
shared_ptr
<
runtime
::
TensorView
>
runtime
::
interpreter
::
INTBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
{
return
make_shared
<
runtime
::
HostTensorView
>
(
element_
type
,
shape
,
memory_pointer
,
"external"
);
return
make_shared
<
runtime
::
HostTensorView
>
(
type
,
shape
,
memory_pointer
,
"external"
);
}
bool
runtime
::
interpreter
::
INT
_Backend
::
compile
(
shared_ptr
<
Function
>
func
)
bool
runtime
::
interpreter
::
INT
Backend
::
compile
(
shared_ptr
<
Function
>
function
)
{
FunctionInstance
&
instance
=
m_function_map
[
func
];
if
(
instance
.
m_external_function
==
nullptr
)
FunctionInstance
&
instance
=
m_function_map
[
func
tion
];
if
(
!
instance
.
m_is_compiled
)
{
instance
.
m_
external_function
=
make_shared
<
ExternalFunction
>
(
func
)
;
auto
cf
=
instance
.
m_external_function
->
make_call_frame
()
;
instance
.
m_call_frame
=
dynamic_pointer_cast
<
INT_CallFrame
>
(
cf
);
instance
.
m_call_frame
->
m_emit_timing
=
instance
.
m_performance_counters_enabled
;
instance
.
m_call_frame
->
set_nan_check
(
instance
.
m_nan_check_enabled
);
instance
.
m_
is_compiled
=
true
;
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorViewLayout
>>
(
);
pass_manager
.
register_pass
<
pass
::
Liveness
>
()
;
pass_manager
.
run_passes
(
function
);
}
return
true
;
}
bool
runtime
::
interpreter
::
INT
_Backend
::
call
(
shared_ptr
<
Function
>
func
,
bool
runtime
::
interpreter
::
INT
Backend
::
call
(
shared_ptr
<
Function
>
function
,
const
vector
<
shared_ptr
<
runtime
::
TensorView
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
TensorView
>>&
inputs
)
{
bool
rc
=
true
;
validate_call
(
function
,
outputs
,
inputs
)
;
validate_call
(
func
,
outputs
,
inputs
);
compile
(
function
);
FunctionInstance
&
instance
=
m_function_map
[
function
];
FunctionInstance
&
instance
=
m_function_map
[
func
];
if
(
instance
.
m_external_function
==
nullptr
)
// convert inputs to HostTensorView
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
func_inputs
;
for
(
auto
tv
:
inputs
)
{
func_inputs
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensorView
>
(
tv
));
}
if
(
instance
.
m_nan_check_enabled
)
{
rc
=
compile
(
func
);
perform_nan_check
(
func_inputs
);
}
instance
.
m_call_frame
->
call
(
outputs
,
inputs
);
// convert outputs to HostTensorView
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
func_outputs
;
for
(
auto
tv
:
outputs
)
{
func_outputs
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensorView
>
(
tv
));
}
return
rc
;
// map function params -> HostTensorView
unordered_map
<
descriptor
::
TensorView
*
,
shared_ptr
<
runtime
::
HostTensorView
>>
tensor_map
;
size_t
input_count
=
0
;
for
(
auto
param
:
function
->
get_parameters
())
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
descriptor
::
TensorView
*
tv
=
param
->
get_output_tensor_view
(
i
).
get
();
tensor_map
.
insert
({
tv
,
func_inputs
[
input_count
++
]});
}
}
// map function outputs -> HostTensorView
for
(
size_t
output_count
=
0
;
output_count
<
function
->
get_output_size
();
++
output_count
)
{
auto
output
=
function
->
get_output_op
(
output_count
);
if
(
!
dynamic_pointer_cast
<
op
::
Result
>
(
output
))
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
}
descriptor
::
TensorView
*
tv
=
output
->
get_output_tensor_view
(
0
).
get
();
tensor_map
.
insert
({
tv
,
func_outputs
[
output_count
]});
}
// for each ordered op in the graph
for
(
shared_ptr
<
Node
>
op
:
function
->
get_ordered_ops
())
{
if
(
op
->
description
()
==
"Parameter"
)
{
continue
;
}
// get op inputs from map
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
op_inputs
;
for
(
const
descriptor
::
Input
&
input
:
op
->
get_inputs
())
{
descriptor
::
TensorView
*
tv
=
input
.
get_output
().
get_tensor_view
().
get
();
op_inputs
.
push_back
(
tensor_map
.
at
(
tv
));
}
// get op outputs from map or create
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
op_outputs
;
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
{
descriptor
::
TensorView
*
tv
=
op
->
get_output_tensor_view
(
i
).
get
();
shared_ptr
<
runtime
::
HostTensorView
>
htv
;
if
(
!
contains_key
(
tensor_map
,
tv
))
{
// the output tensor is not in the tensor map so create a new tensor
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
element
::
Type
&
type
=
op
->
get_output_element_type
(
i
);
string
name
=
op
->
get_output_tensor
(
i
).
get_name
();
htv
=
make_shared
<
runtime
::
HostTensorView
>
(
type
,
shape
,
name
);
tensor_map
.
insert
({
tv
,
htv
});
}
else
{
htv
=
tensor_map
.
at
(
tv
);
}
op_outputs
.
push_back
(
htv
);
}
// get op type
element
::
Type
type
;
if
(
dynamic_pointer_cast
<
op
::
util
::
BinaryElementwiseComparison
>
(
op
)
||
dynamic_pointer_cast
<
op
::
Select
>
(
op
))
{
// Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type
=
op
->
get_inputs
().
at
(
1
).
get_tensor
().
get_element_type
();
}
else
if
(
dynamic_pointer_cast
<
op
::
Convert
>
(
op
))
{
type
=
op
->
get_inputs
().
at
(
0
).
get_tensor
().
get_element_type
();
}
else
{
type
=
op
->
get_element_type
();
}
if
(
instance
.
m_performance_counters_enabled
)
{
instance
.
m_timer_map
[
op
.
get
()].
start
();
}
generate_calls
(
type
,
*
op
,
op_outputs
,
op_inputs
);
if
(
instance
.
m_performance_counters_enabled
)
{
instance
.
m_timer_map
[
op
.
get
()].
stop
();
}
if
(
instance
.
m_nan_check_enabled
)
{
perform_nan_check
(
op_outputs
,
op
.
get
());
}
// delete any obsolete tensors
for
(
const
descriptor
::
Tensor
*
t
:
op
->
liveness_free_list
)
{
for
(
auto
it
=
tensor_map
.
begin
();
it
!=
tensor_map
.
end
();
++
it
)
{
if
(
it
->
second
->
get_tensor
().
get_name
()
==
t
->
get_name
())
{
tensor_map
.
erase
(
it
);
break
;
}
}
}
}
return
true
;
}
void
runtime
::
interpreter
::
INTBackend
::
generate_calls
(
const
element
::
Type
&
type
,
Node
&
op
,
const
vector
<
shared_ptr
<
HostTensorView
>>&
outputs
,
const
vector
<
shared_ptr
<
HostTensorView
>>&
inputs
)
{
if
(
type
==
element
::
boolean
)
{
op_engine
<
char
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
f32
)
{
op_engine
<
float
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
f64
)
{
op_engine
<
double
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i8
)
{
op_engine
<
int8_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i16
)
{
op_engine
<
int16_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i32
)
{
op_engine
<
int32_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
i64
)
{
op_engine
<
int64_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u8
)
{
op_engine
<
uint8_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u16
)
{
op_engine
<
uint16_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u32
)
{
op_engine
<
uint32_t
>
(
op
,
outputs
,
inputs
);
}
else
if
(
type
==
element
::
u64
)
{
op_engine
<
uint64_t
>
(
op
,
outputs
,
inputs
);
}
else
{
stringstream
ss
;
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_name
();
throw
ngraph_error
(
ss
.
str
());
}
}
void
runtime
::
interpreter
::
INT
_
Backend
::
set_nan_check
(
shared_ptr
<
Function
>
func
,
bool
enable
)
void
runtime
::
interpreter
::
INTBackend
::
set_nan_check
(
shared_ptr
<
Function
>
func
,
bool
enable
)
{
FunctionInstance
&
instance
=
m_function_map
[
func
];
instance
.
m_nan_check_enabled
=
enable
;
}
void
runtime
::
interpreter
::
INT
_
Backend
::
enable_performance_data
(
shared_ptr
<
Function
>
func
,
void
runtime
::
interpreter
::
INTBackend
::
enable_performance_data
(
shared_ptr
<
Function
>
func
,
bool
enable
)
{
FunctionInstance
&
instance
=
m_function_map
[
func
];
...
...
@@ -97,11 +276,11 @@ void runtime::interpreter::INT_Backend::enable_performance_data(shared_ptr<Funct
}
vector
<
runtime
::
PerformanceCounter
>
runtime
::
interpreter
::
INT
_
Backend
::
get_performance_data
(
shared_ptr
<
Function
>
func
)
const
runtime
::
interpreter
::
INTBackend
::
get_performance_data
(
shared_ptr
<
Function
>
func
)
const
{
vector
<
runtime
::
PerformanceCounter
>
rc
;
const
FunctionInstance
&
instance
=
m_function_map
.
at
(
func
);
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
instance
.
m_
call_frame
->
m_
timer_map
)
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
instance
.
m_timer_map
)
{
rc
.
emplace_back
(
p
.
first
->
get_name
().
c_str
(),
p
.
second
.
get_total_microseconds
(),
...
...
@@ -109,3 +288,52 @@ vector<runtime::PerformanceCounter>
}
return
rc
;
}
void
runtime
::
interpreter
::
INTBackend
::
perform_nan_check
(
const
vector
<
shared_ptr
<
HostTensorView
>>&
tvs
,
const
Node
*
op
)
{
size_t
arg_number
=
1
;
for
(
shared_ptr
<
HostTensorView
>
tv
:
tvs
)
{
const
element
::
Type
&
type
=
tv
->
get_tensor
().
get_element_type
();
if
(
type
==
element
::
f32
)
{
const
float
*
data
=
reinterpret_cast
<
float
*>
(
tv
->
get_data_ptr
());
for
(
size_t
i
=
0
;
i
<
tv
->
get_element_count
();
i
++
)
{
if
(
std
::
isnan
(
data
[
i
]))
{
if
(
op
)
{
throw
runtime_error
(
"nan found in op '"
+
op
->
get_name
()
+
"' output"
);
}
else
{
throw
runtime_error
(
"nan found in function's input tensor number "
+
to_string
(
arg_number
));
}
}
}
}
else
if
(
type
==
element
::
f64
)
{
const
double
*
data
=
reinterpret_cast
<
double
*>
(
tv
->
get_data_ptr
());
for
(
size_t
i
=
0
;
i
<
tv
->
get_element_count
();
i
++
)
{
if
(
std
::
isnan
(
data
[
i
]))
{
if
(
op
)
{
throw
runtime_error
(
"nan found in op '"
+
op
->
get_name
()
+
"' output"
);
}
else
{
throw
runtime_error
(
"nan found in function's input tensor number "
+
to_string
(
arg_number
));
}
}
}
}
arg_number
++
;
}
}
src/ngraph/runtime/interpreter/int_backend.hpp
View file @
89963725
...
...
@@ -16,10 +16,101 @@
#pragma once
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/op/reduce_window.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/runtime/reference/cos.hpp"
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/reduce.hpp"
#include "ngraph/runtime/reference/reduce_window.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp"
#endif
namespace
ngraph
{
...
...
@@ -27,30 +118,24 @@ namespace ngraph
{
namespace
interpreter
{
class
ExternalFunction
;
class
INT_CallFrame
;
class
INT_Backend
:
public
runtime
::
Backend
{
public
:
std
::
shared_ptr
<
INT_CallFrame
>
make_call_frame
(
const
std
::
shared_ptr
<
ngraph
::
runtime
::
interpreter
::
ExternalFunction
>&
external_function
);
std
::
shared_ptr
<
ngraph
::
runtime
::
TensorView
>
create_tensor
(
const
ngraph
::
element
::
Type
&
element_type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
override
;
class
INTBackend
;
}
}
}
class
ngraph
::
runtime
::
interpreter
::
INTBackend
:
public
Backend
{
public
:
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
override
;
std
::
shared_ptr
<
ngraph
::
runtime
::
TensorView
>
create_tensor
(
const
ngraph
::
element
::
Type
&
element_type
,
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
bool
compile
(
std
::
shared_ptr
<
Function
>
func
)
override
;
bool
compile
(
std
::
shared_ptr
<
Function
>
function
)
override
;
bool
call
(
std
::
shared_ptr
<
Function
>
func
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>&
in
puts
)
override
;
bool
call
(
std
::
shared_ptr
<
Function
>
function
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
int
puts
)
override
;
void
set_nan_check
(
std
::
shared_ptr
<
Function
>
func
,
bool
);
...
...
@@ -58,19 +143,762 @@ namespace ngraph
std
::
vector
<
PerformanceCounter
>
get_performance_data
(
std
::
shared_ptr
<
Function
>
func
)
const
override
;
private
:
private
:
class
FunctionInstance
{
public
:
std
::
shared_ptr
<
interpreter
::
ExternalFunction
>
m_external_function
;
std
::
shared_ptr
<
interpreter
::
INT_CallFrame
>
m_call_frame
;
bool
m_is_compiled
=
false
;
bool
m_nan_check_enabled
=
false
;
bool
m_performance_counters_enabled
=
false
;
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
};
std
::
map
<
std
::
shared_ptr
<
Function
>
,
FunctionInstance
>
m_function_map
;
static
bool
init
;
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
,
const
Node
*
op
=
nullptr
);
void
generate_calls
(
const
element
::
Type
&
type
,
Node
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
inputs
);
template
<
typename
T
>
void
op_engine
(
Node
&
node
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
)
{
std
::
string
node_op
=
node
.
description
();
if
(
node_op
==
"Abs"
)
{
reference
::
abs
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Acos"
)
{
reference
::
acos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Add"
)
{
reference
::
add
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
#ifdef NGRAPH_DISTRIBUTED
else
if
(
node_op
==
"AllReduce"
)
{
reference
::
allreduce
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_element_type
(),
static_cast
<
int
>
(
args
[
0
]
->
get_element_count
()));
}
#endif
else
if
(
node_op
==
"And"
)
{
reference
::
logical_and
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Asin"
)
{
reference
::
asin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Atan"
)
{
reference
::
atan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"AvgPool"
)
{
op
::
AvgPool
*
avg_pool
=
dynamic_cast
<
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_movement_strides
(),
avg_pool
->
get_padding_below
(),
avg_pool
->
get_padding_above
(),
avg_pool
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"AvgPoolBackprop"
)
{
op
::
AvgPoolBackprop
*
apb
=
dynamic_cast
<
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
apb
->
get_window_shape
(),
apb
->
get_window_movement_strides
(),
apb
->
get_padding_below
(),
apb
->
get_padding_above
(),
apb
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"Broadcast"
)
{
op
::
Broadcast
*
broadcast
=
dynamic_cast
<
op
::
Broadcast
*>
(
&
node
);
Shape
in_shape
=
args
[
0
]
->
get_shape
();
Shape
out_shape
=
out
[
0
]
->
get_shape
();
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
reference
::
broadcast
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shape
,
out_shape
,
broadcast_axes
);
}
else
if
(
node_op
==
"Ceiling"
)
{
reference
::
ceiling
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Concat"
)
{
const
op
::
Concat
*
concat
=
static_cast
<
const
op
::
Concat
*>
(
&
node
);
std
::
vector
<
const
T
*>
in_args
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
std
::
shared_ptr
<
HostTensorView
>
arg
:
args
)
{
in_args
.
push_back
(
reinterpret_cast
<
T
*>
(
arg
->
get_data_ptr
()));
in_shapes
.
push_back
(
arg
->
get_shape
());
}
reference
::
concat
<
T
>
(
in_args
,
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shapes
,
out
[
0
]
->
get_shape
(),
concat
->
get_concatenation_axis
());
}
else
if
(
node_op
==
"Constant"
)
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
reference
::
constant
<
T
>
(
reinterpret_cast
<
const
T
*>
(
c
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Convert"
)
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element
::
Type
type
=
node
.
get_element_type
();
if
(
type
==
element
::
boolean
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
f32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
float
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
double
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i8
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int8_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i16
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int16_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int64_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u8
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint8_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u16
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint16_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint32_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint64_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported element type "
<<
type
<<
" op Convert"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
}
else
if
(
node_op
==
"Convolution"
)
{
auto
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides
(),
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropFilters"
)
{
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
1
,
0
,
0
,
1
,
1
,
0
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropData"
)
{
// Note that args[1] and args[0] are switched here from the usual order.
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
0
,
1
,
0
,
1
,
0
,
1
,
true
);
}
else
if
(
node_op
==
"Cos"
)
{
reference
::
cos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Cosh"
)
{
reference
::
cosh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Divide"
)
{
reference
::
divide
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Dot"
)
{
op
::
Dot
*
dot
=
dynamic_cast
<
op
::
Dot
*>
(
&
node
);
reference
::
dot
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
dot
->
get_reduction_axes_count
());
}
else
if
(
node_op
==
"Equal"
)
{
reference
::
equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Exp"
)
{
reference
::
exp
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Floor"
)
{
reference
::
floor
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"FunctionCall"
)
{
std
::
shared_ptr
<
Function
>
function
=
node
.
get_functions
()[
0
];
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
outputs
;
for
(
auto
tv
:
out
)
{
outputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
inputs
;
for
(
auto
tv
:
args
)
{
inputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
call
(
function
,
outputs
,
inputs
);
}
else
if
(
node_op
==
"Greater"
)
{
reference
::
greater
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"GreaterEq"
)
{
reference
::
greater_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Less"
)
{
reference
::
less
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"LessEq"
)
{
reference
::
less_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Log"
)
{
reference
::
log
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Max"
)
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Maximum"
)
{
reference
::
maximum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"MaxPool"
)
{
op
::
MaxPool
*
max_pool
=
dynamic_cast
<
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_movement_strides
(),
max_pool
->
get_padding_below
(),
max_pool
->
get_padding_above
());
}
else
if
(
node_op
==
"MaxPoolBackprop"
)
{
op
::
MaxPoolBackprop
*
max_pool_backprop
=
dynamic_cast
<
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_movement_strides
(),
max_pool_backprop
->
get_padding_below
(),
max_pool_backprop
->
get_padding_above
());
}
else
if
(
node_op
==
"Min"
)
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
min
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Minimum"
)
{
reference
::
minimum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Multiply"
)
{
reference
::
multiply
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Negative"
)
{
reference
::
negate
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Not"
)
{
reference
::
logical_not
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"NotEqual"
)
{
reference
::
not_equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"OneHot"
)
{
auto
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
oh
->
get_one_hot_axis
());
}
else
if
(
node_op
==
"Or"
)
{
reference
::
logical_or
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Parameter"
)
{
}
else
if
(
node_op
==
"Pad"
)
{
op
::
Pad
*
pad
=
dynamic_cast
<
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_padding_interior
());
}
else
if
(
node_op
==
"Power"
)
{
reference
::
power
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Product"
)
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
product
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Reduce"
)
{
op
::
Reduce
*
reduce
=
dynamic_cast
<
op
::
Reduce
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reduce
->
get_reduction_axes
(),
f
);
}
else
if
(
node_op
==
"ReduceWindow"
)
{
op
::
ReduceWindow
*
reduce_window
=
dynamic_cast
<
op
::
ReduceWindow
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce_window
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_window_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_window_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_window_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce_window
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
f
,
reduce_window
->
get_window_shape
(),
reduce_window
->
get_window_movement_strides
());
}
}
else
if
(
node_op
==
"Relu"
)
{
reference
::
relu
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReluBackprop"
)
{
reference
::
relu_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReplaceSlice"
)
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Reshape"
)
{
op
::
Reshape
*
reshape
=
dynamic_cast
<
op
::
Reshape
*>
(
&
node
);
reference
::
reshape
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
reshape
->
get_input_order
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Result"
)
{
op
::
Result
*
res
=
dynamic_cast
<
op
::
Result
*>
(
&
node
);
reference
::
result
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
shape_size
(
res
->
get_shape
()));
}
else
if
(
node_op
==
"Reverse"
)
{
op
::
Reverse
*
reverse
=
dynamic_cast
<
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
reverse
->
get_reversed_axes
());
}
else
if
(
node_op
==
"Select"
)
{
reference
::
select
<
T
>
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"SelectAndScatter"
)
{
ngraph
::
op
::
SelectAndScatter
*
select_and_scatter
=
dynamic_cast
<
ngraph
::
op
::
SelectAndScatter
*>
(
&
node
);
std
::
shared_ptr
<
ngraph
::
Function
>
selection_function
=
select_and_scatter
->
get_functions
()[
0
];
std
::
function
<
bool
(
T
,
T
)
>
f_selection
=
[
this
,
&
node
,
selection_function
](
T
x
,
T
y
)
->
bool
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"selection_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"selection_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
element
::
boolean
,
Shape
{},
"selection_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
selection_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
char
*>
(
tr
->
get_data_ptr
()));
};
std
::
shared_ptr
<
ngraph
::
Function
>
scatter_function
=
select_and_scatter
->
get_functions
()[
1
];
std
::
function
<
T
(
T
,
T
)
>
f_scatter
=
[
this
,
&
node
,
scatter_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"scatter_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"scatter_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"scatter_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
scatter_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
select_and_scatter
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
f_selection
,
f_scatter
,
select_and_scatter
->
get_window_shape
(),
select_and_scatter
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Sign"
)
{
reference
::
sign
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sin"
)
{
reference
::
sin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sinh"
)
{
reference
::
sinh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Slice"
)
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Softmax"
)
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_shape
(),
softmax
->
get_axes
());
}
else
if
(
node_op
==
"Sqrt"
)
{
reference
::
sqrt
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Subtract"
)
{
reference
::
subtract
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sum"
)
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
sum
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Tan"
)
{
reference
::
tan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Tanh"
)
{
reference
::
tanh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported op "
<<
node_op
;
throw
ngraph_error
(
ss
.
str
());
}
}
};
src/ngraph/runtime/interpreter/int_call_frame.cpp
deleted
100644 → 0
View file @
66198b33
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <cstdlib>
#include <iomanip>
#include "ngraph/op/result.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/interpreter/int_call_frame.hpp"
using
namespace
std
;
using
namespace
ngraph
;
runtime
::
interpreter
::
INT_CallFrame
::
INT_CallFrame
(
shared_ptr
<
Function
>
func
)
:
m_function
(
func
)
,
m_emit_timing
(
false
)
,
m_nan_check
(
std
::
getenv
(
"NGRAPH_INTERPRETER_NAN_CHECK"
)
!=
nullptr
)
{
}
void
runtime
::
interpreter
::
INT_CallFrame
::
call
(
std
::
shared_ptr
<
Function
>
function
,
const
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>&
output_tvs
,
const
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>&
input_tvs
)
{
if
(
m_nan_check
)
{
perform_nan_check
(
input_tvs
);
}
unordered_map
<
descriptor
::
TensorView
*
,
shared_ptr
<
runtime
::
HostTensorView
>>
tensor_map
;
size_t
arg_index
=
0
;
for
(
shared_ptr
<
op
::
Parameter
>
param
:
function
->
get_parameters
())
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
descriptor
::
TensorView
*
tv
=
param
->
get_output_tensor_view
(
i
).
get
();
tensor_map
.
insert
({
tv
,
input_tvs
[
arg_index
++
]});
}
}
for
(
size_t
i
=
0
;
i
<
function
->
get_output_size
();
i
++
)
{
auto
output_op
=
function
->
get_output_op
(
i
);
if
(
!
std
::
dynamic_pointer_cast
<
op
::
Result
>
(
output_op
))
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
}
descriptor
::
TensorView
*
tv
=
function
->
get_output_op
(
i
)
->
get_output_tensor_view
(
0
).
get
();
tensor_map
.
insert
({
tv
,
output_tvs
[
i
]});
}
// Invoke computation
for
(
shared_ptr
<
Node
>
op
:
function
->
get_ordered_ops
())
{
if
(
op
->
description
()
==
"Parameter"
)
{
continue
;
}
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
inputs
;
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
outputs
;
for
(
const
descriptor
::
Input
&
input
:
op
->
get_inputs
())
{
descriptor
::
TensorView
*
tv
=
input
.
get_output
().
get_tensor_view
().
get
();
string
name
=
tv
->
get_tensor
().
get_name
();
inputs
.
push_back
(
tensor_map
.
at
(
tv
));
}
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
{
descriptor
::
TensorView
*
tv
=
op
->
get_output_tensor_view
(
i
).
get
();
string
name
=
tv
->
get_tensor
().
get_name
();
shared_ptr
<
runtime
::
HostTensorView
>
itv
;
if
(
!
contains_key
(
tensor_map
,
tv
))
{
// The output tensor is not in the tensor map so create a new tensor
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
element
::
Type
&
element_type
=
op
->
get_output_element_type
(
i
);
string
tensor_name
=
op
->
get_output_tensor
(
i
).
get_name
();
itv
=
make_shared
<
runtime
::
HostTensorView
>
(
element_type
,
shape
,
tensor_name
);
tensor_map
.
insert
({
tv
,
itv
});
}
else
{
itv
=
tensor_map
.
at
(
tv
);
}
outputs
.
push_back
(
itv
);
}
element
::
Type
base_type
;
element
::
Type
secondary_type
;
if
(
op
->
get_inputs
().
empty
())
{
base_type
=
op
->
get_element_type
();
}
else
{
base_type
=
op
->
get_inputs
().
at
(
0
).
get_tensor
().
get_element_type
();
}
secondary_type
=
op
->
get_element_type
();
// Some ops have unusual intput/output types so handle those special cases here
if
(
op
->
description
()
==
"Select"
)
{
base_type
=
op
->
get_inputs
().
at
(
1
).
get_tensor
().
get_element_type
();
secondary_type
=
op
->
get_inputs
().
at
(
0
).
get_tensor
().
get_element_type
();
}
if
(
m_emit_timing
)
{
m_timer_map
[
op
.
get
()].
start
();
}
generate_calls
(
base_type
,
secondary_type
,
*
op
,
inputs
,
outputs
);
if
(
m_emit_timing
)
{
stopwatch
&
timer
=
m_timer_map
[
op
.
get
()];
timer
.
stop
();
}
if
(
m_nan_check
)
{
perform_nan_check
(
outputs
,
op
.
get
());
}
// Delete any obsolete tensors
for
(
const
descriptor
::
Tensor
*
t
:
op
->
liveness_free_list
)
{
for
(
auto
it
=
tensor_map
.
begin
();
it
!=
tensor_map
.
end
();
++
it
)
{
if
(
it
->
second
->
get_tensor
().
get_name
()
==
t
->
get_name
())
{
tensor_map
.
erase
(
it
);
break
;
}
}
}
}
}
void
runtime
::
interpreter
::
INT_CallFrame
::
generate_calls
(
const
element
::
Type
&
base_type
,
const
element
::
Type
&
secondary_type
,
ngraph
::
Node
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
)
{
if
(
base_type
==
element
::
boolean
)
{
generate_calls
<
char
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
f32
)
{
generate_calls
<
float
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
f64
)
{
generate_calls
<
double
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
i8
)
{
generate_calls
<
int8_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
i16
)
{
generate_calls
<
int16_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
i32
)
{
generate_calls
<
int32_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
i64
)
{
generate_calls
<
int64_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
u8
)
{
generate_calls
<
uint8_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
u16
)
{
generate_calls
<
uint16_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
u32
)
{
generate_calls
<
uint32_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
if
(
base_type
==
element
::
u64
)
{
generate_calls
<
uint64_t
>
(
secondary_type
,
op
,
args
,
out
);
}
else
{
stringstream
ss
;
ss
<<
"unsupported element type "
<<
base_type
<<
" op "
<<
op
.
get_name
();
throw
runtime_error
(
ss
.
str
());
}
}
void
runtime
::
interpreter
::
INT_CallFrame
::
call
(
const
vector
<
shared_ptr
<
runtime
::
TensorView
>>&
output_tvs
,
const
vector
<
shared_ptr
<
runtime
::
TensorView
>>&
input_tvs
)
{
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
args
;
vector
<
shared_ptr
<
runtime
::
HostTensorView
>>
out
;
for
(
auto
tv
:
input_tvs
)
{
args
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensorView
>
(
tv
));
}
for
(
auto
tv
:
output_tvs
)
{
out
.
push_back
(
static_pointer_cast
<
runtime
::
HostTensorView
>
(
tv
));
}
call
(
m_function
,
out
,
args
);
}
vector
<
runtime
::
PerformanceCounter
>
runtime
::
interpreter
::
INT_CallFrame
::
get_performance_data
()
const
{
vector
<
runtime
::
PerformanceCounter
>
rc
;
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
m_timer_map
)
{
rc
.
emplace_back
(
p
.
first
->
get_name
().
c_str
(),
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_call_count
());
}
return
rc
;
}
void
runtime
::
interpreter
::
INT_CallFrame
::
perform_nan_check
(
const
vector
<
shared_ptr
<
HostTensorView
>>&
tvs
,
const
Node
*
op
)
{
size_t
arg_number
=
1
;
for
(
shared_ptr
<
HostTensorView
>
tv
:
tvs
)
{
const
element
::
Type
&
type
=
tv
->
get_tensor
().
get_element_type
();
if
(
type
==
element
::
f32
)
{
const
float
*
data
=
reinterpret_cast
<
float
*>
(
tv
->
get_data_ptr
());
for
(
size_t
i
=
0
;
i
<
tv
->
get_element_count
();
i
++
)
{
if
(
std
::
isnan
(
data
[
i
]))
{
if
(
op
)
{
throw
runtime_error
(
"nan found in op '"
+
op
->
get_name
()
+
"' output"
);
}
else
{
throw
runtime_error
(
"nan found in function's input tensor number "
+
to_string
(
arg_number
));
}
}
}
}
else
if
(
type
==
element
::
f64
)
{
const
double
*
data
=
reinterpret_cast
<
double
*>
(
tv
->
get_data_ptr
());
for
(
size_t
i
=
0
;
i
<
tv
->
get_element_count
();
i
++
)
{
if
(
std
::
isnan
(
data
[
i
]))
{
if
(
op
)
{
throw
runtime_error
(
"nan found in op '"
+
op
->
get_name
()
+
"' output"
);
}
else
{
throw
runtime_error
(
"nan found in function's input tensor number "
+
to_string
(
arg_number
));
}
}
}
}
arg_number
++
;
}
}
void
runtime
::
interpreter
::
INT_CallFrame
::
set_nan_check
(
bool
value
)
{
m_nan_check
=
value
;
}
src/ngraph/runtime/interpreter/int_call_frame.hpp
deleted
100644 → 0
View file @
66198b33
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <functional>
#include <memory>
#include <vector>
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/op/reduce_window.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/runtime/host_tensor_view.hpp"
#include "ngraph/runtime/performance_counter.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/runtime/reference/cos.hpp"
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
#include "ngraph/runtime/reference/less_eq.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/reduce.hpp"
#include "ngraph/runtime/reference/reduce_window.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/tensor_view.hpp"
#include "ngraph/util.hpp"
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/runtime/reference/allreduce.hpp"
#endif
namespace
ngraph
{
namespace
runtime
{
namespace
interpreter
{
class
INT_CallFrame
;
}
}
}
// Compile and execute graphs
class
ngraph
::
runtime
::
interpreter
::
INT_CallFrame
{
friend
class
INT_Backend
;
public
:
INT_CallFrame
(
std
::
shared_ptr
<
Function
>
func
);
/// @brief Invoke the function with values matching the signature of the function.
///
/// Tuples will be expanded into their tensor views to build the call frame.
void
call
(
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>&
inputs
);
std
::
vector
<
runtime
::
PerformanceCounter
>
get_performance_data
()
const
;
void
set_nan_check
(
bool
);
private
:
void
call
(
std
::
shared_ptr
<
Function
>
function
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
HostTensorView
>>&
output_tvs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
HostTensorView
>>&
input_tvs
);
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
,
const
Node
*
op
=
nullptr
);
std
::
shared_ptr
<
Function
>
m_function
;
bool
m_emit_timing
;
bool
m_nan_check
;
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
void
generate_calls
(
const
element
::
Type
&
base_type
,
const
element
::
Type
&
secondary_type
,
ngraph
::
Node
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
);
template
<
typename
BASE
>
void
generate_calls
(
const
element
::
Type
&
type
,
ngraph
::
Node
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
)
{
if
(
type
==
element
::
boolean
)
{
op_engine
<
BASE
,
char
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
f32
)
{
op_engine
<
BASE
,
float
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
f64
)
{
op_engine
<
BASE
,
double
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
i8
)
{
op_engine
<
BASE
,
int8_t
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
i16
)
{
op_engine
<
BASE
,
int16_t
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
i32
)
{
op_engine
<
BASE
,
int32_t
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
i64
)
{
op_engine
<
BASE
,
int64_t
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
u8
)
{
op_engine
<
BASE
,
uint8_t
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
u16
)
{
op_engine
<
BASE
,
uint16_t
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
u32
)
{
op_engine
<
BASE
,
uint32_t
>
(
op
,
args
,
out
);
}
else
if
(
type
==
element
::
u64
)
{
op_engine
<
BASE
,
uint64_t
>
(
op
,
args
,
out
);
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_name
();
throw
std
::
runtime_error
(
ss
.
str
());
}
}
template
<
typename
T
,
typename
S
>
void
op_engine
(
ngraph
::
Node
&
node
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
)
{
std
::
string
node_op
=
node
.
description
();
if
(
node_op
==
"Abs"
)
{
reference
::
abs
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Acos"
)
{
reference
::
acos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Add"
)
{
reference
::
add
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
#ifdef NGRAPH_DISTRIBUTED
else
if
(
node_op
==
"AllReduce"
)
{
reference
::
allreduce
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_element_type
(),
static_cast
<
int
>
(
args
[
0
]
->
get_element_count
()));
}
#endif
else
if
(
node_op
==
"And"
)
{
reference
::
logical_and
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Asin"
)
{
reference
::
asin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Atan"
)
{
reference
::
atan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"AvgPool"
)
{
ngraph
::
op
::
AvgPool
*
avg_pool
=
dynamic_cast
<
ngraph
::
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_movement_strides
(),
avg_pool
->
get_padding_below
(),
avg_pool
->
get_padding_above
(),
avg_pool
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"AvgPoolBackprop"
)
{
ngraph
::
op
::
AvgPoolBackprop
*
apb
=
dynamic_cast
<
ngraph
::
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
apb
->
get_window_shape
(),
apb
->
get_window_movement_strides
(),
apb
->
get_padding_below
(),
apb
->
get_padding_above
(),
apb
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"Broadcast"
)
{
ngraph
::
op
::
Broadcast
*
broadcast
=
dynamic_cast
<
ngraph
::
op
::
Broadcast
*>
(
&
node
);
Shape
in_shape
=
args
[
0
]
->
get_shape
();
Shape
out_shape
=
out
[
0
]
->
get_shape
();
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
reference
::
broadcast
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shape
,
out_shape
,
broadcast_axes
);
}
else
if
(
node_op
==
"Ceiling"
)
{
reference
::
ceiling
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Concat"
)
{
const
op
::
Concat
*
concat
=
static_cast
<
const
op
::
Concat
*>
(
&
node
);
std
::
vector
<
const
T
*>
in_args
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
std
::
shared_ptr
<
HostTensorView
>
arg
:
args
)
{
in_args
.
push_back
(
reinterpret_cast
<
T
*>
(
arg
->
get_data_ptr
()));
in_shapes
.
push_back
(
arg
->
get_shape
());
}
reference
::
concat
<
T
>
(
in_args
,
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shapes
,
out
[
0
]
->
get_shape
(),
concat
->
get_concatenation_axis
());
}
else
if
(
node_op
==
"Constant"
)
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
reference
::
constant
<
T
>
(
reinterpret_cast
<
const
T
*>
(
c
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Convert"
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
S
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Convolution"
)
{
auto
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides
(),
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropFilters"
)
{
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
1
,
0
,
0
,
1
,
1
,
0
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropData"
)
{
// Note that args[1] and args[0] are switched here from the usual order.
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
0
,
1
,
0
,
1
,
0
,
1
,
true
);
}
else
if
(
node_op
==
"Cos"
)
{
reference
::
cos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Cosh"
)
{
reference
::
cosh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Divide"
)
{
reference
::
divide
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Dot"
)
{
ngraph
::
op
::
Dot
*
dot
=
dynamic_cast
<
ngraph
::
op
::
Dot
*>
(
&
node
);
reference
::
dot
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
dot
->
get_reduction_axes_count
());
}
else
if
(
node_op
==
"Equal"
)
{
reference
::
equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Exp"
)
{
reference
::
exp
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Floor"
)
{
reference
::
floor
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"FunctionCall"
)
{
std
::
shared_ptr
<
Function
>
function
=
node
.
get_functions
()[
0
];
call
(
function
,
out
,
args
);
}
else
if
(
node_op
==
"Greater"
)
{
reference
::
greater
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"GreaterEq"
)
{
reference
::
greater_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Less"
)
{
reference
::
less
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"LessEq"
)
{
reference
::
less_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Log"
)
{
reference
::
log
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Max"
)
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Maximum"
)
{
reference
::
maximum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"MaxPool"
)
{
ngraph
::
op
::
MaxPool
*
max_pool
=
dynamic_cast
<
ngraph
::
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_movement_strides
(),
max_pool
->
get_padding_below
(),
max_pool
->
get_padding_above
());
}
else
if
(
node_op
==
"MaxPoolBackprop"
)
{
ngraph
::
op
::
MaxPoolBackprop
*
max_pool_backprop
=
dynamic_cast
<
ngraph
::
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_movement_strides
(),
max_pool_backprop
->
get_padding_below
(),
max_pool_backprop
->
get_padding_above
());
}
else
if
(
node_op
==
"Min"
)
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
min
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Minimum"
)
{
reference
::
minimum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Multiply"
)
{
reference
::
multiply
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Negative"
)
{
reference
::
negate
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Not"
)
{
reference
::
logical_not
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"NotEqual"
)
{
reference
::
not_equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"OneHot"
)
{
auto
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
oh
->
get_one_hot_axis
());
}
else
if
(
node_op
==
"Or"
)
{
reference
::
logical_or
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Parameter"
)
{
}
else
if
(
node_op
==
"Pad"
)
{
ngraph
::
op
::
Pad
*
pad
=
dynamic_cast
<
ngraph
::
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_padding_interior
());
}
else
if
(
node_op
==
"Power"
)
{
reference
::
power
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Product"
)
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
product
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Reduce"
)
{
ngraph
::
op
::
Reduce
*
reduce
=
dynamic_cast
<
ngraph
::
op
::
Reduce
*>
(
&
node
);
std
::
shared_ptr
<
ngraph
::
Function
>
reduction_function
=
reduce
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reduce
->
get_reduction_axes
(),
f
);
}
else
if
(
node_op
==
"ReduceWindow"
)
{
ngraph
::
op
::
ReduceWindow
*
reduce_window
=
dynamic_cast
<
ngraph
::
op
::
ReduceWindow
*>
(
&
node
);
std
::
shared_ptr
<
ngraph
::
Function
>
reduction_function
=
reduce_window
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_window_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_window_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_window_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce_window
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
f
,
reduce_window
->
get_window_shape
(),
reduce_window
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Relu"
)
{
reference
::
relu
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReluBackprop"
)
{
reference
::
relu_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
// else if (node_op == "Remainder")
// {
// // node = make_shared<op::Remainder>(args[0], args[1]);
// }
else
if
(
node_op
==
"ReplaceSlice"
)
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Reshape"
)
{
ngraph
::
op
::
Reshape
*
reshape
=
dynamic_cast
<
ngraph
::
op
::
Reshape
*>
(
&
node
);
reference
::
reshape
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
reshape
->
get_input_order
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Result"
)
{
ngraph
::
op
::
Result
*
res
=
dynamic_cast
<
ngraph
::
op
::
Result
*>
(
&
node
);
reference
::
result
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
shape_size
(
res
->
get_shape
()));
}
else
if
(
node_op
==
"Reverse"
)
{
ngraph
::
op
::
Reverse
*
reverse
=
dynamic_cast
<
ngraph
::
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
reverse
->
get_reversed_axes
());
}
else
if
(
node_op
==
"Select"
)
{
reference
::
select
<
T
>
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"SelectAndScatter"
)
{
ngraph
::
op
::
SelectAndScatter
*
select_and_scatter
=
dynamic_cast
<
ngraph
::
op
::
SelectAndScatter
*>
(
&
node
);
std
::
shared_ptr
<
ngraph
::
Function
>
selection_function
=
select_and_scatter
->
get_functions
()[
0
];
std
::
function
<
bool
(
T
,
T
)
>
f_selection
=
[
this
,
&
node
,
selection_function
](
T
x
,
T
y
)
->
bool
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"selection_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"selection_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
element
::
boolean
,
Shape
{},
"selection_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
selection_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
char
*>
(
tr
->
get_data_ptr
()));
};
std
::
shared_ptr
<
ngraph
::
Function
>
scatter_function
=
select_and_scatter
->
get_functions
()[
1
];
std
::
function
<
T
(
T
,
T
)
>
f_scatter
=
[
this
,
&
node
,
scatter_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"scatter_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"scatter_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"scatter_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
scatter_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
select_and_scatter
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
f_selection
,
f_scatter
,
select_and_scatter
->
get_window_shape
(),
select_and_scatter
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Sign"
)
{
reference
::
sign
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sin"
)
{
reference
::
sin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sinh"
)
{
reference
::
sinh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Slice"
)
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Softmax"
)
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_shape
(),
softmax
->
get_axes
());
}
else
if
(
node_op
==
"Sqrt"
)
{
reference
::
sqrt
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Subtract"
)
{
reference
::
subtract
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sum"
)
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
sum
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Tan"
)
{
reference
::
tan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Tanh"
)
{
reference
::
tanh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported op "
<<
node_op
;
throw
std
::
runtime_error
(
ss
.
str
());
}
}
};
src/ngraph/runtime/interpreter/int_external_function.cpp
deleted
100644 → 0
View file @
66198b33
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <fstream>
#include <memory>
#include <string>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/function_call.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/reduce.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/interpreter/int_backend.hpp"
#include "ngraph/runtime/interpreter/int_call_frame.hpp"
#include "ngraph/runtime/interpreter/int_external_function.hpp"
using
namespace
std
;
using
namespace
ngraph
;
static
const
string
s_output_dir
=
"cpu_codegen"
;
class
StaticInitializers
{
public
:
StaticInitializers
()
{
file_util
::
remove_directory
(
s_output_dir
);
}
};
static
StaticInitializers
s_static_initializers
;
using
descriptor
::
layout
::
DenseTensorViewLayout
;
runtime
::
interpreter
::
ExternalFunction
::
ExternalFunction
(
const
shared_ptr
<
Function
>&
function
,
bool
release_function
)
:
m_function
(
function
)
,
m_release_function
(
release_function
)
,
m_is_compiled
(
false
)
,
m_timing
(
false
)
{
}
void
runtime
::
interpreter
::
ExternalFunction
::
compile
()
{
if
(
m_is_compiled
)
{
return
;
}
pass
::
Manager
pass_manager
;
// For now, just make everyone row-major.
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorViewLayout
>>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
run_passes
(
m_function
);
m_is_compiled
=
true
;
if
(
m_release_function
)
{
release_function
();
}
}
shared_ptr
<
runtime
::
interpreter
::
INT_CallFrame
>
runtime
::
interpreter
::
ExternalFunction
::
make_call_frame
()
{
if
(
!
m_is_compiled
)
{
compile
();
}
return
make_shared
<
runtime
::
interpreter
::
INT_CallFrame
>
(
m_function
);
}
src/ngraph/runtime/interpreter/int_external_function.hpp
deleted
100644 → 0
View file @
66198b33
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include <memory>
#include "ngraph/function.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
interpreter
{
class
INT_CallFrame
;
class
ExternalFunction
{
public
:
ExternalFunction
(
const
std
::
shared_ptr
<
ngraph
::
Function
>&
function
,
bool
release_function
=
false
);
std
::
shared_ptr
<
INT_CallFrame
>
make_call_frame
();
protected
:
void
compile
();
void
release_function
()
{
m_function
=
nullptr
;
}
std
::
shared_ptr
<
ngraph
::
Function
>
m_function
;
bool
m_release_function
;
bool
m_is_compiled
;
bool
m_timing
;
};
}
}
}
test/CMakeLists.txt
View file @
89963725
...
...
@@ -69,7 +69,6 @@ add_subdirectory(files)
#================================================================================================
# TODO add interpreter back to unit tests when it works
set
(
BACKEND_NAMES
${
BACKEND_NAMES
}
"INTERPRETER"
)
set
(
BACKEND_NAMES
${
BACKEND_NAMES
}
"IE"
)
if
(
MKLDNN_INCLUDE_DIR
)
include_directories
(
SYSTEM
${
MKLDNN_INCLUDE_DIR
}
)
...
...
test/backend_debug_api.cpp
View file @
89963725
...
...
@@ -37,8 +37,8 @@ TEST(INTERPRETER, nan_check_input)
auto
backend
=
runtime
::
Backend
::
create
(
"INTERPRETER"
);
shared_ptr
<
runtime
::
interpreter
::
INT
_
Backend
>
ibackend
=
static_pointer_cast
<
runtime
::
interpreter
::
INT
_
Backend
>
(
backend
);
shared_ptr
<
runtime
::
interpreter
::
INTBackend
>
ibackend
=
static_pointer_cast
<
runtime
::
interpreter
::
INTBackend
>
(
backend
);
// Create some tensors for input/output
auto
a
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
...
...
@@ -60,8 +60,8 @@ TEST(INTERPRETER, nan_check_output)
auto
backend
=
runtime
::
Backend
::
create
(
"INTERPRETER"
);
shared_ptr
<
runtime
::
interpreter
::
INT
_
Backend
>
ibackend
=
static_pointer_cast
<
runtime
::
interpreter
::
INT
_
Backend
>
(
backend
);
shared_ptr
<
runtime
::
interpreter
::
INTBackend
>
ibackend
=
static_pointer_cast
<
runtime
::
interpreter
::
INTBackend
>
(
backend
);
// Create some tensors for input/output
auto
a
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment