Formatron v0.4.9
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
transformers.py
Go to the documentation of this file.
1"""
2This module integrates the transformers library by providing convenience utilities.
3"""
4import collections
5import typing
6
7import kbnf
8from transformers import LogitsProcessor, PreTrainedTokenizerBase, LogitsProcessorList
9
10from formatron.config import EngineGenerationConfig
11from formatron.formatter import FormatterBuilder, FormatterBase
12from formatron.integrations.utils import get_original_characters
13
14__all__ = ["create_engine_vocabulary", "create_formatter_logits_processor", "create_formatter_logits_processor_list", "FormattersLogitsProcessor"]
15
def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase,
                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine from a Hugging Face tokenizer.
    Args:
        tokenizer: The tokenizer.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    raw_vocab = tokenizer.get_vocab()
    # Map each token string to the original characters it represents.
    decoded = get_original_characters(raw_vocab, vocab_processors)
    tokens = {token_str: kbnf.Token(chars) for token_str, chars in decoded.items()}
    # Invert the tokenizer vocab so the engine can look up strings by token id.
    id_to_string = {token_id: token_str for token_str, token_id in raw_vocab.items()}
    return kbnf.Vocabulary(tokens, id_to_string)
29
30
def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,
                                      formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                      configs: typing.Sequence[EngineGenerationConfig] | None = None,
                                      vocab_processors: typing.Optional[list[typing.Callable]] = None) -> LogitsProcessor:
    """
    Create a formatter logits processor.
    Args:
        tokenizer: The tokenizer.
        formatter_builders: The formatter builders. A single builder is treated as a batch of one;
            None entries leave the corresponding batch row unconstrained.
        configs: The engine generation configurations. If None, a default config is used for each formatter.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    vocab = create_engine_vocabulary(tokenizer, vocab_processors)
    # Normalize a bare FormatterBuilder into a one-element batch.
    if not isinstance(formatter_builders, collections.abc.Sequence):
        formatter_builders = [formatter_builders]
    formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
                  for i in formatter_builders]
    return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
50
51
def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase,
                                           formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                           configs: typing.Sequence[EngineGenerationConfig] | None = None,
                                           vocab_processors: typing.Optional[list[typing.Callable]] = None) \
        -> LogitsProcessorList:
    """
    Create a formatter logits processor list, ready to pass to `model.generate(logits_processor=...)`.
    Args:
        tokenizer: The tokenizer.
        formatter_builders: The formatter builders.
        configs: The engine generation configurations. If None, a default config is used for each formatter.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    # Delegate to create_formatter_logits_processor and wrap the result for transformers.
    return LogitsProcessorList([create_formatter_logits_processor(tokenizer,
                                                                  formatter_builders, configs, vocab_processors)])
68
69
class FormattersLogitsProcessor(LogitsProcessor):
    """
    Logit processor that uses formatters to mask batch logits.

    One formatter per batch row; a None formatter leaves that row unconstrained.
    """

    def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
                 configs: typing.Sequence[EngineGenerationConfig] | None = None):
        """
        Args:
            formatters: One formatter (or None) per batch row.
            eos_token_id: Token id forced once a formatter reports completion.
            configs: Per-formatter generation configs; defaults to one default config each.
        """
        self._formatters = formatters
        self._eos_token_id = eos_token_id
        # None marks "no generation step seen yet"; __call__ uses this to
        # distinguish the first (prompt) iteration from incremental steps.
        # BUGFIX: this attribute was read in __call__ but never initialized,
        # causing AttributeError on the first invocation.
        self._last_input_id_length = None
        if configs is None:
            configs = [EngineGenerationConfig() for _ in formatters]
        assert len(configs) == len(formatters), \
            f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
        self.configs = configs

    def reset(self) -> None:
        """Reset all formatters and the step tracker so the processor can be reused
        for a fresh generation."""
        # BUGFIX: without clearing the tracker, a reused processor would take the
        # incremental branch on the new prompt and fail its length assertion.
        self._last_input_id_length = None
        for f in self._formatters:
            if f is not None:
                f.reset()

    @property
    def formatters_captures(self) -> list[dict[str, typing.Any] | None]:
        """
        Get the captures of the formatters. Each element in the list corresponds to the
        captures of the formatter at the same index. If the formatter is None, the element
        is None.
        """
        return [f.captures if f is not None else None for f in self._formatters]

    def is_completed(self) -> list[bool | None]:
        """
        Check if the formatters are completed. Each boolean in the list corresponds to the
        completion status of the formatter at the same index. If the formatter is None,
        the element is None.
        """
        return [f.is_completed() if f is not None else None for f in self._formatters]

    def __call__(self, input_ids, scores):
        """Mask `scores` so only tokens allowed by each row's formatter remain viable.

        Args:
            input_ids: Batch of token-id sequences seen so far, shape (batch, seq_len).
            scores: Logits for the next token, shape (batch, vocab); modified in place.
        Returns:
            The masked scores.
        """
        assert input_ids.shape[0] == len(self._formatters), (f"Number of formatters({len(self._formatters)})"
                                                             f" must match batch size({input_ids.shape[0]})")
        if self._last_input_id_length is None:  # First iteration: feed the prompt.
            self._last_input_id_length = input_ids.shape[1]
            for formatter, config, prompt in zip(self._formatters, self.configs, input_ids):
                if formatter is None:
                    continue
                if config.reset_at_beginning:
                    formatter.reset()
                if config.read_prompt:
                    for token in prompt:
                        formatter.accept_token(token)
        else:  # Incremental step: exactly one new token per row.
            assert input_ids.shape[1] == self._last_input_id_length + 1, ("One iteration in generation loop"
                                                                          " must add exactly one token.")
            self._last_input_id_length += 1
            for formatter, input_id in zip(self._formatters, input_ids[:, -1]):
                if formatter is not None and not formatter.is_completed():
                    formatter.accept_token(input_id)
        for i, formatter in enumerate(self._formatters):
            if formatter is None:
                continue
            if formatter.is_completed():
                # Completed formatters force EOS by zeroing it and -inf everything else.
                scores[i, :] = float("-inf")
                scores[i, self._eos_token_id] = 0.0
                continue
            formatter.compute_allowed_tokens()
            scores[i, :] = formatter.mask_logits(scores[i, :])
        return scores
Configuration for how a KBNF engine should be used in text generation.
Definition config.py:14
Logit processor that uses formatters to mask batch logits.
list[bool|None] is_completed(self)
Check if the formatters are completed.
list[dict[str, typing.Any]|None] formatters_captures(self)
Get the captures of the formatters.
__init__(self, typing.Sequence[FormatterBase|None] formatters, int eos_token_id, typing.Sequence[EngineGenerationConfig]|None configs=None)
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
LogitsProcessor create_formatter_logits_processor(PreTrainedTokenizerBase tokenizer, typing.Sequence[FormatterBuilder|None]|FormatterBuilder formatter_builders, typing.Sequence[EngineGenerationConfig] configs=None, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a formatter logits processor.
LogitsProcessorList create_formatter_logits_processor_list(PreTrainedTokenizerBase tokenizer, typing.Sequence[FormatterBuilder|None]|FormatterBuilder formatter_builders, typing.Sequence[EngineGenerationConfig] configs=None, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a formatter logits processor list.
kbnf.Vocabulary create_engine_vocabulary(PreTrainedTokenizerBase tokenizer, typing.Optional[list[typing.Callable]] vocab_processors=None)
Create a vocabulary for the KBNF engine.