This module contains the Formatter class and its related classes.
from json import JSONDecodeError
from formatron.schemas.schema import Schema
from formatron.extractor import Extractor, LiteralExtractor, NonterminalExtractor, ChoiceExtractor, SubstringExtractor
    An abstract Formatter that enforces a format on the string generated by a language model.

        Accept a token from the language model.
        token_id: The token ID.
        The result of accepting the token.

        Accept a bytes object from the language model.
        _bytes: The bytes object.

        Compute the allowed tokens based on the current state.

        Mask the logits based on the current state.
        logits: The logits to mask.

        Get the allowed tokens since the last computation (in other words, the last call to `compute_allowed_tokens`).

        Check if the generated string satisfies the format and hence the generation is completed.

        Perform actions when the generation is completed.

    def captures(self) -> dict[str, typing.Any | None]:
        Get the captures from the generated string.

    def reset(self) -> None:
        Reset the formatter to the initial state.
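        A typical integration drives this interface once per decoding step, roughly as in the
        sketch below (`model_forward` and `sample` are hypothetical stand-ins for the host
        framework's forward pass and sampling step, and `mask_logits` is assumed to return the
        masked logits):

            while not formatter.is_completed():
                logits = model_forward(token_ids)       # hypothetical forward pass over the tokens so far
                formatter.compute_allowed_tokens()
                logits = formatter.mask_logits(logits)  # disallow tokens that would break the format
                next_token_id = sample(logits)          # hypothetical sampling step
                formatter.accept_token(next_token_id)
                token_ids.append(next_token_id)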

class Formatter(FormatterBase):

    A Formatter that enforces a format on the string generated by a language model. It is designed to compose
    multiple extractors in a sequential, unambiguous, greedy manner. Check out the Formatter.captures property docs for more details.
    If you need more complex extraction logic, implement your own Extractor.

    def __init__(self, extractors: list[Extractor], engine: kbnf.Engine,
                 decode_callback: typing.Callable[[list[int]], str], grammar_str: str):
        Initialize the formatter.

        extractors: The extractors used to extract data from the generated string.
        engine: The KBNF engine that enforces the format.
        decode_callback: The callback to decode the token IDs to a string.
        grammar_str: The KBNF grammar string.

        self._extractors = extractors
        self._engine = engine

        Get the KBNF grammar string.

        result = self._engine.try_accept_new_token(token_id)
        if result == kbnf.AcceptTokenResult.Finished:
            ...

        def decode_buffer(buffer_type: type, buffer_content: list):
            if buffer_type not in (int, bytes):
                # Buffers of other element types are coerced to token IDs if possible.
                try:
                    buffer_content = [int(item) for item in buffer_content]
                    buffer_type = int
                except ValueError:
                    assert False, f"Invalid type: {buffer_type}. Unable to convert to int."
            if buffer_type is int:
                # Assumed: buffered token IDs are decoded with the decode_callback passed to __init__.
                return self._decode_callback(buffer_content)
            elif buffer_type is bytes:
                return b"".join(buffer_content).decode()

        for element in self._token_id_or_bytes:
            if last_type is None:
                last_type = type(element)
            elif last_type != type(element):
                # Flush the buffer whenever the element type changes.
                output += decode_buffer(last_type, buffer)
                buffer.clear()
                last_type = type(element)
            buffer.append(element)
        output += decode_buffer(last_type, buffer)

    def accept_bytes(self, _bytes: bytes) -> kbnf.AcceptTokenResult:
        result = self._engine.try_accept_new_bytes(_bytes)
        if result == kbnf.AcceptTokenResult.Finished:
            ...

        self._engine.compute_allowed_token_ids()

        return self._engine.get_allowed_token_ids_from_last_computation()

        Check if the generation is completed. This means the generation was ended by the engine itself.
        If the generation was stopped by integration-specific conditions such as `max_new_tokens`,
        it is not considered completed by this method.

            result = matcher.extract(generated_output)
            generated_output, captured = matcher.extract(generated_output)
            if matcher.capture_name:
                if matcher.capture_name in self._captures:
                    ...
                    self._captures[matcher.capture_name].append(captured)
                else:
                    self._captures[matcher.capture_name] = captured

    def captures(self) -> dict[str, typing.Any] | None:
        Get the captures from the generated string. Note that the capture for an extractor is only available if:
        - The extractor has a capture name.
        - Formatter.is_completed() returns True.
        - The extractor successfully extracts the data.
            - This means the extractor identifies the correct string span to extract, and whatever post-processing the extractor performs on the extracted string succeeds.

        Captures are obtained by calling the `Extractor.extract` method on the generated string, in the sequence of extractors appended to the formatter.
        Note that earlier extractors do not 'see' the semantics of the later extractors. For example,
        consider the following formatter:

            f = FormatterBuilder()
            f.append_line(f"{f.regex('.*?', capture_name='a')}{f.regex('.*', capture_name='b')}")

        The `b` extractor will always correspond to `None` because the `a` extractor will always extract the whole string.
        This behavior is different from what a typical regular expression engine would do!
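        For example, once the generation has completed, the captured values can be read as plain
        dictionary entries (a sketch; `vocabulary` and `decode` are assumed to come from the host
        integration):

            formatter = f.build(vocabulary, decode)
            # ... run the decoding loop until formatter.is_completed() ...
            a = formatter.captures['a']        # the whole string extracted by the 'a' extractor
            b = formatter.captures.get('b')    # None, as explained above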

    def reset(self) -> None:

        return (f"Formatter(engine={self._engine}, "
                f"captures={self._captures}, "
                f"extractors={len(self._extractors)}, "
                f"completed={self.is_completed()}, "
                f"token_ids={len(self._token_id_or_bytes)}, "
                f"grammar={self._grammar_str})")

    A builder for creating a Formatter.

    _formatter_builder_counter = 0

        Initialize the formatter builder.

        self._capture_names = set()
        self._nonterminal_to_extractor = {}
        self._instance_id = self.__class__._formatter_builder_counter
        self.__class__._formatter_builder_counter += 1

    def _assert_capture_name_valid(self, capture_name: str):
        assert capture_name.isidentifier(), (f"capture_name {capture_name}"
                                             f" should only contain alphanumeric characters, "
                                             f"underscores, and must not start with a digit!")
        assert capture_name not in self._capture_names, f"capture_name {capture_name} is duplicated!"

    def append_line(self, line: str) -> None:
        Append a line to the format. Specifically, a newline character is appended to the input.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.

        self.append_str(line + '\n')

    def append_multiline_str(self, lines: str) -> None:
        Appends a multiline string to the format, preserving the first line's leading whitespace
        and removing any common leading whitespace from subsequent lines.

        Note that tabs and spaces are both treated as whitespace, but they are not equal:
        the lines " hello" and "\\thello" are considered to have no common leading whitespace.

        Entirely blank lines are normalized to a newline character.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.

        first = lines.find('\n')
        self.append_str(lines[:first + 1] + textwrap.dedent(lines[first + 1:]))
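        For example, the following call keeps the first line as-is and strips the common leading
        whitespace from the later lines, appending "Answer:\nline one\nline two" to the format
        (a sketch of the behavior described above):

            f = FormatterBuilder()
            f.append_multiline_str("""Answer:
                line one
                line two""")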

        Append a string to the format without any post-processing.

        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.

        def append_literal(end):
            literal = string[last:end]
            ...

        for i, char in enumerate(string):
            if state != "escaped":
                ...
            elif state == "dollar":
                ...
                append_literal(i - 1)
                ...
                state = "left_bracket"
            elif state == "left_bracket":
                ...
        append_literal(len(string))
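        As noted in the docstrings above, a `$` that should appear literally in the format must be
        escaped with a backslash, for example:

            f.append_line("Total: \\$42")   # the backslash keeps the dollar sign literal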

        nonterminal = f"__{name}_{self._counter}_{self._instance_id}"

        if extractor.capture_name is None:
            ...

    def choose(self, *extractors: Extractor | str, capture_name: str = None) -> ChoiceExtractor:
        Create a choice extractor.

        Check out the ChoiceExtractor docs for more details.

        extractors: The extractors to choose from.
        capture_name: The capture name of the extractor, or `None` if the extractor does not capture.

        The choice extractor.

        for extractor in extractors:
            if isinstance(extractor, str):
                ...
            new_extractors.append(extractor)

            lambda nonterminal: ChoiceExtractor(new_extractors, capture_name, nonterminal))
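        For example, a field that must be either a literal word or a number could be written in the
        same style as the docstring examples above (a sketch; `f` is a FormatterBuilder):

            f.append_line(f"Decision: {f.choose('yes', 'no', f.regex('[0-9]+'), capture_name='decision')}")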

    def _add_extractor(self, extractor_type: str, create_extractor: typing.Callable[[str], Extractor]):
        extractor = create_extractor(nonterminal)
        if isinstance(extractor, NonterminalExtractor):
            nonterminal = extractor.nonterminal
        self._rules.append(extractor.kbnf_definition)

    def extractor(self, create_extractor: typing.Callable[[str], Extractor]) -> Extractor:
        Create a custom extractor.

        create_extractor: A callable with signature (extractor_nonterminal: str) -> Extractor that creates the extractor.
            extractor_nonterminal is the auto-generated nonterminal reference for the extractor.

    def json(self, schema: typing.Type[Schema] | collections.abc.Sequence, *, capture_name: str = None) -> JsonExtractor:
        Create a JSON extractor. Check out the JsonExtractor docs for more details.

        schema: The schema for extraction.
        capture_name: The capture name of the extractor, or `None` if the extractor does not capture.

        def to_json(_json: str):
            if isinstance(schema, type) and issubclass(schema, Schema):
                try:
                    return schema.from_json(_json)
                except JSONDecodeError:
                    ...
            try:
                return json.loads(_json)
            except JSONDecodeError:
                ...

            lambda nonterminal: JsonExtractor(nonterminal, capture_name, schema, to_json))
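        For example, assuming a Schema subclass `Person` has been defined elsewhere (the class name
        is illustrative), the JSON extractor is embedded like any other extractor:

            f.append_line(f"Here is the person: {f.json(Person, capture_name='person')}")
            # after a completed generation, formatter.captures['person'] would hold the parsed object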

    def regex(self, regex: str, *, capture_name: str = None) -> RegexExtractor:
        Create a regex extractor.

        Check out the RegexExtractor docs for more details.

        regex: The regular expression for extraction.
        capture_name: The capture name of the extractor, or `None` if the extractor does not capture.

        lambda nonterminal: RegexExtractor(regex, capture_name, nonterminal))
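        For example, to capture a run of digits (a sketch in the same style as the docstring
        example in Formatter.captures):

            f.append_line(f"The answer is {f.regex('[0-9]+', capture_name='answer')}.")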

    def regex_complement(self, regex: str, *, capture_name: str = None) -> RegexComplementExtractor:
        Create a regex complement extractor. This is roughly equivalent to 'extract a string that does not match the given regex anywhere'.

        Check out the RegexComplementExtractor docs for more details.

        regex: The regular expression for extraction.
        capture_name: The capture name of the extractor, or `None` if the extractor does not capture.

        The regex complement extractor.
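        For example, to let the model write a line that must not contain the word ERROR anywhere
        (a sketch; the pattern is illustrative):

            f.append_line(f"Log line: {f.regex_complement('ERROR', capture_name='line')}")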

    def str(self, *, stop: typing.Union[str, list[str]] = None,
            capture_name: typing.Optional[str] = None) -> Extractor:
        Create a string extractor.

        The extractor will extract all text up to and including the first occurrence of one of the stop strings.

        stop: The strings for the extractor to stop at. They will be included in text generation and extraction.
        capture_name: The capture name of the extractor, or `None` if the extractor does not capture.

        The string extractor.

        stop = [stop] if isinstance(stop, str) else stop or []

        capture_regex = f".*?(?:{'|'.join([re.escape(i.replace(backslash, backslash * 2)) for i in stop])})"

        lambda nonterminal: RegexExtractor(capture_regex, capture_name, nonterminal))
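        For example, to let the model generate freely until it emits a closing tag (a sketch; per
        the docstring above, the stop string is included in the generated text and the capture):

            f.append_line(f"{f.str(stop=['</answer>'], capture_name='answer')}")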

    def substr(self, string: str, *, capture_name: str = None, extract_empty_substring: bool = False) -> Extractor:
        Create a substring extractor.

        The extractor will extract a substring of the input string.

        string: The string to extract.
        capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
        extract_empty_substring: Whether to treat an empty substring as a valid extraction.

        The substring extractor.

            extract_empty_substring=extract_empty_substring))
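        For example, to constrain the output to a contiguous substring of a known source sentence
        (a sketch; `f` is a FormatterBuilder):

            source = "The quick brown fox jumps over the lazy dog"
            f.append_line(f"Quote: {f.substr(source, capture_name='quote')}")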

    def build(self, vocabulary: kbnf.Vocabulary,
              decode: typing.Callable[[list[int]], str],
              engine_config: kbnf.Config = None) -> Formatter:
        Build a formatter from the builder. The builder will not be consumed and can be used again.

        vocabulary: The KBNF engine vocabulary for the formatter.
        decode: The callback to decode the token IDs to a string.
        engine_config: The KBNF engine configuration.

        assert len(self._main_rule) != 0, "An empty formatter builder cannot build!"

        rules.append(f"start ::= {' '.join(self._main_rule)};")
        grammar_str = "\n".join(rules)
        engine = kbnf.Engine(grammar_str, vocabulary, engine_config)

        f = Formatter(extractors, engine, decode, grammar_str)
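        How `vocabulary` and `decode` are obtained depends on the integration; a rough sketch with a
        hypothetical `tokenizer` object from the host framework looks like:

            builder = FormatterBuilder()
            builder.append_line(f"{builder.regex('[0-9]+', capture_name='n')}")
            vocabulary = ...                                   # kbnf.Vocabulary built from the tokenizer
            formatter = builder.build(vocabulary, decode=tokenizer.decode)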