Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 0 additions & 79 deletions src/simulation/scripts/lanch_one_simu.py

This file was deleted.

76 changes: 76 additions & 0 deletions src/simulation/scripts/launch_one_simu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import os
import sys
from pathlib import Path

import numpy as np
import onnxruntime as ort

import simulation.config as c
from extractors import ( # noqa: F401
CNN1DResNetExtractor,
TemporalResNetExtractor,
)
from simulation import VehicleEnv
from utils import run_onnx_model

ONNX_MODEL_PATH = c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx"


def init_onnx_runtime_session(onnx_path: Path) -> "ort.InferenceSession":
    """Create an ONNX Runtime inference session for the model at *onnx_path*.

    Args:
        onnx_path: location of the exported ``.onnx`` model file.

    Returns:
        A ready-to-use ``onnxruntime.InferenceSession``.

    Raises:
        FileNotFoundError: if the file does not exist — the model must be
            exported by the training script first.
    """
    # Pathlib idiom instead of os.path.exists on a Path object.
    if not onnx_path.exists():
        raise FileNotFoundError(
            f"The ONNX file could not be found at: {onnx_path}. Please export it first."
        )
    # Pass a str explicitly: older onnxruntime releases reject os.PathLike.
    return ort.InferenceSession(str(onnx_path))


if __name__ == "__main__":
    # Scratch directory shared with the simulation processes; (re)create it
    # empty on every launch so stale files from a previous run cannot leak
    # into this one. mkdir(exist_ok=True) avoids the check-then-create race
    # of the former os.path.exists/os.mkdir pair.
    tmp_dir = Path("/tmp/autotech")
    tmp_dir.mkdir(parents=True, exist_ok=True)
    for entry in tmp_dir.iterdir():
        # Pure-Python replacement for the former `rm /tmp/autotech/*`
        # shell-out: portable, and errors are no longer silently ignored.
        if entry.is_file():
            entry.unlink()

    # Start the ONNX Runtime session; without an exported model there is
    # nothing to run, so fail fast with a clear message.
    try:
        ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH)
        input_name = ort_session.get_inputs()[0].name
        output_name = ort_session.get_outputs()[0].name
        print(f"ONNX model loaded from {ONNX_MODEL_PATH}")
        print(f"Input Name: {input_name}, Output Name: {output_name}")
    except FileNotFoundError as e:
        print(f"ERROR: {e}")
        sys.exit(1)

    env = VehicleEnv(0, 0)
    obs, _ = env.reset()

    print("Starting simulation in inference mode...")

    step_count = 0

    while True:
        # The policy emits one concatenated logit vector:
        # [steering logits | speed logits]; obs[None] adds the batch axis.
        raw_action = run_onnx_model(ort_session, obs[None])
        logits = np.array(raw_action).flatten()

        steer_logits = logits[: c.n_actions_steering]
        speed_logits = logits[c.n_actions_steering :]

        # Greedy (deterministic) action selection for inference.
        action = np.array(
            [np.argmax(steer_logits), np.argmax(speed_logits)], dtype=np.int64
        )

        next_obs, reward, done, truncated, info = env.step(action)

        step_count += 1

        if done:
            print(f"Episode(s) finished after {step_count} steps.")
            step_count = 0

            # Carry the most recent sensor frame across the reset so the
            # temporal context does not restart fully stale.
            # NOTE(review): assumes env.context is (sensors, time) with the
            # newest frame last — confirm against VehicleEnv.
            fresh_frame = next_obs[:, -1:]
            obs, _ = env.reset()
            env.context[:, -1:] = fresh_frame
            obs = env.context
        else:
            obs = next_obs
36 changes: 6 additions & 30 deletions src/simulation/scripts/launch_train_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,7 @@
]
)

policy_kwargs: Dict[str, Any] = dict(
features_extractor_class=c.ExtractorClass,
# features_extractor_kwargs=dict(
# device=c.device,
# ),
activation_fn=nn.ReLU,
net_arch=[512, 512, 512],
)

ppo_args: Dict[str, Any] = dict(
n_steps=4096,
n_epochs=10,
batch_size=256,
learning_rate=3e-4,
gamma=0.99,
verbose=1,
normalize_advantage=True,
device=c.device,
)

save_path = (
Path("~/.cache/autotech/checkpoints").expanduser() / c.ExtractorClass.__name__
)
save_path = c.save_dir / "checkpoints" / c.ExtractorClass.__name__

save_path.mkdir(parents=True, exist_ok=True)

Expand All @@ -61,12 +39,12 @@
if valid_files:
model_path = max(valid_files, key=lambda x: int(x.name.rstrip(".zip")))
print(f"Loading model {model_path.name}")
model = PPO.load(model_path, envs, **ppo_args, policy_kwargs=policy_kwargs)
model = PPO.load(model_path, envs, **c.ppo_args, policy_kwargs=c.policy_kwargs)
i = int(model_path.name.rstrip(".zip")) + 1
print(f"Model found, loading {model_path}")

else:
model = PPO("MlpPolicy", envs, **ppo_args, policy_kwargs=policy_kwargs)
model = PPO("MlpPolicy", envs, **c.ppo_args, policy_kwargs=c.policy_kwargs)

i = 0
print("Model not found, creating a new one")
Expand All @@ -83,22 +61,20 @@
while True:
onnx_utils.export_onnx(
model,
os.path.expanduser(
f"~/.cache/autotech/model_{c.ExtractorClass.__name__}.onnx"
),
str(c.save_dir / f"model_{c.ExtractorClass.__name__}.onnx"),
)
onnx_utils.test_onnx(model)

if c.LOG_LEVEL <= DEBUG:
from utils import PlotModelIO

model.learn(
total_timesteps=500_000,
total_timesteps=c.total_timesteps,
progress_bar=False,
callback=PlotModelIO(),
)
else:
model.learn(total_timesteps=500_000, progress_bar=True)
model.learn(total_timesteps=c.total_timesteps, progress_bar=True)

print("iteration over")
# TODO: we could just use a callback to save checkpoints or export the model to onnx
Expand Down
36 changes: 35 additions & 1 deletion src/simulation/src/simulation/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Constants shared by multiple files of the simulation.
import logging
from pathlib import Path
from typing import Any, Dict

import torch.nn as nn
from torch.cuda import is_available

from extractors import ( # noqa: F401
Expand All @@ -9,20 +12,51 @@
TemporalResNetExtractor,
)

# Webots environments config
n_map = 2
n_simulations = 1
n_vehicles = 2
n_vehicles = 1
n_stupid_vehicles = 0
n_actions_steering = 16
n_actions_speed = 16
lidar_max_range = 12.0
respawn_on_crash = True # whether to go backwards or to respawn when crashing


# Training config
device = "cuda" if is_available() else "cpu"
save_dir = Path("~/.cache/autotech").expanduser()
total_timesteps = 500_000
ppo_args: Dict[str, Any] = dict(
n_steps=4096,
n_epochs=10,
batch_size=256,
learning_rate=3e-4,
gamma=0.99,
verbose=1,
normalize_advantage=True,
device=device,
)


# Common extractor shared between the policy and value networks
# (cf: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html)
ExtractorClass = TemporalResNetExtractor
context_size = ExtractorClass.context_size
lidar_horizontal_resolution = ExtractorClass.lidar_horizontal_resolution
camera_horizontal_resolution = ExtractorClass.camera_horizontal_resolution
n_sensors = ExtractorClass.n_sensors


# Architecture of the model
policy_kwargs: Dict[str, Any] = dict(
features_extractor_class=ExtractorClass,
activation_fn=nn.ReLU,
# Architecture of the MLP heads for the Value and Policy networks
net_arch=[512, 512, 512],
)


# Logging config
LOG_LEVEL = logging.INFO
FORMATTER = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
5 changes: 5 additions & 0 deletions src/simulation/src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import numpy as np
import onnxruntime as ort

from .plot_model_io import PlotModelIO

# Public API of the utils package. run_onnx_model (defined below) is part of
# it — the simulation scripts import it from here.
__all__ = ["PlotModelIO", "run_onnx_model"]

def run_onnx_model(session: "ort.InferenceSession", x: np.ndarray) -> np.ndarray:
    """Run *x* through *session* and return the model's first output.

    The feed name is read from the session instead of being hard-coded to
    "input", so the helper works with models exported under any input name
    (backward-compatible: models whose input is named "input" behave the
    same as before).
    """
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: x})[0]
3 changes: 1 addition & 2 deletions src/simulation/src/utils/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def test_onnx(model: OnPolicyAlgorithm):

try:
class_name = model.policy.features_extractor.__class__.__name__
model_path = os.path.expanduser(f"~/.cache/autotech/model_{class_name}.onnx")

model_path = c.save_dir / f"model_{class_name}.onnx"
session = ort.InferenceSession(model_path)
except Exception as e:
print(f"Error loading ONNX model: {e}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def step(self):
done = np.True_
elif b_collided:
reward = np.float32(-0.5)
done = np.False_
done = np.bool(c.respawn_on_crash)
elif b_past_checkpoint:
reward = np.float32(1.0)
done = np.False_
Expand Down
10 changes: 10 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.