diff --git a/exir/tensor.py b/exir/tensor.py index b1619d16bdf..eff2bcb1726 100644 --- a/exir/tensor.py +++ b/exir/tensor.py @@ -69,30 +69,32 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]: """ from torch.fx.experimental.symbolic_shapes import ( guard_or_false, - guard_size_oblivious, + guard_or_true, ) - for _, s in enumerate(stride): - if guard_or_false(s == 0): - raise ValueError("0 in strides is not supported for ExecuTorch.") + for s in stride: + torch._check(s != 0, lambda: "0 in strides is not supported for ExecuTorch.") class K(NamedTuple): stride: int def __lt__(self, other): - return guard_size_oblivious(self.stride < other.stride) - - def __gt__(self, other): - return guard_size_oblivious(self.stride > other.stride) - - def __le__(self, other): - return guard_size_oblivious(self.stride <= other.stride) - - def __ge__(self, other): - return guard_size_oblivious(self.stride >= other.stride) - - def __eq__(self, other): - return guard_size_oblivious(self.stride == other.stride) + # For backed/concrete strides this is practically a `<` operation. + # For unbacked, we return True if `<` is statically known, then + # try to answer symbolically with stride-ordering semantics: + # u0 < u0 -> False + # u0 < u1 (no info) -> False + # u0 < 2 * u0 -> True (divisibility) + # 1 < u0 -> True (1 divides anything; unprovable equality treated optimistically) + return ( + guard_or_false( + self.stride < other.stride + ) # statically known inequality + or ( + guard_or_false(other.stride % self.stride == 0) + and guard_or_true(self.stride != other.stride) + ) # symbolic inequality (e.g. u0 < 2048 * u0) + ) sorted_dims = [ i[0] for i in sorted(enumerate(stride), key=lambda x: K(x[1]), reverse=True) diff --git a/exir/tests/test_tensor.py b/exir/tests/test_tensor.py index c5383b0dac2..d29c059a96a 100644 --- a/exir/tests/test_tensor.py +++ b/exir/tests/test_tensor.py @@ -246,9 +246,52 @@ def test_dim_order_from_stride(self) -> None: # dim[2] is broadcasting dim # shape = (5, 1, 15, 10) strides = (10, 10, 0, 1) - with self.assertRaises(ValueError): + # torch._check raises RuntimeError on concrete 0. + with self.assertRaises(RuntimeError): dim_order = dim_order_from_stride(strides) + def test_dim_order_from_stride_unbacked(self) -> None: + """ + dim_order_from_stride should produce a sane permutation even when the + strides contain unbacked SymInts. The comparator falls back to + divisibility-based reasoning so common cases like (1, u0) and + (u0, 2 * u0) order correctly. + """ + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + shape_env = ShapeEnv() + u0 = shape_env.create_unbacked_symint() + u1 = shape_env.create_unbacked_symint() + + # 1 < u0 should be True via divisibility (u0 % 1 == 0) + optimistic + # `1 != u0`. Descending sort puts u0 outer, stride 1 inner. + dim_order = dim_order_from_stride((1, u0)) + self.assertEqual((1, 0), dim_order) + + # u0 < 2 * u0 should be True via divisibility ((2*u0) % u0 == 0) and + # provable inequality (u0 != 0 after torch._check). + dim_order = dim_order_from_stride((u0, 2 * u0)) + self.assertEqual((1, 0), dim_order) + + # Mixed concrete + symbolic: (1, u0, 2 * u0). Descending stride order + # is (2*u0, u0, 1) -> indices (2, 1, 0). + dim_order = dim_order_from_stride((1, u0, 2 * u0)) + self.assertEqual((2, 1, 0), dim_order) + + # u0 < u1 (independent unbackeds) is genuinely ambiguous; stable sort + # preserves original order under reverse=True (no swap on ambiguous). + dim_order = dim_order_from_stride((u0, u1)) + self.assertEqual((0, 1), dim_order) + + # u0 < u0 is False both ways (symmetric); stable sort preserves order. + dim_order = dim_order_from_stride((u0, u0)) + self.assertEqual((0, 1), dim_order) + + # Unbacked stride of 0 (concrete 0 mixed with unbacked) -> RuntimeError + # via torch._check. + with self.assertRaises(RuntimeError): + dim_order_from_stride((u0, 0, 1)) + def test_strides_from_dim_order(self) -> None: sizes = [] dim_order = []