Formatron v0.4.2
Formatron empowers everyone to control the output format of language models with minimal overhead.
formatter.py
1"""
2This module contains the Formatter class and its related classes.
3"""
4import abc
5from json import JSONDecodeError
6import re
7import textwrap
8import typing
9from copy import copy
10import kbnf
11from formatron.formats.json import JsonExtractor
12from formatron.schemas.schema import Schema
13from formatron.extractor import Extractor, LiteralExtractor, NonterminalExtractor, ChoiceExtractor, SubstringExtractor
14from formatron.formats.regex import RegexExtractor
15
16
17
class FormatterBase(abc.ABC):
    """
    An abstract Formatter that enforces a format on the string generated by a language model.
    """

    @abc.abstractmethod
    def accept_token(self, token_id: int) -> typing.Any:
        """
        Accept a token from the language model.
        Args:
            token_id: The token ID.
        Returns:
            The result of accepting the token.
        """

    @abc.abstractmethod
    def accept_bytes(self, _bytes: bytes):
        """
        Accept a bytes object from the language model.
        Args:
            _bytes: The bytes object.
        """

    @abc.abstractmethod
    def compute_allowed_tokens(self) -> None:
        """
        Compute the allowed tokens based on the current state.
        """

    @abc.abstractmethod
    def mask_logits(self, logits) -> typing.Any:
        """
        Mask the logits based on the current state.
        Args:
            logits: The logits to mask.
        Returns:
            The masked logits.
        """

    @abc.abstractmethod
    def get_allowed_tokens_since_last_computation(self) -> typing.Sequence[int]:
        """
        Get the allowed tokens since the last computation (in other words, the last call to `compute_allowed_tokens`).
        Returns:
            The allowed tokens.
        """

    @abc.abstractmethod
    def is_completed(self) -> bool:
        """
        Check if the generated string satisfies the format and hence the generation is completed.
        """

    @abc.abstractmethod
    def _on_completion(self, generated_output: str) -> None:
        """
        Perform actions when the generation is completed.
        """

    @property
    @abc.abstractmethod
    def captures(self) -> dict[str, typing.Any | None]:
        """
        Get the captures from the generated string.
        """

    @abc.abstractmethod
    def reset(self) -> None:
        """
        Reset the formatter to the initial state.
        """


class Formatter(FormatterBase):
    """
    A Formatter that enforces a format on the string generated by a language model. It is designed to compose
    multiple extractors in a sequential, unambiguous, greedy manner. Check out the Formatter.captures property docs for more details.
    If you need more complex extraction logic, you need to implement your own Extractor.
    """

    def __init__(self, extractors: list[Extractor], engine: kbnf.Engine,
                 decode_callback: typing.Callable[[list[int]], str], grammar_str: str):
        """
        Initialize the formatter.
        Args:
            extractors: The matchers to extract data from the generated string.
            engine: The KBNF engine to enforce the format.
            decode_callback: The callback to decode the token IDs to a string.
            grammar_str: The KBNF grammar string.
        """
        self._extractors = extractors
        self._engine = engine
        self._token_ids = []
        self._decode_callback = decode_callback
        self._grammar_str = grammar_str
        self._captures = {}

    @property
    def grammar_str(self):
        """
        Get the KBNF grammar string.
        """
        return self._grammar_str

    def accept_token(self, token_id: int):
        result = self._engine.try_accept_new_token(token_id)
        self._token_ids.append(token_id)
        if result == kbnf.AcceptTokenResult.Finished:
            output = self._decode_callback(self._token_ids)
            self._on_completion(output)
        return result

    def accept_bytes(self, _bytes: bytes):
        self._engine.try_accept_new_bytes(_bytes)

    def compute_allowed_tokens(self) -> None:
        self._engine.compute_allowed_token_ids()

    def mask_logits(self, logits) -> typing.Any:
        return self._engine.mask_logits(logits)

    def get_allowed_tokens_since_last_computation(self) -> typing.Sequence[int]:
        return self._engine.get_allowed_token_ids_from_last_computation()

    def is_completed(self) -> bool:
        """
        Check if the generation is completed. This means the generation is ended by the engine.
        If the generation is ended by integration-specific stop conditions like `max_new_tokens`,
        the generation is not considered completed by this method.
        """
        return self._engine.is_finished()

    def _on_completion(self, generated_output: str) -> None:
        for matcher in self._extractors:
            result = matcher.extract(generated_output)
            if result is None:
                captured = None
            else:
                generated_output, captured = result
            if matcher.capture_name:
                if matcher.capture_name in self._captures:
                    self._captures[matcher.capture_name] = [
                        self._captures[matcher.capture_name]]
                    self._captures[matcher.capture_name].append(captured)
                else:
                    self._captures[matcher.capture_name] = captured

    @property
    def captures(self) -> dict[str, typing.Any] | None:
        """
        Get the captures from the generated string. Note that a capture is only available for an extractor if:
        - The extractor has a capture name.
        - Formatter.is_completed() returns True.
        - The extractor successfully extracts the data.
          - This means the extractor identifies the correct string span to extract, and whatever post-processing the extractor does on the extracted string is successful.

        Captures are obtained by calling the `Extractor.extract` method on the generated string, following the sequence in which extractors were appended to the formatter.
        Note that earlier extractors do not 'see' the semantics of later extractors. For example,
        consider the following formatter:
        ```python
        f = FormatterBuilder()
        f.append_line(f"{f.regex('.*?', capture_name='a')}{f.regex('.*', capture_name='b')}")
        f = f.build()
        ```
        The `b` extractor will always correspond to `None` because the `a` extractor will always extract the whole string.
        This behavior is different from what a typical regular expression engine would do!
        """
        return self._captures

    def reset(self) -> None:
        self._captures.clear()
        self._engine.reset()
        self._token_ids.clear()

    def __str__(self):
        return (f"Formatter(engine={self._engine}, "
                f"captures={self._captures}, "
                f"extractors={len(self._extractors)}, "
                f"completed={self.is_completed()}, "
                f"token_ids={len(self._token_ids)}, "
                f"grammar={self._grammar_str})")


class FormatterBuilder:
    """
    A builder for creating a Formatter.
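
    Example (a minimal sketch; `vocabulary` and `decode` are integration-specific objects that must
    come from your own inference setup, so they are only placeholders here):
    ```python
    f = FormatterBuilder()
    f.append_line(f"Answer: {f.choose('yes', 'no', capture_name='answer')}")
    formatter = f.build(vocabulary, decode)  # hypothetical vocabulary/decode objects
    ```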
204 """
205 _formatter_builder_counter = 0
206
207 def __init__(self):
208 """
209 Initialize the formatter builder.
210 """
211 self._counter = 0
212 self._main_rule = []
213 self._rules = []
214 self._capture_names = set()
215 self._nonterminal_to_extractor = {}
216 self._extractors = []
217 self._instance_id = self.__class__._formatter_builder_counter
218 self.__class__._formatter_builder_counter += 1
219
220
221 def _assert_capture_name_valid(self, capture_name: str):
222 assert capture_name.isidentifier(), (f"capture_name {capture_name}"
223 f" should only contains alphanumeric characters, "
224 f"underscores, and does not start with digits!")
225 assert capture_name not in self._capture_names, f"capture_name {capture_name} is duplicated!"
    def append_line(self, line: str) -> None:
        """
        Append a line to the format. Specifically, a newline character is appended to the input.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.
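
        Example (a minimal sketch of embedding an extractor placeholder in a line; the regex is
        purely illustrative):
        ```python
        f = FormatterBuilder()
        f.append_line(f"Name: {f.regex('[A-Za-z]+', capture_name='name')}")
        ```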
232 """
233 self.append_str(line + '\n')
234
    def append_multiline_str(self, lines: str) -> None:
        """
        Appends a multiline string to the format, preserving the first line's leading whitespace
        and removing any common leading whitespace from subsequent lines.

        Note that tabs and spaces are both treated as whitespace, but they are not equal:
        the lines "  hello" and "\\thello" are considered to have no common leading whitespace.
        Entirely blank lines are normalized to a newline character.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.
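
        Example (a minimal sketch; the common two-space indentation of the later lines is removed,
        while the first line is preserved as-is):
        ```python
        f = FormatterBuilder()
        f.append_multiline_str("  first line\\n  second line\\n    indented line\\n")
        ```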
246 """
247 first = lines.find('\n')
248 self.append_str(lines[:first + 1] + textwrap.dedent(lines[first + 1:]))
    def append_str(self, string: str) -> None:
        """
        Append a string to the format without any post-processing.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.
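
        Example (a minimal sketch; unlike `append_line`, no trailing newline is added, and extractor
        placeholders are embedded by interpolating extractor objects, which render as `${...}` references):
        ```python
        f = FormatterBuilder()
        f.append_str(f"Total: {f.regex('[0-9]+', capture_name='total')} items")
        ```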
255 """
256 state = "normal"
257 last = 0
259 def append_literal(end):
260 if last < end:
261 literal = string[last:end]
262 self._main_rule.append(repr(literal))
263 self._extractors.append(LiteralExtractor(literal))
264
265 for i, char in enumerate(string):
266 if char == "$":
267 if state != "escaped":
268 state = "dollar"
269 else:
270 state = "normal"
271 elif state == "dollar":
272 if char == "{":
273 append_literal(i - 1)
274 last = i + 1
275 state = "left_bracket"
276 else:
277 state = "normal"
278 elif state == "left_bracket":
279 if char == "}":
280 state = "normal"
281 self._main_rule.append(string[last:i])
282 self._extractors.append(
283 self._nonterminal_to_extractor[string[last:i]])
284 last = i + 1
285 elif char == "\\":
286 state = "escaped"
287 else:
288 state = "normal"
289 append_literal(len(string))
290
    def _create_nonterminal(self, name: str) -> str:
        nonterminal = f"__{name}_{self._counter}_{self._instance_id}"
        self._counter += 1
        return nonterminal

    def _add_capture_name(self, extractor: NonterminalExtractor) -> None:
        if extractor.capture_name is None:
            return None
        self._assert_capture_name_valid(extractor.capture_name)
        self._capture_names.add(extractor.capture_name)

    def choose(self, *extractors: Extractor | str, capture_name: str = None) -> ChoiceExtractor:
        """
        Create a choice extractor.

        Check out the ChoiceExtractor docs for more details.
        Args:
            extractors: The extractors to choose from.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The choice extractor.
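
        Example (a minimal sketch; plain strings are wrapped in literal extractors automatically):
        ```python
        f = FormatterBuilder()
        f.append_line(f"Sentiment: {f.choose('positive', 'negative', 'neutral', capture_name='sentiment')}")
        ```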
312 """
313 new_extractors = []
314 for extractor in extractors:
315 if isinstance(extractor, str):
316 new_extractors.append(LiteralExtractor(extractor))
317 else:
318 new_extractors.append(extractor)
319 return self._add_extractor("choice",
320 lambda nonterminal: ChoiceExtractor(new_extractors, capture_name, nonterminal))
321
    def _add_extractor(self, extractor_type: str, create_extractor: typing.Callable[[str], Extractor]):
        nonterminal = self._create_nonterminal(extractor_type)
        extractor = create_extractor(nonterminal)
        if isinstance(extractor, NonterminalExtractor):
            self._add_capture_name(extractor)
            nonterminal = extractor.nonterminal
        self._nonterminal_to_extractor[nonterminal] = extractor
        self._rules.append(extractor.kbnf_definition)
        return extractor

    def extractor(self, create_extractor: typing.Callable[[str], Extractor]) -> Extractor:
        """
        Create a custom extractor.

        Args:
            create_extractor: A callable with signature (extractor_nonterminal: str) -> Extractor that creates the extractor.
                extractor_nonterminal is the auto-generated nonterminal reference for the extractor.
        """
        return self._add_extractor("extractor", create_extractor)

    def json(self, schema: Schema, *, capture_name: str = None) -> JsonExtractor:
        """
        Create a JSON extractor. Check out the JsonExtractor docs for more details.

        Args:
            schema: The schema for extraction.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The JSON extractor.
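
        Example (a minimal sketch; it assumes a schema defined with `formatron.schemas.pydantic.ClassSchema`,
        which is one way, though not the only way, to obtain a `Schema`):
        ```python
        from formatron.schemas.pydantic import ClassSchema

        class Person(ClassSchema):
            name: str
            age: int

        f = FormatterBuilder()
        f.append_line(f"{f.json(Person, capture_name='person')}")
        ```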
351 """
352 def to_json(json: str):
353 try:
354 return schema.from_json(json)
355 except JSONDecodeError: # make ChoiceExtractor work appropriately
356 return None
357 return self._add_extractor("json",
358 lambda nonterminal: JsonExtractor(nonterminal, capture_name,schema, to_json))
    def regex(self, regex: str, *, capture_name: str = None) -> RegexExtractor:
        """
        Create a regex extractor.

        Check out the RegexExtractor docs for more details.

        Args:
            regex: The regular expression for extraction.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The regex extractor.
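
        Example (a minimal sketch; constrains this field to a single digit and captures it):
        ```python
        f = FormatterBuilder()
        f.append_line(f"Rating: {f.regex('[0-9]', capture_name='rating')} / 10")
        ```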
371 """
372 return self._add_extractor("regex",
373 lambda nonterminal: RegexExtractor(regex, capture_name, nonterminal))
374
    def str(self, *, stop: typing.Union[str, list[str]] = None,
            capture_name: typing.Optional[str] = None) -> Extractor:
        """
        Create a string extractor.

        The extractor will extract all text up to and including the first occurrence of one of the stop strings.

        Args:
            stop: The strings for the extractor to stop at. They will be included in text generation and extraction.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The string extractor.
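
        Example (a minimal sketch; generation of this field stops at, and includes, the first period):
        ```python
        f = FormatterBuilder()
        f.append_line(f"Summary: {f.str(stop=['.'], capture_name='summary')}")
        ```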
387 """
388 stop = [stop] if isinstance(stop, str) else stop or []
389 nonterminal = self._create_nonterminal("str")
390 if not stop:
391 capture_regex = ".*"
392 nonterminal_regex = "#'.*'"
393 else:
394 backslash = '\\'
395 capture_regex = f".*?(?:{'|'.join([i.replace(backslash, backslash * 2) for i in map(re.escape, stop)])})"
396 nonterminal_regex = f"#e'{capture_regex}'"
397 self._rules.append(f"{nonterminal} ::= {nonterminal_regex};")
398 self._nonterminal_to_extractor[nonterminal] = RegexExtractor(
399 capture_regex, capture_name, nonterminal)
400 return self._nonterminal_to_extractor[nonterminal]
401
    def substr(self, string: str, *, capture_name: str = None, extract_empty_substring: bool = False) -> Extractor:
        """
        Create a substring extractor.

        The extractor will extract a substring of the input string.

        Args:
            string: The string to extract.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
            extract_empty_substring: Whether to extract an empty substring as a valid substring.
        Returns:
            The substring extractor.
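
        Example (a minimal sketch; the generated text is constrained to be a contiguous substring of the quote):
        ```python
        f = FormatterBuilder()
        f.append_line(f"Quote: {f.substr('The quick brown fox jumps over the lazy dog', capture_name='quote')}")
        ```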
414 """
415 return self._add_extractor("substr",
416 lambda nonterminal: SubstringExtractor(string, capture_name, nonterminal,
417 extract_empty_substring=extract_empty_substring))
418
    def build(self, vocabulary: kbnf.Vocabulary,
              decode: typing.Callable[[list[int]], str],
              engine_config: kbnf.Config = None) -> Formatter:
        """
        Build a formatter from the builder. The builder will not be consumed and can be used again.

        Args:
            vocabulary: The KBNF engine vocabulary for the formatter.
            decode: The callback to decode the token IDs to a string.
            engine_config: The KBNF engine configuration.
        Returns:
            The formatter.
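
        Example (a minimal sketch; in practice the vocabulary and decode callback usually come from one of
        the adapters in `formatron.integrations`, so the names below are placeholders):
        ```python
        formatter = f.build(vocabulary, decode)  # hypothetical objects from your inference stack
        print(formatter.grammar_str)
        ```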
431 """
432 assert len(
433 self._main_rule) != 0, "An empty formatter builder cannot build!"
434 rules = copy(self._rules)
435 rules.append(f"start ::= {' '.join(self._main_rule)};")
436 grammar_str = "\n".join(rules)
437 engine = kbnf.Engine(grammar_str, vocabulary, engine_config)
438 extractors = copy(self._extractors)
439 f = Formatter(extractors, engine, decode, grammar_str)
440 return f
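

# Usage sketch (not part of formatter.py): driving a built Formatter in a hand-rolled decoding loop.
# `vocabulary`, `decode`, `logits` and `model_sample_token` are hypothetical stand-ins for whatever
# your inference stack provides; real projects typically use the ready-made adapters in
# formatron.integrations instead.
#
#     formatter = formatter_builder.build(vocabulary, decode)
#     while not formatter.is_completed():
#         formatter.compute_allowed_tokens()        # update the allowed-token set
#         logits = formatter.mask_logits(logits)    # forbid tokens that would break the format
#         token_id = model_sample_token(logits)     # hypothetical sampling step
#         formatter.accept_token(token_id)          # advance the engine; captures fill on finish
#     print(formatter.captures)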