Conversation
…ch api shifts from breaking for us. Also increase conv image size for better test stability
```python
if check_version_spec("torch", "2.10.0a"):
    from torch.distributed.tensor._ops.registration import (
        register_prop_rule,
    )
else:
    from torch.distributed.tensor._ops.utils import (
        register_prop_rule,
    )
```
This was fragile and causing issues in CI, and it's the primary reason for refactoring this into a ShardTensor op instead of DTensor.
```diff
-        if weight is not None and weight.requires_grad:
+        if weight is not None and ctx.needs_input_grad[3]:
             # grad_weight_c = sum_{n, spatial} grad_output * y (per-channel)
             y_c = y.view(N, C, HxW_local)
             grad_out_c = local_grad_output.view(N, C, HxW_local)
             grad_weight = (grad_out_c * y_c).sum(dim=(0, 2))  # (C,)

-        if bias is not None and bias.requires_grad:
+        if bias is not None and ctx.needs_input_grad[4]:
```
This was a nasty bug: we have to ask the context whether something needs a grad, not the tensor itself.
Since `.requires_grad` is implicated here, I want to ask whether we'll need to modify/update this fix I'm adding: #1566
It is more related to frozen params with `nn.Parameter(..., requires_grad=False)` getting silently updated by `distribute_module` in PyT <= 2.10. The relevant bugged PyT code is here; it has since been fixed on main here.
Your PR shouldn't be affected. This code here is fragile, I believe, because the requires_grad state is not reliable inside of this torch.autograd.Function context. Outside of that it's still OK.
This isn't a ShardTensor thing in particular; I think I was just doing this subtly wrong before.
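A toy `torch.autograd.Function` (hypothetical, not the actual group-norm backward) illustrating the pattern: inside the Function, `ctx.needs_input_grad` is the reliable signal for which gradients to compute, since the tensors' own `.requires_grad` state is not trustworthy in that context, as noted above.

```python
import torch


class ScaleShift(torch.autograd.Function):
    # Toy op mirroring the fix: out = x * weight + bias.
    @staticmethod
    def forward(ctx, x, weight, bias):
        ctx.save_for_backward(x, weight)
        return x * weight + bias

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_x = grad_weight = grad_bias = None
        # Ask the context, NOT the tensors: .requires_grad on the saved
        # tensors is not a reliable signal inside this Function context.
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * weight
        if ctx.needs_input_grad[1]:
            grad_weight = (grad_output * x).sum()
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum()
        return grad_x, grad_weight, grad_bias


x = torch.randn(4, requires_grad=True)
w = torch.randn((), requires_grad=True)
b = torch.randn(())  # frozen: needs_input_grad[2] is False in backward
ScaleShift.apply(x, w, b).sum().backward()
```

The frozen `b` here plays the same role as a `nn.Parameter(..., requires_grad=False)` bias: its slot in `needs_input_grad` is `False`, so no gradient is computed for it.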
```python
# ============================================================================
# Layer 1 -- Semi-private conversions (no autograd, no spec inference)
# ============================================================================
```
I opted for a design which builds up this dispatch fallback mechanism in stages, so it can be careful and deliberate about what is routing where and when.
So I am describing it in "layers", where each layer adds complexity in the autograd graph and in ShardTensorSpec / TensorSpec handling. This one here is the least complex: no autograd, no spec inference.
```python
dtensor = torch.Tensor._dtensor__new__(
    DTensor, st._local_tensor, st._spec, requires_grad=st.requires_grad
)
```
This is a newer API, so a fallback is available too.
```python
class _DTensorToShardTensor(torch.autograd.Function):
    r"""Differentiable promotion: DTensor -> ShardTensor.

    This is to always connect the graphs for the backward pass
    when we have to use a fallback option.

    Forward: :func:`_dtensor_to_shard_tensor`.
    Backward: :func:`_shard_tensor_to_dtensor`.
    """

    @staticmethod
    def forward(ctx, dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor":
        return _dtensor_to_shard_tensor(dtensor, spec)

    @staticmethod
    def backward(ctx, grad_output: "ShardTensor"):
        return _shard_tensor_to_dtensor(grad_output), None
```
The differentiable conversions just let autograd attach grad functions correctly to the conversion functions in layer 1.
```python
    @staticmethod
    def forward(ctx, st: "ShardTensor") -> DTensor:
        ctx.shard_tensor_spec = st._spec
```
We cache the spec because the spec from the forward pass is likely correct for the input gradients in the backward pass.
```python
    for arg in input_args:
        if (
            isinstance(arg, ShardTensor)
            and dtensor._spec.tensor_meta == arg._spec.tensor_meta
            and dtensor._spec.placements == arg._spec.placements
        ):
            return arg._spec
    return _infer_shard_tensor_spec_from_local_chunks(
        dtensor._local_tensor,
        dtensor._spec.mesh,
        dtensor._spec.placements,
        sharding_shapes="chunk",
        global_shape=dtensor.shape,
    )
```
The goal here is to figure out a correct spec for a DTensor -> ShardTensor promotion based on the other args. So, if we're doing a unary op, for example y = 2*x, we fall back to DTensor, do the op, and get a DTensor back. We don't want to communicate to figure out the sharding shapes, though, so this looks at all the input args to figure out, "Hey, does any input arg match this output well enough to use its spec?"
We match on the tensor meta (overall shape), the placements, and the input arg being a ShardTensor, which seems robust here.
Otherwise, we assume the DTensor has DTensor chunking and promote that.
```python
def _conversion_active() -> bool:
    r"""Return whether ShardTensor<->DTensor conversion is currently active."""
    return getattr(_conversion_guard, "depth", 0) > 0


@contextmanager
def _conversion_scope():
    r"""Re-entrant conversion guard for cast-down/cast-up paths."""
    previous_depth = getattr(_conversion_guard, "depth", 0)
    _conversion_guard.depth = previous_depth + 1
    try:
        yield
    finally:
        if previous_depth == 0:
            delattr(_conversion_guard, "depth")
        else:
            _conversion_guard.depth = previous_depth


def _dispatch_fallback_via_dtensor(
    func: torch._ops.OpOverload,
    args: tuple[object, ...],
    kwargs: dict[str, object] | None = None,
) -> object:
    r"""Execute an ATen op through DTensor fallback using PURE data conversion.

    Native Autograd wraps this hook, so we must NOT build an internal graph
    using .apply(). We just do the math and let PyTorch track the outer graph.
    """
    with _conversion_scope():
        converted_args = tuple(
            _convert_args_to_dtensor(arg, use_autograd=False) for arg in args
        )
        converted_kwargs = {
            k: _convert_args_to_dtensor(v, use_autograd=False)
            for k, v in (kwargs or {}).items()
        }

    dispatch_res = func(*converted_args, **(converted_kwargs or {}))

    with _conversion_scope():
        return _convert_results_to_shard_tensor(dispatch_res, args, use_autograd=False)
```
There are two things going on here, and this one is critical, actually.
First, there is a possible infinite recursion if we're careless: the dispatch fallback itself uses torch functions to attach autograd functions for the differentiable paths, so if we don't prevent that, we will dispatch infinitely and hit Python's recursion depth limit. So that's fun. We use this conversion scope, and check it in `__torch_function__` below in ShardTensor itself, to skip the dispatch conversion if we're already within a conversion scope.
And we make it thread safe with a thread-local counter rather than a boolean. Got my eye on you, free-threaded Python...
I can tell you had fun debugging this 😱 I find Layer 3 to be the gnarliest among the conversion layers
It actually took me 1+ month, off and on, coming back to this and poking at it to track down what the heck was happening in all of this.
```python
    match arg:
        case ShardTensor():
            if use_autograd and arg.requires_grad and torch.is_grad_enabled():
                return _ShardTensorToDTensor.apply(arg)
            return _shard_tensor_to_dtensor(arg)
        case DTensor():
            # DTensor can be iterable; exit early deliberately
            return arg
        case Mapping():
            return type(arg)(
                {k: _convert_args_to_dtensor(v, use_autograd) for k, v in arg.items()}
            )
        case tuple():
            return tuple(_convert_args_to_dtensor(a, use_autograd) for a in arg)
        case list():
            return [_convert_args_to_dtensor(a, use_autograd) for a in arg]
        case _:
            return arg
```
I think I caught most of the cases here, but if I missed any we can add them.
```python
def _convert_results_to_shard_tensor(
    result: object, input_args: tuple, use_autograd: bool = False
) -> object:
    r"""Recursively replace DTensors with ShardTensors in an op result.

    If use_autograd is True, uses Layer 2 to preserve the graph connection.
    Handles None returns gracefully for inplace ATen operations.
    """
```
This is the function that takes outputs from DTensor and turns them back into ShardTensor, if they should be ShardTensor.
```diff
-class ShardTensor(DTensor):
+class ShardTensor(torch.Tensor):
```
Now straight from the torch.Tensor tap!
```python
    def __tensor_flatten__(self):
        return ["_local_tensor"], (self._spec, self.requires_grad)
```
This isn't really doing anything yet. We're going to need this for compilation, but it's not yet enabled... we could cut this for this PR if you like.
```python
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
        spec, requires_grad = flatten_spec
        local_tensor = inner_tensors["_local_tensor"]
        unflatten_meta = TensorMeta(
            shape=outer_size,
            stride=outer_stride,
            dtype=spec.tensor_meta.dtype,
        )
        unflatten_spec = ShardTensorSpec(
            mesh=spec.mesh,
            placements=spec.placements,
            tensor_meta=unflatten_meta,
            _local_shape=local_tensor.shape,
            _sharding_shapes=spec._sharding_shapes,
        )
        return ShardTensor.__new__(
            ShardTensor,
            local_tensor=local_tensor.requires_grad_(requires_grad),
            spec=unflatten_spec,
            requires_grad=requires_grad,
        )
```
Similar to flatten. We could cut for now.
```python
    if isinstance(dtensor, ShardTensor):
        return dtensor
    spec = _resolve_spec_for_dtensor(dtensor)
    if dtensor.grad_fn is not None:
        return _DTensorToShardTensor.apply(dtensor, spec)
    return _dtensor_to_shard_tensor(dtensor, spec)
```
Falls back to the worker functions above.
```python
        if _conversion_active():
            # When converting shard tensor to dtensor, or dtensor to shard tensor,
            # we just run the function without ShardTensor dispatch.
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        if func in cls._function_registry and cls._enable_shard_patches:
            return cls._function_registry[func](func, types, args, kwargs)
        if str(func) in cls._named_function_registry and cls._enable_shard_patches:
            return cls._named_function_registry[str(func)](func, types, args, kwargs)
        res = _torch_function_fallback_via_dtensor(func, args, kwargs)
        return res
```
This is now all streamlined into the functions above. It's really just a conversion check (are we converting? Yes? Then the function being run is the conversion itself, so run it right away and return), then a check of the handlers, and then the fallback handler.
```python
        # Use a handler, if we have one:
        handler = cls._dispatch_registry.get(func)
        if handler is None:
            handler = cls._dispatch_registry_by_name.get(str(func))
        if handler is not None:
            return handler(*args, **kwargs)
        # Otherwise, try the dtensor route:
        return _dispatch_fallback_via_dtensor(func, args, kwargs)
```
Likewise, streamlined. Check the registry, and then go for the fallback.
No conversion check here; the conversion itself is not a dispatch op.
```python
class FSDPOutputTensorAdapter(nn.Module):
    """Wrap a module and convert ShardTensor outputs to torch.Tensor."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        out = self.module(*args, **kwargs)
        return out.to_local() if isinstance(out, ShardTensor) else out


def wrap_for_fsdp(module: nn.Module) -> nn.Module:
    """Return a module wrapper that exposes tensor outputs for FSDP hooks."""
    return FSDPOutputTensorAdapter(module)


def distribute_over_domain_for_fsdp(
    module: nn.Module,
    device_mesh: DeviceMesh,
    partition_fn: (Callable[[str, nn.Module, DeviceMesh], None] | None) = None,
) -> nn.Module:
    """Distribute a module over a domain mesh and adapt outputs for FSDP."""
    distributed_module = distribute_module(
        module,
        device_mesh=device_mesh,
        partition_fn=partition_fn,
    )
    return wrap_for_fsdp(distributed_module)
```
These are new; I added them to let us plug models into FSDP more easily and to try to fix some of these bugs. I am not certain they are still necessary with some of the other fixes.
> I did to let us plug models into FSDP more easily and try to fix some of these bugs.
Can you elaborate on the bugs in question? 😅
I believe we had some reported user issues with FSDP, no? Didn't that bug report come from a user to me via you? 😅
I guess if we can't reproduce any bugs we don't need this API....
```diff
     backward,
 ):
-    """Test view (6,) -> (2, 3, 1) with Shard(0): trailing dim must stay in group.
+    """Test view (48,) -> (8, 6, 1) with Shard(0): trailing singleton in target.
```
Just making it bigger so the view is actually defined properly on more than 2 GPUs.
```diff
 @pytest.mark.multigpu_static
-@pytest.mark.parametrize("H", [32, 256])
+@pytest.mark.parametrize("H", [128, 256])
```
This is just for cuDNN stability; 32/2 and 32/4 were really numerically unstable.
```python
def scatter_tensor_requires_grad_contract_worker(mesh, requires_grad: bool):
    r"""Validate scatter_tensor construction contract for requires_grad modes."""
    dm = DistributedManager()
    rank = dm.rank
    global_shape, placements = init_global_shape_and_placements(mesh)
    source = 0

    if rank == source:
        raw_data = torch.randn(global_shape, device=torch.device(f"cuda:{dm.local_rank}"))
    else:
        raw_data = None

    st = scatter_tensor(
        raw_data,
        source,
        mesh,
        placements,
        global_shape=torch.Size(global_shape),
        dtype=torch.float32,
        requires_grad=requires_grad,
    )

    assert st.requires_grad is requires_grad
    if requires_grad:
        assert st.is_leaf
```
This is a new test to make sure that if we use scatter_tensor and request requires_grad, it actually obeys that.
**Greptile Summary**

This PR refactors …

| Filename | Overview |
|---|---|
| `physicsnemo/domain_parallel/shard_tensor.py` | Core refactor — `ShardTensor` now directly subclasses `torch.Tensor` instead of `DTensor`; introduces layered conversion helpers, dispatch/torch_function fallback via DTensor, and FSDP adapter utilities; has a redundant variable assignment and a potential `Iterable` reconstruction issue in the result conversion path. |
| `physicsnemo/domain_parallel/custom_ops/_tensor_ops.py` | Replaces DTensor-based `unbind_rules` with a self-contained `unbind_wrapper` + `_unbind_dispatch`; both `__torch_function__` and `__torch_dispatch__` handlers are registered correctly. |
| `physicsnemo/domain_parallel/custom_ops/_reductions.py` | Adds `build_reduction_result` helper that constructs `ShardTensor` directly via `__new__` to avoid autograd side-effects in the dispatch path; uses `ShardTensorSpec` instead of `DTensorSpec` throughout. |
| `physicsnemo/domain_parallel/shard_utils/view_ops.py` | New module providing differentiable view/reshape for `ShardTensor`; correctly handles shard dimension tracking, dtype-reinterpret view, and registers both `__torch_function__` and `__torch_dispatch__` handlers. |
| `physicsnemo/domain_parallel/__init__.py` | Exports the three new FSDP helpers and renames `unbind_rules` to `unbind_wrapper`; `None` stubs added for the unavailable-version fallback. |
| `physicsnemo/domain_parallel/custom_ops/__init__.py` | Single rename: `unbind_rules` → `unbind_wrapper`, consistent with the `_tensor_ops.py` change. |
| `physicsnemo/domain_parallel/shard_utils/normalization_patches.py` | Unchanged except for import path update; group norm implementation is intact. |

Reviews (1): Last reviewed commit: "Merge branch 'main' into sharded_view_ba..."
```python
    if isinstance(result, Iterable) and not isinstance(result, (str, bytes)):
        return type(result)(
            _convert_results_to_shard_tensor(d, input_args, use_autograd)
            for d in result
        )
```
**NamedTuple-style return types fail reconstruction via generator**

`type(result)(generator_expression)` works for `list` and `tuple`, but fails for `torch.return_types.*` objects (e.g. `topk`, `sort`, `max`, `min`). These are PyStructSequence types whose constructors expect positional arguments, not a single lazy generator. Any DTensor fallback op with multiple outputs would raise a `TypeError` here at runtime.
```diff
     if isinstance(result, Iterable) and not isinstance(result, (str, bytes)):
-        return type(result)(
-            _convert_results_to_shard_tensor(d, input_args, use_autograd)
-            for d in result
-        )
+        converted = [
+            _convert_results_to_shard_tensor(d, input_args, use_autograd)
+            for d in result
+        ]
+        try:
+            return type(result)(*converted)
+        except TypeError:
+            return type(result)(converted)
```
Hi @coreyjadams, could you run the test suite here https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/stormcast/test_training.py before merging to make sure the PR doesn't break something unexpected? That test suite is part of an example so it won't run automatically as part of CI but contains many end-to-end tests using ShardTensor. See here https://github.com/NVIDIA/physicsnemo/tree/main/examples/weather/stormcast#testing for instructions on how to run the tests (including multi-GPU tests). Or even better, you could enable torch compiling in the tests and check that it works with domain parallelism!
@jleinonen Thanks for bringing it to my attention - I'll run your tests before merge, yes.
Those stormcast end-to-end tests are likely to catch things, but yeah also the domain parallel diffusion and distributed checkpoint save/load multigpu tests I added should provide a second layer of sanity. |
```python
    # # We need to return a fresh Tensor object there as autograd metadata
    # # will be inplaced into it. So we don't want to pollute the Tensor
    # # object stored in the _local_tensor of this ShardTensor.
    # return local_tensor.view_as(local_tensor)
```
Can drop the commented out bit
```python
    # would convert to DTensor and set on a temporary).
    if requires_grad:
        with torch._C.DisableTorchFunctionSubclass():
            torch.Tensor.requires_grad.__set__(ret, True)
```
How stable do you think this approach will be across PyT versions?
As stable as possible - I don't see any issues with this. We were really getting bit on PyTorch moving the DTensor wrapper APIs around in prereleases, which is what NVIDIA was building containers on. That's the #1 reason I pulled the unbind interface up from DTensor into ShardTensor in this PR too, so that we don't have exposure to that API.
```python
    # -- Autograd property overrides -------------------------------------------
    # The C-level requires_grad is authoritative for autograd engine
    # decisions; we read it first and fall back to _local_tensor for the
    # case where _make_wrapper_subclass didn't propagate it correctly.
```
> fall back to `_local_tensor` for the case where `_make_wrapper_subclass` didn't propagate it correctly.
Can you explain the scenario that would lead to this happening a bit more?
I had Claude make a diagram!
tl;dr: one scenario is very much an edge case right now but will become a real concern when we enable compile in the future and start dealing with flatten / unflatten.
The second case is if someone manually sets requires_grad on `_local_tensor` somewhere, but the ShardTensor didn't get set. So we can catch that too.
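A sketch of what such a fallback read could look like; `GuardedTensor` and the attribute handling here are assumptions for illustration, not the PR's implementation:

```python
import torch


class GuardedTensor(torch.Tensor):
    # Illustrative only: read the authoritative C-level flag first; if it was
    # never propagated to the wrapper, fall back to the wrapped local tensor.
    @property
    def requires_grad(self) -> bool:
        if torch.Tensor.requires_grad.__get__(self):
            return True
        local = self.__dict__.get("_local_tensor")
        return local is not None and local.requires_grad


# Wrapper flag was never set, but someone flipped it on the local tensor:
t = torch.ones(2).as_subclass(GuardedTensor)
t._local_tensor = torch.ones(2, requires_grad=True)
```

Here `t`'s C-level flag is `False`, yet the property still reports `True` by consulting the local tensor, which is the "second case" described above.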
@coreyjadams general question on this:
Does that mean the recommended procedure for data and domain parallel training is still to …
Yes, for today this is still the case. I know we've been mulling that API change and whether it's a good design decision or not. That's not part of this PR at least. I propose we at least postpone any major API changes to next release if you don't mind, and focus on getting to a stable product here :)
PhysicsNeMo Pull Request
This pull request implements a substantial refactor of shard tensor components. I'll add some motivations as comments on the change log, for clarity, but major highlights:
- Handling `requires_grad` and `is_leaf` correctly, which requires care since calling setters on them will go through the dispatch mechanism, fall back to DTensor, and eventually no-op unless we guard it. This took me an embarrassingly long time to debug 😭.
- `__torch_function__` ops are differentiable, while `__torch_dispatch__` ops are not and attach a gradient function to the outputs. So a lot of care has to be taken there. I've tried to make the logic chain clean and clear in the codebase... hopefully it is.
- Pieces of `ShardTensor` that are necessary for `torch.compile` support: ops like `__tensor_flatten__` and `__tensor_unflatten__` need to be available and we need to implement them ourselves for compilation to work properly.

I also updated the tests slightly: I increased the conv2d image sizes for stability, and added some tests for new ops.
From a user perspective, this should be nearly transparent: no API changes, all tests still pass, etc. It's basically an under-the-hood refactor. But if we want to enable `torch.compile` for `ShardTensor` we are going to need this.

Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.