Formatron v0.4.9
Formatron empowers everyone to control the output format of language models with minimal overhead.
formatter.py
1"""
2This module contains the Formatter class and its related classes.
3"""
4import abc
5import collections
6from json import JSONDecodeError
7import json
8import re
9import textwrap
10import typing
11from copy import copy
12import kbnf
13from formatron.formats.json import JsonExtractor
14from formatron.schemas.schema import Schema
15from formatron.extractor import Extractor, LiteralExtractor, NonterminalExtractor, ChoiceExtractor, SubstringExtractor
16from formatron.formats.regex import RegexComplementExtractor, RegexExtractor
17
18
19
class FormatterBase(abc.ABC):
    """
    An abstract Formatter that enforces a format on the string generated by a language model.
    """

    @abc.abstractmethod
    def accept_token(self, token_id: int):
        """
        Accept a token from the language model.
        Args:
            token_id: The token ID.
        Returns:
            The result of accepting the token.
        """

    @abc.abstractmethod
    def accept_bytes(self, _bytes: bytes) -> None:
        """
        Accept a bytes object from the language model.
        Args:
            _bytes: The bytes object.
        """

    @abc.abstractmethod
    def compute_allowed_tokens(self) -> None:
        """
        Compute the allowed tokens based on the current state.
        """

    @abc.abstractmethod
    def mask_logits(self, logits) -> typing.Any:
        """
        Mask the logits based on the current state.
        Args:
            logits: The logits to mask.
        Returns:
            The masked logits.
        """

    @abc.abstractmethod
    def get_allowed_tokens_since_last_computation(self) -> typing.Sequence[int]:
        """
        Get the allowed tokens since the last computation (in other words, the last call to `compute_allowed_tokens`).
        Returns:
            The allowed tokens.
        """

    @abc.abstractmethod
    def is_completed(self) -> bool:
        """
        Check if the generated string satisfies the format and hence the generation is completed.
        """

    @abc.abstractmethod
    def _on_completion(self, generated_output: str) -> None:
        """
        Perform actions when the generation is completed.
        """

    @property
    @abc.abstractmethod
    def captures(self) -> dict[str, typing.Any|None]:
        """
        Get the captures from the generated string.
        """

    @abc.abstractmethod
    def reset(self) -> None:
        """
        Reset the formatter to the initial state.
        """


class Formatter(FormatterBase):
    """
    A Formatter that enforces a format on the string generated by a language model. It is designed to compose
    multiple extractors in a sequential, unambiguous, greedy manner. Check out the Formatter.captures property docs for more details.
    If you need more complex extraction logic, you need to implement your own Extractor.
    """

    def __init__(self, extractors: list[Extractor], engine: kbnf.Engine,
                 decode_callback: typing.Callable[[list[int]], str], grammar_str: str):
        """
        Initialize the formatter.
        Args:
            extractors: The extractors used to extract data from the generated string.
            engine: The KBNF engine to enforce the format.
            decode_callback: The callback to decode the token IDs to a string.
            grammar_str: The KBNF grammar string.
        """
        self._extractors = extractors
        self._engine = engine
        self._token_id_or_bytes = []
        self._decode_callback = decode_callback
        self._grammar_str = grammar_str
        self._captures = {}

    @property
    def grammar_str(self):
        """
        Get the KBNF grammar string.
        """
        return self._grammar_str

    def accept_token(self, token_id: int) -> kbnf.AcceptTokenResult:
        result = self._engine.try_accept_new_token(token_id)
        self._token_id_or_bytes.append(token_id)
        if result == kbnf.AcceptTokenResult.Finished:
            output = self._obtain_accepted_output()
            self._on_completion(output)
        return result

    def _obtain_accepted_output(self) -> str:
        buffer = []
        output = ""
        last_type = None

        def decode_buffer(buffer_type: type, buffer_content: list):
            if buffer_type not in (int, bytes):
                try:
                    buffer_content = [int(item) for item in buffer_content]
                    buffer_type = int
                except ValueError:
                    assert False, f"Invalid type: {buffer_type}. Unable to convert to int."
            if buffer_type is int:
                return self._decode_callback(buffer_content)
            elif buffer_type is bytes:
                return b"".join(buffer_content).decode()

        for element in self._token_id_or_bytes:
            if last_type is None:
                last_type = type(element)
            elif last_type != type(element):
                output += decode_buffer(last_type, buffer)
                buffer.clear()
                last_type = type(element)
            buffer.append(element)

        if buffer:
            output += decode_buffer(last_type, buffer)
        return output

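    # Illustrative sketch (not part of the original source): _obtain_accepted_output groups
    # consecutive history entries of the same type into runs before decoding, so a history like
    #     self._token_id_or_bytes == [15, 42, b"wo", b"rld", 7]
    # is decoded as decode_callback([15, 42]) + b"world".decode() + decode_callback([7]).
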
    def accept_bytes(self, _bytes: bytes) -> kbnf.AcceptTokenResult:
        result = self._engine.try_accept_new_bytes(_bytes)
        self._token_id_or_bytes.append(_bytes)
        if result == kbnf.AcceptTokenResult.Finished:
            output = self._obtain_accepted_output()
            self._on_completion(output)
        return result

    def compute_allowed_tokens(self) -> None:
        self._engine.compute_allowed_token_ids()

    def mask_logits(self, logits) -> typing.Any:
        return self._engine.mask_logits(logits)

    def get_allowed_tokens_since_last_computation(self) -> typing.Sequence[int]:
        return self._engine.get_allowed_token_ids_from_last_computation()

    def is_completed(self) -> bool:
        """
        Check if the generation is completed. This means the generation is ended by the engine.
        If the generation is ended by integration-specific stop conditions like `max_new_tokens`,
        the generation is not considered completed by this method.
        """
        return self._engine.is_finished()

    def _on_completion(self, generated_output: str) -> None:
        for matcher in self._extractors:
            result = matcher.extract(generated_output)
            if result is None:
                captured = None
            else:
                generated_output, captured = result
            if matcher.capture_name:
                if matcher.capture_name in self._captures:
                    self._captures[matcher.capture_name] = [
                        self._captures[matcher.capture_name]]
                    self._captures[matcher.capture_name].append(captured)
                else:
                    self._captures[matcher.capture_name] = captured

    @property
    def captures(self) -> dict[str, typing.Any] | None:
        """
        Get the captures from the generated string. Note that a capture is only available for an extractor if:
        - The extractor has a capture name.
        - Formatter.is_completed() returns True.
        - The extractor successfully extracts the data, i.e. it identifies the correct string span
          to extract and whatever post-processing it does on the extracted string succeeds.

        Captures are obtained by calling the `Extractor.extract` method on the generated string, in the order in which the extractors were appended to the formatter.
        Note that earlier extractors do not 'see' the semantics of the later extractors. For example,
        consider the following formatter:
        ```python
        f = FormatterBuilder()
        f.append_line(f"{f.regex('.*?', capture_name='a')}{f.regex('.*', capture_name='b')}")
        f = f.build()
        ```
        The `b` extractor will always correspond to `None` because the `a` extractor will always extract the whole string.
        This behavior is different from what a typical regular expression engine would do!
        """
        return self._captures

    def reset(self) -> None:
        self._captures.clear()
        self._engine.reset()
        self._token_id_or_bytes.clear()

    def __str__(self):
        return (f"Formatter(engine={self._engine}, "
                f"captures={self._captures}, "
                f"extractors={len(self._extractors)}, "
                f"completed={self.is_completed()}, "
                f"token_ids={len(self._token_id_or_bytes)}, "
                f"grammar={self._grammar_str})")


class FormatterBuilder:
    """
    A builder for creating a Formatter.
    """
    _formatter_builder_counter = 0

    def __init__(self):
        """
        Initialize the formatter builder.
        """
        self._counter = 0
        self._main_rule = []
        self._rules = []
        self._capture_names = set()
        self._nonterminal_to_extractor = {}
        self._extractors = []
        self._instance_id = self.__class__._formatter_builder_counter
        self.__class__._formatter_builder_counter += 1

    def _assert_capture_name_valid(self, capture_name: str):
        assert capture_name.isidentifier(), (f"capture_name {capture_name}"
                                             f" should only contain alphanumeric characters and underscores,"
                                             f" and must not start with a digit!")
        assert capture_name not in self._capture_names, f"capture_name {capture_name} is duplicated!"

    def append_line(self, line: str) -> None:
        """
        Append a line to the format. Specifically, a newline character is appended to the input.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.
        """
        self.append_str(line + '\n')

    def append_multiline_str(self, lines: str) -> None:
        """
        Append a multiline string to the format, preserving the first line's leading whitespace
        and removing any common leading whitespace from subsequent lines.

        Note that tabs and spaces are both treated as whitespace, but they are not equal:
        the lines " hello" and "\\thello" are considered to have no common leading whitespace.

        Entirely blank lines are normalized to a newline character.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.
        """
        first = lines.find('\n')
        self.append_str(lines[:first + 1] + textwrap.dedent(lines[first + 1:]))

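    # Illustrative sketch (not part of the original source): only the text after the first
    # newline is dedented, so
    #     f.append_multiline_str("""Answer:
    #             first line
    #             second line
    #     """)
    # keeps "Answer:" as-is and strips the common leading whitespace from the later lines.
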
    def append_str(self, string: str) -> None:
        """
        Append a string to the format without any post-processing.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.
        """
        state = "normal"
        last = 0

        def append_literal(end):
            if last < end:
                literal = string[last:end]
                self._main_rule.append(repr(literal))
                self._extractors.append(LiteralExtractor(literal))

        for i, char in enumerate(string):
            if char == "$":
                if state != "escaped":
                    state = "dollar"
                else:
                    state = "normal"
            elif state == "dollar":
                if char == "{":
                    append_literal(i - 1)
                    last = i + 1
                    state = "left_bracket"
                else:
                    state = "normal"
            elif state == "left_bracket":
                if char == "}":
                    state = "normal"
                    self._main_rule.append(string[last:i])
                    self._extractors.append(
                        self._nonterminal_to_extractor[string[last:i]])
                    last = i + 1
            elif char == "\\":
                state = "escaped"
            else:
                state = "normal"
        append_literal(len(string))

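    # Illustrative sketch (not part of the original source): interpolating an extractor into an
    # f-string produces a `${nonterminal}` placeholder that append_str resolves back to that
    # extractor, so
    #     f = FormatterBuilder()
    #     f.append_str(f"Answer: {f.regex('[0-9]+', capture_name='n')}\n")
    # registers LiteralExtractor("Answer: "), the regex extractor, and LiteralExtractor("\n").
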
    def _create_nonterminal(self, name: str) -> str:
        nonterminal = f"__{name}_{self._counter}_{self._instance_id}"
        self._counter += 1
        return nonterminal

    def _add_capture_name(self, extractor: NonterminalExtractor) -> None:
        if extractor.capture_name is None:
            return None
        self._assert_capture_name_valid(extractor.capture_name)
        self._capture_names.add(extractor.capture_name)

    def choose(self, *extractors: Extractor | str, capture_name: str = None) -> ChoiceExtractor:
        """
        Create a choice extractor.

        Check out the ChoiceExtractor docs for more details.
        Args:
            extractors: The extractors to choose from.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The choice extractor.
        """
        new_extractors = []
        for extractor in extractors:
            if isinstance(extractor, str):
                new_extractors.append(LiteralExtractor(extractor))
            else:
                new_extractors.append(extractor)
        return self._add_extractor("choice",
                                   lambda nonterminal: ChoiceExtractor(new_extractors, capture_name, nonterminal))

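    # Illustrative sketch (not part of the original source): plain strings are wrapped in
    # LiteralExtractor, so
    #     f.append_line(f"Verdict: {f.choose('yes', 'no', capture_name='verdict')}")
    # constrains generation to one of the two literals, with the matching alternative captured
    # under 'verdict'.
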
    def _add_extractor(self, extractor_type: str, create_extractor: typing.Callable[[str], Extractor]):
        nonterminal = self._create_nonterminal(extractor_type)
        extractor = create_extractor(nonterminal)
        if isinstance(extractor, NonterminalExtractor):
            self._add_capture_name(extractor)
            nonterminal = extractor.nonterminal
        self._nonterminal_to_extractor[nonterminal] = extractor
        self._rules.append(extractor.kbnf_definition)
        return extractor

    def extractor(self, create_extractor: typing.Callable[[str], Extractor]) -> Extractor:
        """
        Create a custom extractor.

        Args:
            create_extractor: A callable with the signature (extractor_nonterminal: str) -> Extractor that creates the extractor. extractor_nonterminal is the auto-generated nonterminal reference for the extractor.
        """
        return self._add_extractor("extractor", create_extractor)

    def json(self, schema: typing.Type[Schema] | collections.abc.Sequence, *, capture_name: str = None) -> JsonExtractor:
        """
        Create a JSON extractor. Check out the JsonExtractor docs for more details.

        Args:
            schema: The schema for extraction.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The JSON extractor.
        """
        def to_json(_json: str):
            if isinstance(schema, type) and issubclass(schema, Schema):
                try:
                    return schema.from_json(_json)
                except JSONDecodeError:  # make ChoiceExtractor work appropriately
                    return None
            else:
                try:
                    return json.loads(_json)
                except JSONDecodeError:
                    return None
        return self._add_extractor("json",
                                   lambda nonterminal: JsonExtractor(nonterminal, capture_name, schema, to_json))

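    # Illustrative sketch (not part of the original source): with a schema class, for example a
    # pydantic-backed one (formatron.schemas.pydantic.ClassSchema is assumed to be available),
    #     class Person(ClassSchema):
    #         name: str
    #         age: int
    #     f.append_line(f"{f.json(Person, capture_name='person')}")
    # the 'person' capture holds a Person instance; with a non-Schema argument the capture holds
    # the result of json.loads, or None if decoding fails.
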
    def regex(self, regex: str, *, capture_name: str = None) -> RegexExtractor:
        """
        Create a regex extractor.

        Check out the RegexExtractor docs for more details.

        Args:
            regex: The regular expression for extraction.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The regex extractor.
        """
        return self._add_extractor("regex",
                                   lambda nonterminal: RegexExtractor(regex, capture_name, nonterminal))

    def regex_complement(self, regex: str, *, capture_name: str = None) -> RegexComplementExtractor:
        """
        Create a regex complement extractor. This is roughly equivalent to 'extract a string that does not match the given regex anywhere'.

        Check out the RegexComplementExtractor docs for more details.

        Args:
            regex: The regular expression for extraction.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The regex complement extractor.
        """
        return self._add_extractor("regex_complement",
                                   lambda nonterminal: RegexComplementExtractor(regex, capture_name, nonterminal))

    def str(self, *, stop: typing.Union[str, list[str]] = None,
            capture_name: typing.Optional[str] = None) -> Extractor:
        """
        Create a string extractor.

        The extractor will extract all text until (and including) one of the stop strings is encountered.

        Args:
            stop: The strings for the extractor to stop at. They will be included in text generation and extraction.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        Returns:
            The string extractor.
        """
        stop = [stop] if isinstance(stop, str) else stop or []
        if not stop:
            capture_regex = ".*"
        else:
            backslash = '\\'
            capture_regex = f".*?(?:{'|'.join([re.escape(i.replace(backslash, backslash * 2)) for i in stop])})"
        return self._add_extractor("str",
                                   lambda nonterminal: RegexExtractor(capture_regex, capture_name, nonterminal))

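    # Illustrative sketch (not part of the original source): stop strings are folded into a
    # non-greedy regex, so
    #     f.append_str(f"Thought: {f.str(stop=['\n'], capture_name='thought')}")
    # lets the model write free text up to and including the first newline, and the 'thought'
    # capture contains that text, newline included.
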
    def substr(self, string: str, *, capture_name: str = None, extract_empty_substring: bool = False) -> Extractor:
        """
        Create a substring extractor.

        The extractor will extract a substring of the input string.

        Args:
            string: The string to extract.
            capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
            extract_empty_substring: Whether to extract an empty substring as a valid substring.
        Returns:
            The substring extractor.
        """
        return self._add_extractor("substr",
                                   lambda nonterminal: SubstringExtractor(string, capture_name, nonterminal,
                                                                          extract_empty_substring=extract_empty_substring))

    def build(self, vocabulary: kbnf.Vocabulary,
              decode: typing.Callable[[list[int]], str],
              engine_config: kbnf.Config = None) -> Formatter:
        """
        Build a formatter from the builder. The builder will not be consumed and can be used again.

        Args:
            vocabulary: The KBNF engine vocabulary for the formatter.
            decode: The callback to decode the token IDs to a string.
            engine_config: The KBNF engine configuration.
        Returns:
            The formatter.
        """
        assert len(
            self._main_rule) != 0, "An empty formatter builder cannot build!"
        rules = copy(self._rules)
        rules.append(f"start ::= {' '.join(self._main_rule)};")
        grammar_str = "\n".join(rules)
        engine = kbnf.Engine(grammar_str, vocabulary, engine_config)
        extractors = copy(self._extractors)
        f = Formatter(extractors, engine, decode, grammar_str)
        return f
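
# Illustrative end-to-end sketch (not part of formatter.py). `vocabulary`, `decode`, `logits`
# and `sample` are assumed to be supplied by an integration and its sampling loop; the comments
# show where an integration would call into the Formatter during constrained decoding.
#
#     f = FormatterBuilder()
#     f.append_line(f"My name is {f.regex('[A-Za-z]+', capture_name='name')}.")
#     formatter = f.build(vocabulary, decode)
#     while not formatter.is_completed():
#         formatter.compute_allowed_tokens()
#         logits = formatter.mask_logits(logits)   # disallowed tokens are masked out
#         formatter.accept_token(sample(logits))   # feed the sampled token back in
#     print(formatter.captures["name"])            # capture is available once completed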