2This module contains the Formatter class and its related classes.
5from json
import JSONDecodeError
12from formatron.schemas.schema
import Schema
13from formatron.extractor import Extractor, LiteralExtractor, NonterminalExtractor, ChoiceExtractor, SubstringExtractor
20 An abstract Formatter that enforces a format on the string generated by a language model.
26 Accept a token from the language model.
28 token_id: The token ID.
30 The result of accepting the token.
36 Accept a bytes object from the language model.
38 _bytes: The bytes object.
44 Compute the allowed tokens based on the current state.
50 Mask the logits based on the current state.
52 logits: The logits to mask.
60 Get the allowed tokens since the last computation(in other words, the last call to `compute_allowed_tokens`).
68 Check if the generated string satisfies the format and hence the generation is completed.
74 Perform actions when the generation is completed.
79 def captures(self) -> dict[str, typing.Any|None]:
81 Get the captures from the generated string.
85 def reset(self) -> None:
87 Reset the formatter to the initial state.
91class Formatter(FormatterBase):
93 A Formatter that enforces a format on the string generated by a language model. It is designed to compose
94 multiple extractors in a sequential, unambiguous, greedy manner. Check out the Formatter.captures property docs for more details.
95 If you need more complex extraction logic, you need to implement your own Extractor.
98 def __init__(self, extractors: list[Extractor], engine: kbnf.Engine,
99 decode_callback: typing.Callable[[list[int]], str], grammar_str: str):
101 Initialize the formatter.
103 extractors: The matchers to extract data from the generated string.
104 engine: The KBNF engine to enforce the format.
105 decode_callback: The callback to decode the token IDs to a string.
106 grammar_str: The KBNF grammar string.
108 self._extractors = extractors
109 self._engine = engine
118 Get the KBNF grammar string.
123 result = self.
_engine.try_accept_new_token(token_id)
125 if result == kbnf.AcceptTokenResult.Finished:
131 self.
_engine.try_accept_new_bytes(_bytes)
134 self.
_engine.compute_allowed_token_ids()
140 return self.
_engine.get_allowed_token_ids_from_last_computation()
144 Check if the generation is completed. This means the generation is ended by the engine.
145 If the generation is ended by integration-specific stop conditions like `max_new_tokens`,
146 the generation is not considered completed by this method.
148 return self.
_engine.is_finished()
152 result = matcher.extract(generated_output)
156 generated_output, captured = matcher.extract(generated_output)
157 if matcher.capture_name:
158 if matcher.capture_name
in self.
_captures:
161 self.
_captures[matcher.capture_name].append(captured)
163 self.
_captures[matcher.capture_name] = captured
166 def captures(self) -> dict[str, typing.Any] | None:
168 Get the captures from the generated string. Note that the captures are only available for one extractor if:
169 - The extractor has a capture name.
170 - Formatter.is_completed() returns True.
171 - The extractor successfully extracts the data.
172 - This means the extractor identifies the correct string span to extract and whatever post-processing the extractor does on the extracted string is successful.
174 Captures are obtained by calling `Extractor.extract` method on the generated string in the sequence of extractors appended to the formatter.
175 Note that the previous extractors does not 'see' the semantics of the later extractors. For example,
176 consider the following formatter:
178 f = FormatterBuilder()
179 f.append_line(f"{f.regex('.*?', capture_name='a')}{f.regex('.*', capture_name='b')}")
182 The `b` extractor will always corresponding to `None` because the `a` extractor will always extract the whole string.
183 This behavior is different from what a typical regular expression engine would do!
187 def reset(self) -> None:
193 return (f
"Formatter(engine={self._engine}, "
194 f
"captures={self._captures}, "
195 f
"extractors={len(self._extractors)}, "
196 f
"completed={self.is_completed()}, "
197 f
"token_ids={len(self._token_ids)})"
198 f
"grammar={self._grammar_str})")
203 A builder for creating a Formatter.
205 _formatter_builder_counter = 0
209 Initialize the formatter builder.
214 self._capture_names = set()
215 self._nonterminal_to_extractor = {}
217 self._instance_id = self.__class__._formatter_builder_counter
218 self.__class__._formatter_builder_counter += 1
def _assert_capture_name_valid(self, capture_name: str):
    """
    Validate a capture name before it is used by the builder.

    Args:
        capture_name: The candidate capture name.
    Raises:
        AssertionError: If `capture_name` is not a valid Python identifier, or if it is
            already present in `self._capture_names`.
    """
    # Raise explicitly instead of using `assert` statements: assertions are removed when
    # Python runs with -O, which would silently let invalid or duplicate names through.
    # AssertionError is kept as the exception type so existing callers/tests still match.
    if not capture_name.isidentifier():
        raise AssertionError(f"capture_name {capture_name}"
                             f" should only contain alphanumeric characters, "
                             f"underscores, and should not start with digits!")
    if capture_name in self._capture_names:
        raise AssertionError(f"capture_name {capture_name} is duplicated!")
def append_line(self, line: str) -> None:
    """
    Append `line` to the format, followed by a single newline character.

    Apart from the added newline the text is handed to `append_str` unchanged, so the same
    escaping rules apply: to get a literal `$`, escape it with a backslash: `\\$`.

    Args:
        line: The text of the line; a newline is added after it.
    """
    terminated = line + "\n"
    self.append_str(terminated)
def append_multiline_str(self, lines: str) -> None:
    """
    Append a multiline string to the format, preserving the first line's leading whitespace
    and removing any common leading whitespace from the subsequent lines.

    Note that tabs and spaces are both treated as whitespace, but they are not equal:
    the lines " hello" and "\\thello" are considered to have no common leading whitespace.

    Entirely blank lines (after the first) are normalized to a newline character.

    Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.

    Args:
        lines: The (possibly multiline) text to append.
    """
    # partition() keeps the whole input in `first_line` when there is no newline. The
    # previous find()-based slicing produced -1 in that case and dedented the first line,
    # stripping the leading whitespace this method promises to preserve.
    first_line, newline, rest = lines.partition('\n')
    self.append_str(first_line + newline + textwrap.dedent(rest))
252 Append a string to the format without any post-processing.
254 Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`.
259 def append_literal(end):
261 literal = string[last:end]
265 for i, char
in enumerate(string):
267 if state !=
"escaped":
271 elif state ==
"dollar":
273 append_literal(i - 1)
275 state =
"left_bracket"
278 elif state ==
"left_bracket":
289 append_literal(len(string))
292 nonterminal = f
"__{name}_{self._counter}_{self._instance_id}"
297 if extractor.capture_name
is None:
302 def choose(self, *extractors: Extractor | str, capture_name: str =
None) -> ChoiceExtractor:
304 Create a choice extractor.
306 Check out the ChoiceExtractor docs for more details.
308 extractors: The extractors to choose from.
309 capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
311 The choice extractor.
314 for extractor
in extractors:
315 if isinstance(extractor, str):
318 new_extractors.append(extractor)
320 lambda nonterminal:
ChoiceExtractor(new_extractors, capture_name, nonterminal))
322 def _add_extractor(self, extractor_type: str, create_extractor: typing.Callable[[str], Extractor]):
324 extractor = create_extractor(nonterminal)
325 if isinstance(extractor, NonterminalExtractor):
327 nonterminal = extractor.nonterminal
329 self.
_rules.append(extractor.kbnf_definition)
332 def extractor(self, create_extractor: typing.Callable[[str], Extractor]) -> Extractor:
334 Create a custom extractor.
337 create_extractor: callable with signature (extractor_nonterminal: str)->Extractor that create the extractor. extractor_nonterminal is the auto-generated nonterminal reference for the extractor.
338 capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
342 def json(self, schema: Schema, *, capture_name: str =
None) -> JsonExtractor:
344 Create a JSON extractor. Check out the JsonExtractor docs for more details.
347 schema: The schema for extraction.
348 capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
352 def to_json(json: str):
354 return schema.from_json(json)
355 except JSONDecodeError:
358 lambda nonterminal:
JsonExtractor(nonterminal, capture_name,schema, to_json))
360 def regex(self, regex: str, *, capture_name: str =
None) -> RegexExtractor:
362 Create a regex extractor.
364 Check out the RegexExtractor docs for more details.
367 regex: The regular expression for extraction.
368 capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
373 lambda nonterminal:
RegexExtractor(regex, capture_name, nonterminal))
375 def str(self, *, stop: typing.Union[str, list[str]] =
None,
376 capture_name: typing.Optional[str] =
None) -> Extractor:
378 Create a string extractor.
380 The extractor will extract all text until(inclusive) one of the stop strings is encountered.
383 stop: The strings for the extractors to stop at. They will be included in text generation and extraction.
384 capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
386 The string extractor.
388 stop = [stop]
if isinstance(stop, str)
else stop
or []
392 nonterminal_regex =
"#'.*'"
395 capture_regex = f
".*?(?:{'|'.join([i.replace(backslash, backslash * 2) for i in map(re.escape, stop)])})"
396 nonterminal_regex = f
"#e'{capture_regex}'"
397 self.
_rules.append(f
"{nonterminal} ::= {nonterminal_regex};")
399 capture_regex, capture_name, nonterminal)
402 def substr(self, string: str, *, capture_name: str =
None, extract_empty_substring: bool =
False) -> Extractor:
404 Create a substring extractor.
406 The extractor will extract a substring of the input string.
409 string: The string to extract.
410 capture_name: The capture name of the extractor, or `None` if the extractor does not capture.
411 extract_empty_substring: Whether to extract an empty substring as a valid substring.
413 The substring extractor.
417 extract_empty_substring=extract_empty_substring))
419 def build(self, vocabulary: kbnf.Vocabulary,
420 decode: typing.Callable[[list[int]], str],
421 engine_config: kbnf.Config =
None) -> Formatter:
423 Build a formatter from the builder. The builder will not be consumed and can be used again.
426 vocabulary: The KBNF engine vocabulary for the formatter.
427 decode: The callback to decode the token IDs to a string.
428 engine_config: The KBNF engine configuration.
433 self.
_main_rule) != 0,
"An empty formatter builder cannot build!"
435 rules.append(f
"start ::= {' '.join(self._main_rule)};")
436 grammar_str =
"\n".join(rules)
437 engine = kbnf.Engine(grammar_str, vocabulary, engine_config)
439 f =
Formatter(extractors, engine, decode, grammar_str)