From c26dcf33f4d0e6f6ceaa8a36d149481abc0c98e3 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Wed, 6 May 2026 15:48:26 +0200 Subject: [PATCH] Arm backend: Test avgpool non-square kernels Add a regression test for AVG_POOL2D output shape with a non-square kernel and height-only padding. Change-Id: Ib2f0c15720aa7ba5c15a7406bafbc2d37aa4fa5a Signed-off-by: Sebastian Larsson --- .../test_tosa_dialect_avg_pool2d_adaptive.py | 21 +++++++++++++++++++ backends/arm/tosa/dialect/ops/avg_pool2d.py | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/misc/test_tosa_dialect_avg_pool2d_adaptive.py b/backends/arm/test/misc/test_tosa_dialect_avg_pool2d_adaptive.py index 81b577c708f..014d58eabb4 100644 --- a/backends/arm/test/misc/test_tosa_dialect_avg_pool2d_adaptive.py +++ b/backends/arm/test/misc/test_tosa_dialect_avg_pool2d_adaptive.py @@ -19,6 +19,27 @@ from torch._subclasses.fake_tensor import FakeTensorMode +def test_avg_pool2d_tosa_non_square_kernel_output_shape(): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + x = mode.from_tensor(torch.randn((1, 20, 20, 8), dtype=torch.float32)) + input_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32)) + output_zp = mode.from_tensor(torch.zeros((1,), dtype=torch.float32)) + + output = exir_ops.backend.tosa.AVG_POOL2D.default( + x, + input_zp, + output_zp, + [2, 3], + [2, 1], + [1, 1, 0, 0], + torch.float32, + ) + + assert tuple(output.shape) == (1, 11, 18, 8) + + def test_avg_pool2d_adaptive_tosa_INT(): sample_inputs = [ ( diff --git a/backends/arm/tosa/dialect/ops/avg_pool2d.py b/backends/arm/tosa/dialect/ops/avg_pool2d.py index 61132db9893..8fcf4c85445 100644 --- a/backends/arm/tosa/dialect/ops/avg_pool2d.py +++ b/backends/arm/tosa/dialect/ops/avg_pool2d.py @@ -136,7 +136,7 @@ def compute_avg_pool2d_output_shape( pad: List[IntLikeType] | List[int], op: str = "AVG_POOL2D", ) -> List[IntLikeType]: - """Compute the output shape for NCHW avg-pool.""" + """Compute the output shape for NHWC avg-pool.""" if x.dim() != 4: raise TosaValueError(f"{op} requires a 4D tensor, got {x.dim()}D", op=op)