Commit 3e680b68 authored by tsocha's avatar tsocha Committed by Scott Cyphers

[Py]Fix problem with double set layout (#808)

* [Py]Fix problem with double set layout

* Extend UT for coverage double set layout
parent c2a09de9
...@@ -65,6 +65,8 @@ class Computation: ...@@ -65,6 +65,8 @@ class Computation:
element_type = parameter.get_element_type() element_type = parameter.get_element_type()
self.tensor_views.append(runtime.backend.make_primary_tensor_view(element_type, shape)) self.tensor_views.append(runtime.backend.make_primary_tensor_view(element_type, shape))
self.function = Function(self.node, self.parameters, 'ngraph_computation') self.function = Function(self.node, self.parameters, 'ngraph_computation')
external = self.runtime.manager.compile(self.function)
self.call_frame = self.runtime.backend.make_call_frame(external)
def __repr__(self): # type: () -> str def __repr__(self): # type: () -> str
params_string = ', '.join([param.name for param in self.parameters]) params_string = ', '.join([param.name for param in self.parameters])
...@@ -85,9 +87,7 @@ class Computation: ...@@ -85,9 +87,7 @@ class Computation:
result_element_type, result_shape) result_element_type, result_shape)
result_arr = np.empty(result_shape, dtype=result_dtype) result_arr = np.empty(result_shape, dtype=result_dtype)
external = self.runtime.manager.compile(self.function) self.call_frame.call([result_view], self.tensor_views)
call_frame = self.runtime.backend.make_call_frame(external)
call_frame.call([result_view], self.tensor_views)
Computation._read_tensor_view_to_ndarray(result_view, result_arr) Computation._read_tensor_view_to_ndarray(result_view, result_arr)
result_arr = result_arr.reshape(result_shape) result_arr = result_arr.reshape(result_shape)
......
...@@ -40,6 +40,12 @@ def test_simple_computation_on_ndarrays(dtype): ...@@ -40,6 +40,12 @@ def test_simple_computation_on_ndarrays(dtype):
result = computation(value_a, value_b, value_c) result = computation(value_a, value_b, value_c)
assert np.allclose(result, np.array([[54, 80], [110, 144]], dtype=dtype)) assert np.allclose(result, np.array([[54, 80], [110, 144]], dtype=dtype))
value_a = np.array([[13, 14], [15, 16]], dtype=dtype)
value_b = np.array([[17, 18], [19, 20]], dtype=dtype)
value_c = np.array([[21, 22], [23, 24]], dtype=dtype)
result = computation(value_a, value_b, value_c)
assert np.allclose(result, np.array([[630, 704], [782, 864]], dtype=dtype))
def test_serialization(): def test_serialization():
dtype = np.float32 dtype = np.float32
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment