Commit 86fd6a53 authored by arogowie-intel's avatar arogowie-intel Committed by Scott Cyphers

[Py] Python wrappers for nGraph operations. (#793)

* Add python wrappers for nGraph Cos, Cosh operations.

- Update docstrings.

* Enable auxiliary function running computation to accept multiple input nodes.

* Python wrapper for nGraph Dot function with UT.

* Update python wrappers for nGraph Exp and Equal operations.

- Update docstrings.
- Add UT for exp.

* Update python wrappers for nGraph Floor, Greater, GreaterEq, Less, LessEq operations.

- Update docstrings.
- Add UT for ng.floor.

* Update python wrapper for nGraph Log operation.

- Update docstring.
- Add UT.
parent 1a7ab108
......@@ -29,6 +29,8 @@ from ngraph.ops import concat
from ngraph.ops import constant
from ngraph.ops import convert
from ngraph.ops import convolution
from ngraph.ops import cos
from ngraph.ops import cosh
from ngraph.ops import divide
from ngraph.ops import dot
from ngraph.ops import equal
......
This diff is collapsed.
......@@ -58,15 +58,13 @@ def test_serialization():
assert serial_json[0]['name'] != ''
assert 10 == len(serial_json[0]['ops'])
def test_broadcast():
input_data = np.array([1, 2, 3])
new_shape = [3, 3]
expected = [[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]
result = run_op_node(input_data, ng.broadcast, new_shape)
result = run_op_node([input_data], ng.broadcast, new_shape)
assert np.allclose(result, expected)
axis = 0
......@@ -74,13 +72,13 @@ def test_broadcast():
[2, 2, 2],
[3, 3, 3]]
result = run_op_node(input_data, ng.broadcast, new_shape, axis)
result = run_op_node([input_data], ng.broadcast, new_shape, axis)
assert np.allclose(result, expected)
input_data = np.arange(4)
new_shape = [3, 4, 2, 4]
expected = np.broadcast_to(input_data, new_shape)
result = run_op_node(input_data, ng.broadcast, new_shape)
result = run_op_node([input_data], ng.broadcast, new_shape)
assert np.allclose(result, expected)
......@@ -89,7 +87,7 @@ def test_broadcast():
])
def test_convert_to_bool(val_type, input_data):
expected = np.array(input_data, dtype=val_type)
result = run_op_node(input_data, ng.convert, val_type)
result = run_op_node([input_data], ng.convert, val_type)
assert np.allclose(result, expected)
......@@ -101,7 +99,7 @@ def test_convert_to_float(val_type, range_start, range_end, in_dtype):
np.random.seed(133391)
input_data = np.random.randint(range_start, range_end, size=(2, 2), dtype=in_dtype)
expected = np.array(input_data, dtype=val_type)
result = run_op_node(input_data, ng.convert, val_type)
result = run_op_node([input_data], ng.convert, val_type)
assert np.allclose(result, expected)
......@@ -115,7 +113,7 @@ def test_convert_to_int(val_type):
np.random.seed(133391)
input_data = np.ceil(-8 + np.random.rand(2, 3, 4) * 16)
expected = np.array(input_data, dtype=val_type)
result = run_op_node(input_data, ng.convert, val_type)
result = run_op_node([input_data], ng.convert, val_type)
assert np.allclose(result, expected)
......@@ -129,5 +127,5 @@ def test_convert_to_uint(val_type):
np.random.seed(133391)
input_data = np.ceil(np.random.rand(2, 3, 4) * 16)
expected = np.array(input_data, dtype=val_type)
result = run_op_node(input_data, ng.convert, val_type)
result = run_op_node([input_data], ng.convert, val_type)
assert np.allclose(result, expected)
# ******************************************************************************
# Copyright 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import numpy as np
import pytest
import ngraph as ng
from test.ngraph.util import run_op_node
@pytest.mark.parametrize('left_shape, right_shape, reduction_axes_count, numpy_axes', [
# matrix, vector
([2, 4], [4], None, 1),
([4], [4, 2], None, 1),
# matrix, matrix
([2, 4], [4, 2], None, 1),
# result is a scalar
([2, 4], [2, 4], 2, 2),
# tensor, vector
([2, 4, 5], [5], None, 1),
([5], [5, 4, 2], None, 1),
# tensor, matrix
([2, 4, 5], [5, 4], None, 1),
([5, 4], [4, 5, 2], None, 1),
# tensor, tensor
([2, 3, 4, 5], [5, 2, 3], None, 1),
([2, 3, 4, 5], [4, 5, 2, 4], 2, 2),
])
def test_dot(left_shape, right_shape, reduction_axes_count, numpy_axes):
np.random.seed(133391)
left_input = -100.0 + np.random.rand(*left_shape) * 200.0
right_input = -100.0 + np.random.rand(*right_shape) * 200.0
expected = np.tensordot(left_input, right_input, numpy_axes)
result = run_op_node([left_input, right_input], ng.dot, reduction_axes_count)
assert np.allclose(result, expected)
def test_dot_tensor_scalar():
np.random.seed(133391)
left_input = 10.0
right_input = -100.0 + np.random.rand(2, 3, 4) * 200.0
expected = left_input * right_input
result = run_op_node([left_input, right_input], ng.dot)
assert np.allclose(result, expected)
result = run_op_node([right_input, left_input], ng.dot)
assert np.allclose(result, expected)
......@@ -28,13 +28,18 @@ from test.ngraph.util import run_op_numeric_data, run_op_node
(ng.atan, np.arctan, -100, 100),
(ng.ceiling, np.ceil, -100, 100),
(ng.ceil, np.ceil, -100, 100),
(ng.cos, np.cos, -np.pi, np.pi),
(ng.cosh, np.cosh, -np.pi, np.pi),
(ng.exp, np.exp, -100, 100),
(ng.floor, np.floor, -100, 100),
(ng.log, np.log, 0, 100),
])
def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391)
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data)
result = run_op_node(input_data, ng_api_fn)
result = run_op_node([input_data], ng_api_fn)
assert np.allclose(result, expected)
result = run_op_numeric_data(input_data, ng_api_fn)
......@@ -49,11 +54,16 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
(ng.atan, np.arctan, np.float32(-0.5)),
(ng.ceiling, np.ceil, np.float32(1.5)),
(ng.ceil, np.ceil, np.float32(1.5)),
(ng.cos, np.cos, np.float32(np.pi / 4.0)),
(ng.cosh, np.cosh, np.float32(np.pi / 4.0)),
(ng.exp, np.exp, np.float32(1.5)),
(ng.floor, np.floor, np.float32(1.5)),
(ng.log, np.log, np.float32(1.5)),
])
def test_unary_op_scalar(ng_api_fn, numpy_fn, input_data):
expected = numpy_fn(input_data)
result = run_op_node(input_data, ng_api_fn)
result = run_op_node([input_data], ng_api_fn)
assert np.allclose(result, expected)
result = run_op_numeric_data(input_data, ng_api_fn)
......
......@@ -40,5 +40,5 @@ def test_reduction_ops(ng_api_helper, numpy_function, reduction_axes):
input_data = np.random.randn(*shape).astype(np.float32)
expected = numpy_function(input_data, axis=reduction_axes)
result = run_op_node(input_data, ng_api_helper, reduction_axes)
result = run_op_node([input_data], ng_api_helper, reduction_axes)
assert np.allclose(result, expected)
......@@ -16,9 +16,14 @@
import numpy as np
import pytest
import ngraph as ng
from string import ascii_uppercase
def _get_numpy_dtype(scalar):
return np.array([scalar]).dtype
def get_runtime():
"""Return runtime object."""
......@@ -37,10 +42,21 @@ def run_op_node(input_data, op_fun, *args):
:return: The result from computations.
"""
runtime = get_runtime()
parameter_a = ng.parameter(input_data.shape, name='A', dtype=np.float32)
node = op_fun(parameter_a, *args)
computation = runtime.computation(node, parameter_a)
return computation(input_data)
comp_args = []
op_fun_args = []
comp_inputs = []
for idx, data in enumerate(input_data):
if np.isscalar(data):
op_fun_args.append(ng.constant(data, _get_numpy_dtype(data)))
else:
node = ng.parameter(data.shape, name=ascii_uppercase[idx], dtype=data.dtype)
op_fun_args.append(node)
comp_args.append(node)
comp_inputs.append(data)
op_fun_args.extend(args)
node = op_fun(*op_fun_args)
computation = runtime.computation(node, *comp_args)
return computation(*comp_inputs)
def run_op_numeric_data(input_data, op_fun, *args):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment