diff --git a/src/executorlib/task_scheduler/interactive/blockallocation.py b/src/executorlib/task_scheduler/interactive/blockallocation.py
index dc98c279..1f7474df 100644
--- a/src/executorlib/task_scheduler/interactive/blockallocation.py
+++ b/src/executorlib/task_scheduler/interactive/blockallocation.py
@@ -1,7 +1,8 @@
 import queue
 import random
+import time
 from concurrent.futures import Future
-from threading import Event, Thread
+from threading import Event, Lock, Thread
 from typing import Callable, Optional
 
 from executorlib.standalone.command import get_interactive_execute_command
@@ -83,6 +84,8 @@ def __init__(
         self_id = random.getrandbits(128)
         self._self_id = self_id
         _interrupt_bootup_dict[self._self_id] = False
+        alive_workers = [self._max_workers]
+        alive_workers_lock = Lock()
         bootup_events = [Event() for _ in range(self._max_workers)]
         bootup_events[0].set()
         self._set_process(
@@ -99,6 +102,8 @@ def __init__(
                             if worker_id + 1 < self._max_workers
                             else None
                         ),
+                        "alive_workers": alive_workers,
+                        "alive_workers_lock": alive_workers_lock,
                     },
                 )
                 for worker_id in range(self._max_workers)
@@ -227,6 +232,8 @@ def _execute_multiple_tasks(
     restart_limit: int = 0,
     bootup_event: Optional[Event] = None,
     next_bootup_event: Optional[Event] = None,
+    alive_workers: Optional[list] = None,
+    alive_workers_lock: Optional[Lock] = None,
     **kwargs,
 ) -> None:
     """
@@ -258,6 +265,9 @@ def _execute_multiple_tasks(
         bootup_event (Event): Event to wait on before submitting the job to the scheduler, ensuring workers
                               are submitted in worker_id order.
         next_bootup_event (Event): Event to signal after job submission, unblocking the next worker.
+        alive_workers (list): Single-element list [N] tracking how many worker threads are still alive. Shared across
+                              all worker threads; decremented when a worker is permanently dead.
+        alive_workers_lock (Lock): Lock protecting alive_workers from concurrent modification.
     """
     if bootup_event is not None:
         bootup_event.wait()
@@ -279,11 +289,13 @@ def _execute_multiple_tasks(
     )
     restart_counter = 0
     while True:
-        if not interface.status and restart_counter > restart_limit:
-            interface.status = True  # no more restarts
-            interface_initialization_exception = ExecutorlibSocketError(
-                "SocketInterface crashed during execution."
+        if not interface.status and restart_counter >= restart_limit:
+            _drain_dead_worker(
+                future_queue=future_queue,
+                alive_workers=alive_workers,
+                alive_workers_lock=alive_workers_lock,
             )
+            break
         elif not interface.status:
             interface.bootup()
             interface_initialization_exception = _set_init_function(
@@ -321,6 +333,60 @@ def _execute_multiple_tasks(
             task_done(future_queue=future_queue)
 
 
+def _drain_dead_worker(
+    future_queue: queue.Queue,
+    alive_workers: Optional[list] = None,
+    alive_workers_lock: Optional[Lock] = None,
+) -> None:
+    """Handle a permanently dead worker by recycling or failing its tasks.
+
+    If healthy workers remain, tasks are recycled back into the shared queue
+    so they can be picked up. If all workers are dead, tasks are failed
+    immediately with ExecutorlibSocketError. In both cases, the worker's
+    shutdown message is consumed to prevent hangs in shutdown().
+
+    Args:
+        future_queue (queue.Queue): Queue of task dictionaries shared by all worker threads.
+        alive_workers (list): Single-element list [N] tracking how many worker threads are still alive. If it
+                              or alive_workers_lock is None, no healthy worker is assumed to remain.
+        alive_workers_lock (Lock): Lock protecting alive_workers from concurrent modification.
+    """
+    if alive_workers is not None and alive_workers_lock is not None:
+        with alive_workers_lock:
+            if alive_workers[0] > 0:
+                alive_workers[0] -= 1
+    while True:
+        try:
+            task_dict = future_queue.get(timeout=1)
+        except queue.Empty:
+            continue
+        if "shutdown" in task_dict and task_dict["shutdown"]:
+            task_done(future_queue=future_queue)
+            break
+        elif "fn" in task_dict and "future" in task_dict:
+            if alive_workers is not None and alive_workers_lock is not None:
+                with alive_workers_lock:
+                    has_healthy_workers = alive_workers[0] > 0
+            else:
+                has_healthy_workers = False
+            if has_healthy_workers:
+                future_queue.put(task_dict)
+                task_done(future_queue=future_queue)
+                # Yield to the healthy workers instead of re-reading the recycled task in a busy loop.
+                time.sleep(0.01)
+            else:
+                f = task_dict.pop("future")
+                # A future cancelled while queued is already done; set_exception() would raise on it.
+                if not f.done():
+                    f.set_exception(
+                        ExecutorlibSocketError("SocketInterface crashed during execution.")
+                    )
+                task_done(future_queue=future_queue)
+        else:
+            # Unrecognized message: still mark it as processed so queue.join() cannot hang.
+            task_done(future_queue=future_queue)
+
+
 def _set_init_function(
     interface: SocketInterface,
     init_function: Optional[Callable] = None,
diff --git a/tests/unit/task_scheduler/interactive/test_blockallocation.py b/tests/unit/task_scheduler/interactive/test_blockallocation.py
new file mode 100644
index 00000000..dbd78448
--- /dev/null
+++ b/tests/unit/task_scheduler/interactive/test_blockallocation.py
@@ -0,0 +1,77 @@
+import queue
+import unittest
+from threading import Lock
+from concurrent.futures import Future
+
+from executorlib.task_scheduler.interactive.blockallocation import _drain_dead_worker
+from executorlib.task_scheduler.interactive.shared import task_done
+from executorlib.standalone.interactive.communication import ExecutorlibSocketError
+
+
+class TestDrainDeadWorker(unittest.TestCase):
+    def test_fail_tasks_when_no_workers_remain(self):
+        future_queue = queue.Queue()
+        alive_workers = [1]
+        alive_workers_lock = Lock()
+        future = Future()
+
+        # Add a task and then the shutdown sentinel
+        future_queue.put({"fn": lambda: 42, "future": future})
+        future_queue.put({"shutdown": True})
+
+        _drain_dead_worker(
+            future_queue=future_queue,
+            alive_workers=alive_workers,
+            alive_workers_lock=alive_workers_lock,
+        )
+
+        # Worker count should be decremented
+        self.assertEqual(alive_workers[0], 0)
+
+        # Task should fail with ExecutorlibSocketError
+        with self.assertRaises(ExecutorlibSocketError):
+            future.result()
+
+    def test_recycle_tasks_when_healthy_workers_remain(self):
+        future_queue = queue.Queue()
+        alive_workers = [2]
+        alive_workers_lock = Lock()
+        future = Future()
+
+        # Add a task and then the shutdown sentinel
+        future_queue.put({"fn": lambda: 42, "future": future})
+        future_queue.put({"shutdown": True})
+
+        _drain_dead_worker(
+            future_queue=future_queue,
+            alive_workers=alive_workers,
+            alive_workers_lock=alive_workers_lock,
+        )
+
+        # One healthy worker should remain
+        self.assertEqual(alive_workers[0], 1)
+
+        # Task should be back in the queue, untouched
+        self.assertFalse(future.done())
+        self.assertIs(future_queue.get_nowait()["future"], future)
+
+    def test_skip_cancelled_future_when_no_workers_remain(self):
+        future_queue = queue.Queue()
+        alive_workers = [1]
+        alive_workers_lock = Lock()
+        future = Future()
+        future.cancel()
+
+        # Add the cancelled task and then the shutdown sentinel
+        future_queue.put({"fn": lambda: 42, "future": future})
+        future_queue.put({"shutdown": True})
+
+        _drain_dead_worker(
+            future_queue=future_queue,
+            alive_workers=alive_workers,
+            alive_workers_lock=alive_workers_lock,
+        )
+
+        # Cancelled future should be left as is and the queue drained
+        self.assertTrue(future.cancelled())
+        self.assertTrue(future_queue.empty())