Commit d92dcdfe authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[Py] Update documentation for utility test functions. (#3115)

* Update documentation for utility test functions.

* Remove unnecesary code and update comment.
parent 6b8336a8
......@@ -16,9 +16,8 @@
import numpy as np
import ngraph as ng
from string import ascii_uppercase
from ngraph.utils.types import NumericData
from typing import Any, Callable, List
import test
......@@ -32,10 +31,14 @@ def get_runtime():
def run_op_node(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a node as an argument.
This function converts passed raw input data to nGraph Constant Node and that form is passed
to `op_fun`.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
......@@ -45,14 +48,8 @@ def run_op_node(input_data, op_fun, *args):
comp_args = []
op_fun_args = []
comp_inputs = []
for idx, data in enumerate(input_data):
if np.isscalar(data):
op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
else:
node = ng.parameter(data.shape, name=ascii_uppercase[idx], dtype=data.dtype)
op_fun_args.append(node)
comp_args.append(node)
comp_inputs.append(data)
for data in input_data:
op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
op_fun_args.extend(args)
node = op_fun(*op_fun_args)
computation = runtime.computation(node, *comp_args)
......@@ -60,10 +57,15 @@ def run_op_node(input_data, op_fun, *args):
def run_op_numeric_data(input_data, op_fun, *args):
# type: (NumericData, Callable, *Any) -> List[NumericData]
"""Run computation on node performing `op_fun`.
`op_fun` has to accept a scalar or an array.
This function passess input data AS IS. This mean that in case they're a scalar (integral,
or floating point value) or a NumPy's ndarray object they will be automatically converted
to nGraph's Constant Nodes.
:param input_data: The input data for performed computation.
:param op_fun: The function handler for operation we want to carry out.
:param args: The arguments passed to operation we want to carry out.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment