Commit c8691054 authored by Adam Procter's avatar Adam Procter

Further test refactoring: generate separate tests rather than one big one

parent a5c56eae
This diff is collapsed.
......@@ -208,6 +208,34 @@ def print_shape(dims):
return 'Shape{' + ','.join(strs) + '}'
def print_slice(sl):
if sl is None:
return 'newaxis'
elif sl is Ellipsis:
return "..."
elif isinstance(sl, slice):
s = ''
if sl.start is not None:
s += str(sl.start)
s += ':'
if sl.stop is not None:
s += str(sl.stop)
if sl.step is not None:
s += ':'
s += str(sl.step)
return s
else:
return str(sl)
def print_slices(slices):
slices = make_iterable(slices)
strs = []
for sl in slices:
strs.append(print_slice(sl))
return '[' + ','.join(strs) + ']'
#
# Class to intercept indexing operations and write an nGraph C++ test case. The
# generated test case will ensure that the output is identical to that which
......@@ -239,6 +267,7 @@ class SliceTestWriter:
n_elems = np.prod(shape)
self._dtype = dtype
self._stream = stream
self._test_counter = 0
def __getitem__(self, slices):
self.write_test(slices)
......@@ -252,10 +281,18 @@ class SliceTestWriter:
except TypeError:
pass
self._stream.write('\n')
self._stream.write('// slices are: %s\n' % print_slices(slices))
self._stream.write('// dtype is: %s\n' % self._dtype)
self._stream.write('// input shape is: %s\n' % print_shape(self._shape))
try:
data_out = data_in.__getitem__(slices)
except Exception as e:
self._stream.write(' check_failure<%s>\n'
self._stream.write('// failure is expected\n'
'NGRAPH_TEST(${BACKEND_NAME}, dyn_slice_%d)\n'
'{\n'
' check_failure<%s>\n'
' (%s,\n'
' %s,\n'
' std::vector<int64_t>{%s},\n'
......@@ -266,7 +303,9 @@ class SliceTestWriter:
' AxisSet{%s},\n'
' AxisSet{%s},\n'
' AxisSet{%s});\n'
% (np_dt_to_c(self._dtype),
'}\n'
% (self._test_counter,
np_dt_to_c(self._dtype),
np_dt_to_ng(self._dtype),
print_shape(data_in.shape),
print_lb_values(slices),
......@@ -278,7 +317,10 @@ class SliceTestWriter:
print_shrink_mask_axes(slices),
print_ellipsis_mask_axes(slices)))
else:
self._stream.write(' check_success<%s>\n'
self._stream.write('// expected output shape is %s\n'
'NGRAPH_TEST(${BACKEND_NAME}, dyn_slice_%d)\n'
'{\n'
' check_success<%s>\n'
' (%s,\n'
' %s,\n'
' std::vector<int64_t>{%s},\n'
......@@ -291,7 +333,10 @@ class SliceTestWriter:
' AxisSet{%s},\n'
' %s,\n'
' std::vector<%s>{%s});\n'
% (np_dt_to_c(self._dtype),
'}\n'
% (print_shape(data_out.shape),
self._test_counter,
np_dt_to_c(self._dtype),
np_dt_to_ng(self._dtype),
print_shape(data_in.shape),
print_lb_values(slices),
......@@ -305,16 +350,15 @@ class SliceTestWriter:
print_shape(data_out.shape),
np_dt_to_c(self._dtype), print_values(data_out.reshape(-1))))
self._test_counter += 1
def set_shape(self,shape):
self._shape = shape
def set_dtype(self,dtype):
self._dtype = dtype
def main():
assert(len(sys.argv) > 1)
f = open(sys.argv[1], 'w')
def write_header(f):
f.write('''\
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
......@@ -456,11 +500,22 @@ void check_success(const element::Type& input_element_type,
EXPECT_EQ(output_values, expected_values);
}
''')
NGRAPH_TEST(${BACKEND_NAME}, dyn_slice)
{
def write_footer(f):
f.write('''\
// clang-format on
''')
def main():
if len(sys.argv) < 2:
sys.stderr.write('Output filename is required\n')
sys.exit(1)
f = open(sys.argv[1], 'w')
write_header(f)
t = SliceTestWriter(stream=f)
t.set_shape((4,))
......@@ -561,11 +616,7 @@ NGRAPH_TEST(${BACKEND_NAME}, dyn_slice)
t.set_dtype('int32')
t[...,...] # error expected (too many ellipses)
f.write('''\
}
// clang-format on
''')
write_footer(f)
f.close()
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment