Make each thread have its own default stream #3281
Metal completion handlers run on dispatch queues where C++ exceptions cannot propagate — throwing causes std::terminate → SIGABRT, crashing the process with no diagnostic information. Instead, store the error message atomically in the CommandEncoder and check it at the next synchronous point (commit, synchronize). This converts fatal crashes into catchable runtime_error exceptions that the application can handle gracefully.

Root cause analysis: the crash at 262K+ context reported as mlx#3216 was actually two separate issues:

1. Thread safety in stream management (fixed by PR ml-explore#3281)
2. C++ exceptions thrown from Metal completion handler callbacks (fixed by this commit)

The GPU watchdog error (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) is a separate concern — macOS kills command buffers that block the GPU beyond the watchdog threshold. This commit ensures that error is reported as a Python RuntimeError instead of a SIGABRT.
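The "record the error in the callback, rethrow at the next synchronous point" pattern described above can be sketched language-agnostically. This is a minimal Python stand-in (the `CommandEncoder` name follows the description; nothing here is actual MLX code), with a worker thread playing the role of the dispatch queue:

```python
import threading

class CommandEncoder:
    """Toy stand-in for the encoder described above (not MLX code)."""

    def __init__(self):
        self._error = None            # written by the callback thread
        self._lock = threading.Lock()

    def completion_handler(self, status):
        # Runs on a worker thread (the dispatch-queue analog): raising
        # here would crash the process, so we only record the failure.
        if status != "ok":
            with self._lock:
                if self._error is None:
                    self._error = f"command buffer failed: {status}"

    def commit(self):
        # Synchronous point on the caller's thread: safe to raise here.
        with self._lock:
            if self._error is not None:
                msg, self._error = self._error, None
                raise RuntimeError(msg)

enc = CommandEncoder()
t = threading.Thread(target=enc.completion_handler, args=("watchdog timeout",))
t.start(); t.join()
try:
    enc.commit()                      # failure surfaces as a catchable exception
except RuntimeError as e:
    print(f"caught: {e}")
```

The key property is that the failing thread never throws; the exception is raised only on a thread that called a synchronous API and can catch it.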
Fused SDPA regression on this branch

While testing my chunked SDPA work (#3293, based on this branch), I discovered that the fused SDPA kernel produces incorrect output.

Reproduction

```python
import mlx.core as mx
import numpy as np

B, H, qL, D = 1, 2, 256, 128
scale = 1.0 / np.sqrt(D)
mx.random.seed(42)

q = mx.random.normal((B, H, qL, D)).astype(mx.bfloat16)
k = mx.random.normal((B, H, qL, D)).astype(mx.bfloat16)
v = mx.random.normal((B, H, qL, D)).astype(mx.bfloat16)

# Fused SDPA
o_fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# Manual reference
q32, k32, v32 = q.astype(mx.float32), k.astype(mx.float32), v.astype(mx.float32)
o_ref = mx.softmax((q32 @ k32.swapaxes(-1, -2)) * scale, axis=-1) @ v32

mx.eval(o_fused, o_ref)
diff = (o_fused.astype(mx.float32) - o_ref).abs()
mx.eval(diff)
print(f"max_diff={diff.max().item():.4f}")  # ~14.4 on this branch, ~0.001 on main
```

Results

The fused output is ~25x wrong in magnitude, consistently across all head dims (64, 80, 128, 256) and both float16/bfloat16. The full attention path (…) — the SDPA kernel code itself is identical between the two branches, so the regression must come from the stream/device management changes. Models using explicit matmul + softmax (e.g., mlx-lm's attention implementation) are unaffected since they don't use the fused SDPA path.
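As a sanity check on what the manual reference computes, the same computation can be written in pure NumPy without an MLX install (`sdpa_reference` is a hypothetical helper name, everything in float32):

```python
import numpy as np

def sdpa_reference(q, k, v, scale):
    # softmax(q @ k^T * scale, axis=-1) @ v, the standard SDPA definition
    s = (q @ np.swapaxes(k, -1, -2)) * scale
    s = s - s.max(axis=-1, keepdims=True)    # subtract row max for stability
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)       # rows of p sum to 1
    return p @ v

rng = np.random.default_rng(42)
B, H, qL, D = 1, 2, 256, 128
q, k, v = (rng.standard_normal((B, H, qL, D), dtype=np.float32) for _ in range(3))
o_ref = sdpa_reference(q, k, v, 1.0 / np.sqrt(D))
print(o_ref.shape)  # (1, 2, 256, 128)
```

Because each output row is a convex combination of value rows, the output magnitude is bounded by the values, which is why a ~25x magnitude error in the fused path is immediately suspicious.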
Hmm, the script produces the same result on this branch and on main, and our tests would have caught it if the result were wrong.
Fair enough — I suspect this was a stale Metal JIT cache on my side. I was switching between branches (main, thread-local-streams, and my fix branch) with partial file checkouts and pip reinstalls, which likely left cached kernel binaries from one branch being used with host code from another. The buffer binding changes in my chunked SDPA work (adding write_partial function constants) would cause exactly this kind of mismatch. Apologies for the noise — I should have done a clean build before reporting. I'll verify with a fresh clone if I see it again.
angeloskath
left a comment
This looks great!
I left basically one comment that needs addressing regarding the (natural) assumption that device indices will be less than the available devices. I think the fix should be in the device constructor.
```cpp
    throw std::invalid_argument(
        "[default_stream] Cannot get gpu stream without gpu backend.");
  }
  auto& s = default_stream_storage(d);
```
Well, this is not necessarily a bug in this code, but Device can weirdly have any index, i.e. I can make Device(Device::gpu, 7) and pass it to default_stream, which will access out-of-bounds memory.
So for this code to be correct, I think the constructor of Device needs to check that 0 <= index < device_count(dev_type).
Checking the index in Device::Device would break code like is_available(Device::gpu), which constructs a possibly-invalid device first and then checks it.
I changed default_stream_storage to do a bounds check instead, using default_streams.at(d.index).
Yeah that makes sense. I thought of that but then I thought it is generally weird that we can create arbitrary devices but maybe that's fine.
It is indeed a weird design. I think the API should be is_available(DeviceType type, int index), which should be compatible with most C++ code but would require an API change in mlx-c; not sure if we should change it @andresy.
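The tradeoff settled on in this thread — leave the Device constructor permissive, move the bounds check to the lookup — can be modeled in a few lines. This is a toy Python sketch mirroring the C++ `default_streams.at(d.index)` fix (all names here are illustrative, not MLX code):

```python
class Device:
    """Toy model of the design discussed above (not MLX code)."""
    def __init__(self, kind, index=0):
        # No validation here, so patterns like is_available(Device("gpu", 7)),
        # which construct a possibly-invalid device first, keep working.
        self.kind, self.index = kind, index

# One stream slot per actually-available device.
_default_streams = {"gpu": ["gpu-stream-0"], "cpu": ["cpu-stream-0"]}

def default_stream_storage(d):
    # Bounds check at lookup time, mirroring vector::at in the C++ fix:
    # an out-of-range index raises instead of touching invalid memory.
    streams = _default_streams[d.kind]
    if not 0 <= d.index < len(streams):
        raise IndexError(f"no {d.kind} device with index {d.index}")
    return streams[d.index]

print(default_stream_storage(Device("gpu", 0)))
try:
    default_stream_storage(Device("gpu", 7))   # invalid index is catchable
except IndexError as e:
    print(f"caught: {e}")
```

The design choice is that invalid Device values are representable but every access path that dereferences them fails loudly, which is exactly what `.at()` buys over `operator[]` in the C++ code.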
angeloskath
left a comment
Forgot to approve before. I still think we should fix the device index bug before merging this.
I was hitting the exact Metal assertion this fixes.
Refs #3078, #3216.
Make sure each thread gets a different stream when using get_default_stream(), which makes multi-threaded code safe and lock-free by default.

Changes:

- get_stream(int index), which would require locks and is not public API.
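The per-thread default stream behavior described above can be sketched with Python's threading.local() as an analog (the stream names and `get_default_stream` body here are illustrative, not the MLX implementation):

```python
import threading
from itertools import count

_next_index = count()          # atomic enough under CPython's GIL
_tls = threading.local()       # one slot per thread

def get_default_stream():
    # Each thread lazily creates its own stream on first use: no locks on
    # the lookup path, and no two threads ever share a default stream.
    if not hasattr(_tls, "stream"):
        _tls.stream = f"stream-{next(_next_index)}"
    return _tls.stream

streams = []
def worker():
    streams.append(get_default_stream())

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(streams))  # four distinct per-thread streams
```

Within a thread, repeated calls return the same stream, so code that passes the default stream around stays consistent; across threads, the streams are disjoint, which is what makes multi-threaded use lock-free by default.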