
LLM Inference in Rust: Building a Modular Runtime with unillm

A deep dive into unillm's three-layer architecture — TensorCore, ModelCore, and WeightLoaderCore — and how Rust's type system enables a modular inference runtime supporting 47 model architectures.

Tags: rust · llm · inference · systems-programming · unillm

Most LLM inference runtimes are monoliths. They couple tensor operations to specific hardware, fuse model definitions with weight loading, and make adding a new architecture a multi-thousand-line ordeal. unillm takes a different approach: a modular Rust runtime where each concern lives behind a clean trait boundary. It supports 47 model architectures today, and adding the 48th requires implementing a single trait.

This post walks through unillm’s architecture, the Rust patterns that make it work, and how to extend it with your own models.

Why Rust for LLM Inference

The case for Rust in inference is not about performance alone — it is about correctness under pressure. Inference runtimes juggle GPU memory, KV caches, concurrent requests, and multiple weight formats. The failure modes are subtle: use-after-free on a CUDA buffer, a data race in the batch scheduler, a silent shape mismatch during attention computation.

Rust addresses these at compile time:

  • No garbage collector. Inference is latency-sensitive. GC pauses during token generation are unacceptable. Rust’s ownership model gives you deterministic deallocation without stop-the-world pauses.
  • Zero-cost abstractions. The Model trait, TensorOps trait, and WeightLoader trait all compile down to static dispatch. There is no vtable indirection on the hot path unless you explicitly opt in with dyn.
  • Memory safety without runtime cost. The borrow checker prevents double-free and use-after-free bugs in GPU buffer management — the kind of bugs that manifest as silent corruption in C++ inference code.
  • Type-safe configuration. The model_config! macro generates strongly-typed config structs from model metadata. A missing num_attention_heads field is a compile error, not a runtime KeyError.

Project Structure

unillm is organized into four crates:

crates/
  runtime/       TensorCore, ModelCore, WeightLoaderCore — the foundation
  inference/     High-level inference engine with generation loops
  kv/            Hybrid KV cache (RadixAttention + PagedAttention)
  scheduler/     Request scheduling with continuous batching

Everything composes through traits. The runtime crate defines the abstractions; inference, kv, and scheduler build on top of them.

The Three-Layer Architecture

unillm’s core insight is that inference can be decomposed into three orthogonal layers. Each layer is a Rust trait with a single responsibility.

Layer 1: TensorCore

TensorCore provides device-agnostic tensor operations. Whether you are running on CPU, CUDA, or Metal, the calling code is identical:

use unillm_runtime::tensor::{Tensor, ops_fn};

// These ops dispatch to the correct backend at compile time
let q = ops_fn::linear(&hidden_states, &self.q_proj_weight, self.q_proj_bias.as_ref())?;
let k = ops_fn::linear(&hidden_states, &self.k_proj_weight, self.k_proj_bias.as_ref())?;
let v = ops_fn::linear(&hidden_states, &self.v_proj_weight, self.v_proj_bias.as_ref())?;

// Attention computation
let attn_weights = ops_fn::matmul(&q, &k.transpose(-2, -1)?)?;
let attn_weights = ops_fn::softmax(&(attn_weights / head_dim_sqrt)?, -1)?;
let attn_output = ops_fn::matmul(&attn_weights, &v)?;

The ops_fn module is the key abstraction. Every tensor operation — matmul, softmax, linear, layer_norm, rope, silu — goes through this module. Backend selection happens at build time via Cargo features:

[features]
default = ["cpu"]
cpu = []
cuda = ["dep:cudarc"]
metal = ["dep:metal-rs"]

This means a CPU-only build carries zero GPU dependencies. A CUDA build compiles the CUDA kernels and links against cuBLAS. The model code never changes.

Layer 2: ModelCore

ModelCore defines the Model trait — the universal interface every architecture implements:

pub trait Model: Send + Sync {
    type Config: ModelConfig;

    /// Run a forward pass, returning logits
    fn forward(&self, inputs: &ModelInputs) -> Result<ModelOutputs>;

    /// Autoregressive generation with sampling
    fn generate(&self, inputs: &ModelInputs, params: &GenerateParams) -> Result<Vec<Token>> {
        // Default implementation: repeated forward() with sampling
        let mut tokens = inputs.input_ids.clone();
        for _ in 0..params.max_new_tokens {
            let outputs = self.forward(&ModelInputs::from_tokens(&tokens))?;
            let next_token = params.sampler.sample(&outputs.logits)?;
            if next_token == self.config().eos_token_id() {
                break;
            }
            tokens.push(next_token);
        }
        Ok(tokens)
    }

    fn config(&self) -> &Self::Config;
    fn load(loader: &dyn WeightLoader, config: Self::Config) -> Result<Self> where Self: Sized;
}

The generate() method has a default implementation that works for any autoregressive model. Models can override it for architecture-specific optimizations (e.g., speculative decoding in Medusa-style models).

Configuration is handled by the model_config! macro:

model_config!(LlamaConfig {
    vocab_size: usize = 32000,
    hidden_size: usize = 4096,
    intermediate_size: usize = 11008,
    num_hidden_layers: usize = 32,
    num_attention_heads: usize = 32,
    num_key_value_heads: usize = 32,     // GQA support
    max_position_embeddings: usize = 4096,
    rms_norm_eps: f64 = 1e-5,
    rope_theta: f64 = 10000.0,
    rope_scaling: Option<RopeScaling> = None,
});

This macro generates a struct that implements ModelConfig, with from_json() deserialization, default values, and validation. It also generates builder methods, so you can override individual fields in tests or experiments.

Layer 3: WeightLoaderCore

Weight files come in at least three formats in the wild. WeightLoaderCore abstracts over all of them:

pub trait WeightLoader: Send + Sync {
    /// Load a named tensor from the weight file
    fn load_tensor(&self, name: &str) -> Result<Tensor>;

    /// Check if a tensor exists
    fn has_tensor(&self, name: &str) -> bool;

    /// List all tensor names
    fn tensor_names(&self) -> Vec<String>;

    /// Get metadata (architecture, tokenizer config, etc.)
    fn metadata(&self) -> &WeightMetadata;
}

Three implementations ship with unillm:

| Format      | Loader              | Notes                                        |
|-------------|---------------------|----------------------------------------------|
| SafeTensors | `SafeTensorsLoader` | Zero-copy mmap, the preferred format         |
| GGUF        | `GgufLoader`        | Quantized models, includes embedded metadata |
| PyTorch     | `PytorchLoader`     | Legacy .bin files, pickle deserialization    |

The SafeTensorsLoader deserves a note: it uses memory-mapped I/O to avoid loading the entire weight file into memory. Tensors are read on demand and can be shared across processes via the OS page cache. For a 70B model with 140 GB of weights, this is the difference between needing 140 GB of RAM and needing only what the model actively uses.

Adding a New Model Architecture

Here is what it takes to add a new model. Suppose you are implementing a hypothetical “Falcon2” architecture:

use unillm_runtime::prelude::*;

// 1. Define configuration
model_config!(Falcon2Config {
    vocab_size: usize = 65024,
    hidden_size: usize = 4544,
    num_hidden_layers: usize = 32,
    num_attention_heads: usize = 71,
    num_kv_heads: usize = 1,           // MQA
    layer_norm_epsilon: f64 = 1e-5,
    alibi: bool = true,
});

// 2. Define model struct
pub struct Falcon2Model {
    config: Falcon2Config,
    embed_tokens: Tensor,
    layers: Vec<Falcon2DecoderLayer>,
    ln_f: Tensor,
    lm_head: Tensor,
}

// 3. Implement Model trait
impl Model for Falcon2Model {
    type Config = Falcon2Config;

    fn forward(&self, inputs: &ModelInputs) -> Result<ModelOutputs> {
        let mut hidden = ops_fn::embedding(&inputs.input_ids, &self.embed_tokens)?;

        for layer in &self.layers {
            hidden = layer.forward(&hidden, inputs.attention_mask.as_ref())?;
        }

        hidden = ops_fn::layer_norm(&hidden, &self.ln_f, None, self.config.layer_norm_epsilon)?;
        let logits = ops_fn::linear(&hidden, &self.lm_head, None)?;

        Ok(ModelOutputs { logits })
    }

    fn config(&self) -> &Falcon2Config { &self.config }

    fn load(loader: &dyn WeightLoader, config: Falcon2Config) -> Result<Self> {
        // Build the decoder stack before moving `config` into the struct,
        // since the struct literal would otherwise use `config` after the move
        let layers = (0..config.num_hidden_layers)
            .map(|i| Falcon2DecoderLayer::load(loader, &config, i))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            embed_tokens: loader.load_tensor("transformer.word_embeddings.weight")?,
            layers,
            ln_f: loader.load_tensor("transformer.ln_f.weight")?,
            lm_head: loader.load_tensor("lm_head.weight")?,
            config,
        })
    }
}

That is it. The model automatically gets generate(), works with all three weight formats, runs on any supported device, and plugs into the KV cache and scheduler.

KV Cache: RadixAttention + PagedAttention Hybrid

The kv crate implements a hybrid caching strategy that combines two approaches:

PagedAttention (from vLLM) allocates KV cache memory in fixed-size blocks rather than contiguous buffers. This eliminates fragmentation — you do not need to pre-allocate worst-case sequence length for every request. When a request finishes, its blocks return to the free pool immediately.

RadixAttention adds prefix sharing on top of paged memory. If two requests share a system prompt (common in chat applications), their KV cache entries for that prefix are shared via a radix tree. The tree structure enables O(prefix_length) lookup for cache hits.

// The hybrid cache checks the radix tree first, falls back to fresh allocation
let cache_entry = self.radix_tree
    .find_prefix(&token_sequence)
    .unwrap_or_else(|| self.page_allocator.allocate(num_blocks));

In practice, this hybrid approach reduces KV cache memory by 30-60% for multi-turn chat workloads where conversations share system prompts or common prefixes.

Continuous Batching Scheduler

The scheduler crate implements iteration-level scheduling. Unlike static batching (where you wait for a full batch before processing), continuous batching adds and removes requests at every forward pass:

pub struct ContinuousBatchScheduler {
    waiting: VecDeque<Request>,
    running: Vec<ActiveRequest>,
    max_batch_tokens: usize,
    max_batch_size: usize,
}

impl Scheduler for ContinuousBatchScheduler {
    fn schedule(&mut self) -> ScheduleBatch {
        // Remove finished requests
        self.running.retain(|r| !r.is_finished());

        // Fill available capacity with waiting requests
        while let Some(req) = self.waiting.front() {
            let new_tokens = self.running.iter().map(|r| r.num_tokens()).sum::<usize>()
                + req.input_len();
            if new_tokens > self.max_batch_tokens
                || self.running.len() >= self.max_batch_size {
                break;
            }
            let req = self.waiting.pop_front().unwrap();
            self.running.push(ActiveRequest::new(req));
        }

        ScheduleBatch::from_active(&self.running)
    }
}

The scheduler also handles preemption: if a long-running request monopolizes GPU memory, the scheduler can swap its KV cache to CPU and prioritize shorter requests. This keeps p99 latency bounded even under load.

Supported Model Architectures

unillm supports 47 architectures across 10 categories:

| Category         | Architectures |
|------------------|---------------|
| Core LLMs        | LLaMA, LLaMA-2, LLaMA-3, Qwen, Qwen2, Gemma, Gemma-2, Phi, Phi-3, DeepSeek, DeepSeek-V2, Mistral, Mixtral |
| GPT Family       | GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT |
| Code             | StarCoder, StarCoder2, CodeLlama |
| MoE              | DeepSeek-MoE, DBRX, Grok-1, Arctic, Jamba |
| RWKV / Recurrent | RWKV-4, RWKV-6, RecurrentGemma |
| Vision-Language  | Qwen2-VL, Phi-3-Vision, InternVL, CogVLM, LLaVA, CLIP |
| Audio / Speech   | Wav2Vec2, HuBERT, MusicGen, Encodec, Whisper |
| Encoder          | BERT, RoBERTa, T5, FLAN-T5 |
| Specialized      | Mamba, Mamba-2, MiniCPM, OLMo, Granite |
| Embedding        | E5, BGE, GTE |

Every architecture implements the same Model trait. Swapping LLaMA for Qwen or Whisper for Wav2Vec2 is a configuration change, not a code change.

Getting Started

Clone and run:

git clone https://github.com/cognisoc/unillm.git
cd unillm

# Generate text with TinyLlama (downloads on first run)
cargo run --bin unillm -p unillm-runtime -- generate --prompt "Explain gravity in one sentence"

# Use a specific model
cargo run --bin unillm -p unillm-runtime -- generate \
    --model mistral:7b-instruct \
    --prompt "Write a Rust function that reverses a string"

# Run with CUDA
cargo run --bin unillm -p unillm-runtime --features cuda -- generate \
    --model llama3:8b \
    --prompt "What is the capital of France?"

# List downloaded models
cargo run --bin unillm -p unillm-runtime -- models

# Serve over HTTP
cargo run --bin unillm -p unillm-runtime -- serve --model llama3:8b --port 8080

For library usage in your own Rust project:

[dependencies]
unillm-runtime = { git = "https://github.com/cognisoc/unillm.git" }
unillm-inference = { git = "https://github.com/cognisoc/unillm.git" }

Then in your code:

use unillm_runtime::prelude::*;
use unillm_inference::Engine;

fn main() -> Result<()> {
    let engine = Engine::builder()
        .model("llama3:8b")
        .device(Device::Cuda(0))
        .build()?;

    let response = engine.generate("Explain ownership in Rust", GenerateParams::default())?;
    println!("{}", response.text);
    Ok(())
}

Closing Thoughts

The value of unillm is not in any single model implementation — it is in the boundaries between layers. TensorCore does not know about models. ModelCore does not know about weight formats. The KV cache does not know about scheduling. Each piece is testable, replaceable, and comprehensible in isolation.

Rust makes these boundaries enforceable. A Model implementation cannot accidentally bypass the tensor abstraction because the type system will not let it. A WeightLoader cannot return a tensor on the wrong device because Device is part of the type. These are not conventions documented in a README — they are compile-time guarantees.

If you want to add a model, start with the Model trait. If you want to add a backend, start with ops_fn. If you want to experiment with caching strategies, the kv crate is self-contained. The architecture is designed so you never have to understand the whole system to change one part of it.

Check out the repository and open an issue if you have questions.