vLLM TPU’s Unified Backend is Revolutionizing LLM Inference

Seamless LLM Inference on TPUs

The latest vLLM TPU release lets developers run open-source LLMs on TPUs with strong performance and flexibility. Powered by the new tpu-inference backend, it delivers a smooth, high-throughput experience whether you’re working in PyTorch or JAX, all behind a consistent interface.
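
To give a feel for how this looks in practice, here is a minimal sketch of offline inference with vLLM’s Python API. The model name is illustrative, and the assumption that tpu-inference is selected automatically on a TPU host is ours, not a statement from the release notes.

```python
# Minimal offline-inference sketch with vLLM's Python API.
# Assumption: on a TPU host with the TPU release installed, the
# tpu-inference backend is picked up without extra configuration.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model
params = SamplingParams(temperature=0.7, max_tokens=128)

outputs = llm.generate(["Explain what a TPU is in one sentence."], params)
print(outputs[0].outputs[0].text)
```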

Overcoming Core Technical Challenges

The vLLM team’s journey began with building a robust TPU backend in time for Cloud Next 2025. They tackled three major challenges:

  • Advanced Attention Kernels: Implementing the ragged paged attention (RPA v2) kernel brought optimized chunked prefill and prefix caching to TPUs without sacrificing vLLM’s attention paradigm.

  • Programming Model Alignment: Reconciling vLLM’s MPMD approach with TPU’s SPMD model led to more efficient data handling and communication.

  • PyTorch/XLA Optimization: Running PyTorch code natively on TPU via PyTorch/XLA required significant kernel and stack-level enhancements for peak performance.

The result? Up to 3.6x faster throughput for Llama 3.1-8B and 2.1x for Llama 3.1-70B on cutting-edge TPUs—a huge leap for open-source inference.

Unified Support: PyTorch and JAX via tpu-inference

The new tpu-inference backend brings PyTorch and JAX together through a single JAX→XLA lowering path. Using Torchax, PyTorch users run their models on TPU with zero code changes, while JAX’s mature graph generation ensures strong performance for both frameworks.
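
The lowering path itself is standard JAX machinery. The toy example below is our own sketch, not code from tpu-inference: a model function is traced by jax.jit and lowered to XLA’s StableHLO, the same compiler route every model takes under the unified backend.

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    # Toy model body; real models are lowered the same way, layer by layer.
    return jax.nn.gelu(x @ w1) @ w2

x = jnp.ones((8, 512))
w1 = jnp.ones((512, 2048))
w2 = jnp.ones((2048, 512))

# jax.jit traces the function and lowers it to XLA (StableHLO); this is the
# single lowering path that tpu-inference relies on for every model.
lowered = jax.jit(mlp).lower(x, w1, w2)
print(lowered.as_text()[:400])   # inspect the generated StableHLO
compiled = lowered.compile()     # XLA compiles it for the attached device
print(compiled(x, w1, w2).shape)
```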

This unified approach results in:

  • All models lowered with JAX, unlocking about 20% higher throughput even for PyTorch definitions.

  • One streamlined installation, just a single pip command for both PyTorch and JAX workflows.

  • Automatic routing to TPU-optimized code where available, or fallback to PyTorch through Torchax as needed.

For most users, this means frictionless model serving, regardless of framework. The team can still reimplement select models natively in JAX for further hardware-specific gains.

Breakthrough: Ragged Paged Attention v3

The debut of Ragged Paged Attention v3 (RPA v3) marks a significant step forward for TPU inference. This open-source kernel supports arbitrary head dimensions, various quantization types, and tensor parallelism. Notable features include:

  • Fused KV Cache Updates to streamline pipelines and reduce latency.

  • Flexible batch type support (prefill-only, decode-only, mixed) for optimal compute and memory usage.

  • Up to 10% further throughput gains over previous versions, without sacrificing versatility.

RPA v3 establishes a new reference point for attention kernels in open-source projects and lays the groundwork for future TPU-centric innovations like Mixture-of-Experts (MoE).
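
To make the core idea concrete, here is a deliberately simplified, pure-NumPy illustration of attention over a paged KV cache with ragged sequence lengths. It is a conceptual sketch only; nothing here reflects the API or internals of the actual fused Pallas kernel.

```python
import numpy as np

PAGE, HEAD = 4, 8                              # page size (tokens) and head dim
kv_pages_k = np.random.randn(16, PAGE, HEAD)   # pooled key pages shared by all sequences
kv_pages_v = np.random.randn(16, PAGE, HEAD)   # pooled value pages

def paged_attention(q, page_table, seq_len):
    """Attend one query vector over the first `seq_len` cached tokens of a sequence."""
    pages = page_table[: (seq_len + PAGE - 1) // PAGE]
    k = kv_pages_k[pages].reshape(-1, HEAD)[:seq_len]   # gather pages, trim the ragged tail
    v = kv_pages_v[pages].reshape(-1, HEAD)[:seq_len]
    scores = q @ k.T / np.sqrt(HEAD)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ v

# Two sequences of different lengths (a ragged batch) sharing one page pool.
out_a = paged_attention(np.random.randn(HEAD), page_table=np.array([3, 7]), seq_len=6)
out_b = paged_attention(np.random.randn(HEAD), page_table=np.array([1]), seq_len=2)
print(out_a.shape, out_b.shape)
```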

Deeper TPU Integration with Native SPMD

This release signals a strategic move from GPU-like multi-worker paradigms to Single Program, Multiple Data (SPMD), fully leveraging TPU- and XLA-native workflows. With SPMD, developers write code as if for a single device and let the compiler partition and optimize the work across the hardware, improving both computation and communication.
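
As a minimal JAX illustration of the SPMD style (our own sketch, assuming a multi-chip TPU host; shapes and mesh layout are arbitrary):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all attached chips (e.g. one TPU host's devices).
mesh = Mesh(jax.devices(), axis_names=("tp",))

# Shard the weight matrix column-wise across the mesh; keep activations replicated.
w = jax.device_put(jnp.ones((4096, 4096)), NamedSharding(mesh, P(None, "tp")))
x = jax.device_put(jnp.ones((8, 4096)), NamedSharding(mesh, P()))

@jax.jit
def forward(x, w):
    # Written as if for a single device; XLA inserts the partitioning and
    # collectives implied by the input shardings.
    return x @ w

y = forward(x, w)
print(y.sharding)  # the output's sharding is chosen by the compiler
```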

This deep integration ensures scalable, high-performance inference, cementing vLLM’s role as the TPU inference library of choice.

Performance Milestones and Future Directions

From its early-2025 prototype, vLLM TPU has improved performance by roughly two to five times on key workloads and broadened its model and feature support. The latest release brings:

  • Dense and multimodal model support (with multimodal via tpu-inference)
  • Compatibility with Trillium (v6e) and v5e TPUs
  • Prefix caching, chunked prefill, structured decoding, and speculative sampling (see the sketch after this list)
  • Comprehensive quantization and cutting-edge kernel optimizations
  • Experimental multimodal inference and v5p TPU support
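
As a rough sketch of how two of these features are switched on through vLLM’s engine arguments (the flag names reflect vLLM’s general engine options; defaults and the exact flag surface on TPU may differ):

```python
# Hedged sketch: enabling prefix caching and chunked prefill via engine arguments.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    enable_prefix_caching=True,     # reuse KV cache across shared prompt prefixes
    enable_chunked_prefill=True,    # split long prefills into scheduler-sized chunks
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16))[0].outputs[0].text)
```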

Looking ahead, the roadmap includes SparseCore offloading, advanced speculative decoding, distributed serving, and RL integrations, keeping vLLM TPU at the forefront of open-source LLM inference.

Key Takeaway

vLLM TPU’s unified backend, powered by tpu-inference, redefines performance and flexibility for LLM inference on TPUs. By uniting PyTorch and JAX under a single, optimized path and innovating in attention kernels, it empowers the AI community to achieve more, faster. Explore the documentation and contribute to help shape the future of large model inference on TPUs.

Source: vLLM Blog


Joshua Berkowitz, October 18, 2025