
feature: support native YOLO .pt models while ensuring compatibility with Torchvision models#495

Open
kashtennyson wants to merge 3 commits into JdeRobot:master from kashtennyson:issue-449

Conversation

@kashtennyson

Description

This PR adds support for loading native Ultralytics YOLOv8 .pt models while keeping a consistent interface for the rest of the library. Fixes #449.

The Problem:
Native YOLO .pt models often return a tuple (inference_tensor, loss_tensor) rather than a raw tensor, which causes "too many values to unpack" errors in the inference and eval methods. Additionally, these models frequently use float16 (Half) precision, leading to dtype mismatches with input images or NMS kernel errors on certain backends.
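For illustration, the tuple issue can be reproduced with a stand-in forward function (`fake_yolo_forward` is hypothetical, not part of the library or of Ultralytics):

```python
import torch

def fake_yolo_forward(x):
    # Stand-in for a native YOLO .pt forward pass: returns
    # (inference_tensor, loss_tensor) instead of a bare tensor.
    return (x * 2, torch.zeros(1))

result = fake_yolo_forward(torch.ones(4))
print(type(result))  # <class 'tuple'>, not torch.Tensor

# Downstream code that treats the result as a single tensor
# must unpack it first:
inference, loss = result
```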

The Solution:
Following previous feedback, I have centralized the fix within the TorchImageDetectionModel class. I implemented a local Adapter class (DetectionModelWrapper) that standardizes the model's behavior at the source:

  • Tuple Unpacking: Automatically extracts the primary detection tensor.
  • Input Alignment: Automatically casts input images to match the model's native dtype (fixing "Float vs Half" errors).
  • Output Alignment: Ensures results are returned as float32 to maintain compatibility with torchvision.ops.nms.
  • Graceful Fallback: Wrapped the .pt loading logic to provide a clear error message suggesting the installation of ultralytics if it is missing.
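A minimal sketch of what such an adapter might look like (names and structure are illustrative; the actual `DetectionModelWrapper` in the PR may differ):

```python
import torch
import torch.nn as nn

class DetectionModelWrapper(nn.Module):
    """Illustrative adapter that normalizes a native YOLO model's behavior."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Infer the model's native dtype (e.g. float16 for many .pt exports).
        self.model_dtype = next(model.parameters()).dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Input alignment: cast images to the model's native dtype.
        x = x.to(self.model_dtype)
        out = self.model(x)
        # Tuple unpacking: native YOLO models may return (inference, loss).
        if isinstance(out, tuple):
            out = out[0]
        # Output alignment: torchvision.ops.nms expects float32 boxes/scores.
        return out.float()
```

Callers then interact with the wrapped model exactly as they would with a plain Torchvision model, which is what keeps the rest of the library unchanged.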

This PR supersedes #469. It implements a more stable version by ensuring compatibility with Torchvision models as well as Ultralytics YOLO models.


Architectural Question for Maintainers

I have implemented DetectionModelWrapper as a local class within the __init__ method of TorchImageDetectionModel to keep the fix strictly within the requested section and to ensure that the normalization is context-specific to the model instance.

Do you prefer this local encapsulation, or would you like me to refactor the wrapper into a private, module-level class (e.g., _ModelNormalizationWrapper) at the top of the file to keep the __init__ method more concise?

@dpascualhe dpascualhe self-requested a review March 25, 2026 19:08
@dpascualhe dpascualhe self-assigned this Mar 25, 2026
@dpascualhe
Collaborator

Hi, thanks for your contribution! I'll review the PR thoroughly when I can since this is an important upgrade.

@kashtennyson
Author

Alright @dpascualhe. Thanks for the update!

I am also currently working on a broader refactor to provide global .pt support across all tasks (Detection, Segmentation, and LiDAR) by centralizing the loading and normalization logic into a shared BaseTorchModel utility. Your guidance and feedback are therefore crucial for these architectural decisions. Looking forward to your thoughts!

@kashtennyson
Author

Hi @dpascualhe, I am still waiting on this PR's evaluation from your side since I plan to extend the support for .pt files across all perception tasks by centralizing the logic in a similar way. Review it whenever you can. Thanks!

Collaborator

@dpascualhe dpascualhe left a comment


Looks good and works for me! Can you get rid of the model_dtype extraction and subsequent casting of the input tensor during inference? I'd rather keep things as they are in that regard, so that an error is raised if they don't match for whatever reason. Upon changing that we can merge. Also, resolve conflicts with current master branch.

@kashtennyson
Author

Thanks for the review! I have made the requested changes. However, to keep the solution intact, I added an explicit .float() cast during model initialization to avoid the Float-vs-Half precision error that occurs with .pt models during inference. Let me know if you'd prefer to remove it or handle the casting differently. Thanks!
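As a sketch, the explicit cast could live in the loading path (`load_pt_model` is a hypothetical helper; the real code keeps this logic inside TorchImageDetectionModel):

```python
import torch
import torch.nn as nn

def load_pt_model(path: str) -> nn.Module:
    # Hypothetical loader: many .pt checkpoints store weights in float16,
    # so cast the whole model to float32 once at initialization instead of
    # casting inputs on every inference call.
    try:
        model = torch.load(path, map_location="cpu", weights_only=False)
    except ModuleNotFoundError as err:
        # Graceful fallback: unpickling an Ultralytics checkpoint requires
        # its classes to be importable.
        raise RuntimeError(
            "Loading this .pt file requires the 'ultralytics' package; "
            "install it with 'pip install ultralytics'."
        ) from err
    return model.float().eval()
```

Casting once at load time means no per-call dtype extraction is needed during inference, which matches the reviewer's preference for leaving the inference path untouched.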
