Stable LLM Inference
A mathematical function gives the same output for the same input. LLMs do not. Not even at temperature = 0. Determinism matters for reproducibility, debugging, and for keeping pipelines stable. Without it, identical prompts can yield divergent completions, making experiments irreproducible and systems brittle. We can intervene at three broad layers to improve stability: generation, training, and infrastructure.
Generation Constraints
Beyond temperature or beam-width tuning, we can shape generation behavior directly.
- Constrain Generation. Borrow from schema- or grammar-constrained decoding: restrict the space of valid continuations so the model can only produce syntactically or structurally valid outputs. Fewer admissible continuations mean less variance (see the sketch after this list).
- Context engineering (or prompt optimization). Define what counts as success—bitwise-identical tokens, consistent field extraction, etc.—and optimize prompts for that goal on unseen data.
- Prefer retrieval over regeneration. When possible, bypass generation entirely. For repeated or semantically equivalent inputs, retrieve and reuse the cached, validated answer instead of regenerating it. This turns an unstable generative step into a deterministic lookup.
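As an illustration of the first point, here is a minimal sketch of constrained decoding via logit masking. The toy vocabulary, the `allowed_ids` set, and the helper names are assumptions made for the example, not any particular library's API; in practice the admissible set would come from a grammar or JSON schema.

```python
# Minimal sketch: grammar/schema-constrained decoding by masking logits.
import torch

def mask_logits(logits: torch.Tensor, allowed_ids: list[int]) -> torch.Tensor:
    """Set every disallowed token's logit to -inf so it can never be chosen."""
    mask = torch.full_like(logits, float("-inf"))
    mask[allowed_ids] = 0.0
    return logits + mask

def constrained_greedy_step(logits: torch.Tensor, allowed_ids: list[int]) -> int:
    """Greedy pick restricted to the admissible continuation set."""
    return int(torch.argmax(mask_logits(logits, allowed_ids)))

# Toy usage: a 6-token vocabulary where only tokens {2, 4} are grammatically valid.
logits = torch.tensor([1.5, 3.0, 0.2, 2.9, 0.1, -1.0])
print(constrained_greedy_step(logits, allowed_ids=[2, 4]))  # -> 2
```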
To stabilize outputs for similar inputs, options include:
- Reduce input entropy. Predictive typing, menus, or FAQs keep user inputs within a narrow, predictable range.
- Canonicalize inputs. Normalize casing, spelling, field order, or formatting before they reach the model.
- Leverage semantic retrieval. When paraphrases are detected, return cached or canonical responses instead of regenerating them (a sketch combining canonicalization with a cached lookup follows).
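A minimal sketch of the last two ideas, assuming a hypothetical `generate(prompt)` callable for the model; a real system might replace the string normalization with embedding-based paraphrase detection.

```python
# Sketch: canonicalize inputs, then serve repeats from a deterministic cache.
import hashlib
import re

def canonicalize(text: str) -> str:
    """Normalize casing, whitespace, and trailing punctuation before caching."""
    text = text.strip().lower()
    text = re.sub(r"\s+", " ", text)   # collapse repeated whitespace
    return text.rstrip("?!.")          # drop trailing punctuation variants

_cache: dict[str, str] = {}

def answer(prompt: str, generate) -> str:
    """Return a cached, validated answer when the canonical form was seen before."""
    key = hashlib.sha256(canonicalize(prompt).encode()).hexdigest()
    if key not in _cache:
        _cache[key] = generate(prompt)  # unstable generative step, run once
    return _cache[key]                  # deterministic lookup afterwards

# Example: both calls hit the same cache entry.
fake_model = lambda p: "Paris"          # stand-in for the real generator
print(answer("What is the capital of France?", fake_model))
print(answer("  what is the Capital of france ", fake_model))
```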
Learning Constraints
Let the model’s own distribution become more decisive.
Assume a conditional distribution $P_\theta(y \mid x)$. Even with greedy decoding, if the top few continuations are close, tiny numeric or scheduling jitters can flip the argmax. The fix is to:
- Pick a canonical target $y^*$ per input $x$.
- Increase the margin $\log P_\theta(y^* \mid x) - \max_{y \neq y^*} \log P_\theta(y \mid x)$ so it stays comfortably positive across the whole sequence.
Plain SFT—cross-entropy with hard targets—already concentrates probability on $y^*$. To make it more 'stable', we can augment CE with a margin term so the chosen token is not merely the highest, but higher by at least $\gamma$:
$$L_{\text{margin}} = \sum_t \max\left(0, \gamma - \left(\ell_{y_t^*} - \max_{k \neq y_t^*} \ell_k\right)\right)$$
where $\ell_k$ are the logits at step $t$. This trains the model to keep a cushion between the canonical token and the runner-up—exactly what prevents flips. For each $x$, we can also include non-canonical but semantically valid outputs $\tilde{y}$ as hard negatives and add a ranking loss:
$$L_{\text{rank}} = \log\left(1 + \exp\left(s(x, \tilde{y}) - s(x, y^*)\right)\right)$$
where $s$ is a sequence score (e.g., sum of token logits or a learned reranker).
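A minimal PyTorch sketch of both losses, assuming per-step logits of shape [T, V], canonical token ids, and precomputed sequence scores; the shapes and the toy usage are illustrative, not a prescribed training recipe.

```python
# Sketch: token-level margin loss plus pairwise ranking loss.
import torch

def margin_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """logits: [T, V] per-step logits; targets: [T] canonical token ids y*_t."""
    target_logits = logits.gather(1, targets.unsqueeze(1)).squeeze(1)   # l_{y*_t}
    masked = logits.clone()
    masked[torch.arange(len(targets)), targets] = float("-inf")         # hide y*_t
    runner_up = masked.max(dim=1).values                                # max_{k != y*_t} l_k
    return torch.clamp(gamma - (target_logits - runner_up), min=0).sum()

def rank_loss(score_canonical: torch.Tensor, score_negative: torch.Tensor) -> torch.Tensor:
    """Pairwise ranking loss: log(1 + exp(s(x, y~) - s(x, y*)))."""
    return torch.nn.functional.softplus(score_negative - score_canonical)

# Toy usage with a 4-token vocabulary over 3 steps.
logits = torch.randn(3, 4)
targets = torch.tensor([1, 0, 2])
total = margin_loss(logits, targets) + rank_loss(torch.tensor(5.0), torch.tensor(4.2))
```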
Infrastructure Constraints (fixes temperature = 0 stochasticity)
Systems produce different outputs because we haven't defined the system (libraries and infrastructure) precisely enough, because we use non-deterministic algorithms, and because computations depend on batch composition. The solutions are:
- Pin the stack: exact GPU model, driver, CUDA/cuBLAS/cuDNN versions, inference engine build, tokenizer build, quantization config.
- Pin precision. Mixed-precision settings and different GPU architectures can produce subtly different numerics even with the same seed.
- Set deterministic flags: Enable deterministic algorithms and set cuBLAS’ reproducibility env var (it controls internal workspaces/reduction strategies), e.g., torch.use_deterministic_algorithms(True) and CUBLAS_WORKSPACE_CONFIG=:4096:8 (or :16:8).
- Batch invariance. Floating-point addition isn’t associative: ((a + b) + c) can differ slightly from (a + (b + c)) because of rounding. On GPUs, sums are computed in reduction trees whose shape depends on how work is divided among threads, blocks, or tiles. Change that shape, and you change the rounding order—and therefore the final value. In large-language-model inference, this becomes visible because your request is rarely executed alone. Servers dynamically batch multiple user requests and chunk them into segments (for example, prefill “tiles,” paged attention, or prefix-cache merges). Each batching or chunking decision alters how those reductions are scheduled inside kernels, which changes the order of additions and therefore the computed logits. Different reduction orders → slightly different sums → slightly different logits. When the top two logits are close, even a micro-difference can flip the argmax—producing a different next token, even with temperature = 0 (see the sketch after this list).
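To tie these knobs together, here is a minimal sketch of a deterministic PyTorch setup plus a small demonstration of reduction-order sensitivity; which flags matter in practice depends on your inference engine and hardware, so treat this as a starting point rather than a complete recipe.

```python
# Sketch: deterministic flags, plus a demo of why reduction order matters.
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"   # or ":16:8"; set before CUDA init

import torch
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)            # error out on non-deterministic kernels
torch.backends.cudnn.benchmark = False              # don't let autotuning pick different kernels

# Float addition is not associative: the same numbers summed under two
# reduction shapes can round differently, just like batch-dependent kernels.
x = torch.randn(2**20, dtype=torch.float32)
flat = x.sum()                                      # one reduction order
tiled = x.reshape(1024, 1024).sum(dim=1).sum()      # "tile then reduce" order
print(flat.item(), tiled.item(), bool(flat == tiled))  # may differ in the last bits
```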