Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
ef2e0118
Commit
ef2e0118
authored
Jan 26, 2019
by
Robert Kimball
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
copy executable from bob/backend_api2
parent
122754c1
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
374 additions
and
258 deletions
+374
-258
__init__.py
python/ngraph/impl/runtime/__init__.py
+1
-0
runtime.py
python/ngraph/runtime.py
+2
-2
backend.cpp
python/pyngraph/runtime/backend.cpp
+1
-17
executable.cpp
python/pyngraph/runtime/executable.cpp
+40
-0
executable.hpp
python/pyngraph/runtime/executable.hpp
+23
-0
regmodule_pyngraph_runtime.cpp
python/pyngraph/runtime/regmodule_pyngraph_runtime.cpp
+1
-0
setup.py
python/setup.py
+1
-0
test_ops.py
python/test/test_ops.py
+0
-0
CMakeLists.txt
src/ngraph/CMakeLists.txt
+2
-2
backend.cpp
src/ngraph/runtime/backend.cpp
+72
-27
backend.hpp
src/ngraph/runtime/backend.hpp
+61
-41
hybrid_backend.cpp
src/ngraph/runtime/hybrid/hybrid_backend.cpp
+32
-30
hybrid_backend.hpp
src/ngraph/runtime/hybrid/hybrid_backend.hpp
+24
-15
int_backend.cpp
src/ngraph/runtime/interpreter/int_backend.cpp
+44
-66
int_backend.hpp
src/ngraph/runtime/interpreter/int_backend.hpp
+29
-31
nop_backend.cpp
src/ngraph/runtime/nop/nop_backend.cpp
+16
-5
nop_backend.hpp
src/ngraph/runtime/nop/nop_backend.hpp
+10
-4
benchmark.cpp
src/tools/nbench/benchmark.cpp
+4
-5
CMakeLists.txt
test/CMakeLists.txt
+0
-1
backend_debug_api.cpp
test/backend_debug_api.cpp
+11
-12
No files found.
python/ngraph/impl/runtime/__init__.py
View file @
ef2e0118
...
@@ -28,4 +28,5 @@ else:
...
@@ -28,4 +28,5 @@ else:
sys
.
setdlopenflags
(
flags
)
sys
.
setdlopenflags
(
flags
)
from
_pyngraph.runtime
import
Backend
from
_pyngraph.runtime
import
Backend
from
_pyngraph.runtime
import
Executable
from
_pyngraph.runtime
import
Tensor
from
_pyngraph.runtime
import
Tensor
python/ngraph/runtime.py
View file @
ef2e0118
...
@@ -20,7 +20,7 @@ from typing import List, Union
...
@@ -20,7 +20,7 @@ from typing import List, Union
import
numpy
as
np
import
numpy
as
np
from
ngraph.impl
import
Function
,
Node
,
Shape
,
serialize
,
util
from
ngraph.impl
import
Function
,
Node
,
Shape
,
serialize
,
util
from
ngraph.impl.runtime
import
Backend
,
Tensor
from
ngraph.impl.runtime
import
Backend
,
Executable
,
Tensor
from
ngraph.utils.types
import
get_dtype
,
NumericData
from
ngraph.utils.types
import
get_dtype
,
NumericData
from
ngraph.exceptions
import
UserInputError
from
ngraph.exceptions
import
UserInputError
...
@@ -93,7 +93,7 @@ class Computation(object):
...
@@ -93,7 +93,7 @@ class Computation(object):
value
=
np
.
array
(
value
)
value
=
np
.
array
(
value
)
Computation
.
_write_ndarray_to_tensor_view
(
value
,
tensor_view
)
Computation
.
_write_ndarray_to_tensor_view
(
value
,
tensor_view
)
self
.
runtime
.
backend
.
call
(
self
.
handle
,
self
.
result_views
,
self
.
tensor_views
)
self
.
handle
.
call
(
self
.
result_views
,
self
.
tensor_views
)
results
=
[]
results
=
[]
for
result_view
in
self
.
result_views
:
for
result_view
in
self
.
result_views
:
...
...
python/pyngraph/runtime/backend.cpp
View file @
ef2e0118
...
@@ -35,23 +35,7 @@ void regclass_pyngraph_runtime_Backend(py::module m)
...
@@ -35,23 +35,7 @@ void regclass_pyngraph_runtime_Backend(py::module m)
const
ngraph
::
element
::
Type
&
,
const
ngraph
::
Shape
&
))
&
const
ngraph
::
element
::
Type
&
,
const
ngraph
::
Shape
&
))
&
ngraph
::
runtime
::
Backend
::
create_tensor
);
ngraph
::
runtime
::
Backend
::
create_tensor
);
backend
.
def
(
"compile"
,
backend
.
def
(
"compile"
,
(
std
::
shared_ptr
<
ngraph
::
Function
>
(
ngraph
::
runtime
::
Backend
::*
)(
(
std
::
unique_ptr
<
ngraph
::
runtime
::
Executable
>
(
ngraph
::
runtime
::
Backend
::*
)(
std
::
shared_ptr
<
ngraph
::
Function
>
))
&
std
::
shared_ptr
<
ngraph
::
Function
>
))
&
ngraph
::
runtime
::
Backend
::
compile
);
ngraph
::
runtime
::
Backend
::
compile
);
backend
.
def
(
"call"
,
(
bool
(
ngraph
::
runtime
::
Backend
::*
)(
std
::
shared_ptr
<
ngraph
::
Function
>
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
))
&
ngraph
::
runtime
::
Backend
::
call
);
backend
.
def
(
"remove_compiled_function"
,
(
void
(
ngraph
::
runtime
::
Backend
::*
)(
std
::
shared_ptr
<
ngraph
::
Function
>
))
&
ngraph
::
runtime
::
Backend
::
remove_compiled_function
);
backend
.
def
(
"enable_performance_data"
,
(
void
(
ngraph
::
runtime
::
Backend
::*
)(
std
::
shared_ptr
<
ngraph
::
Function
>
,
bool
))
&
ngraph
::
runtime
::
Backend
::
enable_performance_data
);
backend
.
def
(
"get_performance_data"
,
(
std
::
vector
<
ngraph
::
runtime
::
PerformanceCounter
>
(
ngraph
::
runtime
::
Backend
::*
)(
std
::
shared_ptr
<
ngraph
::
Function
>
))
&
ngraph
::
runtime
::
Backend
::
get_performance_data
);
}
}
python/pyngraph/runtime/executable.cpp
0 → 100644
View file @
ef2e0118
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ngraph/runtime/backend.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "pyngraph/runtime/executable.hpp"
namespace
py
=
pybind11
;
void
regclass_pyngraph_runtime_Executable
(
py
::
module
m
)
{
py
::
class_
<
ngraph
::
runtime
::
Executable
,
std
::
unique_ptr
<
ngraph
::
runtime
::
Executable
>>
executable
(
m
,
"Executable"
);
executable
.
doc
()
=
"ngraph.impl.runtime.Executable wraps ngraph::runtime::Executable"
;
executable
.
def
(
"call"
,
(
bool
(
ngraph
::
runtime
::
Executable
::*
)(
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
))
&
ngraph
::
runtime
::
Executable
::
call
);
executable
.
def
(
"get_performance_data"
,
(
std
::
vector
<
ngraph
::
runtime
::
PerformanceCounter
>
(
ngraph
::
runtime
::
Executable
::*
)())
&
ngraph
::
runtime
::
Executable
::
get_performance_data
);
}
python/pyngraph/runtime/executable.hpp
0 → 100644
View file @
ef2e0118
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <pybind11/pybind11.h>
namespace
py
=
pybind11
;
void
regclass_pyngraph_runtime_Executable
(
py
::
module
m
);
python/pyngraph/runtime/regmodule_pyngraph_runtime.cpp
View file @
ef2e0118
...
@@ -25,4 +25,5 @@ void regmodule_pyngraph_runtime(py::module m)
...
@@ -25,4 +25,5 @@ void regmodule_pyngraph_runtime(py::module m)
m
.
def_submodule
(
"runtime"
,
"Package ngraph.impl.runtime wraps ngraph::runtime"
);
m
.
def_submodule
(
"runtime"
,
"Package ngraph.impl.runtime wraps ngraph::runtime"
);
regclass_pyngraph_runtime_Tensor
(
m_runtime
);
regclass_pyngraph_runtime_Tensor
(
m_runtime
);
regclass_pyngraph_runtime_Backend
(
m_runtime
);
regclass_pyngraph_runtime_Backend
(
m_runtime
);
regclass_pyngraph_runtime_Executable
(
m_runtime
);
}
}
python/setup.py
View file @
ef2e0118
...
@@ -228,6 +228,7 @@ sources = [
...
@@ -228,6 +228,7 @@ sources = [
'pyngraph/ops/softmax.cpp'
,
'pyngraph/ops/softmax.cpp'
,
'pyngraph/ops/result.cpp'
,
'pyngraph/ops/result.cpp'
,
'pyngraph/runtime/backend.cpp'
,
'pyngraph/runtime/backend.cpp'
,
'pyngraph/runtime/executable.cpp'
,
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp'
,
'pyngraph/runtime/regmodule_pyngraph_runtime.cpp'
,
'pyngraph/runtime/tensor.cpp'
,
'pyngraph/runtime/tensor.cpp'
,
'pyngraph/passes/manager.cpp'
,
'pyngraph/passes/manager.cpp'
,
...
...
python/test/test_ops.py
View file @
ef2e0118
This diff is collapsed.
Click to expand it.
src/ngraph/CMakeLists.txt
View file @
ef2e0118
...
@@ -139,8 +139,8 @@ set (SRC
...
@@ -139,8 +139,8 @@ set (SRC
pass/memory_visualize.cpp
pass/memory_visualize.cpp
pass/nop_elimination.cpp
pass/nop_elimination.cpp
pass/pass.cpp
pass/pass.cpp
pass/pass_config.cpp
pass/pass_config.cpp
pass/prefix_reshape_elimination.cpp
pass/prefix_reshape_elimination.cpp
pass/propagate_cacheability.cpp
pass/propagate_cacheability.cpp
pass/reshape_elimination.cpp
pass/reshape_elimination.cpp
pass/reshape_sinking.cpp
pass/reshape_sinking.cpp
...
...
src/ngraph/runtime/backend.cpp
View file @
ef2e0118
...
@@ -39,78 +39,123 @@ vector<string> runtime::Backend::get_registered_devices()
...
@@ -39,78 +39,123 @@ vector<string> runtime::Backend::get_registered_devices()
return
BackendManager
::
get_registered_backends
();
return
BackendManager
::
get_registered_backends
();
}
}
void
runtime
::
Backend
::
remove_compiled_function
(
shared_ptr
<
Function
>
func
)
bool
runtime
::
Backend
::
is_supported
(
const
Node
&
node
)
const
{
{
// The default behavior is that a backend does not support any ops. If this is not the case
// then override this method and enhance.
return
false
;
}
}
vector
<
ngraph
::
runtime
::
PerformanceCounter
>
runtime
::
Executable
::
Executable
()
runtime
::
Backend
::
get_performance_data
(
shared_ptr
<
Function
>
func
)
const
{
{
return
vector
<
PerformanceCounter
>
();
}
}
void
runtime
::
Backend
::
validate
(
shared_ptr
<
const
Function
>
function
,
runtime
::
Executable
::~
Executable
()
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
{
const
ParameterVector
&
input_parameters
=
function
->
get_parameters
();
}
if
(
input_parameters
.
size
()
!=
inputs
.
size
())
bool
runtime
::
Executable
::
call_with_validate
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
validate
(
outputs
,
inputs
);
return
call
(
outputs
,
inputs
);
}
void
runtime
::
Executable
::
validate
(
const
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
const
ParameterVector
&
parameters
=
get_parameters
();
const
ResultVector
&
results
=
get_results
();
if
(
parameters
.
size
()
!=
inputs
.
size
())
{
{
stringstream
ss
;
stringstream
ss
;
ss
<<
"Call input count "
<<
inputs
.
size
()
<<
" does not match Function's Parameter count "
ss
<<
"Call input count "
<<
inputs
.
size
()
<<
" does not match Function's Parameter count "
<<
input_
parameters
.
size
();
<<
parameters
.
size
();
throw
runtime_error
(
ss
.
str
());
throw
runtime_error
(
ss
.
str
());
}
}
if
(
function
->
get_output_
size
()
!=
outputs
.
size
())
if
(
results
.
size
()
!=
outputs
.
size
())
{
{
stringstream
ss
;
stringstream
ss
;
ss
<<
"Call output count "
<<
outputs
.
size
()
<<
" does not match Function's Result count "
ss
<<
"Call output count "
<<
outputs
.
size
()
<<
" does not match Function's Result count "
<<
function
->
get_output_
size
();
<<
results
.
size
();
throw
runtime_error
(
ss
.
str
());
throw
runtime_error
(
ss
.
str
());
}
}
for
(
size_t
i
=
0
;
i
<
input_
parameters
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
parameters
.
size
();
i
++
)
{
{
if
(
input_
parameters
[
i
]
->
get_element_type
()
!=
inputs
[
i
]
->
get_element_type
())
if
(
parameters
[
i
]
->
get_element_type
()
!=
inputs
[
i
]
->
get_element_type
())
{
{
stringstream
ss
;
stringstream
ss
;
ss
<<
"Input "
<<
i
<<
" type '"
<<
inputs
[
i
]
->
get_element_type
()
ss
<<
"Input "
<<
i
<<
" type '"
<<
inputs
[
i
]
->
get_element_type
()
<<
"' does not match Parameter type '"
<<
input_parameters
[
i
]
->
get_element_type
()
<<
"' does not match Parameter type '"
<<
parameters
[
i
]
->
get_element_type
()
<<
"'"
;
<<
"'"
;
throw
runtime_error
(
ss
.
str
());
throw
runtime_error
(
ss
.
str
());
}
}
if
(
input_
parameters
[
i
]
->
get_shape
()
!=
inputs
[
i
]
->
get_shape
())
if
(
parameters
[
i
]
->
get_shape
()
!=
inputs
[
i
]
->
get_shape
())
{
{
stringstream
ss
;
stringstream
ss
;
ss
<<
"Input "
<<
i
<<
" shape {"
<<
join
(
inputs
[
i
]
->
get_shape
())
ss
<<
"Input "
<<
i
<<
" shape {"
<<
join
(
inputs
[
i
]
->
get_shape
())
<<
"} does not match Parameter shape {"
<<
join
(
input_parameters
[
i
]
->
get_shape
())
<<
"} does not match Parameter shape {"
<<
join
(
parameters
[
i
]
->
get_shape
())
<<
"}"
;
<<
"}"
;
throw
runtime_error
(
ss
.
str
());
throw
runtime_error
(
ss
.
str
());
}
}
}
}
for
(
size_t
i
=
0
;
i
<
function
->
get_output_
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
results
.
size
();
i
++
)
{
{
if
(
function
->
get_output_element_type
(
i
)
!=
outputs
[
i
]
->
get_element_type
())
if
(
results
[
i
]
->
get_element_type
(
)
!=
outputs
[
i
]
->
get_element_type
())
{
{
stringstream
ss
;
stringstream
ss
;
ss
<<
"Output "
<<
i
<<
" type '"
<<
outputs
[
i
]
->
get_element_type
()
ss
<<
"Output "
<<
i
<<
" type '"
<<
outputs
[
i
]
->
get_element_type
()
<<
"' does not match Result type '"
<<
function
->
get_output_element_type
(
i
)
<<
"'"
;
<<
"' does not match Result type '"
<<
results
[
i
]
->
get_element_type
(
)
<<
"'"
;
throw
runtime_error
(
ss
.
str
());
throw
runtime_error
(
ss
.
str
());
}
}
if
(
function
->
get_output_shape
(
i
)
!=
outputs
[
i
]
->
get_shape
())
if
(
results
[
i
]
->
get_shape
(
)
!=
outputs
[
i
]
->
get_shape
())
{
{
stringstream
ss
;
stringstream
ss
;
ss
<<
"Output "
<<
i
<<
" shape {"
<<
join
(
outputs
[
i
]
->
get_shape
())
ss
<<
"Output "
<<
i
<<
" shape {"
<<
join
(
outputs
[
i
]
->
get_shape
())
<<
"} does not match Result shape {"
<<
join
(
function
->
get_output_shape
(
i
))
<<
"}"
;
<<
"} does not match Result shape {"
<<
join
(
results
[
i
]
->
get_shape
(
))
<<
"}"
;
throw
runtime_error
(
ss
.
str
());
throw
runtime_error
(
ss
.
str
());
}
}
}
}
}
}
bool
runtime
::
Backend
::
is_supported
(
const
Node
&
node
)
const
const
ngraph
::
ParameterVector
&
runtime
::
Executable
::
get_parameters
()
const
{
return
m_parameters
;
}
const
ngraph
::
ResultVector
&
runtime
::
Executable
::
get_results
()
const
{
return
m_results
;
}
void
runtime
::
Executable
::
set_parameters_and_results
(
const
Function
&
func
)
{
m_parameters
=
func
.
get_parameters
();
m_results
=
func
.
get_results
();
}
vector
<
runtime
::
PerformanceCounter
>
runtime
::
Executable
::
get_performance_data
()
const
{
return
vector
<
PerformanceCounter
>
();
}
bool
runtime
::
Backend
::
is_supported_property
(
const
Property
prop
)
const
{
{
// The default behavior is that a backend does not support any ops. If this is not the case
// then override this method and enhance.
return
false
;
return
false
;
}
}
bool
runtime
::
Backend
::
call_with_validate
(
std
::
shared_ptr
<
Executable
>
exec
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
return
exec
->
call_with_validate
(
outputs
,
inputs
);
}
bool
runtime
::
Backend
::
call_with_validate
(
const
std
::
unique_ptr
<
Executable
>&
exec
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
return
exec
->
call_with_validate
(
outputs
,
inputs
);
}
src/ngraph/runtime/backend.hpp
View file @
ef2e0118
...
@@ -30,7 +30,8 @@ namespace ngraph
...
@@ -30,7 +30,8 @@ namespace ngraph
class
ExternalFunction
;
class
ExternalFunction
;
class
Tensor
;
class
Tensor
;
class
Backend
;
class
Backend
;
using
Handle
=
std
::
shared_ptr
<
Function
>
;
class
Executable
;
using
Handle
=
std
::
shared_ptr
<
Executable
>
;
}
}
}
}
...
@@ -81,43 +82,8 @@ public:
...
@@ -81,43 +82,8 @@ public:
/// \brief Compiles a Function.
/// \brief Compiles a Function.
/// \param func The function to compile
/// \param func The function to compile
/// \returns compiled function or nullptr on failure
/// \returns compiled function or nullptr on failure
virtual
Handle
compile
(
std
::
shared_ptr
<
Function
>
func
)
=
0
;
virtual
std
::
shared_ptr
<
Executable
>
compile
(
std
::
shared_ptr
<
Function
>
func
,
bool
enable_performance_data
=
false
)
=
0
;
/// \brief Executes a single iteration of a Function. If func is not compiled the call will
/// compile it.
/// \param func The function to execute
/// \returns true if iteration is successful, false otherwise
virtual
bool
call
(
std
::
shared_ptr
<
Function
>
func
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
=
0
;
/// \brief Executes a single iteration of a Function. If func is not compiled the call will
/// compile it. Optionally validates the inputs and outputs against the function graph.
/// \param func The function to execute
/// \returns true if iteration is successful, false otherwise
bool
call_with_validate
(
std
::
shared_ptr
<
Function
>
func
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
validate
(
func
,
outputs
,
inputs
);
return
call
(
func
,
outputs
,
inputs
);
}
/// \brief Compiled functions may be cached. This function removes a compiled function
/// from the cache.
/// \param func The function to execute
virtual
void
remove_compiled_function
(
std
::
shared_ptr
<
Function
>
func
);
/// \brief Enable the collection of per-op performance information on a specified Function.
/// Data collection is via the `get_performance_data` method.
/// \param func The function to collect perfomance data on.
/// \param enable Set to true to enable or false to disable data collection
virtual
void
enable_performance_data
(
std
::
shared_ptr
<
Function
>
func
,
bool
enable
)
{}
/// \brief Collect performance information gathered on a Function.
/// \param func The function to get collected data.
/// \returns Vector of PerformanceCounter information.
virtual
std
::
vector
<
PerformanceCounter
>
get_performance_data
(
std
::
shared_ptr
<
Function
>
func
)
const
;
/// \brief Test if a backend is capable of supporting an op
/// \brief Test if a backend is capable of supporting an op
/// \param node is the op to test.
/// \param node is the op to test.
...
@@ -133,8 +99,62 @@ public:
...
@@ -133,8 +99,62 @@ public:
/// \brief Test if a backend particular property is supported
/// \brief Test if a backend particular property is supported
/// \param prop is the feature to test.
/// \param prop is the feature to test.
/// \returns true if the property is supported, false otherwise.
/// \returns true if the property is supported, false otherwise.
virtual
bool
is_supported_property
(
const
Property
prop
)
const
{
return
false
;
}
virtual
bool
is_supported_property
(
const
Property
prop
)
const
;
void
validate
(
std
::
shared_ptr
<
const
Function
>
func
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
/// The following methods are temporary hacks to reduce the number of changes in this PR
/// They will be removed in a follow-on PR
bool
call_with_validate
(
std
::
shared_ptr
<
Executable
>
handle
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
);
bool
call_with_validate
(
const
std
::
unique_ptr
<
Executable
>&
handle
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
);
};
class
ngraph
::
runtime
::
Executable
{
public
:
Executable
();
virtual
~
Executable
();
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
virtual
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
=
0
;
/// \brief Executes a single iteration of a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
/// \returns true if iteration is successful, false otherwise
bool
call_with_validate
(
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
);
/// \brief Collect performance information gathered on a Function.
/// \returns Vector of PerformanceCounter information.
virtual
std
::
vector
<
PerformanceCounter
>
get_performance_data
()
const
;
/// \brief Validates a Function.
/// \param outputs vector of runtime::Tensor used as outputs
/// \param inputs vector of runtime::Tensor used as inputs
void
validate
(
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
);
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
);
/// \brief Query the input Parameters
/// \returns an ngraph::op::ParameterVector of all input parameters
const
ngraph
::
ParameterVector
&
get_parameters
()
const
;
/// \brief Query the output Results
/// \returns an ngraph::ResultVector of all input parameters
const
ngraph
::
ResultVector
&
get_results
()
const
;
protected
:
/// \brief Called at the end of compile to the the values to be returned by get_parameters
/// and get_results
/// \param func The function with Results fully resolved.
void
set_parameters_and_results
(
const
Function
&
func
);
private
:
ngraph
::
ParameterVector
m_parameters
;
ngraph
::
ResultVector
m_results
;
};
};
src/ngraph/runtime/hybrid/hybrid_backend.cpp
View file @
ef2e0118
...
@@ -62,14 +62,24 @@ static void node_modifiers(const Node& node, vector<string>& attributes)
...
@@ -62,14 +62,24 @@ static void node_modifiers(const Node& node, vector<string>& attributes)
}
}
}
}
runtime
::
Handle
runtime
::
hybrid
::
HybridBackend
::
compile
(
shared_ptr
<
Function
>
func
)
shared_ptr
<
runtime
::
Executable
>
runtime
::
hybrid
::
HybridBackend
::
compile
(
shared_ptr
<
Function
>
func
,
bool
enable_performance_collection
)
{
{
if
(
m_function_map
.
find
(
func
)
==
m_function_map
.
end
())
return
make_shared
<
HybridExecutable
>
(
{
m_backend_list
,
func
,
enable_performance_collection
,
m_debug_enabled
);
// Clone function
}
FunctionInstance
instance
;
instance
.
m_function
=
clone_function
(
*
func
);
runtime
::
hybrid
::
HybridExecutable
::
HybridExecutable
(
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Backend
>>&
backend_list
,
const
shared_ptr
<
Function
>&
func
,
bool
enable_performance_collection
,
bool
debug_enabled
)
:
m_function
{
func
}
,
m_backend_list
{
backend_list
}
,
m_debug_enabled
{
debug_enabled
}
{
{
// Run placement pass
// Run placement pass
ngraph
::
pass
::
Manager
pass_manager
;
ngraph
::
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
runtime
::
hybrid
::
pass
::
AssignPlacement
>
(
m_backend_list
);
pass_manager
.
register_pass
<
runtime
::
hybrid
::
pass
::
AssignPlacement
>
(
m_backend_list
);
...
@@ -81,16 +91,15 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
...
@@ -81,16 +91,15 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
{
{
pass_manager
.
register_pass
<
ngraph
::
pass
::
VisualizeTree
>
(
"graph.png"
,
node_modifiers
);
pass_manager
.
register_pass
<
ngraph
::
pass
::
VisualizeTree
>
(
"graph.png"
,
node_modifiers
);
}
}
pass_manager
.
run_passes
(
instance
.
m_function
);
pass_manager
.
run_passes
(
m_function
);
// Split function to sub_functions
// Split function to sub_functions
tie
(
instance
.
m_sub_functions
,
instance
.
m_map_parameter_to_result
)
=
tie
(
m_sub_functions
,
m_map_parameter_to_result
)
=
runtime
::
hybrid
::
split_function_by_placement
(
instance
.
m_function
);
runtime
::
hybrid
::
split_function_by_placement
(
m_function
);
m_function_map
.
insert
({
func
,
instance
});
// Compile subfunctions in corresponding backends
// Compile subfunctions in corresponding backends
size_t
subfunction_number
=
0
;
size_t
subfunction_number
=
0
;
for
(
shared_ptr
<
Function
>&
sub_function
:
instance
.
m_sub_functions
)
for
(
shared_ptr
<
Function
>&
sub_function
:
m_sub_functions
)
{
{
size_t
placement
=
runtime
::
hybrid
::
get_colocated_function_placement
(
sub_function
);
size_t
placement
=
runtime
::
hybrid
::
get_colocated_function_placement
(
sub_function
);
if
(
m_debug_enabled
)
if
(
m_debug_enabled
)
...
@@ -102,7 +111,8 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
...
@@ -102,7 +111,8 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
pm
.
run_passes
(
sub_function
);
pm
.
run_passes
(
sub_function
);
}
}
auto
backend
=
m_backend_list
[
placement
];
auto
backend
=
m_backend_list
[
placement
];
backend
->
compile
(
sub_function
);
shared_ptr
<
Executable
>
exec
=
backend
->
compile
(
sub_function
);
m_executable_map
[
sub_function
]
=
exec
;
// Compile will replace nodes so we need to make one more pass through all
// Compile will replace nodes so we need to make one more pass through all
// ops to reset placement
// ops to reset placement
...
@@ -113,38 +123,29 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
...
@@ -113,38 +123,29 @@ runtime::Handle runtime::hybrid::HybridBackend::compile(shared_ptr<Function> fun
}
}
}
}
return
func
;
set_parameters_and_results
(
*
func
)
;
}
}
bool
runtime
::
hybrid
::
HybridBackend
::
call
(
shared_ptr
<
Function
>
func
,
bool
runtime
::
hybrid
::
HybridExecutable
::
call
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
{
// Get FunctionInstance
bool
rc
=
true
;
bool
rc
=
true
;
using
node_map_t
=
unordered_map
<
shared_ptr
<
Node
>
,
shared_ptr
<
runtime
::
Tensor
>>
;
using
node_map_t
=
unordered_map
<
shared_ptr
<
Node
>
,
shared_ptr
<
runtime
::
Tensor
>>
;
auto
fit
=
m_function_map
.
find
(
func
);
if
(
fit
==
m_function_map
.
end
())
{
throw
runtime_error
(
"compile() must be called before call()."
);
}
FunctionInstance
&
instance
=
fit
->
second
;
// Parameter and result node in sub_function maps to one Tensor
// Parameter and result node in sub_function maps to one Tensor
node_map_t
map_node_to_tensor
;
node_map_t
map_node_to_tensor
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
{
map_node_to_tensor
[
instance
.
m_function
->
get_parameters
()[
i
]]
=
inputs
[
i
];
map_node_to_tensor
[
m_function
->
get_parameters
()[
i
]]
=
inputs
[
i
];
}
}
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
{
map_node_to_tensor
[
instance
.
m_function
->
get_results
()[
i
]]
=
outputs
[
i
];
map_node_to_tensor
[
m_function
->
get_results
()[
i
]]
=
outputs
[
i
];
}
}
// Call subfunctions
// Call subfunctions
for
(
const
shared_ptr
<
Function
>&
sub_function
:
instance
.
m_sub_functions
)
for
(
const
shared_ptr
<
Function
>&
sub_function
:
m_sub_functions
)
{
{
// Init backend
// Init backend
size_t
placement
=
runtime
::
hybrid
::
get_colocated_function_placement
(
sub_function
);
size_t
placement
=
runtime
::
hybrid
::
get_colocated_function_placement
(
sub_function
);
...
@@ -172,7 +173,7 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
...
@@ -172,7 +173,7 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
else
else
{
{
// Handle temporary tensors that go between subgraphs
// Handle temporary tensors that go between subgraphs
auto
result_node
=
instance
.
m_map_parameter_to_result
.
at
(
parameter_node
);
auto
result_node
=
m_map_parameter_to_result
.
at
(
parameter_node
);
auto
result
=
map_node_to_tensor
.
at
(
result_node
);
auto
result
=
map_node_to_tensor
.
at
(
result_node
);
auto
parameter
=
backend
->
create_tensor
(
parameter_node
->
get_element_type
(),
auto
parameter
=
backend
->
create_tensor
(
parameter_node
->
get_element_type
(),
parameter_node
->
get_shape
());
parameter_node
->
get_shape
());
...
@@ -213,7 +214,8 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
...
@@ -213,7 +214,8 @@ bool runtime::hybrid::HybridBackend::call(shared_ptr<Function> func,
}
}
// Call
// Call
backend
->
call
(
sub_function
,
results
,
parameters
);
auto
exec
=
m_executable_map
[
sub_function
];
exec
->
call
(
results
,
parameters
);
// Need to copy any results to the correct device
// Need to copy any results to the correct device
for
(
const
auto
&
p
:
copy_back
)
for
(
const
auto
&
p
:
copy_back
)
...
@@ -229,7 +231,7 @@ bool runtime::hybrid::HybridBackend::is_supported(const Node& node) const
...
@@ -229,7 +231,7 @@ bool runtime::hybrid::HybridBackend::is_supported(const Node& node) const
return
true
;
return
true
;
}
}
size_t
runtime
::
hybrid
::
Hybrid
Backend
::
get_placement
(
const
runtime
::
Tensor
*
t
)
size_t
runtime
::
hybrid
::
Hybrid
Executable
::
get_placement
(
const
runtime
::
Tensor
*
t
)
{
{
size_t
index
=
0
;
size_t
index
=
0
;
for
(
const
shared_ptr
<
ngraph
::
runtime
::
Backend
>&
be
:
m_backend_list
)
for
(
const
shared_ptr
<
ngraph
::
runtime
::
Backend
>&
be
:
m_backend_list
)
...
...
src/ngraph/runtime/hybrid/hybrid_backend.hpp
View file @
ef2e0118
...
@@ -30,6 +30,7 @@ namespace ngraph
...
@@ -30,6 +30,7 @@ namespace ngraph
namespace
hybrid
namespace
hybrid
{
{
class
HybridBackend
;
class
HybridBackend
;
class
HybridExecutable
;
}
}
}
}
}
}
...
@@ -48,29 +49,37 @@ public:
...
@@ -48,29 +49,37 @@ public:
const
ngraph
::
Shape
&
shape
,
const
ngraph
::
Shape
&
shape
,
void
*
memory_pointer
)
override
;
void
*
memory_pointer
)
override
;
Handle
compile
(
std
::
shared_ptr
<
ngraph
::
Function
>
func
)
override
;
std
::
shared_ptr
<
Executable
>
compile
(
std
::
shared_ptr
<
ngraph
::
Function
>
func
,
bool
enable_performance_data
=
false
)
override
;
bool
call
(
std
::
shared_ptr
<
ngraph
::
Function
>
func
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
inputs
)
override
;
bool
is_supported
(
const
ngraph
::
Node
&
node
)
const
override
;
bool
is_supported
(
const
ngraph
::
Node
&
node
)
const
override
;
void
set_debug_enabled
(
bool
flag
)
{
m_debug_enabled
=
flag
;
}
void
set_debug_enabled
(
bool
flag
)
{
m_debug_enabled
=
flag
;
}
private
:
private
:
class
FunctionInstance
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Backend
>>
m_backend_list
;
{
bool
m_debug_enabled
=
false
;
public
:
};
std
::
shared_ptr
<
ngraph
::
Function
>
m_function
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>>
m_sub_functions
;
class
ngraph
::
runtime
::
hybrid
::
HybridExecutable
:
public
runtime
::
Executable
std
::
unordered_map
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>
,
{
std
::
shared_ptr
<
ngraph
::
op
::
Result
>>
public
:
m_map_parameter_to_result
;
HybridExecutable
(
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Backend
>>&
backend_list
,
};
const
std
::
shared_ptr
<
Function
>&
func
,
bool
enable_performance_collection
=
false
,
bool
debug_enabled
=
false
);
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
runtime
::
Tensor
>>&
inputs
)
override
;
private
:
std
::
shared_ptr
<
ngraph
::
Function
>
m_function
;
std
::
vector
<
std
::
shared_ptr
<
ngraph
::
Function
>>
m_sub_functions
;
std
::
unordered_map
<
std
::
shared_ptr
<
ngraph
::
op
::
Parameter
>
,
std
::
shared_ptr
<
ngraph
::
op
::
Result
>>
m_map_parameter_to_result
;
std
::
map
<
std
::
shared_ptr
<
ngraph
::
Function
>
,
FunctionInstance
>
m_function_map
;
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Backend
>>
m_backend_list
;
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Backend
>>
m_backend_list
;
bool
m_debug_enabled
=
false
;
bool
m_debug_enabled
=
false
;
std
::
unordered_map
<
std
::
shared_ptr
<
Function
>
,
std
::
shared_ptr
<
Executable
>>
m_executable_map
;
size_t
get_placement
(
const
runtime
::
Tensor
*
t
);
size_t
get_placement
(
const
runtime
::
Tensor
*
t
);
};
};
src/ngraph/runtime/interpreter/int_backend.cpp
View file @
ef2e0118
...
@@ -64,12 +64,17 @@ shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor(
...
@@ -64,12 +64,17 @@ shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor(
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
,
this
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
,
this
);
}
}
runtime
::
Handle
runtime
::
interpreter
::
INTBackend
::
compile
(
shared_ptr
<
Function
>
function
)
shared_ptr
<
runtime
::
Executable
>
runtime
::
interpreter
::
INTBackend
::
compile
(
shared_ptr
<
Function
>
function
,
bool
enable_performance_collection
)
{
return
make_shared
<
INTExecutable
>
(
function
,
enable_performance_collection
);
}
runtime
::
interpreter
::
INTExecutable
::
INTExecutable
(
const
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
)
{
{
FunctionInstance
&
instance
=
m_function_map
[
function
];
if
(
!
instance
.
m_is_compiled
)
{
{
instance
.
m_is_compiled
=
true
;
pass
::
Manager
pass_manager
;
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
...
@@ -78,32 +83,20 @@ runtime::Handle runtime::interpreter::INTBackend::compile(shared_ptr<Function> f
...
@@ -78,32 +83,20 @@ runtime::Handle runtime::interpreter::INTBackend::compile(shared_ptr<Function> f
pass_manager
.
run_passes
(
function
);
pass_manager
.
run_passes
(
function
);
size_t
memory_pool_size
=
function
->
get_temporary_pool_size
();
size_t
memory_pool_size
=
function
->
get_temporary_pool_size
();
instance
.
m_temporary_memory
.
reset
(
new
AlignedBuffer
(
memory_pool_size
,
get_alignment
()));
m_temporary_memory
.
reset
(
new
AlignedBuffer
(
memory_pool_size
,
get_alignment
()));
for
(
const
shared_ptr
<
Node
>&
node
:
function
->
get_ordered_ops
())
for
(
const
shared_ptr
<
Node
>&
node
:
function
->
get_ordered_ops
())
{
{
instance
.
m_wrapped_nodes
.
emplace_back
(
node
);
m_wrapped_nodes
.
emplace_back
(
node
);
}
}
}
}
return
function
;
set_parameters_and_results
(
*
function
)
;
}
}
bool
runtime
::
interpreter
::
INTBackend
::
call
(
shared_ptr
<
Function
>
function
,
bool
runtime
::
interpreter
::
INTExecutable
::
call
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
{
auto
fit
=
m_function_map
.
find
(
function
);
if
(
fit
==
m_function_map
.
end
())
{
throw
runtime_error
(
"compile() must be called before call()."
);
}
FunctionInstance
&
instance
=
fit
->
second
;
if
(
!
instance
.
m_is_compiled
)
{
throw
runtime_error
(
"compile() must be called before call()."
);
}
// convert inputs to HostTensor
// convert inputs to HostTensor
vector
<
void
*>
func_inputs
;
vector
<
void
*>
func_inputs
;
vector
<
shared_ptr
<
runtime
::
HostTensor
>>
htv_inputs
;
vector
<
shared_ptr
<
runtime
::
HostTensor
>>
htv_inputs
;
...
@@ -113,7 +106,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
...
@@ -113,7 +106,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
func_inputs
.
push_back
(
static_cast
<
void
*>
(
host_tensor
->
get_data_ptr
()));
func_inputs
.
push_back
(
static_cast
<
void
*>
(
host_tensor
->
get_data_ptr
()));
htv_inputs
.
push_back
(
host_tensor
);
htv_inputs
.
push_back
(
host_tensor
);
}
}
if
(
instance
.
m_nan_check_enabled
)
if
(
m_nan_check_enabled
)
{
{
perform_nan_check
(
htv_inputs
);
perform_nan_check
(
htv_inputs
);
}
}
...
@@ -129,7 +122,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
...
@@ -129,7 +122,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
// map function params -> HostTensor
// map function params -> HostTensor
unordered_map
<
descriptor
::
Tensor
*
,
void
*>
tensor_map
;
unordered_map
<
descriptor
::
Tensor
*
,
void
*>
tensor_map
;
size_t
input_count
=
0
;
size_t
input_count
=
0
;
for
(
auto
param
:
function
->
get_parameters
())
for
(
auto
param
:
get_parameters
())
{
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
{
...
@@ -139,9 +132,9 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
...
@@ -139,9 +132,9 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
}
}
// map function outputs -> HostTensor
// map function outputs -> HostTensor
for
(
size_t
output_count
=
0
;
output_count
<
function
->
get_output_
size
();
++
output_count
)
for
(
size_t
output_count
=
0
;
output_count
<
get_results
().
size
();
++
output_count
)
{
{
auto
output
=
function
->
get_output_op
(
output_count
)
;
auto
output
=
get_results
()[
output_count
]
;
if
(
!
dynamic_pointer_cast
<
op
::
Result
>
(
output
))
if
(
!
dynamic_pointer_cast
<
op
::
Result
>
(
output
))
{
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
...
@@ -151,7 +144,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
...
@@ -151,7 +144,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
}
}
// for each ordered op in the graph
// for each ordered op in the graph
for
(
const
NodeWrapper
&
wrapped
:
instance
.
m_wrapped_nodes
)
for
(
const
NodeWrapper
&
wrapped
:
m_wrapped_nodes
)
{
{
const
Node
*
op
=
&
wrapped
.
get_node
();
const
Node
*
op
=
&
wrapped
.
get_node
();
auto
type_id
=
wrapped
.
get_typeid
();
auto
type_id
=
wrapped
.
get_typeid
();
...
@@ -185,7 +178,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
...
@@ -185,7 +178,7 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
if
(
it
==
tensor_map
.
end
())
if
(
it
==
tensor_map
.
end
())
{
{
auto
offset
=
op
->
get_output_tensor
(
i
).
get_pool_offset
();
auto
offset
=
op
->
get_output_tensor
(
i
).
get_pool_offset
();
host_tensor
=
instance
.
get_temporary_pointer
(
offset
);
host_tensor
=
get_temporary_pointer
(
offset
);
tensor_map
.
insert
({
tensor
,
host_tensor
});
tensor_map
.
insert
({
tensor
,
host_tensor
});
}
}
else
else
...
@@ -224,16 +217,16 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
...
@@ -224,16 +217,16 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
}
}
#pragma GCC diagnostic pop
#pragma GCC diagnostic pop
if
(
instance
.
m_performance_counters_enabled
)
if
(
m_performance_counters_enabled
)
{
{
instance
.
m_timer_map
[
op
].
start
();
m_timer_map
[
op
].
start
();
}
}
generate_calls
(
type
,
wrapped
,
op_outputs
,
op_inputs
,
instance
);
generate_calls
(
type
,
wrapped
,
op_outputs
,
op_inputs
);
if
(
instance
.
m_performance_counters_enabled
)
if
(
m_performance_counters_enabled
)
{
{
instance
.
m_timer_map
[
op
].
stop
();
m_timer_map
[
op
].
stop
();
}
}
if
(
instance
.
m_nan_check_enabled
)
if
(
m_nan_check_enabled
)
{
{
perform_nan_check
(
htv_outputs
,
op
);
perform_nan_check
(
htv_outputs
,
op
);
}
}
...
@@ -242,26 +235,25 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
...
@@ -242,26 +235,25 @@ bool runtime::interpreter::INTBackend::call(shared_ptr<Function> function,
return
true
;
return
true
;
}
}
void
runtime
::
interpreter
::
INTBackend
::
generate_calls
(
const
element
::
Type
&
type
,
void
runtime
::
interpreter
::
INTExecutable
::
generate_calls
(
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
NodeWrapper
&
op
,
const
vector
<
void
*>&
outputs
,
const
vector
<
void
*>&
outputs
,
const
vector
<
const
void
*>&
inputs
,
const
vector
<
const
void
*>&
inputs
)
FunctionInstance
&
instance
)
{
{
stringstream
ss
;
stringstream
ss
;
switch
(
type
.
get_type_enum
())
switch
(
type
.
get_type_enum
())
{
{
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
boolean
:
op_engine
<
char
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
f32
:
op_engine
<
float
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
f64
:
op_engine
<
double
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
f64
:
op_engine
<
double
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
i8
:
op_engine
<
int8_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
i8
:
op_engine
<
int8_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
i16
:
op_engine
<
int16_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
i16
:
op_engine
<
int16_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
i32
:
op_engine
<
int32_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
i32
:
op_engine
<
int32_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
i64
:
op_engine
<
int64_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
i64
:
op_engine
<
int64_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
u8
:
op_engine
<
uint8_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
u8
:
op_engine
<
uint8_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
u16
:
op_engine
<
uint16_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
u16
:
op_engine
<
uint16_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
u32
:
op_engine
<
uint32_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
u32
:
op_engine
<
uint32_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
u64
:
op_engine
<
uint64_t
>
(
op
,
outputs
,
inputs
,
instance
);
break
;
case
element
:
:
Type_t
::
u64
:
op_engine
<
uint64_t
>
(
op
,
outputs
,
inputs
);
break
;
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
bf16
:
...
@@ -270,25 +262,11 @@ void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
...
@@ -270,25 +262,11 @@ void runtime::interpreter::INTBackend::generate_calls(const element::Type& type,
}
}
}
}
void
runtime
::
interpreter
::
INTBackend
::
set_nan_check
(
shared_ptr
<
Function
>
func
,
bool
enable
)
{
FunctionInstance
&
instance
=
m_function_map
[
func
];
instance
.
m_nan_check_enabled
=
enable
;
}
void
runtime
::
interpreter
::
INTBackend
::
enable_performance_data
(
shared_ptr
<
Function
>
func
,
bool
enable
)
{
FunctionInstance
&
instance
=
m_function_map
[
func
];
instance
.
m_performance_counters_enabled
=
enable
;
}
vector
<
runtime
::
PerformanceCounter
>
vector
<
runtime
::
PerformanceCounter
>
runtime
::
interpreter
::
INT
Backend
::
get_performance_data
(
shared_ptr
<
Function
>
func
)
const
runtime
::
interpreter
::
INT
Executable
::
get_performance_data
(
)
const
{
{
vector
<
runtime
::
PerformanceCounter
>
rc
;
vector
<
runtime
::
PerformanceCounter
>
rc
;
const
FunctionInstance
&
instance
=
m_function_map
.
at
(
func
);
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
m_timer_map
)
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
instance
.
m_timer_map
)
{
{
rc
.
emplace_back
(
p
.
first
->
get_name
().
c_str
(),
rc
.
emplace_back
(
p
.
first
->
get_name
().
c_str
(),
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_total_microseconds
(),
...
@@ -297,7 +275,7 @@ vector<runtime::PerformanceCounter>
...
@@ -297,7 +275,7 @@ vector<runtime::PerformanceCounter>
return
rc
;
return
rc
;
}
}
void
runtime
::
interpreter
::
INT
Backend
::
perform_nan_check
(
void
runtime
::
interpreter
::
INT
Executable
::
perform_nan_check
(
const
vector
<
shared_ptr
<
HostTensor
>>&
tensors
,
const
Node
*
op
)
const
vector
<
shared_ptr
<
HostTensor
>>&
tensors
,
const
Node
*
op
)
{
{
size_t
arg_number
=
1
;
size_t
arg_number
=
1
;
...
...
src/ngraph/runtime/interpreter/int_backend.hpp
View file @
ef2e0118
...
@@ -143,6 +143,7 @@ namespace ngraph
...
@@ -143,6 +143,7 @@ namespace ngraph
namespace
interpreter
namespace
interpreter
{
{
class
INTBackend
;
class
INTBackend
;
class
INTExecutable
;
}
}
}
}
}
}
...
@@ -161,52 +162,49 @@ public:
...
@@ -161,52 +162,49 @@ public:
std
::
shared_ptr
<
Tensor
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
std
::
shared_ptr
<
Tensor
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
Handle
compile
(
std
::
shared_ptr
<
Function
>
function
)
override
;
std
::
shared_ptr
<
Executable
>
compile
(
std
::
shared_ptr
<
Function
>
function
,
bool
enable_performance_data
=
false
)
override
;
bool
call
(
std
::
shared_ptr
<
Function
>
function
,
bool
is_supported
(
const
Node
&
node
)
const
override
;
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
intputs
)
override
;
void
set_nan_check
(
std
::
shared_ptr
<
Function
>
func
,
bool
);
private
:
std
::
set
<
std
::
string
>
m_unsupported_op_name_list
;
};
void
enable_performance_data
(
std
::
shared_ptr
<
Function
>
func
,
bool
enable
)
override
;
class
ngraph
::
runtime
::
interpreter
::
INTExecutable
:
public
Executable
std
::
vector
<
PerformanceCounter
>
{
get_performance_data
(
std
::
shared_ptr
<
Function
>
func
)
const
override
;
public
:
INTExecutable
(
const
std
::
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
=
false
);
bool
is_supported
(
const
Node
&
node
)
const
override
;
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
intputs
)
override
;
void
set_nan_check
(
bool
value
)
{
m_nan_check_enabled
=
value
;
}
std
::
vector
<
PerformanceCounter
>
get_performance_data
()
const
override
;
private
:
private
:
int
get_alignment
()
const
{
return
64
;
}
int
get_alignment
()
const
{
return
64
;
}
class
FunctionInstance
bool
m_nan_check_enabled
=
false
;
{
bool
m_performance_counters_enabled
=
false
;
public
:
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
bool
m_is_compiled
=
false
;
std
::
vector
<
NodeWrapper
>
m_wrapped_nodes
;
bool
m_nan_check_enabled
=
false
;
std
::
unordered_map
<
const
Node
*
,
std
::
shared_ptr
<
RNGState
>>
m_states
;
bool
m_performance_counters_enabled
=
false
;
std
::
shared_ptr
<
AlignedBuffer
>
m_temporary_memory
;
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
std
::
vector
<
NodeWrapper
>
m_wrapped_nodes
;
std
::
unordered_map
<
const
Node
*
,
std
::
shared_ptr
<
RNGState
>>
m_states
;
std
::
shared_ptr
<
AlignedBuffer
>
m_temporary_memory
;
void
*
get_temporary_pointer
(
size_t
offset
)
{
return
m_temporary_memory
->
get_ptr
(
offset
);
}
};
std
::
map
<
std
::
shared_ptr
<
Function
>
,
FunctionInstance
>
m_function_map
;
std
::
set
<
std
::
string
>
m_unsupported_op_name_list
;
void
*
get_temporary_pointer
(
size_t
offset
)
{
return
m_temporary_memory
->
get_ptr
(
offset
);
}
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>>&
,
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>>&
,
const
Node
*
op
=
nullptr
);
const
Node
*
op
=
nullptr
);
void
generate_calls
(
const
element
::
Type
&
type
,
void
generate_calls
(
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
NodeWrapper
&
op
,
const
std
::
vector
<
void
*>&
outputs
,
const
std
::
vector
<
void
*>&
outputs
,
const
std
::
vector
<
const
void
*>&
inputs
,
const
std
::
vector
<
const
void
*>&
inputs
);
FunctionInstance
&
instance
);
template
<
typename
T
>
template
<
typename
T
>
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
const
std
::
vector
<
void
*>&
out
,
const
std
::
vector
<
void
*>&
out
,
const
std
::
vector
<
const
void
*>&
args
,
const
std
::
vector
<
const
void
*>&
args
)
FunctionInstance
&
instance
)
{
{
const
Node
&
node
=
node_wrapper
.
get_node
();
const
Node
&
node
=
node_wrapper
.
get_node
();
std
::
string
node_op
=
node
.
description
();
std
::
string
node_op
=
node
.
description
();
...
@@ -364,15 +362,15 @@ private:
...
@@ -364,15 +362,15 @@ private:
}
}
case
OP_TYPEID
:
:
GenerateMask
:
case
OP_TYPEID
:
:
GenerateMask
:
{
{
if
(
instance
.
m_states
.
count
(
&
node
)
==
0
)
if
(
m_states
.
count
(
&
node
)
==
0
)
{
{
const
op
::
GenerateMask
*
gm
=
static_cast
<
const
op
::
GenerateMask
*>
(
&
node
);
const
op
::
GenerateMask
*
gm
=
static_cast
<
const
op
::
GenerateMask
*>
(
&
node
);
instance
.
m_states
[
&
node
]
=
std
::
unique_ptr
<
ngraph
::
RNGState
>
(
m_states
[
&
node
]
=
std
::
unique_ptr
<
ngraph
::
RNGState
>
(
ngraph
::
RNGState
::
create_rng_state
(
gm
->
get_seed
(),
gm
->
get_probability
()));
ngraph
::
RNGState
::
create_rng_state
(
gm
->
get_seed
(),
gm
->
get_probability
()));
}
}
bool
training
=
static_cast
<
bool
>
(
static_cast
<
const
T
*>
(
args
[
0
])[
0
]);
bool
training
=
static_cast
<
bool
>
(
static_cast
<
const
T
*>
(
args
[
0
])[
0
]);
auto
state
=
instance
.
m_states
.
at
(
&
node
).
get
();
auto
state
=
m_states
.
at
(
&
node
).
get
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
generate_mask
<
T
>
(
reference
::
generate_mask
<
T
>
(
reinterpret_cast
<
T
*>
(
out
[
0
]),
element_count
,
state
,
training
);
reinterpret_cast
<
T
*>
(
out
[
0
]),
element_count
,
state
,
training
);
...
...
src/ngraph/runtime/nop/nop_backend.cpp
View file @
ef2e0118
...
@@ -54,14 +54,25 @@ shared_ptr<runtime::Tensor> runtime::nop::NOPBackend::create_tensor(const elemen
...
@@ -54,14 +54,25 @@ shared_ptr<runtime::Tensor> runtime::nop::NOPBackend::create_tensor(const elemen
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
,
"external"
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
,
"external"
);
}
}
runtime
::
Handle
runtime
::
nop
::
NOPBackend
::
compile
(
shared_ptr
<
Function
>
function
)
shared_ptr
<
runtime
::
Executable
>
runtime
::
nop
::
NOPBackend
::
compile
(
shared_ptr
<
Function
>
function
,
bool
enable_performance_collection
)
{
{
return
function
;
return
make_shared
<
NOPExecutable
>
(
function
,
enable_performance_collection
)
;
}
}
bool
runtime
::
nop
::
NOPBackend
::
call
(
shared_ptr
<
Function
>
function
,
runtime
::
nop
::
NOPExecutable
::
NOPExecutable
(
shared_ptr
<
Function
>
function
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
bool
enable_performance_collection
)
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
run_passes
(
function
);
set_parameters_and_results
(
*
function
);
}
bool
runtime
::
nop
::
NOPExecutable
::
call
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
{
{
return
true
;
return
true
;
}
}
src/ngraph/runtime/nop/nop_backend.hpp
View file @
ef2e0118
...
@@ -32,6 +32,7 @@ namespace ngraph
...
@@ -32,6 +32,7 @@ namespace ngraph
namespace
nop
namespace
nop
{
{
class
NOPBackend
;
class
NOPBackend
;
class
NOPExecutable
;
}
}
}
}
}
}
...
@@ -44,9 +45,14 @@ public:
...
@@ -44,9 +45,14 @@ public:
std
::
shared_ptr
<
Tensor
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
std
::
shared_ptr
<
Tensor
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
Handle
compile
(
std
::
shared_ptr
<
Function
>
function
)
override
;
std
::
shared_ptr
<
Executable
>
compile
(
std
::
shared_ptr
<
Function
>
function
,
bool
enable_performance_data
=
false
)
override
;
};
bool
call
(
std
::
shared_ptr
<
Function
>
function
,
class
ngraph
::
runtime
::
nop
::
NOPExecutable
:
public
Executable
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
outputs
,
{
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
intputs
)
override
;
public
:
NOPExecutable
(
std
::
shared_ptr
<
Function
>
function
,
bool
enable_performance_collection
=
false
);
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
runtime
::
Tensor
>>&
inputs
)
override
;
};
};
src/tools/nbench/benchmark.cpp
View file @
ef2e0118
...
@@ -136,8 +136,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
...
@@ -136,8 +136,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
stopwatch
timer
;
stopwatch
timer
;
timer
.
start
();
timer
.
start
();
auto
backend
=
runtime
::
Backend
::
create
(
backend_name
);
auto
backend
=
runtime
::
Backend
::
create
(
backend_name
);
backend
->
enable_performance_data
(
f
,
timing_detail
);
auto
compiled_func
=
backend
->
compile
(
f
,
timing_detail
);
auto
compiled_func
=
backend
->
compile
(
f
);
timer
.
stop
();
timer
.
stop
();
cout
.
imbue
(
locale
(
""
));
cout
.
imbue
(
locale
(
""
));
cout
<<
"compile time: "
<<
timer
.
get_milliseconds
()
<<
"ms"
<<
endl
;
cout
<<
"compile time: "
<<
timer
.
get_milliseconds
()
<<
"ms"
<<
endl
;
...
@@ -183,7 +182,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
...
@@ -183,7 +182,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
{
{
for
(
int
i
=
0
;
i
<
warmup_iterations
;
i
++
)
for
(
int
i
=
0
;
i
<
warmup_iterations
;
i
++
)
{
{
backend
->
call
(
compiled_func
,
results
,
args
);
compiled_func
->
call
(
results
,
args
);
}
}
}
}
...
@@ -205,7 +204,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
...
@@ -205,7 +204,7 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
}
}
}
}
}
}
backend
->
call
(
compiled_func
,
results
,
args
);
compiled_func
->
call
(
results
,
args
);
if
(
copy_data
)
if
(
copy_data
)
{
{
for
(
size_t
result_index
=
0
;
result_index
<
results
.
size
();
result_index
++
)
for
(
size_t
result_index
=
0
;
result_index
<
results
.
size
();
result_index
++
)
...
@@ -222,6 +221,6 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
...
@@ -222,6 +221,6 @@ vector<runtime::PerformanceCounter> run_benchmark(shared_ptr<Function> f,
float
time
=
t1
.
get_milliseconds
();
float
time
=
t1
.
get_milliseconds
();
cout
<<
time
/
iterations
<<
"ms per iteration"
<<
endl
;
cout
<<
time
/
iterations
<<
"ms per iteration"
<<
endl
;
vector
<
runtime
::
PerformanceCounter
>
perf_data
=
backend
->
get_performance_data
(
f
);
vector
<
runtime
::
PerformanceCounter
>
perf_data
=
compiled_func
->
get_performance_data
(
);
return
perf_data
;
return
perf_data
;
}
}
test/CMakeLists.txt
View file @
ef2e0118
...
@@ -36,7 +36,6 @@ set(SRC
...
@@ -36,7 +36,6 @@ set(SRC
cse.cpp
cse.cpp
element_type.cpp
element_type.cpp
file_util.cpp
file_util.cpp
graph_partition.cpp
includes.cpp
includes.cpp
input_output_assign.cpp
input_output_assign.cpp
main.cpp
main.cpp
...
...
test/backend_debug_api.cpp
View file @
ef2e0118
...
@@ -37,9 +37,6 @@ TEST(INTERPRETER, nan_check_input)
...
@@ -37,9 +37,6 @@ TEST(INTERPRETER, nan_check_input)
shared_ptr
<
runtime
::
Backend
>
backend
=
runtime
::
Backend
::
create
(
"INTERPRETER"
);
shared_ptr
<
runtime
::
Backend
>
backend
=
runtime
::
Backend
::
create
(
"INTERPRETER"
);
shared_ptr
<
runtime
::
interpreter
::
INTBackend
>
ibackend
=
static_pointer_cast
<
runtime
::
interpreter
::
INTBackend
>
(
backend
);
// Create some tensors for input/output
// Create some tensors for input/output
auto
a
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
auto
a
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
copy_data
(
a
,
vector
<
float
>
{
2
,
4
,
NAN
,
16
});
copy_data
(
a
,
vector
<
float
>
{
2
,
4
,
NAN
,
16
});
...
@@ -47,9 +44,12 @@ TEST(INTERPRETER, nan_check_input)
...
@@ -47,9 +44,12 @@ TEST(INTERPRETER, nan_check_input)
copy_data
(
b
,
vector
<
float
>
{
1
,
2
,
1
,
8
});
copy_data
(
b
,
vector
<
float
>
{
1
,
2
,
1
,
8
});
auto
result
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
auto
result
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
auto
handle
=
backend
->
compile
(
f
);
shared_ptr
<
runtime
::
Executable
>
handle
=
backend
->
compile
(
f
);
ibackend
->
set_nan_check
(
handle
,
true
);
EXPECT_ANY_THROW
(
ibackend
->
call_with_validate
(
handle
,
{
result
},
{
a
,
b
}));
shared_ptr
<
runtime
::
interpreter
::
INTExecutable
>
ihandle
=
static_pointer_cast
<
runtime
::
interpreter
::
INTExecutable
>
(
handle
);
ihandle
->
set_nan_check
(
true
);
EXPECT_ANY_THROW
(
handle
->
call_with_validate
({
result
},
{
a
,
b
}));
}
}
TEST
(
INTERPRETER
,
nan_check_output
)
TEST
(
INTERPRETER
,
nan_check_output
)
...
@@ -61,9 +61,6 @@ TEST(INTERPRETER, nan_check_output)
...
@@ -61,9 +61,6 @@ TEST(INTERPRETER, nan_check_output)
shared_ptr
<
runtime
::
Backend
>
backend
=
runtime
::
Backend
::
create
(
"INTERPRETER"
);
shared_ptr
<
runtime
::
Backend
>
backend
=
runtime
::
Backend
::
create
(
"INTERPRETER"
);
shared_ptr
<
runtime
::
interpreter
::
INTBackend
>
ibackend
=
static_pointer_cast
<
runtime
::
interpreter
::
INTBackend
>
(
backend
);
// Create some tensors for input/output
// Create some tensors for input/output
auto
a
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
auto
a
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
copy_data
(
a
,
vector
<
float
>
{
2
,
4
,
0
,
16
});
copy_data
(
a
,
vector
<
float
>
{
2
,
4
,
0
,
16
});
...
@@ -71,7 +68,9 @@ TEST(INTERPRETER, nan_check_output)
...
@@ -71,7 +68,9 @@ TEST(INTERPRETER, nan_check_output)
copy_data
(
b
,
vector
<
float
>
{
1
,
2
,
0
,
8
});
copy_data
(
b
,
vector
<
float
>
{
1
,
2
,
0
,
8
});
auto
result
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
auto
result
=
backend
->
create_tensor
(
element
::
f32
,
shape
);
auto
handle
=
backend
->
compile
(
f
);
shared_ptr
<
runtime
::
Executable
>
handle
=
backend
->
compile
(
f
);
ibackend
->
set_nan_check
(
handle
,
true
);
shared_ptr
<
runtime
::
interpreter
::
INTExecutable
>
ihandle
=
EXPECT_ANY_THROW
(
ibackend
->
call_with_validate
(
handle
,
{
result
},
{
a
,
b
}));
static_pointer_cast
<
runtime
::
interpreter
::
INTExecutable
>
(
handle
);
ihandle
->
set_nan_check
(
true
);
EXPECT_ANY_THROW
(
handle
->
call_with_validate
({
result
},
{
a
,
b
}));
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment