Unverified Commit ff2e7fe4 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge branch 'master' into cyphers/s-barannikov

parents 50f3bcbd d92dcdfe
...@@ -16,9 +16,8 @@ ...@@ -16,9 +16,8 @@
import numpy as np import numpy as np
import ngraph as ng import ngraph as ng
from ngraph.utils.types import NumericData
from string import ascii_uppercase from typing import Any, Callable, List
import test import test
...@@ -32,10 +31,14 @@ def get_runtime(): ...@@ -32,10 +31,14 @@ def get_runtime():
def run_op_node(input_data, op_fun, *args): def run_op_node(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`. """Run computation on node performing `op_fun`.
`op_fun` has to accept a node as an argument. `op_fun` has to accept a node as an argument.
This function converts passed raw input data to nGraph Constant Node and that form is passed
to `op_fun`.
:param input_data: The input data for performed computation. :param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out. :param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out. :param args: The arguments passed to operation we want to carry out.
...@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args): ...@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args):
comp_args = [] comp_args = []
op_fun_args = [] op_fun_args = []
comp_inputs = [] comp_inputs = []
for idx, data in enumerate(input_data): for data in input_data:
if np.isscalar(data): op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
else:
node = ng.parameter(data.shape, name=ascii_uppercase[idx], dtype=data.dtype)
op_fun_args.append(node)
comp_args.append(node)
comp_inputs.append(data)
op_fun_args.extend(args) op_fun_args.extend(args)
node = op_fun(*op_fun_args) node = op_fun(*op_fun_args)
computation = runtime.computation(node, *comp_args) computation = runtime.computation(node, *comp_args)
...@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args): ...@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args):
def run_op_numeric_data(input_data, op_fun, *args): def run_op_numeric_data(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`. """Run computation on node performing `op_fun`.
`op_fun` has to accept a scalar or an array. `op_fun` has to accept a scalar or an array.
This function passess input data AS IS. This mean that in case they're a scalar (integral,
or floating point value) or a NumPy's ndarray object they will be automatically converted
to nGraph's Constant Nodes.
:param input_data: The input data for performed computation. :param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out. :param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out. :param args: The arguments passed to operation we want to carry out.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment