Conversation
…r.py Move create_dist_sampler(), SamplerInput, and SamplerRuntime out of dist_sampling_producer.py into a shared utils module so they can be reused by the upcoming SharedDistSamplingBackend. Also rename `w` -> `worker` in DistSamplingProducer.init() for clarity. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce SharedDistSamplingBackend which manages a pool of worker processes servicing multiple compute-rank channels through a fair-queued round-robin scheduler. This replaces the per-channel producer model in graph-store mode with a shared backend + lightweight per-channel state. Includes tests for pure business logic helpers (_compute_num_batches, _epoch_batch_indices, _compute_worker_seeds_ranges), shuffle behavior, and completion reporting. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace the single-step create_sampling_producer with a two-phase API: - init_sampling_backend: creates/reuses a SharedDistSamplingBackend - register_sampling_input: registers a lightweight per-channel input The existing create_sampling_producer/destroy_sampling_producer methods are preserved as bridge methods that delegate to the new API, keeping existing loaders working without changes. Also adds InitSamplingBackendRequest and RegisterBackendRequest message dataclasses, and per-channel fetch stats logging. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Refactor BaseDistLoader to use the two-phase sampling API directly: - init_sampling_backend (shared across all ranks per loader instance) - register_sampling_input (unique per compute rank) Key changes: - Add GroupLeaderInfo, _compute_group_leader, _dispatch_grouped_graph_store_phase for generic leader-elected grouped RPC dispatch - Add _init_graph_store_sampling_backends and _register_graph_store_sampling_inputs - Replace _producer_id_list with _backend_id_list + _channel_id_list - Remove create_sampling_producer/destroy_sampling_producer bridge methods - Keep per-class _counter in each loader (not a global counter) since type-prefixed _backend_key already prevents cross-type collisions - Fix test_multiple_loaders_in_graph_store to use num_compute_nodes=2 so backend-sharing assertions are exercised across ranks Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
/all_test |
GiGL Automation@ 21:22:55UTC : 🔄 @ 22:40:25UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 21:22:57UTC : 🔄 @ 22:32:01UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 21:22:59UTC : 🔄 @ 21:32:06UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 21:22:59UTC : 🔄 @ 22:47:31UTC : ✅ Workflow completed successfully. |
GiGL Automation@ 21:23:00UTC : 🔄 @ 21:31:11UTC : ✅ Workflow completed successfully. |
| time.sleep(group_info.stagger_sleep) | ||
| results = issue_phase_rpcs() if group_info.is_leader else [] | ||
| all_results: list[list[T]] = [[] for _ in range(runtime.world_size)] | ||
| torch.distributed.all_gather_object(all_results, results) |
There was a problem hiding this comment.
Semgrep identified an issue in your code:
The torch.distributed.all_gather_object() call deserializes untrusted data from remote ranks using pickle, allowing arbitrary code execution if an attacker controls any rank in the distributed system.
More details about this
The torch.distributed.all_gather_object() call is using pickle deserialization under the hood to share producer_id_list across all ranks in the distributed system. This means untrusted data from other processes gets automatically deserialized and executed.
An attacker with access to any rank in the distributed training job could craft a malicious pickled object and send it during the all-gather phase. When producer_id_list (or other ranks' data) gets deserialized on your rank, the attacker's code runs immediately with the same privileges as your training process.
For example:
- Attacker gains write access to rank 1's memory or intercepts its state
- They insert a pickled Python object that executes shell commands when unpickled (e.g.,
os.system('steal_data.sh')) - Your rank calls
torch.distributed.all_gather_object(all_producer_ids, producer_id_list) - PyTorch's pickle unpickles all ranks' data, triggering the attacker's code during deserialization on your process
- The shell commands run with your process's credentials, potentially exfiltrating model weights or training data
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Avoid using
torch.distributed.all_gather_object, as this relies on Python pickling, which may lead to arbitrary code execution if untrusted data is ever deserialized. - Replace the use of
all_gather_objectwith a tensor-based collective, such astorch.distributed.all_gather, by converting your data to a tensor (for example, usetorch.tensor(producer_id_list)). - Predefine the size and type of the tensor holding the gathered data, for example:
all_producer_ids = torch.empty(runtime.world_size * num_producers, dtype=torch.long), wherenum_producersis the expected length ofproducer_id_listfor each rank. - Call
torch.distributed.all_gather([all_producer_ids], producer_id_list_tensor), whereproducer_id_list_tensor = torch.tensor(producer_id_list, dtype=torch.long)for each rank. - After gathering, reconstruct the original data structure as needed from
all_producer_ids. For example, split the combined tensor into per-rank lists using slicing. - If the number of producers may vary, agree on a fixed length and pad with a sentinel value such as
-1so tensors can be safely communicated.
This change ensures only primitive tensor data is shared between ranks, eliminating pickle-related risks.
💬 Ignore this finding
Reply with Semgrep commands to ignore this finding.
/fp <comment>for false positive/ar <comment>for acceptable risk/other <comment>for all other reasons
Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.
You can view more details about this finding in the Semgrep AppSec Platform.
| The leader's RPC results, broadcast to all ranks in the group. | ||
| """ | ||
| all_keys: list[Optional[str]] = [None] * runtime.world_size | ||
| torch.distributed.all_gather_object(all_keys, my_key) |
There was a problem hiding this comment.
Semgrep identified an issue in your code:
Using torch.distributed.all_gather_object() to share my_worker_key creates an arbitrary code execution risk, since pickle deserialization can execute attacker-controlled code if a compromised worker sends malicious data.
More details about this
The torch.distributed.all_gather_object() call on this line uses pickle to serialize and deserialize the my_worker_key string across distributed processes. An attacker who can control the data sent from any worker process could craft a malicious pickle payload that executes arbitrary code when deserialized.
Exploit scenario:
- An attacker compromises or spoofs one of the worker processes in the distributed training cluster
- They set
my_worker_keyto a malicious pickle-serialized object instead of a normal string - When this line executes, PyTorch deserializes the pickle object on all receiving ranks
- The malicious pickle payload executes arbitrary code on those processes with the same privileges as the training job, potentially stealing model weights, injecting backdoors, or exfiltrating data
To resolve this comment:
✨ Commit Assistant Fix Suggestion
- Avoid using
torch.distributed.all_gather_object, as it uses pickle internally and can allow arbitrary code execution if untrusted data is deserialized. - If exchanging strings across ranks, switch to using
torch.distributed.all_gather, which is safe for tensors. You can do this by encoding your strings to byte tensors before gathering and decoding after. - Replace the vulnerable line with logic similar to:
- Convert the local string to bytes:
my_worker_key_bytes = my_worker_key.encode('utf-8') - Find the maximum length of all keys to ensure tensors are the same size across ranks. This usually requires an
all_reduceto get the max. For example:
key_len_tensor = torch.tensor([len(my_worker_key_bytes)], device='cpu')
max_len_tensor = key_len_tensor.clone()
torch.distributed.all_reduce(max_len_tensor, op=torch.distributed.ReduceOp.MAX) - Pad your byte string to
max_len_tensor.item():padded_bytes = my_worker_key_bytes.ljust(max_len_tensor.item(), b'\x00') - Create a tensor:
my_worker_key_tensor = torch.ByteTensor(list(padded_bytes)) - Prepare a gather tensor:
all_worker_key_tensors = [torch.empty_like(my_worker_key_tensor) for _ in range(runtime.world_size)] - Call
torch.distributed.all_gather(all_worker_key_tensors, my_worker_key_tensor) - After gathering, decode each tensor:
[bytes(t.tolist()).rstrip(b'\x00').decode('utf-8') for t in all_worker_key_tensors]
- Convert the local string to bytes:
- Replace all uses of
all_worker_keyswith the decoded string list.
Alternatively, if all ranks already know or can deterministically construct the set of worker keys, you can avoid broadcasting entirely by constructing the list locally.
Using tensors for communication prevents vulnerabilities from deserialization attacks, as tensor operations do not use pickle.
💬 Ignore this finding
Reply with Semgrep commands to ignore this finding.
/fp <comment>for false positive/ar <comment>for acceptable risk/other <comment>for all other reasons
Alternatively, triage in Semgrep AppSec Platform to ignore the finding created by pickles-in-pytorch-distributed.
You can view more details about this finding in the Semgrep AppSec Platform.
Scope of work done
Where is the documentation for this feature?: N/A
Did you add automated tests or write a test plan?
Updated Changelog.md? NO
Ready for code review?: NO