diff --git a/src/simulation/scripts/lanch_one_simu.py b/src/simulation/scripts/lanch_one_simu.py deleted file mode 100644 index 8779f0b..0000000 --- a/src/simulation/scripts/lanch_one_simu.py +++ /dev/null @@ -1,79 +0,0 @@ -raise NotImplementedError("This file is currently begin worked on") - -import os -import sys - -import onnxruntime as ort - -from simulation import ( - VehicleEnv, -) -from simulation import config as c -from utils import onnx_utils - -# ------------------------------------------------------------------------- - - -# --- Chemin vers le fichier ONNX --- - -ONNX_MODEL_PATH = "model.onnx" - - -# --- Initialisation du moteur d'inférence ONNX Runtime (ORT) --- -def init_onnx_runtime_session(onnx_path: str) -> ort.InferenceSession: - if not os.path.exists(onnx_path): - raise FileNotFoundError( - f"Le fichier ONNX est introuvable à : {onnx_path}. Veuillez l'exporter d'abord." - ) - - # Crée la session d'inférence - return ort.InferenceSession( - onnx_path - ) # On peut modifier le providers afin de mettre une CUDA - - -if __name__ == "__main__": - if not os.path.exists("/tmp/autotech/"): - os.mkdir("/tmp/autotech/") - - os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi') - - # 2. Initialisation de la session ONNX Runtime - try: - ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH) - input_name = ort_session.get_inputs()[0].name - output_name = ort_session.get_outputs()[0].name - print(f"Modèle ONNX chargé depuis {ONNX_MODEL_PATH}") - print(f"Input Name: {input_name}, Output Name: {output_name}") - except FileNotFoundError as e: - print(f"ERREUR : {e}") - print( - "Veuillez vous assurer que vous avez exécuté une fois le script d'entraînement pour exporter 'model.onnx'." - ) - sys.exit(1) - - # 3. Boucle d'inférence (Test) - env = VehicleEnv(0, 0) - obs = env.reset() - print("Début de la simulation en mode inférence...") - - max_steps = 5000 - step_count = 0 - - while True: - action = onnx_utils.run_onnx_model(ort_session, obs) - - # 4. Exécuter l'action dans l'environnement - obs, reward, done, info = env.step(action) - - # Note: L'environnement Webots gère généralement son propre affichage - # env.render() # Décommenter si votre env supporte le rendu externe - - # Gestion des fins d'épisodes - if done: - print(f"Épisode(s) terminé(s) après {step_count} étapes.") - obs = env.reset() - - # Fermeture propre (très important pour les processus parallèles SubprocVecEnv) - envs.close() - print("Simulation terminée. Environnements fermés.") diff --git a/src/simulation/scripts/launch_one_simu.py b/src/simulation/scripts/launch_one_simu.py new file mode 100644 index 0000000..aad8dde --- /dev/null +++ b/src/simulation/scripts/launch_one_simu.py @@ -0,0 +1,76 @@ +import os +import sys +from pathlib import Path + +import numpy as np +import onnxruntime as ort + +import simulation.config as c +from extractors import ( # noqa: F401 + CNN1DResNetExtractor, + TemporalResNetExtractor, +) +from simulation import VehicleEnv +from utils import run_onnx_model + +ONNX_MODEL_PATH = c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx" + + +def init_onnx_runtime_session(onnx_path: Path) -> ort.InferenceSession: + if not os.path.exists(onnx_path): + raise FileNotFoundError( + f"The ONNX file could not be found at: {onnx_path}. Please export it first." + ) + return ort.InferenceSession(onnx_path) + + +if __name__ == "__main__": + if not os.path.exists("/tmp/autotech/"): + os.mkdir("/tmp/autotech/") + + os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi') + + # Starting the ONNX session + try: + ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH) + input_name = ort_session.get_inputs()[0].name + output_name = ort_session.get_outputs()[0].name + print(f"ONNX model loaded from {ONNX_MODEL_PATH}") + print(f"Input Name: {input_name}, Output Name: {output_name}") + except FileNotFoundError as e: + print(f"ERROR: {e}") + sys.exit(1) + + env = VehicleEnv(0, 0) + obs, _ = env.reset() + + print("Starting simulation in inference mode...") + + step_count = 0 + + while True: + raw_action = run_onnx_model(ort_session, obs[None]) + logits = np.array(raw_action).flatten() + + steer_logits = logits[: c.n_actions_steering] + speed_logits = logits[c.n_actions_steering :] + + action_steer = np.argmax(steer_logits) + action_speed = np.argmax(speed_logits) + + action = np.array([action_steer, action_speed], dtype=np.int64) + + next_obs, reward, done, truncated, info = env.step(action) + + step_count += 1 + + if done: + print(f"Episode(s) finished after {step_count} steps.") + step_count = 0 + + fresh_frame = next_obs[:, -1:] + obs, _ = env.reset() + env.context[:, -1:] = fresh_frame + obs = env.context + else: + obs = next_obs diff --git a/src/simulation/src/simulation/config.py b/src/simulation/src/simulation/config.py index 5b1dd8c..5fb064f 100644 --- a/src/simulation/src/simulation/config.py +++ b/src/simulation/src/simulation/config.py @@ -15,7 +15,7 @@ # Webots environments config n_map = 2 n_simulations = 1 -n_vehicles = 2 +n_vehicles = 1 n_stupid_vehicles = 0 n_actions_steering = 16 n_actions_speed = 16 diff --git a/src/simulation/src/utils/__init__.py b/src/simulation/src/utils/__init__.py index 4b7ee72..7e56893 100644 --- a/src/simulation/src/utils/__init__.py +++ b/src/simulation/src/utils/__init__.py @@ -1,3 +1,8 @@ from .plot_model_io import PlotModelIO +import onnxruntime as ort +import numpy as np __all__ = ["PlotModelIO"] + +def run_onnx_model(session: ort.InferenceSession, x: np.ndarray): + return session.run(None, {"input": x})[0] \ No newline at end of file diff --git a/uv.lock b/uv.lock index 7effdcb..7ac586b 100644 --- a/uv.lock +++ b/uv.lock @@ -852,6 +852,11 @@ dependencies = [ { name = "zmq" }, ] +[package.optional-dependencies] +controller = [ + { name = "pygame" }, +] + [package.metadata] requires-dist = [ { name = "adafruit-blinka", specifier = ">=8.0.0" }, @@ -869,6 +874,7 @@ requires-dist = [ { name = "onnxruntime", specifier = ">=1.8.0" }, { name = "opencv-python", specifier = ">=4.12.0.88" }, { name = "picamera2", specifier = ">=0.3.0" }, + { name = "pygame", marker = "extra == 'controller'", specifier = ">=2.6.1" }, { name = "pyps4controller", specifier = ">=1.2.0" }, { name = "rpi-gpio", specifier = ">=0.7.1" }, { name = "rpi-hardware-pwm", specifier = ">=0.1.0" }, @@ -879,6 +885,7 @@ requires-dist = [ { name = "websockets", specifier = ">=16.0" }, { name = "zmq", specifier = ">=0.0.0" }, ] +provides-extras = ["controller"] [[package]] name = "humanfriendly" @@ -2392,6 +2399,9 @@ dependencies = [ wheels = [ { url = "https://files.pythonhosted.org/packages/d3/54/a2ba279afcca44bbd320d4e73675b282fcee3d81400ea1b53934efca6462/torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574", size = 79498202, upload-time = "2026-02-10T21:44:52.603Z" }, { url = "https://files.pythonhosted.org/packages/ec/23/2c9fe0c9c27f7f6cb865abcea8a4568f29f00acaeadfc6a37f6801f84cb4/torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e", size = 79498254, upload-time = "2026-02-10T21:44:44.095Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7a/abada41517ce0011775f0f4eacc79659bc9bc6c361e6bfe6f7052a6b9363/torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6", size = 915622781, upload-time = "2026-03-11T14:17:11.354Z" }, + { url = "https://files.pythonhosted.org/packages/ab/c6/4dfe238342ffdcec5aef1c96c457548762d33c40b45a1ab7033bb26d2ff2/torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b", size = 915627275, upload-time = "2026-03-11T14:16:11.325Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f0/72bf18847f58f877a6a8acf60614b14935e2f156d942483af1ffc081aea0/torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49", size = 915523474, upload-time = "2026-03-11T14:17:44.422Z" }, { url = "https://files.pythonhosted.org/packages/cc/af/758e242e9102e9988969b5e621d41f36b8f258bb4a099109b7a4b4b50ea4/torch-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5fd4117d89ffd47e3dcc71e71a22efac24828ad781c7e46aaaf56bf7f2796acf", size = 145996088, upload-time = "2026-01-21T16:24:44.171Z" }, { url = "https://files.pythonhosted.org/packages/23/8e/3c74db5e53bff7ed9e34c8123e6a8bfef718b2450c35eefab85bb4a7e270/torch-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:787124e7db3b379d4f1ed54dd12ae7c741c16a4d29b49c0226a89bea50923ffb", size = 915711952, upload-time = "2026-01-21T16:23:53.503Z" }, { url = "https://files.pythonhosted.org/packages/6e/01/624c4324ca01f66ae4c7cd1b74eb16fb52596dce66dbe51eff95ef9e7a4c/torch-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:2c66c61f44c5f903046cc696d088e21062644cbe541c7f1c4eaae88b2ad23547", size = 113757972, upload-time = "2026-01-21T16:24:39.516Z" },