2This module contains the Formatter class and its related classes. 
    6from json 
import JSONDecodeError
 
   14from formatron.schemas.schema 
import Schema
 
   15from formatron.extractor import Extractor, LiteralExtractor, NonterminalExtractor, ChoiceExtractor, SubstringExtractor
 
   22    An abstract Formatter that enforces a format on the string generated by a language model.  
   28        Accept a token from the language model. 
   30            token_id: The token ID. 
   32            The result of accepting the token. 
   38        Accept a bytes object from the language model. 
   40            _bytes: The bytes object. 
   46        Compute the allowed tokens based on the current state. 
   52        Mask the logits based on the current state. 
   54            logits: The logits to mask. 
   62        Get the allowed tokens since the last computation (in other words, since the last call to `compute_allowed_tokens`). 
   70        Check if the generated string satisfies the format and hence the generation is completed. 
   76        Perform actions when the generation is completed. 
   81    def captures(self) -> dict[str, typing.Any|None]:
 
 
   83        Get the captures from the generated string.  
   87    def reset(self) -> None:
 
   89        Reset the formatter to the initial state. 
   93class Formatter(FormatterBase):
 
   95    A Formatter that enforces a format on the string generated by a language model. It is designed to compose 
   96    multiple extractors in a sequential, unambiguous, greedy manner. Check out the Formatter.captures property docs for more details. 
   97    If you need more complex extraction logic, you need to implement your own Extractor. 
  100    def __init__(self, extractors: list[Extractor], engine: kbnf.Engine,
 
  101                 decode_callback: typing.Callable[[list[int]], str], grammar_str: str):
 
  103        Initialize the formatter. 
  105            extractors: The matchers to extract data from the generated string. 
  106            engine: The KBNF engine to enforce the format. 
 
  107            decode_callback: The callback to decode the token IDs to a string. 
  108            grammar_str: The KBNF grammar string. 
  110        self._extractors = extractors
 
  111        self._engine = engine
 
 
  120        Get the KBNF grammar string. 
  125        result = self.
_engine.try_accept_new_token(token_id)
 
  127        if result == kbnf.AcceptTokenResult.Finished:
 
 
  136        def decode_buffer(buffer_type: type, buffer_content: list):
 
  137            if buffer_type 
not in (int, bytes):
 
  139                    buffer_content = [int(item) 
for item 
in buffer_content]
 
  142                    assert False, f
"Invalid type: {buffer_type}. Unable to convert to int." 
  143            if buffer_type 
is int:
 
 
  145            elif buffer_type 
is bytes:
 
  146                return b
"".join(buffer_content).decode()
 
  149            if last_type 
is None:
 
  150                last_type = type(element)
 
  151            elif last_type != type(element):
 
  152                output += decode_buffer(last_type, buffer)
 
  154                last_type = type(element)
 
 
  155            buffer.append(element)
 
  158            output += decode_buffer(last_type, buffer)
 
  161    def accept_bytes(self, _bytes: bytes)->kbnf.AcceptTokenResult:
 
  162        result = self.
_engine.try_accept_new_bytes(_bytes)
 
  164        if result == kbnf.AcceptTokenResult.Finished:
 
  170        self.
_engine.compute_allowed_token_ids()
 
  176        return self.
_engine.get_allowed_token_ids_from_last_computation()
 
  180        Check if the generation is completed. This means the generation is ended by the engine. 
  181        If the generation is ended by integration-specific stop conditions like `max_new_tokens`, 
  182        the generation is not considered completed by this method. 
 
  188            result = matcher.extract(generated_output)
 
 
  192                generated_output, captured = matcher.extract(generated_output)
 
  193            if matcher.capture_name:
 
  194                if matcher.capture_name 
in self.
_captures:
 
 
  197                    self.
_captures[matcher.capture_name].append(captured)
 
 
  199                    self.
_captures[matcher.capture_name] = captured
 
 
  202    def captures(self) -> dict[str, typing.Any] | None:
 
  204        Get the captures from the generated string. Note that the captures are only available for one extractor if: 
  205        - The extractor has a capture name. 
  206        - Formatter.is_completed() returns True. 
  207        - The extractor successfully extracts the data. 
  208          - This means the extractor identifies the correct string span to extract and whatever post-processing the extractor does on the extracted string is successful. 
 
  210        Captures are obtained by calling `Extractor.extract` method on the generated string in the sequence of extractors appended to the formatter. 
  211        Note that earlier extractors do not 'see' the semantics of later extractors. For example, 
  212        consider the following formatter: 
  214        f = FormatterBuilder() 
  215        f.append_line(f"{f.regex('.*?', capture_name='a')}{f.regex('.*', capture_name='b')}") 
  218        The `b` extractor will always correspond to `None` because the `a` extractor will always extract the whole string. 
  219        This behavior is different from what a typical regular expression engine would do!  
  223    def reset(self) -> None:
 
 
  229        return (f
"Formatter(engine={self._engine}, " 
  230                f
"captures={self._captures}, " 
  231                f
"extractors={len(self._extractors)}, " 
  232                f
"completed={self.is_completed()}, " 
  233                f
"token_ids={len(self._token_id_or_bytes)})" 
  234                f
"grammar={self._grammar_str})")
 
  239    A builder for creating a Formatter. 
  241    _formatter_builder_counter = 0
 
  245        Initialize the formatter builder. 
  250        self._capture_names = set()
 
  251        self._nonterminal_to_extractor = {}
 
  253        self._instance_id = self.__class__._formatter_builder_counter
 
  254        self.__class__._formatter_builder_counter += 1
 
 
  257    def _assert_capture_name_valid(self, capture_name: str):
 
  258        assert capture_name.isidentifier(), (f
"capture_name {capture_name}" 
  259                                             f
" should only contains alphanumeric characters, " 
  260                                             f
"underscores, and does not start with digits!")
 
  261        assert capture_name 
not in self._capture_names, f
"capture_name {capture_name} is duplicated!" 
 
  263    def append_line(self, line: str) -> 
None:
 
  265        Append a line to the format. Specifically, a newline character is appended to the input. 
  267        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`. 
  269        self.append_str(line + 
'\n')
 
 
 
  271    def append_multiline_str(self, lines: str) -> 
None:
 
  273        Append a multiline string to the format, preserving the first line's leading whitespace 
  274        and removing any common leading whitespace from the subsequent lines. 
  276        Note that tabs and spaces are both treated as whitespace, but they are not equal: 
  277        the lines " hello" and "\\thello" are considered to have no common leading whitespace. 
  279        Entirely blank lines are normalized to a newline character. 
  281        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`. 
  283        first = lines.find(
'\n')
 
  284        self.
append_str(lines[:first + 1] + textwrap.dedent(lines[first + 1:]))
 
  288        Append a string to the format without any post-processing. 
  290        Note that if you need a literal `$`, you need to escape it by adding a backslash: `\\$`. 
 
  295        def append_literal(end):
 
  297                literal = string[last:end]
 
 
  301        for i, char 
in enumerate(string):
 
  303                if state != 
"escaped":
 
  307            elif state == 
"dollar":
 
 
  309                    append_literal(i - 1)
 
  311                    state = 
"left_bracket" 
  314            elif state == 
"left_bracket":
 
  325        append_literal(len(string))
 
  328        nonterminal = f
"__{name}_{self._counter}_{self._instance_id}" 
  333        if extractor.capture_name 
is None:
 
  338    def choose(self, *extractors: Extractor | str, capture_name: str = 
None) -> ChoiceExtractor:
 
  340        Create a choice extractor. 
  342        Check out the ChoiceExtractor docs for more details. 
  344            extractors: The extractors to choose from. 
  345            capture_name: The capture name of the extractor, or `None` if the extractor does not capture. 
  347            The choice extractor. 
  350        for extractor 
in extractors:
 
  351            if isinstance(extractor, str):
 
  354                new_extractors.append(extractor)
 
  356                                   lambda nonterminal: 
ChoiceExtractor(new_extractors, capture_name, nonterminal))
 
  358    def _add_extractor(self, extractor_type: str, create_extractor: typing.Callable[[str], Extractor]):
 
  360        extractor = create_extractor(nonterminal)
 
  361        if isinstance(extractor, NonterminalExtractor):
 
  363            nonterminal = extractor.nonterminal
 
 
  365        self.
_rules.append(extractor.kbnf_definition)
 
  368    def extractor(self, create_extractor: typing.Callable[[str], Extractor]) -> Extractor:
 
 
  370        Create a custom extractor. 
  373            create_extractor: A callable with signature (extractor_nonterminal: str)->Extractor that creates the extractor. extractor_nonterminal is the auto-generated nonterminal reference for the extractor. 
 
  377    def json(self, schema: typing.Type[Schema]|collections.abc.Sequence, *, capture_name: str = 
None) -> JsonExtractor:
 
  379        Create a JSON extractor. Check out the JsonExtractor docs for more details. 
  382            schema: The schema for extraction. 
  383            capture_name: The capture name of the extractor, or `None` if the extractor does not capture. 
  387        def to_json(_json: str):
 
  388            local_schema = schema
 
  389            origin = typing.get_origin(local_schema)
 
  390            if origin 
is not None:
 
  391                local_schema = origin
 
  392            if isinstance(local_schema, type) 
and issubclass(local_schema, Schema):
 
  394                    return local_schema.from_json(_json)
 
 
  395                except JSONDecodeError:  
 
  399                    return json.loads(_json)
 
  400                except JSONDecodeError:
 
  403                                   lambda nonterminal: 
JsonExtractor(nonterminal, capture_name,schema, to_json))
 
 
  405    def regex(self, regex: str, *, capture_name: str = 
None) -> RegexExtractor:
 
  407        Create a regex extractor. 
  409        Check out the RegexExtractor docs for more details. 
  412            regex: The regular expression for extraction. 
  413            capture_name: The capture name of the extractor, or `None` if the extractor does not capture. 
 
  418                                   lambda nonterminal: 
RegexExtractor(regex, capture_name, nonterminal))
 
  420    def regex_complement(self, regex: str, *, capture_name: str = 
None) -> RegexComplementExtractor:
 
  422        Create a regex complement extractor. This is roughly equivalent to 'extract a string that does not match the given regex anywhere'. 
  424        Check out the RegexComplementExtractor docs for more details. 
  427            regex: The regular expression for extraction. 
  428            capture_name: The capture name of the extractor, or `None` if the extractor does not capture. 
  430            The regex complement extractor. 
  435    def str(self, *, stop: typing.Union[str, list[str]] = 
None,
 
  436            capture_name: typing.Optional[str] = 
None) -> Extractor:
 
  438        Create a string extractor. 
  440        The extractor will extract all text up to and including the first stop string that is encountered.  
 
  443            stop: The strings for the extractors to stop at. They will be included in text generation and extraction. 
  444            capture_name: The capture name of the extractor, or `None` if the extractor does not capture. 
  446            The string extractor. 
  448        stop = [stop] 
if isinstance(stop, str) 
else stop 
or []
 
  453            capture_regex = f
".*?(?:{'|'.join([re.escape(i.replace(backslash, backslash * 2)) for i in stop])})" 
  455                                   lambda nonterminal: 
RegexExtractor(capture_regex, capture_name, nonterminal))
 
 
  457    def substr(self, string: str, *, capture_name: str = 
None, extract_empty_substring: bool = 
False) -> Extractor:
 
  459        Create a substring extractor. 
  461        The extractor will extract a substring of the input string. 
  464            string: The string to extract. 
  465            capture_name: The capture name of the extractor, or `None` if the extractor does not capture. 
  466            extract_empty_substring: Whether to extract an empty substring as a valid substring. 
  468            The substring extractor. 
 
  472                                                                           extract_empty_substring=extract_empty_substring))
 
  476    def build(self, vocabulary: kbnf.Vocabulary,
 
  477              decode: typing.Callable[[list[int]], str],
 
  478              engine_config: kbnf.Config = 
None) -> Formatter:
 
  480        Build a formatter from the builder. The builder will not be consumed and can be used again. 
  483            vocabulary: The KBNF engine vocabulary for the formatter. 
  484            decode: The callback to decode the token IDs to a string. 
  485            engine_config: The KBNF engine configuration. 
  490            self.
_main_rule) != 0, 
"An empty formatter builder cannot build!" 
  492        rules.append(f
"start ::= {' '.join(self._main_rule)};")
 
  493        grammar_str = 
"\n".join(rules)
 
 
  494        engine = kbnf.Engine(grammar_str, vocabulary, engine_config)
 
  496        f = 
Formatter(extractors, engine, decode, grammar_str)