Formatron v0.5.0
Formatron empowers everyone to control the output format of language models with minimal overhead.
transformers.py
1"""
2This module integrates the transformers library by providing convenience utilities.
3"""
4import collections
5import typing
6import torch
7import kbnf
8from transformers import LogitsProcessor, PreTrainedTokenizerBase, LogitsProcessorList
9
10from formatron.config import EngineGenerationConfig
11from formatron.formatter import FormatterBuilder, FormatterBase
12from formatron.integrations.utils import get_original_characters
13
14__all__ = ["create_engine_vocabulary", "create_formatter_logits_processor", "create_formatter_logits_processor_list", "FormattersLogitsProcessor"]
15
16def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase,
17 vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
18 """
19 Create a vocabulary for the KBNF engine.
20 Args:
21 tokenizer: The tokenizer.
22 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
23 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
24 """
25 vocab = tokenizer.get_vocab()
26 new_vocab = get_original_characters(vocab, vocab_processors)
27 return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
28 {v: k for k, v in vocab.items()})
29
30
31def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,
32 formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
33 configs: typing.Sequence[EngineGenerationConfig] = None,
34 vocab_processors: typing.Optional[list[typing.Callable]] = None) -> LogitsProcessor:
35 """
36 Create a formatter logits processor.
37 Args:
38 tokenizer: The tokenizer.
39 formatter_builders: The formatter builders.
40 configs: The engine generation configurations.
41 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
42 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
43 """
44 vocab = create_engine_vocabulary(tokenizer, vocab_processors)
45 if not isinstance(formatter_builders, collections.abc.Sequence):
46 formatter_builders = [formatter_builders]
47 formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
48 for i in formatter_builders]
49 return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
50
51
52def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase,
53 formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
54 configs: typing.Sequence[EngineGenerationConfig] = None,
55 vocab_processors: typing.Optional[list[typing.Callable]] = None) \
56 -> LogitsProcessorList:
57 """
58 Create a formatter logits processor list.
59 Args:
60 tokenizer: The tokenizer.
61 formatter_builders: The formatter builders.
62 configs: The engine generation configurations.
63 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
64 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
65 """
66 return LogitsProcessorList([create_formatter_logits_processor(tokenizer,
67 formatter_builders, configs, vocab_processors)])
68
69
70class FormattersLogitsProcessor(LogitsProcessor):
71 """
72 Logit processor that uses formatters to mask batch logits.
73 """
75 def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
76 configs: typing.Sequence[EngineGenerationConfig] | None = None):
77 self._formatters = formatters
78 self._eos_token_id = eos_token_id
80 if configs is None:
81 configs = [EngineGenerationConfig() for _ in formatters]
82 assert len(configs) == len(formatters), \
83 f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
84 self.configs = configs
85 self._mask_logits_fn = FormattersLogitsProcessor._get_fastest_compatible_logits_mask_fn()
87 @staticmethod
89 def default_mask_logits_fn(bit_masks, formatter, scores, i):
90 scores[i, :] = formatter.mask_logits(scores[i, :])
91 try:
92 from kbnf.triton_logits_mask import mask_logits_inplace
93 def fast_mask_logits_fn(bit_masks, formatter, scores, i):
94 mask_logits_inplace(scores[i, :], bit_masks[i, :], [formatter._engine])
95 return fast_mask_logits_fn
96 except ImportError:
97 return default_mask_logits_fn
98
99 def reset(self) -> None:
101 for f in self._formatters:
102 if f is not None:
103 f.reset()
104
105 @property
106 def formatters_captures(self) -> list[dict[str, typing.Any] | None]:
107 """
108 Get the captures of the formatters. Each element in the list corresponds to the
109 captures of the formatter at the same index. If the formatter is None, the element
110 is None.
111 """
112 return [f.captures if f is not None else None for f in self._formatters]
113
114 def is_completed(self) -> list[bool | None]:
115 """
116 Check if the formatters are completed. Each boolean in the list corresponds to the
117 completion status of the formatter at the same index. If the formatter is None,
118 the element is None.
119 """
120 return [f.is_completed() if f is not None else None for f in self._formatters]
121
122 def __call__(self, input_ids, scores):
123 assert input_ids.shape[0] == len(self._formatters), (f"Number of formatters({len(self._formatters)})"
124 f" must match batch size({input_ids.shape[0]})")
125 if self._last_input_id_length is None: # First iteration
126 self._last_input_id_length = input_ids.shape[1]
127 for formatter, config, prompt in zip(self._formatters, self.configs, input_ids):
128 if formatter is None:
129 continue
130 if config.reset_at_beginning:
131 formatter.reset()
132 if config.read_prompt:
133 for token in prompt:
134 formatter.accept_token(token)
135 self._bit_masks = torch.empty((scores.shape[0],
136 (scores.shape[1]+31)//32), dtype=torch.int32, device='cpu', pin_memory=True)
137 else:
138 assert input_ids.shape[1] == self._last_input_id_length + 1, ("One iteration in generation loop"
139 " must add exactly one token.")
140 self._last_input_id_length += 1
141 for formatter, input_id in zip(self._formatters, input_ids[:, -1]):
142 if formatter is not None and not formatter.is_completed():
143 formatter.accept_token(input_id)
144 for i, formatter in enumerate(self._formatters):
145 if formatter is None:
146 continue
147 if formatter.is_completed():
148 scores[i, :] = float("-inf")
149 scores[i, self._eos_token_id] = 0.0
150 continue
151 formatter.compute_allowed_tokens()
152 self._mask_logits_fn(self._bit_masks, formatter, scores, i)
153 return scores
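Example: the sketch below shows one way these helpers plug into transformers' generate() loop. It is a minimal, illustrative sketch rather than part of transformers.py; the model name, prompt, and the regex/capture_name builder calls are assumptions to be adapted to your own FormatterBuilder usage.

from transformers import AutoModelForCausalLM, AutoTokenizer

from formatron.formatter import FormatterBuilder
from formatron.integrations.transformers import create_formatter_logits_processor_list

# Placeholder model; any causal LM loadable by transformers should work.
model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Hypothetical builder calls: constrain the continuation to a single integer
# and capture it under the name "digit".
f = FormatterBuilder()
f.append_line(f"My favorite integer is {f.regex('[1-9][0-9]*', capture_name='digit')}.")

# One FormatterBuilder for a batch of size one; the processor list goes
# straight into generate().
logits_processor = create_formatter_logits_processor_list(tokenizer, f)
inputs = tokenizer("Name an integer you like.", return_tensors="pt")
outputs = model.generate(**inputs, logits_processor=logits_processor, max_new_tokens=32)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:]))

# Captured groups can be read back from the FormattersLogitsProcessor afterwards.
print(logits_processor[0].formatters_captures)

Note that when kbnf's Triton kernels are importable, FormattersLogitsProcessor automatically uses the in-place bit-mask path; otherwise it falls back to the formatter's mask_logits method.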
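The vocab_processors hook exists for tokenizers whose vocabularies store mangled byte sequences that the automatic detection in get_original_characters does not handle. The sketch below is a hypothetical processor, assuming a GPT-2 style vocabulary that marks a leading space with "Ġ"; it only illustrates the expected (token_to_char: typing.Dict[bytes, bytes]) -> None contract of mutating the mapping in place.

import typing


def unmangle_space_marker(token_to_char: typing.Dict[bytes, bytes]) -> None:
    # Hypothetical processor: rewrite the GPT-2 style "Ġ" space marker back
    # to a plain space byte. The mapping is mutated in place; nothing is returned.
    marker = "Ġ".encode("utf-8")
    for token, chars in token_to_char.items():
        token_to_char[token] = chars.replace(marker, b" ")


# Supplied to create_engine_vocabulary (and, transitively, to the
# create_formatter_logits_processor helpers) via the vocab_processors argument:
# vocab = create_engine_vocabulary(tokenizer, vocab_processors=[unmangle_space_marker])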