Commit ff581e05 authored by tsocha's avatar tsocha Committed by Michał Karzyński

[Py] Check input shape is the same as tensor shape (#1016)

parent cd59bfe4
......@@ -24,6 +24,7 @@ from ngraph.impl.runtime import Backend
from ngraph.impl.op import Parameter
from ngraph.utils.types import get_dtype, NumericData
from ngraph.exceptions import UserInputError
log = logging.getLogger(__file__)
......@@ -107,6 +108,9 @@ class Computation:
def _write_ndarray_to_tensor_view(value, tensor_view):
# type: (np.ndarray, TensorViewType) -> None
tensor_view_dtype = get_dtype(tensor_view.element_type)
if list(tensor_view.shape) != list(value.shape) and len(value.shape) > 0:
raise UserInputError('Provided tensor\'s shape: %s does not match the expected: %s.',
list(value.shape), list(tensor_view.shape))
if value.dtype != tensor_view_dtype:
log.warning(
'Attempting to write a %s value to a %s tensor. Will attempt type conversion.',
......
......@@ -20,6 +20,7 @@ import json
import ngraph as ng
from test.ngraph.util import get_runtime, run_op_node
from ngraph.impl import Function, NodeVector
from ngraph.exceptions import UserInputError
@pytest.mark.parametrize('dtype', [np.float32, np.float64,
......@@ -157,3 +158,16 @@ def test_convert_to_uint(val_type):
expected = np.array(input_data, dtype=val_type)
result = run_op_node([input_data], ng.convert, val_type)
assert np.allclose(result, expected)
def test_bad_data_shape():
A = ng.parameter(shape=[2, 2], name='A', dtype=np.float32)
B = ng.parameter(shape=[2, 2], name='B')
model = (A + B)
runtime = ng.runtime(backend_name='CPU')
computation = runtime.computation(model, A, B)
value_a = np.array([[1, 2]], dtype=np.float32)
value_b = np.array([[5, 6], [7, 8]], dtype=np.float32)
with pytest.raises(UserInputError):
computation(value_a, value_b)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment