Skip to content

Commit aa20f6e

Browse files
committed
Protect against failed fits
1 parent b041971 commit aa20f6e

File tree

3 files changed: 85 additions and 19 deletions.

3 files changed: 85 additions and 19 deletions.

wirecell/test/__main__.py

Lines changed: 48 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def cli(ctx):
2222
'''
2323
pass
2424

25+
2526
@cli.command("plot")
2627
@click.option("-n", "--name", default="noise",
2728
help="The test name")
@@ -37,7 +38,6 @@ def plot(ctx, name, datafile, output):
3738
fp = ario.load(datafile)
3839
with plottools.pages(output) as out:
3940
mod.plot(fp, out)
40-
4141

4242

4343
def ssss_args(func):
@@ -53,8 +53,11 @@ def ssss_args(func):
5353
@functools.wraps(func)
5454
def wrapper(*args, **kwds):
5555

56-
kwds["splat"] = ssss.load_frame(kwds.pop("splat"))
57-
kwds["signal"] = ssss.load_frame(kwds.pop("signal"))
56+
kwds["splat_filename"] = kwds.pop("splat")
57+
kwds["signal_filename"] = kwds.pop("signal")
58+
59+
kwds["splat"] = ssss.load_frame(kwds["splat_filename"])
60+
kwds["signal"] = ssss.load_frame(kwds["signal_filename"])
5861

5962
channel_ranges = kwds.pop("channel_ranges")
6063
if channel_ranges:
@@ -74,6 +77,8 @@ def plot_ssss(channel_ranges, nsigma, nbins, splat, signal, output,
7477
Perform the simple splat / sim+signal process comparison test and make plots.
7578
'''
7679

80+
nminsig = 3 # sanity check
81+
7782
with pages(output) as out:
7883

7984
ssss.plot_frames(splat, signal, channel_ranges, title)
@@ -100,6 +105,13 @@ def plot_ssss(channel_ranges, nsigma, nbins, splat, signal, output,
100105

101106
spl_qch = numpy.sum(spl.activity[bbox], axis=1)
102107
sig_qch = numpy.sum(sig.activity[bbox], axis=1)
108+
109+
nspl = len(spl_qch)
110+
nsig = len(sig_qch)
111+
if nspl != nsig or nsig < nminsig:
112+
log.error(f'error: bad signals: {nspl=} {nsig=} {pln=} {ch=}')
113+
raise ValueError(f'bad signals: {nspl=} {nsig=}')
114+
103115
byplane.append((spl_qch, sig_qch))
104116

105117

@@ -132,7 +144,14 @@ def ssss_metrics(channel_ranges, nsigma, nbins, splat, signal, output, params, *
132144
spl_qch = numpy.sum(spl.activity[bbox], axis=1)
133145
sig_qch = numpy.sum(sig.activity[bbox], axis=1)
134146

135-
m = ssss.calc_metrics(spl_qch, sig_qch, nbins)
147+
try:
148+
m = ssss.calc_metrics(spl_qch, sig_qch, nbins)
149+
except Exception as err:
150+
splat_filename = kwds['splat_filename']
151+
signal_filename = kwds['signal_filename']
152+
log.error(f'error: ({err}) failed to calculate metrics for {pln=} {ch=} {splat_filename=} {signal_filename=}')
153+
m = ssss.Metrics()
154+
136155
metrics.append(dataclasses.asdict(m))
137156

138157
if params:
@@ -149,8 +168,10 @@ def ssss_metrics(channel_ranges, nsigma, nbins, splat, signal, output, params, *
149168
help="PDF file in which to plot metrics")
150169
@click.option("--coordinate-plane", default=None, type=int,
151170
help="Use given plane number as global coordinates plane, default uses per-plane coordinates")
171+
@click.option("-t","--title", default="",
172+
help="The title string")
152173
@click.argument("files",nargs=-1)
153-
def plot_metrics(output, coordinate_plane, files):
174+
def plot_metrics(output, coordinate_plane, title, files):
154175
'''Plot per-plane metrics from files.
155176
156177
Files are as produced by ssss-metrics and must include a "params" key.
@@ -176,9 +197,15 @@ def add(k,v):
176197

177198
pmet = met[plane]
178199
add('ineff', pmet['ineff'])
179-
add('bias', pmet['fit']['avg'])
180-
hi = pmet['fit']['hi']
181-
lo = pmet['fit']['lo']
200+
fit = pmet['fit']
201+
if fit is None:
202+
add('bias', 1) # fixme: best way to show failure?
203+
add('reso', 1)
204+
continue
205+
206+
add('bias', fit['avg'])
207+
hi = fit['hi']
208+
lo = fit['lo']
182209
add('reso', 0.5*(hi+lo) )
183210
continue;
184211

@@ -193,9 +220,15 @@ def add(k,v):
193220
add('ty', par['theta_y_wps'][plane])
194221
add('txz', par['theta_xz_wps'][plane])
195222
add('ineff', pmet['ineff'])
196-
add('bias', pmet['fit']['avg'])
197-
hi = pmet['fit']['hi']
198-
lo = pmet['fit']['lo']
223+
fit = pmet['fit']
224+
if fit is None:
225+
add('bias', 1) # fixme: best way to show failure?
226+
add('reso', 1)
227+
continue
228+
229+
add('bias', fit['avg'])
230+
hi = fit['hi']
231+
lo = fit['lo']
199232
add('reso', 0.5*(hi+lo) )
200233

201234

@@ -207,11 +240,13 @@ def add(k,v):
207240
pcolors = ('#58D453', '#7D99D1', '#D45853')
208241

209242
fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True)
243+
if title:
244+
title = ' - ' + title
210245
if coordinate_plane is None:
211-
fig.suptitle("Per-plane angles")
246+
fig.suptitle("Per-plane angles" + title)
212247
else:
213248
letter = "UVW"[coordinate_plane]
214-
fig.suptitle(f'Global angles ({letter}-plane)')
249+
fig.suptitle(f'Global angles ({letter}-plane)' + title)
215250

216251
todeg = 180/numpy.pi
217252
# xlabs = [f'{txz}/{ty}' for txz,ty in zip(

wirecell/test/ssss.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
baseline_noise,
1616
gauss as gauss_func
1717
)
18+
import logging
19+
log = logging.getLogger("wirecell.test")
20+
1821

1922
def relbias(a,b):
2023
'''
@@ -197,17 +200,17 @@ def plot_plane(spl_act, sig_act, nsigma=3.0, title=""):
197200
class Metrics:
198201
'''Metrics about a signal vs splat'''
199202

200-
neor: int
203+
neor: int = 0
201204
''' Number of channels with activity in either the signal or splat (or both)
202205
and over which the rest are calculated. This can be less than the number of
203206
channels in the original "activity" arrays if any given channel has zero
204207
activity in both "signal" and "splat". '''
205208

206-
ineff: float
209+
ineff: float = -1
207210
''' The relative inefficiency. This is the fraction of channels with splat
208211
but with zero signal. '''
209212

210-
fit: BaselineNoise
213+
fit: BaselineNoise | None = None
211214
'''
212215
Gaussian fit to relative difference. .mu is bias and .sigma is resolution.
213216
'''
@@ -220,6 +223,12 @@ def calc_metrics(spl_qch, sig_qch, nbins=50):
220223
- nbins :: the number of bins over which to fit the relative difference.
221224
'''
222225

226+
nspl = len(spl_qch)
227+
nsig = len(sig_qch)
228+
229+
if nspl != nsig:
230+
raise ValueError(f'length mismatch {nspl=} != {nsig=}')
231+
223232
# either-or, exclude channels where both are zero
224233
eor = numpy.logical_or (spl_qch > 0, sig_qch > 0)
225234
# both are nonzero
@@ -247,7 +256,13 @@ def plot_metrics(splat_signal_activity_pairs, nbins=50, title="", letters="UVW")
247256
fig, axes = plt.subplots(nrows=2, ncols=3, sharey="row")
248257
for pln, (spl_qch, sig_qch) in enumerate(splat_signal_activity_pairs):
249258

250-
m = calc_metrics(spl_qch, sig_qch, nbins)
259+
try:
260+
m = calc_metrics(spl_qch, sig_qch, nbins)
261+
except:
262+
log.error(f'error: failed to get metric for {pln=} {spl_qch.size=} {sig_qch.size=} {nbins=} {title=}')
263+
log.debug(f'skipped splat: {spl_qch=}')
264+
log.debug(f'skipped signal: {sig_qch=}')
265+
continue
251266
counts, edges = m.fit.hist
252267
model = gauss_func(edges[:-1], m.fit.A, m.fit.mu, m.fit.sigma)
253268

wirecell/util/peaks.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from wirecell.util.codec import dataclass_dictify
1717
from wirecell.util.bbox import union as union_bbox
1818

19+
import logging
20+
log = logging.getLogger("wirecell.util")
21+
1922
sqrt2pi = sqrt(2*pi)
2023

2124
def gauss(x, A, mu, sigma, *p):
@@ -45,6 +48,11 @@ class BaselineNoise:
4548
Width (fit standard deviation)
4649
'''
4750

51+
N : int
52+
'''
53+
Number of samples
54+
'''
55+
4856
C : float
4957
'''
5058
Normalization (sum)
@@ -93,11 +101,18 @@ def baseline_noise(array, bins=200, vrange=100):
93101
defines an extent about the MEDIAN VALUE. If it is a tuple it gives this
94102
extent explicitly or if scalar the extent is symmetric, ie median+/-vrange.
95103
104+
This will raise exceptions:
105+
106+
- ZeroDivisionError when the signal in the vrange is zero.
107+
108+
- RuntimeError when the fit fails.
109+
96110
'''
111+
nsig = len(array)
97112
lo, med, hi = numpy.quantile(array, [0.5-0.34,0.5,0.5+0.34])
98113

99114
if not isinstance(vrange, tuple):
100-
vrange=(-vrange, vrange)
115+
vrange=(med-vrange, med+vrange)
101116
vrange=(med+vrange[0], med+vrange[1])
102117

103118
hist = numpy.histogram(array, bins=bins, range=vrange)
@@ -113,11 +128,12 @@ def baseline_noise(array, bins=200, vrange=100):
113128
(A,mu,sig),cov = curve_fit(gauss, edges[:-1], counts, p0=p0)
114129
except RuntimeError:
115130
cov = None
131+
116132
return BaselineNoise(A=A, mu=mu, sigma=sig,
133+
N=nsig,
117134
C=C, avg=avg, rms=rms,
118135
med=med, lo=lo, hi=hi,
119136
cov=cov, hist=hist)
120-
121137

122138
@dataclasses.dataclass
123139
@dataclass_dictify

0 commit comments

Comments
 (0)