import itertools
import os
import subprocess
import sys
import time
from collections import deque

# Return code recorded in the completion log for runs the user stopped from
# the TUI (distinguishes "terminated" from an ordinary failure when drawing).
_TERMINATED_RC = -1

# Seconds a child gets to exit after terminate() before it is kill()ed.
_TERMINATE_GRACE_SECONDS = 5


def make_log_dir(session):
    """Create and return a fresh per-invocation directory for child logs.

    The directory is ``sessions/<session>/proc_logs/<unix-ts>`` when
    ``session`` is truthy, otherwise ``proc_logs/<unix-ts>``.
    """
    ts = str(int(time.time()))
    if session:
        log_dir = os.path.join("sessions", session, "proc_logs", ts)
    else:
        log_dir = os.path.join("proc_logs", ts)
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def label_to_filename(label):
    """Turn a run label such as ``"a=1, b=2"`` into ``"a1_b2.log"``."""
    return label.replace(", ", "_").replace("=", "") + ".log"


def _elapsed_str(seconds):
    """Format a duration as ``"MmSSs"`` or, from one hour up, ``"HhMMmSSs"``."""
    # Durations come from wall-clock differences, which can go negative if
    # the system clock is stepped backwards; show zero rather than garbage.
    m, s = divmod(max(0, int(seconds)), 60)
    h, m = divmod(m, 60)
    return f"{h}h{m:02d}m{s:02d}s" if h else f"{m}m{s:02d}s"


def _stop_all(procs_and_logs):
    """Terminate every (proc, log_fh) pair, kill stragglers, close the logs.

    All processes are signalled first and then share a single grace period,
    so cleanup time does not grow with the number of children.
    """
    entries = list(procs_and_logs)
    for proc, _ in entries:
        try:
            proc.terminate()
        except OSError:
            pass

    deadline = time.monotonic() + _TERMINATE_GRACE_SECONDS
    for proc, fh in entries:
        try:
            try:
                proc.wait(timeout=max(0.0, deadline - time.monotonic()))
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        except OSError:
            pass
        try:
            fh.close()
        except OSError:
            pass


def _collect_finished(active, done_log):
    """Move exited runs from ``active`` to ``done_log``.

    ``active`` holds ``(proc, label, log_fh, start_time)`` tuples; each exited
    run has its log closed and is appended to ``done_log`` as
    ``(label, rc, elapsed)``.

    Returns ``(still_running, n_failed)`` where ``n_failed`` counts the newly
    finished runs whose return code was non-zero.
    """
    still_running = []
    n_failed = 0
    for entry in active:
        proc, label, fh, t0 = entry
        rc = proc.poll()
        if rc is None:
            still_running.append(entry)
            continue
        fh.close()
        if rc != 0:
            n_failed += 1
        done_log.append((label, rc, time.time() - t0))
    return still_running, n_failed


def _run_plain(tasks, max_parallel, log_dir, launch):
    """Run all tasks with plain line-oriented progress output.

    Args:
        tasks: iterable of ``(combo, seed)`` pairs.
        max_parallel: maximum number of concurrent child processes; values
            below 1 are treated as 1.
        log_dir: directory shown to the user as the log location.
        launch: callable ``(combo, seed) -> (proc, label, log_fh, start_time)``.

    Returns:
        Number of runs that exited with a non-zero return code.
    """
    pending = deque(tasks)
    active = []  # (proc, label, log_fh, start_time)
    done_log = []  # (label, rc, elapsed)
    failure_count = 0
    total = len(pending)
    # A limit below 1 would never launch anything and spin forever.
    max_parallel = max(1, max_parallel)

    print(f"[run] logs -> {log_dir}/")

    try:
        while pending or active:
            while pending and len(active) < max_parallel:
                combo, seed = pending.popleft()
                proc, label, fh, t0 = launch(combo, seed)
                print(f"[run] starting ({len(done_log) + len(active) + 1}/{total}): {label}")
                active.append((proc, label, fh, t0))

            already_reported = len(done_log)
            active, n_failed = _collect_finished(active, done_log)
            failure_count += n_failed
            for n, (label, rc, elapsed) in enumerate(done_log[already_reported:], start=already_reported + 1):
                status = "done" if rc == 0 else f"FAILED (rc={rc})"
                print(f"[run] [{n}/{total}] {status}: {label} ({_elapsed_str(elapsed)})")

            # Only wait when there is nothing to start right now, so a freed
            # slot is refilled immediately instead of after a full second.
            slot_free = bool(pending) and len(active) < max_parallel
            if active and not slot_free:
                time.sleep(1)
    finally:
        # Reached with live children only on an exception / Ctrl-C.
        _stop_all((proc, fh) for proc, _, fh, _ in active)

    return failure_count


def _draw_tui(stdscr, active, done_log, n_pending, total, log_dir, input_buf=""):
    """Redraw the whole status screen.

    Layout: a bold header with counters, a separator, the numbered list of
    running jobs, the tail of the completion log, and the "Terminate #"
    prompt on the last line. Drawing is best-effort: curses errors (for
    example a terminal too small for the layout) are ignored.
    """
    import curses as _curses
    try:
        stdscr.erase()
        h, w = stdscr.getmaxyx()
        # Never write into the last column; curses raises when the cursor
        # would advance past the bottom-right cell.
        width = max(0, w - 1)
        row = 0

        hdr = (f"train_iter  [{len(done_log)}/{total} done | {len(active)} running | "
               f"{n_pending} queued]  logs: {log_dir}/")
        stdscr.addstr(row, 0, hdr[:width], _curses.A_BOLD)
        row += 1
        stdscr.addstr(row, 0, "-" * min(width, 80))
        row += 1

        # Reserve last line for the terminate prompt
        body_end = h - 2

        if active and row < body_end:
            stdscr.addstr(row, 0, "Running:")
            row += 1
            for i, (_, label, _, t0) in enumerate(active):
                if row >= body_end:
                    break
                line = f"  [{i + 1}] {_elapsed_str(time.time() - t0)}  {label}"
                stdscr.addstr(row, 0, line[:width])
                row += 1
            row += 1

        max_show = body_end - row
        if done_log and max_show > 1 and row < body_end:
            stdscr.addstr(row, 0, "Completed:")
            row += 1
            max_show -= 1
            # Show only the most recent entries that fit.
            for label, rc, elapsed in done_log[-max_show:]:
                if row >= body_end:
                    break
                if rc == 0:
                    status = "done"
                elif rc == _TERMINATED_RC:
                    status = "terminated"
                else:
                    status = f"FAILED(rc={rc})"
                stdscr.addstr(row, 0, f"  {status}: {label} ({_elapsed_str(elapsed)})"[:width])
                row += 1

        # Terminate prompt at the last line
        prompt = f"Terminate #: {input_buf}_"
        stdscr.addstr(h - 1, 0, prompt[:width])

        stdscr.refresh()
    except _curses.error:
        pass


def _pending_keys(stdscr, curses_error):
    """Yield every key currently buffered on a non-blocking curses window."""
    while True:
        try:
            yield stdscr.getkey()
        except curses_error:
            # In nodelay mode getkey() raises when no input is waiting.
            return


def _apply_key(key, input_buf):
    """Apply one keypress to the terminate prompt.

    Returns ``(new_buffer, selection)``. ``selection`` is the 1-based job
    number the user confirmed with Enter, or ``None`` if nothing was
    confirmed (including Enter on an empty or non-numeric buffer).
    """
    if key in ("\n", "\r", "KEY_ENTER"):
        try:
            return "", int(input_buf)
        except ValueError:
            return "", None
    if key in ("KEY_BACKSPACE", "\x7f", "\b"):
        return input_buf[:-1], None
    if key == "\x1b":  # ESC
        return "", None
    if len(key) == 1 and key in "0123456789":
        return input_buf + key, None
    return input_buf, None


def _run_tui(stdscr, tasks, max_parallel, log_dir, launch):
    """Run all tasks behind an interactive curses status screen.

    The user can type the number of a running job and press Enter to
    terminate it. Terminated jobs free their slot immediately, are killed if
    they have not exited after a grace period, and count as failures.

    Args:
        stdscr: curses window supplied by ``curses.wrapper``.
        tasks: iterable of ``(combo, seed)`` pairs.
        max_parallel: maximum number of concurrent child processes; values
            below 1 are treated as 1.
        log_dir: directory shown to the user as the log location.
        launch: callable ``(combo, seed) -> (proc, label, log_fh, start_time)``.

    Returns:
        Number of failed runs (non-zero exit or user-terminated).
    """
    import curses as _curses
    try:
        _curses.curs_set(0)
    except _curses.error:
        # Some terminals cannot hide the cursor; that is purely cosmetic.
        pass
    stdscr.nodelay(True)

    pending = deque(tasks)
    active = []  # (proc, label, log_fh, start_time)
    stopping = []  # (proc, label, log_fh, elapsed, kill_at) -- user-terminated, not yet exited
    done_log = []  # (label, rc, elapsed)
    failure_count = 0
    total = len(pending)
    input_buf = ""
    max_parallel = max(1, max_parallel)

    try:
        while pending or active or stopping:
            while pending and len(active) < max_parallel:
                combo, seed = pending.popleft()
                active.append(launch(combo, seed))

            active, n_failed = _collect_finished(active, done_log)
            failure_count += n_failed

            # Reap user-terminated runs here instead of in a background
            # thread so done_log is complete before the loop can exit.
            still_stopping = []
            for proc, label, fh, elapsed, kill_at in stopping:
                if proc.poll() is not None:
                    try:
                        fh.close()
                    except OSError:
                        pass
                    done_log.append((label, _TERMINATED_RC, elapsed))
                    continue
                if kill_at is not None and time.monotonic() >= kill_at:
                    try:
                        proc.kill()
                    except OSError:
                        pass
                    kill_at = None  # kill only once
                still_stopping.append((proc, label, fh, elapsed, kill_at))
            stopping = still_stopping

            _draw_tui(stdscr, active, done_log, len(pending), total, log_dir, input_buf)

            for key in _pending_keys(stdscr, _curses.error):
                input_buf, selection = _apply_key(key, input_buf)
                if selection is None:
                    continue
                idx = selection - 1
                if 0 <= idx < len(active):
                    proc, label, fh, t0 = active.pop(idx)
                    try:
                        proc.terminate()
                    except OSError:
                        pass
                    stopping.append((proc, label, fh, time.time() - t0,
                                     time.monotonic() + _TERMINATE_GRACE_SECONDS))
                    failure_count += 1
                # Job numbers refer to the list currently on screen; redraw
                # before interpreting any further confirmed selection.
                break

            time.sleep(0.25)
    finally:
        # Reached with live children only on an exception / Ctrl-C.
        leftovers = [(proc, fh) for proc, _, fh, _ in active]
        leftovers += [(proc, fh) for proc, _, fh, _, _ in stopping]
        _stop_all(leftovers)

    _draw_tui(stdscr, [], done_log, 0, total, log_dir)
    try:
        h, w = stdscr.getmaxyx()
        summary = f"All {total} runs done.  {failure_count} failure(s).  Press any key to exit."
        stdscr.addstr(h - 1, 0, summary[:max(0, w - 1)], _curses.A_BOLD)
        stdscr.refresh()
    except _curses.error:
        pass
    if sys.stdin.isatty():
        stdscr.nodelay(False)
        stdscr.getch()

    return failure_count


def run_all_parallel(combinations, max_parallel, iter_limit_per_step, session, prices,
                     job_durations, jobs, hourly_jobs, job_arrival_scale, jobs_exact_replay,
                     jobs_exact_replay_aggregate, plot_dashboard, dashboard_hours,
                     seeds, seed_sweep, evaluate_savings, eval_months, workloadgen_args,
                     no_tui=False):
    """Run one training subprocess per (weight combination, seed) pair.

    Each child's stdout and stderr are redirected to its own log file under
    the directory returned by ``make_log_dir(session)``. Progress is shown in
    a curses TUI when stdout is a TTY, ``no_tui`` is false and the ``curses``
    module is available; otherwise plain progress lines are printed.

    Args:
        combinations: iterable of 5-tuples ``(efficiency, price, idle,
            job_age, drop)`` weights.
        max_parallel: maximum number of concurrent child processes.
        seeds: seeds to run every combination with; the seed is added to the
            run label only when more than one seed is given.
        no_tui: force plain output (set from the ``--no-tui`` CLI flag in
            ``main()``).
        All remaining parameters are forwarded unchanged to
        ``build_command`` (defined elsewhere in this file).

    Returns:
        Number of failed runs.
    """
    multi_seed = len(seeds) > 1
    current_env = os.environ.copy()
    log_dir = make_log_dir(session)
    tasks = list(itertools.product(combinations, seeds))

    def launch(combo, seed):
        """Start one run; return ``(proc, label, log_fh, start_time)``."""
        efficiency_weight, price_weight, idle_weight, job_age_weight, drop_weight = combo
        label = f"efficiency={efficiency_weight}, price={price_weight}, idle={idle_weight}, job_age={job_age_weight}, drop={drop_weight}"
        if multi_seed:
            label += f", seed={seed}"
        command = build_command(
            efficiency_weight, price_weight, idle_weight, job_age_weight, drop_weight,
            iter_limit_per_step, session, prices, job_durations, jobs, hourly_jobs,
            job_arrival_scale, jobs_exact_replay, jobs_exact_replay_aggregate,
            plot_dashboard, dashboard_hours,
            seed, seed_sweep,
            evaluate_savings, eval_months, workloadgen_args,
        )
        log_path = os.path.join(log_dir, label_to_filename(label))
        log_fh = open(log_path, "w")
        try:
            proc = subprocess.Popen(command, env=current_env, stdout=log_fh, stderr=subprocess.STDOUT)
        except OSError:
            # Do not leak the log handle when the child cannot be started.
            log_fh.close()
            raise
        return proc, label, log_fh, time.time()

    if not no_tui and sys.stdout.isatty():
        try:
            import curses
        except ImportError:
            # curses is not shipped on every platform; fall back to plain.
            curses = None
        if curses is not None:
            # wrapper() passes the window as first argument and returns the
            # callable's return value after restoring the terminal.
            return curses.wrapper(_run_tui, tasks, max_parallel, log_dir, launch)
    return _run_plain(tasks, max_parallel, log_dir, launch)