Formatron v0.4.2
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
vllm.py
Go to the documentation of this file.
1"""
2This module integrates the vllm library by providing convenience utilities.
3"""
4import collections.abc
5import time
6import typing
7import kbnf
8import torch
9from vllm import LLM
10from formatron.config import EngineGenerationConfig
11from formatron.formatter import FormatterBase, FormatterBuilder
12from formatron.integrations._utils import get_original_characters
13
14
16 """
17 Logit processor that uses formatters to mask batch logits.
18 """
def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
             configs: typing.Sequence[EngineGenerationConfig] | None = None):
    """
    Initialize the logits processor.

    :param formatters: One formatter per sequence in the batch; a None entry
        disables formatting for that sequence.
    :param eos_token_id: Token id used to detect (and later force) end of sequence.
    :param configs: Optional per-formatter generation configs. Defaults to one
        ``EngineGenerationConfig()`` per formatter.
    :raises AssertionError: If the number of configs does not match the number of formatters.
    """
    self._formatters = formatters
    self._eos_token_id = eos_token_id
    # __call__ compares len(generated_tokens) against this to detect the first
    # iteration and batch-step boundaries, so it must start at 0.
    self._last_input_id_length = 0
    if configs is None:
        configs = [EngineGenerationConfig() for _ in formatters]
    assert len(configs) == len(formatters), \
        f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
    self._configs = configs
    # Pairs each formatter with its config; __call__ advances this once per
    # sequence in the batch.
    self._iter = zip(self._formatters, self._configs)
    # Counts sequences seen in the current batch; used only in error messages.
    self._debug_counter = 0
33 @property
34 def formatters_captures(self) -> list[dict[str, typing.Any] | None]:
35 return [f.captures if f is not None else None for f in self._formatters]
36
37 def is_completed(self) -> list[bool | None]:
38 """
39 Check if the formatters are completed. Each boolean in the list corresponds to the
40 completion status of the formatter at the same index.
41 """
42 return [f.is_completed() if f is not None else None for f in self._formatters]
43
# Reset every formatter and the per-batch bookkeeping so this processor can be
# reused for a fresh generation run.
44 def reset(self) -> None:
45 for f in self._formatters:
46 if f is not None:
47 f.reset()
50
# NOTE(review): listing lines 48-49 and 51 are missing from this extraction;
# the omitted code presumably resets self._last_input_id_length to 0 (it is
# compared against len(generated_tokens) in __call__) — confirm upstream.
# Rebuild the formatter/config iterator and zero the batch counter.
52 self._iter = zip(self._formatters, self._configs)
53 self._debug_counter = 0
# Mask the logits of one sequence in the batch according to its formatter.
# Called once per sequence per decoding step by vLLM.
55 def __call__(self, prompt, generated_tokens, logits):
# Advance to the formatter/config pair for the current sequence; None means
# the iterator is exhausted.
56 result = next(self._iter, None)
57 if result is None and len(generated_tokens) == self._last_input_id_length:
58 # We exhausted all formatters but still have sequences to process in this batch
59 raise ValueError(f"Batch size {self._debug_counter} "
60 f"is greater than number of formatters({len(self._formatters)})!")
61 if len(generated_tokens) == 0: # First iteration
62 self._debug_counter += 1
63 formatter, config = result
64 if formatter is None:
65 return logits
66 if config.reset_at_beginning and formatter.is_completed():
67 formatter.reset()
68 if config.read_prompt:
# Feed prompt tokens into the formatter so its state accounts for them.
69 for token in prompt:
70 formatter.accept_token(token)
71 elif len(generated_tokens) == self._last_input_id_length + 1: # to next batch step
72 assert result is None, (f"Batch size {self._debug_counter} "
73 f"is less than number of formatters({len(self._formatters)})!")
# NOTE(review): listing line 74 is missing from this extraction; it likely
# re-creates self._iter (zip of formatters and configs) for the new decoding
# step, since line 75 calls next() on it again — confirm against upstream.
75 result = next(self._iter)
76 self._last_input_id_length += 1
77 formatter, _ = result
78 if formatter is None:
79 return logits
# Skip formatters that already finished; if the sequence itself ended with
# EOS, leave the logits untouched.
80 while formatter.is_completed():
81 if generated_tokens[-1] == self._eos_token_id:
82 return logits
83 formatter, _ = next(self._iter)
84 if formatter is None:
85 return logits
86 if len(generated_tokens) != 0: # accept new token
87 input_id = generated_tokens[-1]
88 if input_id != self._eos_token_id:
89 formatter.accept_token(input_id)
90
# Once the formatter completes, force EOS: everything is -inf except the
# EOS token, which gets a large positive logit.
91 if formatter.is_completed():
92 logits[:] = float("-inf")
93 logits[self._eos_token_id] = 1000
94 return logits
95 formatter.compute_allowed_tokens()
96 logits = formatter.mask_logits(logits)
97 return logits
98
99
def create_engine_vocabulary(llm: LLM) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine from a vLLM LLM's tokenizer.
    """
    tokenizer = llm.get_tokenizer()
    vocab = tokenizer.get_vocab()
    # Recover the original characters behind the tokenizer's byte-level encoding.
    decoded = get_original_characters(vocab)
    id_to_token = {token_id: kbnf.Token(text) for text, token_id in decoded.items()}
    id_to_string = {token_id: text for text, token_id in vocab.items()}
    return kbnf.Vocabulary(id_to_token, id_to_string)
109
110
def create_formatters_logits_processor(llm: LLM,
                                       formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                       configs: typing.Sequence[EngineGenerationConfig] | None = None) \
        -> FormattersLogitsProcessor:
    """
    Create a formatter logits processor.

    :param llm: The vLLM LLM whose tokenizer supplies the vocabulary.
    :param formatter_builders: A single builder, or one builder per batch
        sequence; a None entry disables formatting for that sequence.
    :param configs: Optional per-formatter generation configs.
    """
    tokenizer = llm.get_tokenizer()
    vocab = create_engine_vocabulary(llm)
    if not isinstance(formatter_builders, collections.abc.Sequence):
        # A bare FormatterBuilder means a batch of one.
        formatter_builders = [formatter_builders]
    formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
                  for i in formatter_builders]
    return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
Configuration for how a KBNF engine should be used in text generation.
Definition config.py:14
Logit processor that uses formatters to mask batch logits.
Definition vllm.py:19
__init__(self, typing.Sequence[FormatterBase|None] formatters, int eos_token_id, typing.Sequence[EngineGenerationConfig]|None configs=None)
Definition vllm.py:22
__call__(self, prompt, generated_tokens, logits)
Definition vllm.py:65
list[bool|None] is_completed(self)
Check if the formatters are completed.
Definition vllm.py:51
list[dict[str, typing.Any]|None] formatters_captures(self)
Definition vllm.py:42
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
FormattersLogitsProcessor create_formatters_logits_processor(LLM llm, typing.Sequence[FormatterBuilder|None]|FormatterBuilder formatter_builders, typing.Sequence[EngineGenerationConfig] configs=None)
Create a formatter logits processor.
Definition vllm.py:127
kbnf.Vocabulary create_engine_vocabulary(LLM llm)
Create a vocabulary for the KBNF engine.
Definition vllm.py:113