diff --git a/Lib/asyncio/exceptions.py b/Lib/asyncio/exceptions.py index 5ece595aad6475..fd02735f9946ec 100644 --- a/Lib/asyncio/exceptions.py +++ b/Lib/asyncio/exceptions.py @@ -4,7 +4,8 @@ __all__ = ('BrokenBarrierError', 'CancelledError', 'InvalidStateError', 'TimeoutError', 'IncompleteReadError', 'LimitOverrunError', - 'SendfileNotAvailableError') + 'SendfileNotAvailableError', + 'WouldBlock') class CancelledError(BaseException): @@ -60,3 +61,7 @@ def __reduce__(self): class BrokenBarrierError(RuntimeError): """Barrier is broken by barrier.abort() call.""" + + +class WouldBlock(Exception): + """Raised by nowait functions when the operation would block.""" diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py index fa3a94764b507a..6a7f4b668c55b5 100644 --- a/Lib/asyncio/locks.py +++ b/Lib/asyncio/locks.py @@ -1,10 +1,13 @@ """Synchronization primitives.""" __all__ = ('Lock', 'Event', 'Condition', 'Semaphore', - 'BoundedSemaphore', 'Barrier') + 'BoundedSemaphore', 'Barrier', + 'CapacityLimiter', 'CapacityLimiterStatistics') import collections +import dataclasses import enum +import math from . import exceptions from . import mixins @@ -615,3 +618,181 @@ def n_waiting(self): def broken(self): """Return True if the barrier is in a broken state.""" return self._state is _BarrierState.BROKEN + + +@dataclasses.dataclass(frozen=True) +class CapacityLimiterStatistics: + """Statistics for a CapacityLimiter.""" + borrowed_tokens: int + total_tokens: int | float + borrowers: tuple[object, ...] + tasks_waiting: int + + +class CapacityLimiter(_ContextManagerMixin, mixins._LoopBoundMixin): + """A capacity limiter that tracks borrowers and supports dynamic capacity. + + Unlike a Semaphore, a CapacityLimiter: + - Tracks which tasks hold tokens, preventing the same task from + acquiring twice (which would deadlock a semaphore). + - Allows dynamic adjustment of total_tokens at runtime. + - Supports acquiring/releasing on behalf of arbitrary objects. + + Usage:: + + limiter = CapacityLimiter(10) + + async with limiter: + # At most 10 tasks can be here concurrently + ... + + """ + + def __init__(self, total_tokens: int | float): + self._validate_tokens(total_tokens) + self._total_tokens: int | float = total_tokens + self._borrowers: set[object] = set() + self._waiters: collections.OrderedDict[object, object] = ( + collections.OrderedDict() + ) + + def __repr__(self): + res = super().__repr__() + extra = (f'borrowed:{self.borrowed_tokens}, ' + f'total:{self._total_tokens}') + if self._waiters: + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' + + @staticmethod + def _validate_tokens(total_tokens): + if not isinstance(total_tokens, (int, float)): + raise TypeError("'total_tokens' must be an int or float") + if isinstance(total_tokens, float) and total_tokens != math.inf: + raise ValueError( + "'total_tokens' must be an integer or math.inf" + ) + if total_tokens < 0: + raise ValueError("'total_tokens' must be >= 0") + + @property + def total_tokens(self) -> int | float: + """The total number of tokens available (read-write).""" + return self._total_tokens + + @total_tokens.setter + def total_tokens(self, value: int | float): + self._validate_tokens(value) + self._total_tokens = value + self._notify_waiters() + + @property + def borrowed_tokens(self) -> int: + """The number of tokens currently borrowed.""" + return len(self._borrowers) + + @property + def available_tokens(self) -> int | float: + """The number of tokens currently available.""" + return self._total_tokens - len(self._borrowers) + + def acquire_nowait(self) -> None: + """Acquire a token on behalf of the current task without blocking. + + Raises WouldBlock if a token is not immediately available. + Raises RuntimeError if the current task already holds a token. + """ + from . import tasks + self.acquire_on_behalf_of_nowait(tasks.current_task()) + + async def acquire(self) -> None: + """Acquire a token on behalf of the current task. + + Blocks until a token is available. + Raises RuntimeError if the current task already holds a token. + """ + from . import tasks + await self.acquire_on_behalf_of(tasks.current_task()) + + def acquire_on_behalf_of_nowait(self, borrower) -> None: + """Acquire a token on behalf of the given borrower without blocking. + + Raises WouldBlock if a token is not immediately available. + Raises RuntimeError if the borrower already holds a token. + """ + if borrower in self._borrowers: + raise RuntimeError( + "this borrower is already holding one of this " + "CapacityLimiter's tokens" + ) + if self._waiters or len(self._borrowers) >= self._total_tokens: + raise exceptions.WouldBlock + self._borrowers.add(borrower) + + async def acquire_on_behalf_of(self, borrower) -> None: + """Acquire a token on behalf of the given borrower. + + Blocks until a token is available. + Raises RuntimeError if the borrower already holds a token. + """ + try: + self.acquire_on_behalf_of_nowait(borrower) + except exceptions.WouldBlock: + pass + else: + return + + fut = self._get_loop().create_future() + self._waiters[borrower] = fut + try: + await fut + except exceptions.CancelledError: + self._waiters.pop(borrower, None) + # If the future was already resolved before we got cancelled, + # we already hold the token — release it and wake the next waiter. + if fut.done() and not fut.cancelled(): + self._borrowers.discard(borrower) + self._notify_waiters() + raise + else: + # Future completed successfully; borrower was added by + # _notify_waiters, nothing more to do. + pass + + def release(self) -> None: + """Release a token on behalf of the current task. + + Raises RuntimeError if the current task does not hold a token. + """ + from . import tasks + self.release_on_behalf_of(tasks.current_task()) + + def release_on_behalf_of(self, borrower) -> None: + """Release a token on behalf of the given borrower. + + Raises RuntimeError if the borrower does not hold a token. + """ + if borrower not in self._borrowers: + raise RuntimeError( + "this borrower is not holding any of this " + "CapacityLimiter's tokens" + ) + self._borrowers.discard(borrower) + self._notify_waiters() + + def _notify_waiters(self): + """Wake up waiters while capacity is available.""" + while self._waiters and len(self._borrowers) < self._total_tokens: + borrower, fut = self._waiters.popitem(last=False) + if not fut.done(): + self._borrowers.add(borrower) + fut.set_result(None) + + def statistics(self) -> CapacityLimiterStatistics: + """Return statistics about the current state of the limiter.""" + return CapacityLimiterStatistics( + borrowed_tokens=len(self._borrowers), + total_tokens=self._total_tokens, + borrowers=tuple(self._borrowers), + tasks_waiting=len(self._waiters), + ) diff --git a/Lib/test/test_asyncio/test_locks.py b/Lib/test/test_asyncio/test_locks.py index e025d2990a3f8a..7add7a286856ea 100644 --- a/Lib/test/test_asyncio/test_locks.py +++ b/Lib/test/test_asyncio/test_locks.py @@ -1821,5 +1821,326 @@ async def coro(): self.assertEqual(barrier1.n_waiting, 0) +class CapacityLimiterTests(unittest.IsolatedAsyncioTestCase): + + async def test_capacity_limiter_basic(self): + limiter = asyncio.CapacityLimiter(2) + self.assertEqual(limiter.total_tokens, 2) + self.assertEqual(limiter.borrowed_tokens, 0) + self.assertEqual(limiter.available_tokens, 2) + + await limiter.acquire() + self.assertEqual(limiter.borrowed_tokens, 1) + self.assertEqual(limiter.available_tokens, 1) + + limiter.release() + self.assertEqual(limiter.borrowed_tokens, 0) + self.assertEqual(limiter.available_tokens, 2) + + async def test_acquire_nowait(self): + limiter = asyncio.CapacityLimiter(1) + limiter.acquire_nowait() + self.assertEqual(limiter.borrowed_tokens, 1) + limiter.release() + self.assertEqual(limiter.borrowed_tokens, 0) + + async def test_acquire_nowait_would_block(self): + limiter = asyncio.CapacityLimiter(1) + limiter.acquire_nowait() + # Second acquire from a different borrower should raise WouldBlock + with self.assertRaises(asyncio.WouldBlock): + limiter.acquire_on_behalf_of_nowait("other_borrower") + + async def test_same_borrower_reacquire_error(self): + limiter = asyncio.CapacityLimiter(2) + limiter.acquire_nowait() + # Same task trying to acquire again should raise RuntimeError + with self.assertRaises(RuntimeError): + limiter.acquire_nowait() + + async def test_release_unborrowed_error(self): + limiter = asyncio.CapacityLimiter(1) + with self.assertRaises(RuntimeError): + limiter.release() + + async def test_fifo_fairness(self): + limiter = asyncio.CapacityLimiter(1) + order = [] + + async def waiter(name): + await limiter.acquire() + order.append(name) + limiter.release() + + # Fill the limiter + await limiter.acquire() + + # Start waiters in known order + t1 = asyncio.create_task(waiter("first")) + await asyncio.sleep(0) # let t1 reach the wait point + t2 = asyncio.create_task(waiter("second")) + await asyncio.sleep(0) # let t2 reach the wait point + t3 = asyncio.create_task(waiter("third")) + await asyncio.sleep(0) # let t3 reach the wait point + + # Release the token — waiters should proceed in FIFO order + limiter.release() + await t1 + await t2 + await t3 + + self.assertEqual(order, ["first", "second", "third"]) + + async def test_cancellation_during_acquire(self): + limiter = asyncio.CapacityLimiter(1) + await limiter.acquire() + + async def waiter(): + await limiter.acquire() + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0) # let task reach the wait point + + # Cancel the waiting task + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + # The limiter should still be usable — no leftover waiters + self.assertEqual(limiter.statistics().tasks_waiting, 0) + limiter.release() + self.assertEqual(limiter.borrowed_tokens, 0) + + async def test_dynamic_total_tokens_increase(self): + limiter = asyncio.CapacityLimiter(1) + acquired = [] + + async def waiter(name): + await limiter.acquire() + acquired.append(name) + # Hold the token until released externally + await event.wait() + limiter.release() + + event = asyncio.Event() + await limiter.acquire() + + t1 = asyncio.create_task(waiter("first")) + await asyncio.sleep(0) + t2 = asyncio.create_task(waiter("second")) + await asyncio.sleep(0) + + # Both should be waiting + self.assertEqual(limiter.statistics().tasks_waiting, 2) + + # Increase capacity — should wake waiters + limiter.total_tokens = 3 + await asyncio.sleep(0) + + self.assertEqual(acquired, ["first", "second"]) + self.assertEqual(limiter.borrowed_tokens, 3) + + # Release all and clean up + limiter.release() + event.set() + await t1 + await t2 + self.assertEqual(limiter.borrowed_tokens, 0) + + async def test_dynamic_total_tokens_decrease(self): + limiter = asyncio.CapacityLimiter(3) + await limiter.acquire() + self.assertEqual(limiter.borrowed_tokens, 1) + + # Decrease capacity — doesn't evict current borrowers + limiter.total_tokens = 1 + self.assertEqual(limiter.borrowed_tokens, 1) + self.assertEqual(limiter.available_tokens, 0) + + # New acquires should block + with self.assertRaises(asyncio.WouldBlock): + limiter.acquire_on_behalf_of_nowait("other") + + limiter.release() + + async def test_total_tokens_validation(self): + # Bad types + with self.assertRaises(TypeError): + asyncio.CapacityLimiter("10") + with self.assertRaises(TypeError): + asyncio.CapacityLimiter(None) + + # Negative values + with self.assertRaises(ValueError): + asyncio.CapacityLimiter(-1) + + # Non-inf float + with self.assertRaises(ValueError): + asyncio.CapacityLimiter(1.5) + + # Valid: int and math.inf + import math + limiter = asyncio.CapacityLimiter(0) + self.assertEqual(limiter.total_tokens, 0) + limiter = asyncio.CapacityLimiter(math.inf) + self.assertEqual(limiter.total_tokens, math.inf) + + # Setter validation + limiter = asyncio.CapacityLimiter(1) + with self.assertRaises(TypeError): + limiter.total_tokens = "10" + with self.assertRaises(ValueError): + limiter.total_tokens = -1 + with self.assertRaises(ValueError): + limiter.total_tokens = 1.5 + + async def test_on_behalf_of(self): + limiter = asyncio.CapacityLimiter(2) + borrower1 = "task_a" + borrower2 = "task_b" + + limiter.acquire_on_behalf_of_nowait(borrower1) + self.assertEqual(limiter.borrowed_tokens, 1) + + limiter.acquire_on_behalf_of_nowait(borrower2) + self.assertEqual(limiter.borrowed_tokens, 2) + + # Can't acquire on behalf of same borrower again + with self.assertRaises(RuntimeError): + limiter.acquire_on_behalf_of_nowait(borrower1) + + limiter.release_on_behalf_of(borrower1) + self.assertEqual(limiter.borrowed_tokens, 1) + + limiter.release_on_behalf_of(borrower2) + self.assertEqual(limiter.borrowed_tokens, 0) + + # Can't release a non-borrower + with self.assertRaises(RuntimeError): + limiter.release_on_behalf_of(borrower1) + + async def test_statistics(self): + limiter = asyncio.CapacityLimiter(2) + + stats = limiter.statistics() + self.assertEqual(stats.borrowed_tokens, 0) + self.assertEqual(stats.total_tokens, 2) + self.assertEqual(stats.borrowers, ()) + self.assertEqual(stats.tasks_waiting, 0) + + limiter.acquire_on_behalf_of_nowait("borrower_a") + stats = limiter.statistics() + self.assertEqual(stats.borrowed_tokens, 1) + self.assertIn("borrower_a", stats.borrowers) + self.assertEqual(stats.tasks_waiting, 0) + + limiter.release_on_behalf_of("borrower_a") + + async def test_repr(self): + limiter = asyncio.CapacityLimiter(5) + r = repr(limiter) + self.assertIn('borrowed:0', r) + self.assertIn('total:5', r) + + limiter.acquire_on_behalf_of_nowait("x") + r = repr(limiter) + self.assertIn('borrowed:1', r) + + async def test_context_manager(self): + limiter = asyncio.CapacityLimiter(1) + + async with limiter: + self.assertEqual(limiter.borrowed_tokens, 1) + + self.assertEqual(limiter.borrowed_tokens, 0) + + async def test_concurrent_tasks(self): + limiter = asyncio.CapacityLimiter(2) + active = 0 + max_active = 0 + + async def worker(): + nonlocal active, max_active + await limiter.acquire() + active += 1 + max_active = max(max_active, active) + await asyncio.sleep(0.01) + active -= 1 + limiter.release() + + tasks = [asyncio.create_task(worker()) for _ in range(5)] + await asyncio.gather(*tasks) + + # At most 2 tasks should have been active at once + self.assertLessEqual(max_active, 2) + self.assertEqual(limiter.borrowed_tokens, 0) + + async def test_acquire_on_behalf_of_async(self): + limiter = asyncio.CapacityLimiter(1) + limiter.acquire_on_behalf_of_nowait("blocker") + + acquired = False + + async def waiter(): + nonlocal acquired + await limiter.acquire_on_behalf_of("async_borrower") + acquired = True + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0) + self.assertFalse(acquired) + + limiter.release_on_behalf_of("blocker") + await task + self.assertTrue(acquired) + + limiter.release_on_behalf_of("async_borrower") + self.assertEqual(limiter.borrowed_tokens, 0) + + async def test_nowait_fairness_with_waiters(self): + """nowait should raise WouldBlock even if tokens are available, + when there are waiters queued (to maintain FIFO fairness).""" + limiter = asyncio.CapacityLimiter(1) + limiter.acquire_on_behalf_of_nowait("holder") + + async def waiter(): + await limiter.acquire_on_behalf_of("queued") + + task = asyncio.create_task(waiter()) + await asyncio.sleep(0) + + # Release — the queued waiter should get the token + limiter.release_on_behalf_of("holder") + await asyncio.sleep(0) + + # Even though the token was just freed, the waiter got it + self.assertEqual(limiter.borrowed_tokens, 1) + + # nowait should fail because there might still be waiters + # (in this case the waiter already got the token, so it's full) + with self.assertRaises(asyncio.WouldBlock): + limiter.acquire_on_behalf_of_nowait("newcomer") + + limiter.release_on_behalf_of("queued") + await task + + async def test_zero_capacity(self): + limiter = asyncio.CapacityLimiter(0) + self.assertEqual(limiter.available_tokens, 0) + with self.assertRaises(asyncio.WouldBlock): + limiter.acquire_on_behalf_of_nowait("x") + + async def test_inf_capacity(self): + import math + limiter = asyncio.CapacityLimiter(math.inf) + # Should be able to acquire many tokens + for i in range(100): + limiter.acquire_on_behalf_of_nowait(f"borrower_{i}") + self.assertEqual(limiter.borrowed_tokens, 100) + self.assertEqual(limiter.available_tokens, math.inf) + for i in range(100): + limiter.release_on_behalf_of(f"borrower_{i}") + + if __name__ == '__main__': unittest.main()