Commit 2bd23b27 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[Py]Update unit tests (#1011)

parent bb3a8143
...@@ -20,20 +20,6 @@ import ngraph as ng ...@@ -20,20 +20,6 @@ import ngraph as ng
from test.ngraph.util import run_op_numeric_data, run_op_node from test.ngraph.util import run_op_numeric_data, run_op_node
@pytest.mark.xfail(reason='Results mismatch when passing created Constant node from raw data.')
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.cos, np.cos, -100., 100.),
(ng.sin, np.sin, -100., 100.),
])
def test_unary_op_array_err(ng_api_fn, numpy_fn, range_start, range_end):
np.random.seed(133391)
input_data = range_start + np.random.rand(2, 3, 4) * (range_end - range_start)
expected = numpy_fn(input_data)
result = run_op_numeric_data(input_data, ng_api_fn)
assert np.allclose(result, expected)
@pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [ @pytest.mark.parametrize('ng_api_fn, numpy_fn, range_start, range_end', [
(ng.absolute, np.abs, -1, 1), (ng.absolute, np.abs, -1, 1),
(ng.abs, np.abs, -1, 1), (ng.abs, np.abs, -1, 1),
...@@ -42,14 +28,14 @@ def test_unary_op_array_err(ng_api_fn, numpy_fn, range_start, range_end): ...@@ -42,14 +28,14 @@ def test_unary_op_array_err(ng_api_fn, numpy_fn, range_start, range_end):
(ng.atan, np.arctan, -100., 100.), (ng.atan, np.arctan, -100., 100.),
(ng.ceiling, np.ceil, -100., 100.), (ng.ceiling, np.ceil, -100., 100.),
(ng.ceil, np.ceil, -100., 100.), (ng.ceil, np.ceil, -100., 100.),
(ng.cos, np.cos, -np.pi * 2., np.pi * 2.), (ng.cos, np.cos, -100., 100.),
(ng.cosh, np.cosh, -100., 100.), (ng.cosh, np.cosh, -100., 100.),
(ng.exp, np.exp, -100., 100.), (ng.exp, np.exp, -100., 100.),
(ng.floor, np.floor, -100., 100.), (ng.floor, np.floor, -100., 100.),
(ng.log, np.log, 0, 100.), (ng.log, np.log, 0, 100.),
(ng.relu, lambda x: np.maximum(0, x), -100., 100.), (ng.relu, lambda x: np.maximum(0, x), -100., 100.),
(ng.sign, np.sign, -100., 100.), (ng.sign, np.sign, -100., 100.),
(ng.sin, np.sin, -np.pi * 2., np.pi * 2.), (ng.sin, np.sin, -100., 100.),
(ng.sinh, np.sinh, -100., 100.), (ng.sinh, np.sinh, -100., 100.),
(ng.sqrt, np.sqrt, 0., 100.), (ng.sqrt, np.sqrt, 0., 100.),
(ng.tan, np.tan, -1., 1.), (ng.tan, np.tan, -1., 1.),
...@@ -62,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end): ...@@ -62,10 +48,10 @@ def test_unary_op_array(ng_api_fn, numpy_fn, range_start, range_end):
expected = numpy_fn(input_data) expected = numpy_fn(input_data)
result = run_op_node([input_data], ng_api_fn) result = run_op_node([input_data], ng_api_fn)
assert np.allclose(result, expected) np.testing.assert_allclose(result, expected, rtol=0.001)
result = run_op_numeric_data(input_data, ng_api_fn) result = run_op_numeric_data(input_data, ng_api_fn)
assert np.allclose(result, expected) np.testing.assert_allclose(result, expected, rtol=0.001)
@pytest.mark.parametrize('ng_api_fn, numpy_fn, input_data', [ @pytest.mark.parametrize('ng_api_fn, numpy_fn, input_data', [
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment