
Optimizing Models: A Train Of Thought

An T. Le, Hanoi, Nov 2025 (revised Jan 2026)

In practice, modern foundation models are optimized in two tightly coupled layers:

  • (A) Model-side optimization: quantization, pruning/sparsity, distillation, low-rank / factorized parameterization, etc. (reduce compute/memory while preserving task performance).
  • (B) Deployment optimization: compilers + runtimes (e.g., TensorRT-LLM/TensorRT, OpenVINO, ONNX Runtime, LiteRT (ex-TFLite), TVM, ncnn, vendor SDKs) that turn an optimized checkpoint into a hardware-efficient engine.

NOTE: TorchAO / torchao mostly belongs to (A), but it increasingly acts as the “bridge” into (B) via export/compile flows.

This note is a quick mental map of the mainstream pathways for compressing and deploying foundation models (with links to docs + code).


1. Mainstream (A) Model-side optimization

1.1 Quantization (almost always the first step)

Goal: reduce weights/activations from FP16/FP32 -> INT8/INT4/FP8/FP4 without unacceptable accuracy loss.

Common variants (esp. for transformers) (survey: Zhu et al., 2024):

  • Post-Training Quantization (PTQ)
    No retraining; use a small calibration set.

    • Static (typical INT8): collect activation stats offline, quantize weights + activations.
    • Dynamic (common on CPUs): quantize weights offline; compute some activation quant params at runtime.
    • Transformer-specific PTQ recipes (examples): SmoothQuant, AWQ.

    Codebases:

  • Quantization-Aware Training (QAT)
    Simulate quantization during fine-tuning to recover accuracy when PTQ is too lossy.

  • Mixed precision & vendor-specific formats

    • Mixed FP16/BF16 is the default “cheap win.”
    • NVIDIA-specific low-bit floats: FP8 and NVFP4/FP4 are supported through TensorRT + NVIDIA Model Optimizer (ModelOpt). Practical entry points: NVFP4 overview, TensorRT quantized types.

Reality check: the format only matters if your deployment stack has the matching kernels (this is where TensorRT-LLM / TensorRT / OpenVINO / ORT / LiteRT differ).
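A minimal dynamic-PTQ sketch in PyTorch, using torch.ao.quantization.quantize_dynamic (the toy MLP is a stand-in for a real model):

```python
import torch
import torch.nn as nn

# Toy stand-in for a transformer block's MLP; a real model works the same way.
model = nn.Sequential(
    nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512)
).eval()

# Dynamic PTQ: Linear weights are quantized to INT8 offline; activation
# scales are computed on the fly at inference time (the common CPU recipe).
qmodel = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 512)
with torch.no_grad():
    y = qmodel(x)
```

Static PTQ additionally needs a calibration pass over sample inputs to fix activation scales offline; see the PyTorch quantization docs for that flow.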

Example (GR00T-style VLA):

  • Vision + language backbones: INT8 / FP8 / FP4 (where supported)
  • Action/diffusion head + final layers: FP16/BF16

References:


1.2 Pruning & sparsity

Goal: remove “unimportant” parameters (often combined with quantization). Speedups depend heavily on kernel support and sparsity structure.

Common flavors:

  • Unstructured pruning (weight sparsity)

    • Easy to apply, but usually needs specialized sparse kernels to see wall-clock speedups.
  • Structured pruning (more reliable speedups)

    • Prune attention heads, MLP channels, entire blocks/layers, tokens, etc.
  • N:M (semi-structured) sparsity

    • Example: 2:4 sparsity (50% zeros in a constrained pattern), which maps onto NVIDIA Sparse Tensor Cores.
    • Acceleration stack often involves TensorRT and/or cuSPARSELt plus an export path that preserves sparsity metadata.

    References:
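For intuition, the 2:4 pattern can be emulated in a few lines of PyTorch. This is an illustrative magnitude-based sketch, not the actual cuSPARSELt/TensorRT acceleration path:

```python
import torch

def prune_2_of_4(weight: torch.Tensor) -> torch.Tensor:
    """Keep the 2 largest-magnitude values in every contiguous group of 4."""
    groups = weight.reshape(-1, 4)
    keep = groups.abs().topk(2, dim=1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(1, keep, True)
    return (groups * mask).reshape(weight.shape)

w = torch.randn(64, 64)
sw = prune_2_of_4(w)
# Result: exactly 50% zeros, in a pattern sparse tensor cores can exploit.
```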


1.3 Knowledge Distillation (KD)

Goal: train a smaller/cheaper student to mimic a larger teacher.

Main flavors:

  • Logit distillation: student matches teacher soft logits.
  • Feature distillation: align hidden states / attention maps.
  • Sequence / behavior distillation: student imitates teacher-generated trajectories/actions.
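A minimal logit-distillation loss (the classic Hinton-style soft-target KL; the student/teacher logits below are random placeholders):

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, T: float = 2.0):
    """KL between temperature-softened distributions; the T**2 factor
    keeps gradient magnitudes comparable across temperatures."""
    s = F.log_softmax(student_logits / T, dim=-1)
    t = F.softmax(teacher_logits / T, dim=-1)
    return F.kl_div(s, t, reduction="batchmean") * (T * T)

student_logits = torch.randn(4, 32000, requires_grad=True)
teacher_logits = torch.randn(4, 32000)
loss = distill_loss(student_logits, teacher_logits)
loss.backward()
```

In practice this term is mixed with the ordinary task loss (e.g., cross-entropy on labels) via a weighting coefficient.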

Tutorials / tooling:


1.4 Low-rank & factorization tricks

Often used for parameter-efficient adaptation, and sometimes for compression if merged or baked into the final model:

  • LoRA / low-rank adapters: train low-rank deltas, optionally merge.

  • Matrix/tensor decompositions: SVD / Tucker / CP, etc.
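The LoRA "merge" step can be sketched directly (hypothetical shapes; note that B is zero-initialized in LoRA, so the delta starts at zero):

```python
import torch

d, r, alpha = 512, 8, 16.0
W = torch.randn(d, d)           # frozen base weight
A = torch.randn(d, r) * 0.01    # trainable low-rank factor
B = torch.zeros(r, d)           # zero-init, so the initial delta is 0

# During training, the adapter adds (alpha / r) * A @ B alongside W.
# For deployment, the delta can be merged so inference pays no adapter cost:
W_merged = W + (alpha / r) * (A @ B)

x = torch.randn(1, d)
# With B still zero, merged and base outputs are identical.
```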


2. Mainstream (B) Deployment pipelines / toolchains

2.1 NVIDIA-centric: TensorRT-LLM + NVIDIA Model Optimizer (ModelOpt)

NVIDIA Model Optimizer (formerly “TensorRT Model Optimizer”) is the main toolkit for PTQ/QAT + pruning + distillation + speculative decoding + sparsity in the NVIDIA stack.

Pipeline sketch

  1. Start from a Hugging Face / PyTorch checkpoint (e.g., GR00T N1.5).
  2. Apply PTQ or QAT with ModelOpt (INT8/FP8/NVFP4, etc.).
  3. If needed: pruning/sparsity + distillation to recover accuracy at lower cost.
  4. Export/build a TensorRT(-LLM) engine; deploy (Jetson / server GPUs / Blackwell-class hardware).

Pre-quantized checkpoints:


2.2 Intel / CPU-centric: OpenVINO + NNCF (plus Intel Neural Compressor where useful)

For Intel CPUs/GPUs and many industrial deployments, the “mainline” path today is:

Hugging Face-friendly workflow:

Intel Neural Compressor (INC) is still relevant as a cross-framework quant/prune/distill toolkit (especially outside OpenVINO-only workflows):


2.3 PyTorch-native: torchao + torch.export (PT2E quantization)

If you want to stay close to PyTorch while exploring low-bit + sparsity:

Key tutorials (PyTorch 2 export quantization):

Pruning in “vanilla PyTorch”:

  • torch.nn.utils.prune is useful for simple experiments, but for structured pruning (channels/blocks with dependency handling), libraries like VainF/Torch-Pruning are often more practical.
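A quick unstructured-pruning experiment with torch.nn.utils.prune (mask-based; as noted above, wall-clock speedups still require sparse kernels):

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(64, 64)

# L1 (magnitude) unstructured pruning: zeros the 50% smallest weights
# by attaching a mask (weight_orig + weight_mask reparameterization).
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Make the pruning permanent: fold the mask into the weight tensor.
prune.remove(layer, "weight")
sparsity = (layer.weight == 0).float().mean().item()
```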

2.4 Framework-agnostic / edge-oriented runtimes

Common choices across heterogeneous edge targets:

Typical flow:

  1. Do pruning/KD in PyTorch/TF.
  2. Export (ONNX / LiteRT / IR).
  3. Run runtime-specific quantization + graph optimizations.
  4. Deploy.

2.5 “All-in-one” compression frameworks

If you want config-driven automation across methods:

Note: older “SparseML” references exist, but the upstream repo is archived; treat it as legacy unless your org already depends on it.


2.6 Example: GR00T-like robotics FM

A practical “first pass” for a GR00T-style VLA:

  1. Baseline profiling on target (Jetson / server GPU / CPU box / SoC).
  2. Quantize most transformer layers (PTQ INT8/FP8/FP4 depending on hardware + stack); keep sensitive heads higher precision.
  3. Structured pruning / 2:4 sparsity only if your deployment engine has real sparse kernels for your shapes.
  4. Distill:
    • smaller VLA, or
    • task-specific students (e.g., manipulator-only policy).
  5. Export to the deployment stack:
    • NVIDIA path: PyTorch -> ModelOpt -> TensorRT-LLM/TensorRT
    • Intel path: PyTorch/HF -> OpenVINO IR -> NNCF -> OpenVINO runtime
    • General path: PyTorch -> ONNX -> ORT / TVM / ncnn
    • Mobile path: TF/PyTorch -> LiteRT

3. Serving-time optimization (often the biggest real-world win)

A useful mental model: cost splits into prefill (prompt processing) and decode (token-by-token).

  • Prefill is usually compute-bound (big GEMMs + attention).
  • Decode is often memory / KV-cache bandwidth bound.

So:

  • Weight-only INT4/FP4 helps decode if your stack has good kernels.
  • Better attention kernels (FlashAttention/FlashInfer) help prefill and reduce memory traffic.
  • KV-cache tricks matter most for long context and high concurrency.
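A back-of-envelope KV-cache sizing calculation makes the decode-side pressure concrete (a hypothetical mid-size GQA config, FP16 cache):

```python
# Hypothetical config, roughly a mid-size GQA transformer.
n_layers, n_kv_heads, head_dim = 32, 8, 128
bytes_per_elem = 2                  # FP16/BF16
seq_len, batch = 8192, 16

# 2x for keys and values.
kv_bytes = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len * batch
kv_gib = kv_bytes / 2**30
print(f"KV cache: {kv_gib:.0f} GiB")  # 16 GiB at this concurrency
```

At 16 concurrent 8K-token sequences, the cache alone can rival or exceed the weights in VRAM, which is why the paged-KV and KV-quantization techniques below matter.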

3.1 Batching + scheduling + KV memory management

If you do nothing else, choose a serving engine that gives you:

  • continuous / dynamic batching
  • paged KV cache (reduces fragmentation under concurrency)
  • optional chunked prefill (smooths very long prompts)

Good entry points:

Reality check: feature parity differs (quant formats, speculative decoding, MoE, multi-modal, etc.).
Always confirm against each runtime’s “supported hardware + quantization” tables.

3.2 Kernel libraries that matter in practice

For transformer-heavy workloads, attention + MLP kernels are usually the make-or-break factor.

3.3 KV-cache optimization for long context + high concurrency

When context length or concurrency grows, KV cache can dominate VRAM and drive latency cliffs.

Two complementary strategies:

  • Systems: paged KV cache + chunked prefill (runtime feature).
  • Model-side: KV cache quantization/compression (typically 4–8 bit; often mixed precision).
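The model-side idea can be sketched as symmetric per-channel INT8 quantization of cached K/V tensors. This is an illustrative sketch only; real systems use fused kernels and often finer granularity:

```python
import torch

def quantize_kv(kv: torch.Tensor):
    """kv: [seq, heads, head_dim]. Symmetric INT8 with per-(head, dim) scales."""
    scale = kv.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0
    q = (kv / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale  # 2x smaller than FP16, plus a small scale overhead

def dequantize_kv(q: torch.Tensor, scale: torch.Tensor):
    return q.float() * scale

kv = torch.randn(128, 8, 64)
q, scale = quantize_kv(kv)
recon = dequantize_kv(q, scale)
```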

Researchy-but-usable codebases:

Kernel library for efficient pruning decisions:

  • Flash-ColReduce: Triton kernels for attention column-wise reductions (sum/mean/max) with O(N) memory; enables efficient token/KV importance estimation without materializing the O(N²) attention matrix. Repo: z-lab/flash-colreduce

3.4 Decoding acceleration (reduce target-model forward passes)

If decode is the bottleneck, you can reduce the number of expensive target-model steps:

  • Speculative decoding (draft model + verification): romsto/Speculative-Decoding
  • Multi-token heads (Medusa): FasterDecoding/Medusa
  • Block diffusion parallel drafting (DFlash): lightweight diffusion-based draft model generating multiple tokens in parallel; proven on Qwen/Llama/GPT-OSS; benefits LLM serving + high concurrency. Repo: z-lab/dflash
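The core draft-and-verify loop (greedy variant) fits in a short sketch; `target` and `draft` here are placeholder functions mapping a token sequence to per-position next-token logits:

```python
import torch

def speculative_step(target, draft, prefix: torch.Tensor, k: int = 4):
    """One greedy speculative-decoding step.

    The cheap draft model proposes k tokens autoregressively; the expensive
    target model then scores all of them in a single forward pass, and we
    keep the longest agreeing prefix (plus one 'free' target token)."""
    proposal = prefix.clone()
    for _ in range(k):                        # cheap: k small forward passes
        nxt = draft(proposal)[-1].argmax()
        proposal = torch.cat([proposal, nxt.view(1)])

    # Expensive model runs once over the whole proposal; positions
    # len(prefix)-1 .. end yield its next-token predictions.
    verified = target(proposal)[len(prefix) - 1:].argmax(dim=-1)

    drafted = proposal[len(prefix):]
    accept = int((drafted == verified[:k]).to(torch.int).cumprod(0).sum())
    # accept is in [0, k]; we always emit at least one target-approved token.
    return torch.cat([prefix, verified[: accept + 1]])
```

With sampling instead of greedy decoding, acceptance uses the modified rejection-sampling rule from the speculative-decoding papers so the output distribution matches the target model exactly.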

4. Deployment lanes by hardware (quick cheat sheet)

NVIDIA GPUs / Jetson / Blackwell-class

AMD GPUs (ROCm/HIP) + non-NVIDIA datacenter

  • Serving engines like vLLM can run with HIP backends; quantization support is more kernel-dependent and can be narrower than on NVIDIA.
  • PyTorch torch.compile (Inductor) is a good “graph+kernel” optimization baseline across NVIDIA/AMD/Intel GPUs (via Triton): API: torch.compile · Guide: torch.compiler docs

CPUs (x86 + ARM servers)

First levers:

  • smaller model (distill) and/or weight-only quantization (INT8/INT4).

Runtimes:

Apple silicon (laptop / mobile-class SoC)

Android / Qualcomm / embedded SoCs

“Runs everywhere” local inference engines

These are often the fastest way to get something working across laptops + edge boxes:


5. Beyond LLMs: VLMs + diffusion model optimization in robotics

VLMs / VLAs

  • Optimize each submodule separately (vision encoder, LLM, action head) and re-profile end-to-end.
  • Watch for non-model bottlenecks: image decode, resizing, tokenization, simulator/robot loop.

Diffusion / image generation

Model-side levers (reduce steps or faster steps):

Distillation frameworks (end-to-end model compression):

Serving-time optimization (inference-time caching, not training-based):

  • Cache-DiT: a PyTorch-native, flexible inference engine with hybrid cache acceleration and parallelism for DiTs. Repo: vipshop/cache-dit
  • MeanCache: moves from instantaneous to average velocity to accelerate flow-matching inference. Paper: ICLR 2026 · Repo: UnicomAI/MeanCache

6. Minimal “what should I do first?” decision tree

  1. Profile and label the bottleneck: weights vs KV cache vs kernels vs scheduling.
  2. If decode/VRAM dominates -> start with weight-only INT4/FP4, but only if your runtime supports it well.
  3. If long context/concurrency dominates -> fix paged KV + chunked prefill, then consider KV cache quantization.
  4. If prefill compute dominates -> better kernels (FlashAttention/FlashInfer) + compile (torch.compile / TensorRT).
  5. If you still can’t hit constraints -> distill (often the only way to cut both compute and memory).

7. Closing remarks

  • Model optimization and deployment optimization are inseparable.
  • Most “wins” come from matching a compression method to the runtime’s kernels.
  • Treat it as a feedback loop: profile -> compress -> compile -> measure -> iterate.