Skip to content

Commit ab57724

Browse files
committed
Added missing unit tests
1 parent fc7d1d2 commit ab57724

2 files changed

Lines changed: 28 additions & 11 deletions

File tree

syncropatch_export/trace.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def get_trace_sweeps(self, sweeps=None, leakcorrect=False):
175175
corrected data can be obtained by setting ``leakcorrect=True``.
176176
177177
Args:
178-
sweeps (int): The number of sweeps to return.
178+
sweeps (list): A list of sweep indexes to return, e.g. ``[0, 1, 2]``.
179179
leakcorrect (bool): Used to choose corrected or uncorrected data.
180180
181181
Returns:
@@ -191,18 +191,22 @@ def get_trace_sweeps(self, sweeps=None, leakcorrect=False):
191191
for ijWell in iCol:
192192
out_dict[ijWell] = []
193193

194+
# No sweeps selected? Then return full set
194195
if sweeps is None:
195-
# Sometimes NofSweeps seems to be incorrect
196196
sweeps = list(range(self.NofSweeps))
197-
198-
# Check `sweeps` is something sensible
199-
elif len(sweeps) > self.NofSweeps:
200-
raise ValueError('Required #sweeps > total #sweeps.')
201-
202-
# convert negative values to positive
203-
for i, sweep in enumerate(sweeps):
204-
if sweep < 0:
205-
sweeps[i] = self.NofSweeps + sweep
197+
else:
198+
# Allow negative values to index later sweeps
199+
sweeps = [self.NofSweeps + x if x < 0 else x for x in sweeps]
200+
# Check all sweeps exist
201+
if max(sweeps) >= self.NofSweeps:
202+
raise ValueError(
203+
f'Invalid sweep selection: sweep {max(sweeps)} requested,'
204+
f' but only {self.NofSweeps} available.')
205+
if min(sweeps) < 0:
206+
raise ValueError(
207+
f'Invalid sweep selection: sweep'
208+
f' {min(sweeps) - self.NofSweeps} requested, but only'
209+
f' {self.NofSweeps} available.')
206210

207211
trace_file_idxs, idx_is = self.get_trace_file(sweeps)
208212

tests/test_trace_class.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,26 @@ def test_get_traces(self):
9999
ts = self.trace.get_times()
100100
all_traces = self.trace.get_all_traces(leakcorrect=True)
101101
all_traces = self.trace.get_all_traces()
102+
# TODO: Check the output, numerically, by comparing a few points
102103

103104
self.assertTrue(np.all(np.isfinite(v)))
104105
self.assertTrue(np.all(np.isfinite(ts)))
105106

106107
for well, trace in all_traces.items():
107108
self.assertTrue(np.all(np.isfinite(trace)))
108109

110+
# Test complex sweep selection
111+
a = self.trace.get_trace_sweeps([-1, -2])
112+
b = self.trace.get_trace_sweeps([1, 0])
113+
self.assertEqual(len(a), len(b))
114+
self.assertTrue(np.all(a['A01'] == b['A01']))
115+
116+
# Test asking for non-existent sweeps
117+
self.assertRaisesRegex(ValueError, 'Invalid sweep selection',
118+
self.trace.get_trace_sweeps, [2])
119+
self.assertRaisesRegex(ValueError, 'Invalid sweep selection',
120+
self.trace.get_trace_sweeps, [-3])
121+
109122
'''
110123
# plot test output
111124
if False:

0 commit comments

Comments
 (0)