Formatron v0.4.9
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
vllm.py
Go to the documentation of this file.
1"""
2This module integrates the vllm library by providing convenience utilities.
3"""
4import collections.abc
5import typing
6import kbnf
7from vllm import LLM
8from formatron.config import EngineGenerationConfig
9from formatron.formatter import FormatterBase, FormatterBuilder
10from formatron.integrations.utils import get_original_characters
11from vllm.transformers_utils.tokenizer import AnyTokenizer
12
13
15 """
16 Logit processor that uses formatters to mask batch logits.
17 """
19 def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
20 configs: typing.Sequence[EngineGenerationConfig] | None = None):
21 self._formatters = formatters
22 self._eos_token_id = eos_token_id
24 if configs is None:
25 configs = [EngineGenerationConfig() for _ in formatters]
26 assert len(configs) == len(formatters), \
27 f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
28 self._configs = configs
29 self._iter = zip(self._formatters, self._configs)
32 @property
33 def formatters_captures(self) -> list[dict[str, typing.Any] | None]:
34 return [f.captures if f is not None else None for f in self._formatters]
35
36 def is_completed(self) -> list[bool | None]:
37 """
38 Check if the formatters are completed. Each boolean in the list corresponds to the
39 completion status of the formatter at the same index.
40 """
41 return [f.is_completed() if f is not None else None for f in self._formatters]
42
43 def reset(self) -> None:
44 for f in self._formatters:
45 if f is not None:
46 f.reset()
49
51 self._iter = zip(self._formatters, self._configs)
52 self._debug_counter = 0
54 def __call__(self, prompt, generated_tokens, logits):
55 result = next(self._iter, None)
56 if result is None and len(generated_tokens) == self._last_input_id_length:
57 # We exhausted all formatters but still have sequences to process in this batch
58 raise ValueError(f"Batch size {self._debug_counter} "
59 f"is greater than number of formatters({len(self._formatters)})!")
60 if len(generated_tokens) == 0: # First iteration
61 self._debug_counter += 1
62 formatter, config = result
63 if formatter is None:
64 return logits
65 if config.reset_at_beginning and formatter.is_completed():
66 formatter.reset()
67 if config.read_prompt:
68 for token in prompt:
69 formatter.accept_token(token)
70 elif len(generated_tokens) == self._last_input_id_length + 1: # to next batch step
71 assert result is None, (f"Batch size {self._debug_counter} "
72 f"is less than number of formatters({len(self._formatters)})!")
74 result = next(self._iter)
75 self._last_input_id_length += 1
76 formatter, _ = result
77 if formatter is None:
78 return logits
79 while formatter.is_completed():
80 if generated_tokens[-1] == self._eos_token_id:
81 return logits
82 formatter, _ = next(self._iter)
83 if formatter is None:
84 return logits
85 if len(generated_tokens) != 0: # accept new token
86 input_id = generated_tokens[-1]
87 if not formatter.is_completed():
88 formatter.accept_token(input_id)
89
90 if formatter.is_completed():
91 logits[:] = float("-inf")
92 logits[self._eos_token_id] = 1000
93 return logits
94 formatter.compute_allowed_tokens()
95 logits = formatter.mask_logits(logits)
96 return logits
97
98
def create_engine_vocabulary(tokenizer: AnyTokenizer,
                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
    """
    Build a KBNF engine vocabulary from a tokenizer.

    Args:
        tokenizer: The tokenizer; its ``get_vocab()`` result is the source vocabulary.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters.
            If None, processors will be auto-detected.

    Returns:
        A ``kbnf.Vocabulary`` built from two mappings: the output of
        ``get_original_characters`` with every value wrapped in ``kbnf.Token``,
        and the tokenizer vocabulary with its keys and values swapped.
    """
    raw_vocab = tokenizer.get_vocab()
    restored_vocab = get_original_characters(raw_vocab, vocab_processors)
    # First mapping: same keys as the restored vocabulary, values wrapped as engine tokens.
    engine_tokens = {}
    for key, value in restored_vocab.items():
        engine_tokens[key] = kbnf.Token(value)
    # Second mapping: the tokenizer vocabulary inverted (values become keys).
    inverted_vocab = {}
    for key, value in raw_vocab.items():
        inverted_vocab[value] = key
    return kbnf.Vocabulary(engine_tokens, inverted_vocab)
112
113
115 formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
116 configs: typing.Sequence[EngineGenerationConfig] = None,
117 vocab_processors: typing.Optional[list[typing.Callable]] = None) \
118 -> FormattersLogitsProcessor:
119 """
120 Create a formatter logits processor.
121 Args:
122 llm: The LLM.
123 formatter_builders: The formatter builders.
124 configs: The engine generation configurations.
125 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
126 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
127 """
128 tokenizer = llm.get_tokenizer()
129 vocab = create_engine_vocabulary(tokenizer, vocab_processors)
130 if not isinstance(formatter_builders, collections.abc.Sequence):
131 formatter_builders = [formatter_builders]
132 formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
133 for i in formatter_builders]
134 return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
Configuration for how a KBNF engine should be used in text generation.
Definition config.py:14
Logit processor that uses formatters to mask batch logits.
Definition vllm.py:18
__init__(self, typing.Sequence[FormatterBase|None] formatters, int eos_token_id, typing.Sequence[EngineGenerationConfig]|None configs=None)
Definition vllm.py:21
__call__(self, prompt, generated_tokens, logits)
Definition vllm.py:64
list[bool|None] is_completed(self)
Check if the formatters are completed.
Definition vllm.py:50
list[dict[str, typing.Any]|None] formatters_captures(self)
Definition vllm.py:41
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
kbnf.Vocabulary create_engine_vocabulary(AnyTokenizer tokenizer, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a vocabulary for the KBNF engine.
Definition vllm.py:117
FormattersLogitsProcessor create_formatters_logits_processor(LLM llm, typing.Sequence[FormatterBuilder|None]|FormatterBuilder formatter_builders, typing.Sequence[EngineGenerationConfig] configs=None, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a formatter logits processor.
Definition vllm.py:137