Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
ff2e7fe4
Unverified
Commit
ff2e7fe4
authored
Jul 22, 2019
by
Scott Cyphers
Committed by
GitHub
Jul 22, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into cyphers/s-barannikov
parents
50f3bcbd
d92dcdfe
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
11 deletions
+13
-11
util.py
python/test/ngraph/util.py
+13
-11
No files found.
python/test/ngraph/util.py
View file @
ff2e7fe4
...
...
@@ -16,9 +16,8 @@
import
numpy
as
np
import
ngraph
as
ng
from
string
import
ascii_uppercase
from
ngraph.utils.types
import
NumericData
from
typing
import
Any
,
Callable
,
List
import
test
...
...
@@ -32,10 +31,14 @@ def get_runtime():
def
run_op_node
(
input_data
,
op_fun
,
*
args
):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a node as an argument.
This function converts passed raw input data to nGraph Constant Node and that form is passed
to `op_fun`.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
...
...
@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args):
comp_args
=
[]
op_fun_args
=
[]
comp_inputs
=
[]
for
idx
,
data
in
enumerate
(
input_data
):
if
np
.
isscalar
(
data
):
op_fun_args
.
append
(
ng
.
constant
(
data
,
_get_numpy_dtype
(
data
)))
else
:
node
=
ng
.
parameter
(
data
.
shape
,
name
=
ascii_uppercase
[
idx
],
dtype
=
data
.
dtype
)
op_fun_args
.
append
(
node
)
comp_args
.
append
(
node
)
comp_inputs
.
append
(
data
)
for
data
in
input_data
:
op_fun_args
.
append
(
ng
.
constant
(
data
,
_get_numpy_dtype
(
data
)))
op_fun_args
.
extend
(
args
)
node
=
op_fun
(
*
op_fun_args
)
computation
=
runtime
.
computation
(
node
,
*
comp_args
)
...
...
@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args):
def
run_op_numeric_data
(
input_data
,
op_fun
,
*
args
):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a scalar or an array.
This function passess input data AS IS. This mean that in case they're a scalar (integral,
or floating point value) or a NumPy's ndarray object they will be automatically converted
to nGraph's Constant Nodes.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment