Formatron v0.5.0
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
vllm.py
Go to the documentation of this file.
1"""
2This module integrates the vllm library by providing convenience utilities.
3"""
4import collections.abc
5import typing
6import kbnf
7import torch
8from vllm import LLM
9from formatron.config import EngineGenerationConfig
10from formatron.formatter import FormatterBase, FormatterBuilder
11from formatron.integrations.utils import get_original_characters, get_fastest_compatible_logits_mask_fn,get_bit_mask
12from vllm.transformers_utils.tokenizer import AnyTokenizer
13
14
16 """
17 Logit processor that uses formatters to mask batch logits.
18 """
20 def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
21 configs: typing.Sequence[EngineGenerationConfig] | None = None):
22 self._formatters = formatters
23 self._eos_token_id = eos_token_id
25 if configs is None:
26 configs = [EngineGenerationConfig() for _ in formatters]
27 assert len(configs) == len(formatters), \
28 f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
29 self._configs = configs
30 self._iter = zip(self._formatters, self._configs)
32 self._mask_logits_fn = get_fastest_compatible_logits_mask_fn()
33 self._bit_masks = []
35 @property
36 def formatters_captures(self) -> list[dict[str, typing.Any] | None]:
37 return [f.captures if f is not None else None for f in self._formatters]
38
39 def is_completed(self) -> list[bool | None]:
40 """
41 Check if the formatters are completed. Each boolean in the list corresponds to the
42 completion status of the formatter at the same index.
43 """
44 return [f.is_completed() if f is not None else None for f in self._formatters]
45
46 def reset(self) -> None:
47 for f in self._formatters:
48 if f is not None:
49 f.reset()
52 self._bit_masks.clear()
def _to_next_batch_step(self):
    """
    Rewind the per-batch bookkeeping.
    Zeroes the debug counter and creates fresh iterators over the stored
    bit masks and over the (formatter, config) pairs.
    """
    self._debug_counter = 0
    self._bit_mask_iter = iter(self._bit_masks)
    self._iter = zip(self._formatters, self._configs)
58
59 def __call__(self, prompt, generated_tokens, logits):
60 result = next(self._iter, None)
61 if result is None and len(generated_tokens) == self._last_input_id_length:
62 # We exhausted all formatters but still have sequences to process in this batch
63 raise ValueError(f"Batch size {self._debug_counter} "
64 f"is greater than number of formatters({len(self._formatters)})!")
65 bit_mask = False
66 if len(generated_tokens) == 0: # First iteration
67 self._debug_counter += 1
68 formatter, config = result
69 self._bit_masks.append(get_bit_mask(logits))
70 if formatter is None:
71 return logits
72 if config.reset_at_beginning and formatter.is_completed():
73 formatter.reset()
74 if config.read_prompt:
75 for token in prompt:
76 formatter.accept_token(token)
77 bit_mask = self._bit_masks[-1]
78 elif len(generated_tokens) == self._last_input_id_length + 1: # to next batch step
79 assert result is None, (f"Batch size {self._debug_counter} "
80 f"is less than number of formatters({len(self._formatters)})!")
82 result = next(self._iter)
83 self._last_input_id_length += 1
84 if bit_mask is False:
85 bit_mask = next(self._bit_mask_iter)
86 formatter, _ = result
87 if formatter is None:
88 return logits
89 while formatter.is_completed():
90 if generated_tokens[-1] == self._eos_token_id:
91 return logits
92 formatter, _ = next(self._iter)
93 if formatter is None:
94 return logits
95 if len(generated_tokens) != 0: # accept new token
96 input_id = generated_tokens[-1]
97 if not formatter.is_completed():
98 formatter.accept_token(input_id)
99 if formatter.is_completed():
100 logits[:] = float("-inf")
101 logits[self._eos_token_id] = 1000
102 return logits
103 formatter.compute_allowed_tokens()
104 logits = self._mask_logits_fn(bit_mask, formatter, logits)
105 return logits
106
107
def create_engine_vocabulary(tokenizer: AnyTokenizer,
                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
    """
    Build the KBNF engine vocabulary from a tokenizer.
    Args:
        tokenizer: The tokenizer; its get_vocab() mapping is the source of the vocabulary.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    raw_vocab = tokenizer.get_vocab()
    original_chars = get_original_characters(raw_vocab, vocab_processors)
    # First argument: every entry of the processed vocabulary wrapped in a kbnf.Token.
    engine_tokens = {key: kbnf.Token(chars) for key, chars in original_chars.items()}
    # Second argument: the tokenizer vocabulary with keys and values swapped.
    inverted_vocab = {value: key for key, value in raw_vocab.items()}
    return kbnf.Vocabulary(engine_tokens, inverted_vocab)
121
122
def create_formatters_logits_processor(llm: LLM,
                                       formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                       configs: typing.Sequence[EngineGenerationConfig] | None = None,
                                       vocab_processors: typing.Optional[list[typing.Callable]] = None) \
        -> FormattersLogitsProcessor:
    """
    Create a formatter logits processor.
    Args:
        llm: The LLM.
        formatter_builders: The formatter builders. A single builder is treated as a one-element sequence;
            None entries are kept as None formatters.
        configs: The engine generation configurations. Passed through unchanged (including None)
            to FormattersLogitsProcessor.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    Returns:
        A FormattersLogitsProcessor built from the formatters, the tokenizer's eos_token_id and configs.
    """
    tokenizer = llm.get_tokenizer()
    vocab = create_engine_vocabulary(tokenizer, vocab_processors)
    # Accept a bare builder for convenience by wrapping it in a list.
    if not isinstance(formatter_builders, collections.abc.Sequence):
        formatter_builders = [formatter_builders]
    # Every formatter shares the same vocabulary and uses the tokenizer's decode as its decode callback.
    formatters = [builder.build(vocab, tokenizer.decode) if builder is not None else None
                  for builder in formatter_builders]
    return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
Configuration for how a KBNF engine should be used in text generation.
Definition config.py:14
Logit processor that uses formatters to mask batch logits.
Definition vllm.py:19
__init__(self, typing.Sequence[FormatterBase|None] formatters, int eos_token_id, typing.Sequence[EngineGenerationConfig]|None configs=None)
Definition vllm.py:22
__call__(self, prompt, generated_tokens, logits)
Definition vllm.py:69
list[bool|None] is_completed(self)
Check if the formatters are completed.
Definition vllm.py:53
list[dict[str, typing.Any]|None] formatters_captures(self)
Definition vllm.py:44
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
kbnf.Vocabulary create_engine_vocabulary(AnyTokenizer tokenizer, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a vocabulary for the KBNF engine.
Definition vllm.py:126
FormattersLogitsProcessor create_formatters_logits_processor(LLM llm, typing.Sequence[FormatterBuilder|None]|FormatterBuilder formatter_builders, typing.Sequence[EngineGenerationConfig] configs=None, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a formatter logits processor.
Definition vllm.py:146