diff --git a/README.md b/README.md
index b1c1d8c..10f70ff 100644
--- a/README.md
+++ b/README.md
@@ -69,8 +69,8 @@ puma rm inftyai/tiny-random-gpt2
 ### API Server
 
 ```bash
-# Start the inference server
-puma serve
+# Start the inference server with a model
+puma serve inftyai/tiny-random-gpt2
 
 # Server will start on http://0.0.0.0:8000
 # API endpoints:
@@ -109,7 +109,7 @@ curl http://localhost:8000/v1/chat/completions \
 | `rm <model>` | ✅ | Remove model and cache |
 | `info` | ✅ | Display system information |
 | `version` | ✅ | Show PUMA version |
-| `serve` | ✅ | Start OpenAI-compatible API server |
+| `serve <model>` | ✅ | Start OpenAI-compatible API server with a model |
 | `ps` | 🚧 | List running models |
 | `run` | 🚧 | Start model inference |
 | `stop` | 🚧 | Stop running model |
@@ -151,11 +151,14 @@ PUMA provides an OpenAI-compatible API server for model inference.
 ### Starting the Server
 
 ```bash
-# Default: 0.0.0.0:8000
-puma serve
+# Start server with a model (default: 0.0.0.0:8000)
+puma serve inftyai/tiny-random-gpt2
 
 # Custom host and port
-puma serve --host 127.0.0.1 --port 3000
+puma serve inftyai/tiny-random-gpt2 --host 127.0.0.1 --port 3000
+
+# Model must be pulled first
+puma pull inftyai/tiny-random-gpt2
 ```
 
 ### API Endpoints
@@ -188,13 +191,14 @@ curl http://localhost:8000/v1/chat/completions \
 
 #### List Models
 ```bash
+# Returns the currently loaded model
 curl http://localhost:8000/v1/models
 ```
 
 #### Health Check
 ```bash
 curl http://localhost:8000/health
-# Returns: {"status":"ok","version":"0.0.2"}
+# Returns: {"status":"ok"}
 ```
 
 ### OpenAI Python Client
diff --git a/src/api/routes.rs b/src/api/routes.rs
index 403dd5b..dae6c0e 100644
--- a/src/api/routes.rs
+++ b/src/api/routes.rs
@@ -60,13 +60,11 @@ pub fn create_router(
 #[derive(Serialize)]
 struct HealthResponse {
     status: String,
-    version: String,
 }
 
 /// Health check endpoint
 async fn health_check() -> Json<HealthResponse> {
     Json(HealthResponse {
         status: "ok".to_string(),
-        version: env!("CARGO_PKG_VERSION").to_string(),
     })
 }
diff --git a/src/api/tests.rs b/src/api/tests.rs
index 512e55a..fba05ca 100644
--- a/src/api/tests.rs
+++ b/src/api/tests.rs
@@ -100,7 +100,6 @@ async fn test_health_check() {
 
     assert_eq!(status, StatusCode::OK);
     assert_eq!(json["status"], "ok");
-    assert!(json["version"].is_string());
 }
 
 #[tokio::test]
diff --git a/src/cli/commands.rs b/src/cli/commands.rs
index ea0a87c..651d88f 100644
--- a/src/cli/commands.rs
+++ b/src/cli/commands.rs
@@ -43,6 +43,9 @@ enum Commands {
 
 #[derive(Parser)]
 struct ServeArgs {
+    /// Model name to serve (e.g., inftyai/tiny-random-gpt2)
+    model: String,
+
     /// Host address to bind to
     #[arg(long, default_value = "0.0.0.0")]
     host: String,
@@ -221,7 +224,24 @@ pub async fn run(cli: Cli) {
         }
 
         Commands::SERVE(args) => {
-            if let Err(e) = crate::cli::serve::execute(&args.host, args.port).await {
+            // Verify model exists
+            let registry = ModelRegistry::new(None);
+            match registry.get_model(&args.model) {
+                Ok(Some(_)) => {
+                    // Model exists, proceed
+                }
+                Ok(None) => {
+                    eprintln!("❌ Error: Model '{}' not found in registry", args.model);
+                    eprintln!("Run 'puma pull {}' to download it first", args.model);
+                    std::process::exit(1);
+                }
+                Err(e) => {
+                    eprintln!("❌ Error checking model: {}", e);
+                    std::process::exit(1);
+                }
+            }
+
+            if let Err(e) = crate::cli::serve::execute(&args.host, args.port, &args.model).await {
                 eprintln!("Error starting server: {}", e);
                 std::process::exit(1);
             }
@@ -392,4 +412,58 @@ mod tests {
         assert_eq!(result.metadata.cache.revision, "v2");
         assert_eq!(result.metadata.cache.size, 2000);
     }
+
+    #[test]
+    fn test_serve_with_existing_model() {
+        let temp_dir = TempDir::new().unwrap();
+        let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));
+
+        let model = create_test_model("test/serve-model", "abc123");
+        registry.register_model(model).unwrap();
+
+        // Verify model exists (this is what serve command checks)
+        let result = registry.get_model("test/serve-model");
+        assert!(result.is_ok());
+        assert!(result.unwrap().is_some());
+    }
+
+    #[test]
+    fn test_serve_with_nonexistent_model() {
+        let temp_dir = TempDir::new().unwrap();
+        let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));
+
+        // Verify model doesn't exist
+        let result = registry.get_model("nonexistent/model");
+        assert!(result.is_ok());
+        assert!(result.unwrap().is_none());
+    }
+
+    #[test]
+    fn test_serve_args_parsing() {
+        // Test that ServeArgs requires model argument
+        use clap::CommandFactory;
+        let app = Cli::command();
+
+        // This should fail without model argument
+        let result = app.clone().try_get_matches_from(vec!["puma", "serve"]);
+        assert!(result.is_err());
+
+        // This should succeed with model argument
+        let result = app
+            .clone()
+            .try_get_matches_from(vec!["puma", "serve", "test/model"]);
+        assert!(result.is_ok());
+
+        // This should succeed with model and optional args
+        let result = app.try_get_matches_from(vec![
+            "puma",
+            "serve",
+            "test/model",
+            "--host",
+            "127.0.0.1",
+            "--port",
+            "9000",
+        ]);
+        assert!(result.is_ok());
+    }
 }
diff --git a/src/cli/serve.rs b/src/cli/serve.rs
index d0bb814..c9f7a31 100644
--- a/src/cli/serve.rs
+++ b/src/cli/serve.rs
@@ -7,7 +7,11 @@ use crate::backend::mock::MockEngine;
 use crate::registry::model_registry::ModelRegistry;
 
 /// Execute the serve command
-pub async fn execute(host: &str, port: u16) -> Result<(), Box<dyn std::error::Error>> {
+pub async fn execute(
+    host: &str,
+    port: u16,
+    model_name: &str,
+) -> Result<(), Box<dyn std::error::Error>> {
     println!(
         "{}",
         "
@@ -23,7 +27,7 @@ pub async fn execute(host: &str, port: u16) -> Result<(), Box