Skip to content

feat(causal_llm): add LogitsProcessor hook for guidance / constrained decoding#2381

Open
JulienBalianSonos wants to merge 1 commit into
mainfrom
feat/causal-llm-logits-processor
Open

feat(causal_llm): add LogitsProcessor hook for guidance / constrained decoding#2381
JulienBalianSonos wants to merge 1 commit into
mainfrom
feat/causal-llm-logits-processor

Conversation

@JulienBalianSonos

Copy link
Copy Markdown
Collaborator

What

Adds a small, generic hook to CausalLlmState so callers can adjust the
next-token logits before selection — the building block for constrained
decoding / guidance (grammar or JSON-Schema masking, logit biasing, custom
sampling), without baking any such logic into causal_llm.

/// Adjust next-token logits before selection (constrained decoding / guidance).
pub trait LogitsProcessor {
    fn process(&mut self, logits: &mut [f32], tokens: &[u32]);
}

impl CausalLlmState {
    pub fn generate_next_token(&mut self) -> Result<()>;                  // unchanged
    pub fn generate_next_token_with(&mut self, p: &mut dyn LogitsProcessor) -> Result<()>; // new
}
  • LogitsProcessor::process is called once per token, after the repeat
    penalty and just before argmax, with the full token sequence so far.
  • A blanket impl makes any FnMut(&mut [f32], &[u32]) a LogitsProcessor, and
    NoLogitsProcessor is a no-op default.
  • generate_next_token() now simply delegates to generate_next_token_with(&mut NoLogitsProcessor), so existing behavior and the public API are unchanged.

Why

The sampler does argmax internally with no way to influence it, so downstream
projects can't do grammar/JSON-constrained decoding. This is the minimal seam
to enable it.

We've used it downstream to drive llguidance
for guaranteed JSON / JSON-Schema / Lark-grammar / regex output. None of that
lands here
causal_llm gains no dependency; the hook is engine-agnostic.

Design notes

  • Minimal & non-invasive: no struct fields added, so Debug/freeze/Send of
    CausalLlmState are untouched. The processor is passed per call.
  • Masking pattern: set disallowed tokens' logits to f32::NEG_INFINITY, then
    the existing argmax picks the best allowed token.

Tests

Added unit tests (no model needed): a closure masks a token; NoLogitsProcessor
leaves logits unchanged. cargo check -p causal_llm passes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant