ShardTensor Refactor#1556

Open
coreyjadams wants to merge 10 commits into NVIDIA:main from coreyjadams:sharded_view_backwards

Conversation

@coreyjadams
Collaborator

PhysicsNeMo Pull Request

This pull request implements a substantial refactor of shard tensor components. I'll add some motivations as comments on the change log, for clarity, but major highlights:

  • ShardTensor becomes its own class inheriting straight from torch.Tensor, instead of DTensor. We are no longer tied to the DTensor API, which frees us to make future decisions about our own API. It doesn't change today's programming model.
  • One challenge here was ensuring we manage the requires_grad and is_leaf tensor properties correctly, which requires care since calling their setters goes through the dispatch mechanism, falls back to DTensor, and eventually no-ops unless we guard it. This took me an embarrassingly long time to debug 😭 .
  • The logic of the dispatch mechanism has been streamlined and clarified. The challenge is that __torch_function__ ops are differentiable, while __torch_dispatch__ ops are not and instead attach a gradient function to their outputs. So a lot of care has to be taken there. I've tried to make the logic chain clean and clear in the codebase... hopefully it is.
  • I've implemented several components of ShardTensor that are necessary for torch.compile support. Ops like __tensor_flatten__ and __tensor_unflatten__ need to be available, and we need to implement them ourselves for compilation to work properly.
  • The unbind op, which was previously implemented as an extension to DTensor and used as a fallback, is now implemented directly as a ShardTensor op. The old approach was causing issues because the PyTorch API kept changing ....
  • I also fixed a couple small bugs. Thanks Claude.

I also updated the tests slightly: I increased the conv2d image sizes for stability, and added some tests for new ops.

From a user perspective, this should be nearly transparent: no API changes, all tests still pass, etc. It's basically an under-the-hood refactor. But if we want to enable torch.compile for ShardTensor we are going to need this.
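To make the subclassing change concrete, here is a minimal, self-contained sketch of the wrapper-subclass pattern (the class name MyWrapper and its _local attribute are invented for illustration; this is not the ShardTensor code):

```python
import torch

class MyWrapper(torch.Tensor):
    """Toy wrapper subclass: stores a local tensor and intercepts ATen ops."""

    @staticmethod
    def __new__(cls, local: torch.Tensor):
        # _make_wrapper_subclass creates a "shell" tensor; the real data
        # lives in the wrapped local tensor.
        wrapper = torch.Tensor._make_wrapper_subclass(
            cls,
            local.shape,
            dtype=local.dtype,
            device=local.device,
            requires_grad=local.requires_grad,
        )
        wrapper._local = local
        return wrapper

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to plain tensors, run the op, and rewrap the result.
        unwrapped = [a._local if isinstance(a, MyWrapper) else a for a in args]
        out = func(*unwrapped, **(kwargs or {}))
        return MyWrapper(out) if isinstance(out, torch.Tensor) else out

x = MyWrapper(torch.ones(3))
y = x + x  # routed through __torch_dispatch__
assert isinstance(y, MyWrapper)
```

Subclassing torch.Tensor this way means every op flows through our own dispatch hooks rather than DTensor's.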

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Comment on lines -42 to -50
if check_version_spec("torch", "2.10.0a"):
    from torch.distributed.tensor._ops.registration import (
        register_prop_rule,
    )
else:
    from torch.distributed.tensor._ops.utils import (
        register_prop_rule,
    )

Collaborator Author

This was fragile and causing issues in CI, and was the primary reason for refactoring this into a ShardTensor op instead of a DTensor one.

Comment on lines -285 to +291
-    if weight is not None and weight.requires_grad:
+    if weight is not None and ctx.needs_input_grad[3]:
         # grad_weight_c = sum_{n, spatial} grad_output * y (per-channel)
         y_c = y.view(N, C, HxW_local)
         grad_out_c = local_grad_output.view(N, C, HxW_local)
         grad_weight = (grad_out_c * y_c).sum(dim=(0, 2))  # (C,)

-    if bias is not None and bias.requires_grad:
+    if bias is not None and ctx.needs_input_grad[4]:
Collaborator Author

This was a nasty bug: we have to check the context if something needs a grad, not the tensor itself.
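The distinction is easy to reproduce outside of ShardTensor (a minimal sketch; Scale is a made-up op, not code from this PR): in a custom autograd.Function, ctx.needs_input_grad is what the autograd engine actually consults, one entry per positional input of forward(), and it stays correct even when a saved tensor's own requires_grad flag is misleading.

```python
import torch

class Scale(torch.autograd.Function):
    """Toy op: y = x * w, with grads computed only where the engine asks."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_w = None
        # Check ctx.needs_input_grad (one flag per forward() input),
        # not x.requires_grad / w.requires_grad on the saved tensors.
        if ctx.needs_input_grad[0]:
            grad_x = grad_out * w
        if ctx.needs_input_grad[1]:
            grad_w = grad_out * x
        return grad_x, grad_w

x = torch.randn(4, requires_grad=True)
w = torch.randn(4)  # frozen weight: no gradient wanted
Scale.apply(x, w).sum().backward()
assert x.grad is not None and w.grad is None
```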

Collaborator

Since .requires_grad is implicated, I want to ask whether we'll need to modify/update the fix I'm adding in #1566.

It is more related to frozen params with nn.Parameter(..., requires_grad=False) getting silently updated by distribute_module in PyT <= 2.10. The relevant bugged PyT code is here; it has since been fixed on main here.

Collaborator Author

Your PR shouldn't be affected. This code here is fragile, I believe, because the requires_grad state is not reliable inside of this torch.autograd.Function context. Outside of that it's still OK.

This isn't a ShardTensor thing in particular - I think I was subtly doing this wrong before.

Comment on lines +54 to +56
# ============================================================================
# Layer 1 -- Semi-private conversions (no autograd, no spec inference)
# ============================================================================
Collaborator Author

I opted for a design that builds up this dispatch fallback mechanism in stages, so it can be careful and deliberate about what is routed where, and when.

So I am describing it in "layers", where each layer adds more complexity in terms of the autograd graph and ShardTensorSpec / TensorSpec handling. This one here is the least complex: no autograd, no spec inference.

Comment on lines +66 to +68
dtensor = torch.Tensor._dtensor__new__(
    DTensor, st._local_tensor, st._spec, requires_grad=st.requires_grad
)
Collaborator Author

This is a newer API, so a fallback is available too.

Comment on lines +107 to +123
class _DTensorToShardTensor(torch.autograd.Function):
    r"""Differentiable promotion: DTensor -> ShardTensor.

    This is to always connect the graphs for the backward pass
    when we have to use a fallback option.

    Forward: :func:`_dtensor_to_shard_tensor`.
    Backward: :func:`_shard_tensor_to_dtensor`.
    """

    @staticmethod
    def forward(ctx, dtensor: DTensor, spec: ShardTensorSpec) -> "ShardTensor":
        return _dtensor_to_shard_tensor(dtensor, spec)

    @staticmethod
    def backward(ctx, grad_output: "ShardTensor"):
        return _shard_tensor_to_dtensor(grad_output), None
Collaborator Author

The differentiable conversions just let autograd attach grad functions correctly to the conversion functions in Layer 1.


@staticmethod
def forward(ctx, st: "ShardTensor") -> DTensor:
    ctx.shard_tensor_spec = st._spec
Collaborator Author

We cache the spec, because in the backward pass it's likely that the spec from the forward pass is correct for the input gradients in the backward pass.

Comment on lines +158 to +171
    for arg in input_args:
        if (
            isinstance(arg, ShardTensor)
            and dtensor._spec.tensor_meta == arg._spec.tensor_meta
            and dtensor._spec.placements == arg._spec.placements
        ):
            return arg._spec
    return _infer_shard_tensor_spec_from_local_chunks(
        dtensor._local_tensor,
        dtensor._spec.mesh,
        dtensor._spec.placements,
        sharding_shapes="chunk",
        global_shape=dtensor.shape,
    )
Collaborator Author

The goal here is to figure out a correct spec for a DTensor -> ShardTensor promotion based on the other args. So, if we're doing a unary op, for example y = 2*x, we fall back to DTensor, do the op, and get a DTensor back. We don't want to communicate to figure out the sharding shapes, though, so this looks at all the input args to figure out, "Hey, does any input arg match this output well enough to use its spec?"

We match on the tensor meta (overall shape) + placements + the input arg having to be a ShardTensor, which seems to be robust here.

Otherwise, we assume the DTensor has DTensor chunking and promote that.

Comment on lines +178 to +215
def _conversion_active() -> bool:
    r"""Return whether ShardTensor<->DTensor conversion is currently active."""
    return getattr(_conversion_guard, "depth", 0) > 0


@contextmanager
def _conversion_scope():
    r"""Re-entrant conversion guard for cast-down/cast-up paths."""
    previous_depth = getattr(_conversion_guard, "depth", 0)
    _conversion_guard.depth = previous_depth + 1
    try:
        yield
    finally:
        if previous_depth == 0:
            delattr(_conversion_guard, "depth")
        else:
            _conversion_guard.depth = previous_depth


def _dispatch_fallback_via_dtensor(
    func: torch._ops.OpOverload,
    args: tuple[object, ...],
    kwargs: dict[str, object] | None = None,
) -> object:
    r"""Execute an ATen op through DTensor fallback using PURE data conversion.

    Native Autograd wraps this hook, so we must NOT build an internal graph
    using .apply(). We just do the math and let PyTorch track the outer graph.
    """
    with _conversion_scope():
        converted_args = tuple(
            _convert_args_to_dtensor(arg, use_autograd=False) for arg in args
        )
        converted_kwargs = {
            k: _convert_args_to_dtensor(v, use_autograd=False)
            for k, v in (kwargs or {}).items()
        }

    dispatch_res = func(*converted_args, **(converted_kwargs or {}))

    with _conversion_scope():
        return _convert_results_to_shard_tensor(dispatch_res, args, use_autograd=False)
Collaborator Author

There are two things going on here. This one is critical, actually.

First, there is a possible infinite recursion if we're not careful: because the dispatch fallback itself uses torch functions to attach autograd functions for the differentiable paths, if we do not prevent that we will dispatch infinitely and hit Python's recursion limit. So that's fun. So we use this conversion scope, and check it in __torch_function__ below, in ShardTensor itself, to skip the dispatch conversion if we're already within a conversion scope.

And, we make it thread safe with a thread-local counter rather than a boolean. Got my eye on you, free threaded python ...
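The guard pattern can be sketched in isolation (a simplified stdlib-only version; the real code above also deletes the thread-local attribute when the outermost scope exits):

```python
import threading
from contextlib import contextmanager

_guard = threading.local()  # each thread sees its own depth counter

def conversion_active() -> bool:
    # A depth counter rather than a boolean: nested scopes on one
    # thread stack cleanly, and other threads are unaffected.
    return getattr(_guard, "depth", 0) > 0

@contextmanager
def conversion_scope():
    _guard.depth = getattr(_guard, "depth", 0) + 1
    try:
        yield
    finally:
        _guard.depth -= 1

assert not conversion_active()
with conversion_scope():
    assert conversion_active()
    with conversion_scope():  # re-entrant: safe to nest
        assert conversion_active()
    assert conversion_active()
assert not conversion_active()
```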

Collaborator

I can tell you had fun debugging this 😱 I find Layer 3 to be the gnarliest among the conversion layers

Collaborator Author

It actually took me 1+ month, off and on, coming back to this and poking at it to track down what the heck was happening in all of this.

Comment on lines +250 to +265
    match arg:
        case ShardTensor():
            if use_autograd and arg.requires_grad and torch.is_grad_enabled():
                return _ShardTensorToDTensor.apply(arg)
            return _shard_tensor_to_dtensor(arg)
        case DTensor():
            # DTensor can be iterable; exit early deliberately
            return arg
        case Mapping():
            return type(arg)(
                {k: _convert_args_to_dtensor(v, use_autograd) for k, v in arg.items()}
            )
        case tuple():
            return tuple(_convert_args_to_dtensor(a, use_autograd) for a in arg)
        case list():
            return [_convert_args_to_dtensor(a, use_autograd) for a in arg]
        case _:
            return arg
Collaborator Author

I think I caught most of the cases, here, but if I missed any we can add them.

Comment on lines +268 to +275
def _convert_results_to_shard_tensor(
    result: object, input_args: tuple, use_autograd: bool = False
) -> object:
    r"""Recursively replace DTensors with ShardTensors in an op result.

    If use_autograd is True, uses Layer 2 to preserve the graph connection.
    Handles None returns gracefully for inplace ATen operations.
    """
Collaborator Author

This is the function that is meant to take outputs from DTensor and turn them back into ShardTensors, if they should be ShardTensors.



-class ShardTensor(DTensor):
+class ShardTensor(torch.Tensor):
Collaborator Author

Now straight from the torch.Tensor tap!

Comment thread physicsnemo/domain_parallel/shard_tensor.py Outdated
Comment on lines +749 to +750
def __tensor_flatten__(self):
    return ["_local_tensor"], (self._spec, self.requires_grad)
Collaborator Author

This isn't really doing anything yet. We're going to need this for compilation, but it's not yet enabled... we could cut this for this PR if you like.

Comment on lines +753 to +773
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    spec, requires_grad = flatten_spec
    local_tensor = inner_tensors["_local_tensor"]
    unflatten_meta = TensorMeta(
        shape=outer_size,
        stride=outer_stride,
        dtype=spec.tensor_meta.dtype,
    )
    unflatten_spec = ShardTensorSpec(
        mesh=spec.mesh,
        placements=spec.placements,
        tensor_meta=unflatten_meta,
        _local_shape=local_tensor.shape,
        _sharding_shapes=spec._sharding_shapes,
    )
    return ShardTensor.__new__(
        ShardTensor,
        local_tensor=local_tensor.requires_grad_(requires_grad),
        spec=unflatten_spec,
        requires_grad=requires_grad,
    )
Collaborator Author

Similar to flatten. We could cut for now.

Comment on lines +861 to +866
if isinstance(dtensor, ShardTensor):
    return dtensor
spec = _resolve_spec_for_dtensor(dtensor)
if dtensor.grad_fn is not None:
    return _DTensorToShardTensor.apply(dtensor, spec)
return _dtensor_to_shard_tensor(dtensor, spec)
Collaborator Author

Falls back to the worker functions above.

Comment on lines +872 to +882
if _conversion_active():
    # When converting shard tensor to dtensor, or dtensor to shard tensor,
    # we just run the function without ShardTensor dispatch.
    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **kwargs)
if func in cls._function_registry and cls._enable_shard_patches:
    return cls._function_registry[func](func, types, args, kwargs)
if str(func) in cls._named_function_registry and cls._enable_shard_patches:
    return cls._named_function_registry[str(func)](func, types, args, kwargs)
res = _torch_function_fallback_via_dtensor(func, args, kwargs)
return res
Collaborator Author

This is now all streamlined into the functions above. It's really just a conversion check (are we converting? Yes? Then the function being run is the conversion, so run it right away and return), otherwise check the handlers and then go to the fallback handler.

Comment on lines +892 to +899
# Use a handler, if we have one:
handler = cls._dispatch_registry.get(func)
if handler is None:
    handler = cls._dispatch_registry_by_name.get(str(func))
if handler is not None:
    return handler(*args, **kwargs)
# Otherwise, try the dtensor route:
return _dispatch_fallback_via_dtensor(func, args, kwargs)
Collaborator Author

Likewise, streamlined. Check the registry, and then go for the fallback.

No conversion check here; the conversions themselves aren't dispatch ops.

Comment on lines +1124 to +1152
class FSDPOutputTensorAdapter(nn.Module):
    """Wrap a module and convert ShardTensor outputs to torch.Tensor."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        out = self.module(*args, **kwargs)
        return out.to_local() if isinstance(out, ShardTensor) else out


def wrap_for_fsdp(module: nn.Module) -> nn.Module:
    """Return a module wrapper that exposes tensor outputs for FSDP hooks."""
    return FSDPOutputTensorAdapter(module)


def distribute_over_domain_for_fsdp(
    module: nn.Module,
    device_mesh: DeviceMesh,
    partition_fn: (Callable[[str, nn.Module, DeviceMesh], None] | None) = None,
) -> nn.Module:
    """Distribute a module over a domain mesh and adapt outputs for FSDP."""
    distributed_module = distribute_module(
        module,
        device_mesh=device_mesh,
        partition_fn=partition_fn,
    )
    return wrap_for_fsdp(distributed_module)
Collaborator Author

These are new; I added them to let us plug models into FSDP more easily and to try to fix some of these bugs. I am not certain they are still necessary with some of the other fixes.

Collaborator

I did to let us plug models into FSDP more easily and try to fix some of these bugs.

Can you elaborate on the bugs in question? 😅

Collaborator Author

I believe we had some reported user issues with FSDP, no? Didn't that bug report come from a user to me via you? 😅

I guess if we can't reproduce any bugs we don't need this API....

backward,
):
"""Test view (6,) -> (2, 3, 1) with Shard(0): trailing dim must stay in group.
"""Test view (48,) -> (8, 6, 1) with Shard(0): trailing singleton in target.
Collaborator Author

Just making it bigger so the view is actually defined properly on more than 2 GPUs.
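For context, the arithmetic behind the size bump (a standalone sketch mimicking torch.chunk's dim-0 split, not code from the tests): with only 6 elements over 4 ranks, the chunking produces three chunks of 2 and leaves one rank empty, so the view can't be expressed rank-locally; with 48 elements every rank holds 12, which reshapes cleanly for the (8, 6, 1) target.

```python
def chunk_sizes(n: int, ranks: int) -> list[int]:
    # torch.chunk-style split: ceil-sized chunks first; may return
    # fewer chunks than ranks, leaving trailing ranks empty.
    base = -(-n // ranks)  # ceiling division
    sizes, left = [], n
    while left > 0:
        take = min(base, left)
        sizes.append(take)
        left -= take
    return sizes

# 6 elements over 4 ranks: only 3 non-empty chunks -> the view is ill-defined.
assert chunk_sizes(6, 4) == [2, 2, 2]
# 48 elements over 4 ranks: even 12-element shards, rows of 6 fit exactly.
assert chunk_sizes(48, 4) == [12, 12, 12, 12]
```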


@pytest.mark.multigpu_static
@pytest.mark.parametrize("H", [32, 256])
@pytest.mark.parametrize("H", [128, 256])
Collaborator Author

This is just for cuDNN stability: 32/2 and 32/4 were really numerically unstable.

Comment on lines +124 to +148
def scatter_tensor_requires_grad_contract_worker(mesh, requires_grad: bool):
    r"""Validate scatter_tensor construction contract for requires_grad modes."""
    dm = DistributedManager()
    rank = dm.rank
    global_shape, placements = init_global_shape_and_placements(mesh)
    source = 0

    if rank == source:
        raw_data = torch.randn(
            global_shape, device=torch.device(f"cuda:{dm.local_rank}")
        )
    else:
        raw_data = None

    st = scatter_tensor(
        raw_data,
        source,
        mesh,
        placements,
        global_shape=torch.Size(global_shape),
        dtype=torch.float32,
        requires_grad=requires_grad,
    )

    assert st.requires_grad is requires_grad
    if requires_grad:
        assert st.is_leaf
Collaborator Author

This is a new test to make sure that if we use scatter_tensor and request requires_grad, it actually obeys that.

@coreyjadams coreyjadams marked this pull request as ready for review April 9, 2026 17:20
@coreyjadams coreyjadams requested a review from pzharrington April 9, 2026 17:20
@greptile-apps
Contributor

greptile-apps bot commented Apr 9, 2026

Greptile Summary

This PR refactors ShardTensor to inherit directly from torch.Tensor instead of DTensor, adds torch.compile support via __tensor_flatten__/__tensor_unflatten__, introduces a layered DTensor fallback dispatch mechanism, and replaces the DTensor-backed unbind with a native implementation.

  • The new _convert_results_to_shard_tensor helper reconstructs multi-valued op results with type(result)(generator), which fails for torch.return_types.* (PyStructSequence) objects — any DTensor fallback op with multiple outputs (e.g. topk, sort, max) would raise a TypeError.

Vulnerabilities

No security concerns identified.

Important Files Changed

Filename Overview
physicsnemo/domain_parallel/shard_tensor.py Core refactor — ShardTensor now directly subclasses torch.Tensor instead of DTensor; introduces layered conversion helpers, dispatch/torch_function fallback via DTensor, and FSDP adapter utilities; has a redundant variable assignment and a potential Iterable reconstruction issue in the result conversion path.
physicsnemo/domain_parallel/custom_ops/_tensor_ops.py Replaces DTensor-based unbind_rules with a self-contained unbind_wrapper + _unbind_dispatch; both torch_function and torch_dispatch handlers are registered correctly.
physicsnemo/domain_parallel/custom_ops/_reductions.py Adds build_reduction_result helper that constructs ShardTensor directly via new to avoid autograd side-effects in the dispatch path; uses ShardTensorSpec instead of DTensorSpec throughout.
physicsnemo/domain_parallel/shard_utils/view_ops.py New module providing differentiable view/reshape for ShardTensor; correctly handles shard dimension tracking, dtype-reinterpret view, and registers both torch_function and torch_dispatch handlers.
physicsnemo/domain_parallel/__init__.py Exports the three new FSDP helpers and renames unbind_rules to unbind_wrapper; None stubs added for the unavailable-version fallback.
physicsnemo/domain_parallel/custom_ops/__init__.py Single rename: unbind_rules → unbind_wrapper, consistent with the _tensor_ops.py change.
physicsnemo/domain_parallel/shard_utils/normalization_patches.py Unchanged except for import path update; group norm implementation is intact.

Reviews (1): Last reviewed commit: "Merge branch 'main' into sharded_view_ba..."

Comment on lines +297 to +302
if isinstance(result, Iterable) and not isinstance(result, (str, bytes)):
    return type(result)(
        _convert_results_to_shard_tensor(d, input_args, use_autograd)
        for d in result
    )

Contributor

P1 NamedTuple-style return types fail reconstruction via generator

type(result)(generator_expression) works for list and tuple, but fails for torch.return_types.* objects (e.g. topk, sort, max, min). These are PyStructSequence types whose constructors expect positional arguments, not a single lazy generator. Any DTensor fallback op with multiple outputs would raise a TypeError here at runtime.

Suggested change
-if isinstance(result, Iterable) and not isinstance(result, (str, bytes)):
-    return type(result)(
-        _convert_results_to_shard_tensor(d, input_args, use_autograd)
-        for d in result
-    )
+if isinstance(result, Iterable) and not isinstance(result, (str, bytes)):
+    converted = [
+        _convert_results_to_shard_tensor(d, input_args, use_autograd)
+        for d in result
+    ]
+    try:
+        return type(result)(*converted)
+    except TypeError:
+        return type(result)(converted)
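The reviewer's point can be sanity-checked in isolation (a hedged sketch, not the suggested patch itself; whether the generator form raises can vary by PyTorch version, so this only demonstrates the list-based reconstruction, which is safe either way):

```python
import torch

# Multi-output ops return PyStructSequence objects such as
# torch.return_types.topk rather than plain tuples.
result = torch.topk(torch.tensor([3.0, 1.0, 2.0]), k=2)

# Stand-in for the per-element ShardTensor conversion: materialize
# the converted outputs into a list first...
converted = [t.clone() for t in result]

# ...then rebuild with the struct-sequence constructor, which accepts
# a single sequence argument.
rebuilt = type(result)(converted)
assert torch.equal(rebuilt.values, torch.tensor([3.0, 2.0]))
assert rebuilt.indices.tolist() == [0, 2]
```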

Comment thread physicsnemo/domain_parallel/shard_tensor.py Outdated
Comment thread physicsnemo/domain_parallel/shard_tensor.py
@dallasfoster dallasfoster self-requested a review April 9, 2026 17:33
@jleinonen
Collaborator

jleinonen commented Apr 13, 2026

Hi @coreyjadams, could you run the test suite here https://github.com/NVIDIA/physicsnemo/blob/main/examples/weather/stormcast/test_training.py before merging to make sure the PR doesn't break something unexpected? That test suite is part of an example so it won't run automatically as part of CI but contains many end-to-end tests using ShardTensor. See here https://github.com/NVIDIA/physicsnemo/tree/main/examples/weather/stormcast#testing for instructions on how to run the tests (including multi-GPU tests).

Or even better, you could enable torch compiling in the tests and check that it works with domain parallelism!

@coreyjadams
Collaborator Author

@jleinonen Thanks for bringing it to my attention - I'll run your tests before merge, yes.

@pzharrington
Collaborator

Those stormcast end-to-end tests are likely to catch things, but yeah also the domain parallel diffusion and distributed checkpoint save/load multigpu tests I added should provide a second layer of sanity.

# # We need to return a fresh Tensor object there as autograd metadata
# # will be inplaced into it. So we don't want to pollute the Tensor
# # object stored in the _local_tensor of this ShardTensor.
# return local_tensor.view_as(local_tensor)
Collaborator

Can drop the commented out bit

# would convert to DTensor and set on a temporary).
if requires_grad:
    with torch._C.DisableTorchFunctionSubclass():
        torch.Tensor.requires_grad.__set__(ret, True)
Collaborator

How stable do you think this approach will be across PyT versions?

Collaborator Author

As stable as possible - I don't see any issues with this. We were really getting bitten by PyTorch moving the DTensor wrapper APIs around in prereleases, which is what NVIDIA was building containers on. That's the #1 reason I pulled the unbind interface up from DTensor into ShardTensor in this PR too, so that we don't have exposure to that API.

# -- Autograd property overrides -------------------------------------------
# The C-level requires_grad is authoritative for autograd engine
# decisions; we read it first and fall back to _local_tensor for the
# case where _make_wrapper_subclass didn't propagate it correctly.
Collaborator

fall back to _local_tensor for the case where _make_wrapper_subclass didn't propagate it correctly.

Can you explain the scenario that would lead to this happening a bit more?

Collaborator Author

I had Claude make a diagram!

(diagram attached in the PR)

tl;dr one scenario is very much an edge case right now but will become a real concern when we enable compile in the future and start dealing with flatten / unflatten.

The second case is if someone manually sets the grad on _local_tensor somewhere, but the ShardTensor didn't get set. So we can catch that too.

@pzharrington
Collaborator

@coreyjadams general question on this:

From a user perspective, this should be nearly transparent: no API changes, all tests still pass, etc. It's basically an under-the-hood refactor.

Does that mean the recommended procedure for data and domain parallel training is still to distribute_module for any params that need sharding, then wrap with FSDP?

@coreyjadams
Collaborator Author

Does that mean the recommended procedure for data and domain parallel training is still to distribute_module for any params that need sharding, then wrap with FSDP?

Yes, for today this is still the case. I know we've been mulling that API change and whether it's a good design decision or not. That's not part of this PR at least. I propose we at least postpone any major API changes to next release if you don't mind, and focus on getting to a stable product here :)
