Formatron v0.4.2
Formatron empowers everyone to control the output format of language models with minimal overhead.
Loading...
Searching...
No Matches
transformers.py
Go to the documentation of this file.
1"""
2This module integrates the transformers library by providing convenience utilities.
3"""
4import collections
5import typing
6
7import kbnf
8from transformers import LogitsProcessor, PreTrainedTokenizerBase, LogitsProcessorList
9
10from formatron.config import EngineGenerationConfig
11from formatron.formatter import FormatterBuilder, FormatterBase
12from formatron.integrations._utils import get_original_characters
13
14
def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine from a Hugging Face tokenizer.

    :param tokenizer: The tokenizer whose vocabulary should be converted.
    :return: A ``kbnf.Vocabulary`` mapping token ids to their original
        character representations and to their raw vocabulary strings.
    """
    raw_vocab = tokenizer.get_vocab()
    # Recover the original characters behind the tokenizer's internal encoding.
    original_chars = get_original_characters(raw_vocab)
    id_to_token = {token_id: kbnf.Token(chars) for chars, token_id in original_chars.items()}
    id_to_string = {token_id: text for text, token_id in raw_vocab.items()}
    return kbnf.Vocabulary(id_to_token, id_to_string)
23
24
def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,
                                      formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                      configs: typing.Sequence[EngineGenerationConfig] | None = None) -> LogitsProcessor:
    """
    Create a formatter logits processor.

    :param tokenizer: The tokenizer used to build the engine vocabulary and decode tokens.
    :param formatter_builders: A single builder or a sequence of builders; ``None`` entries
        disable formatting for the corresponding batch element.
    :param configs: Optional per-formatter generation configs; defaults are used when omitted.
    :return: A ``FormattersLogitsProcessor`` wrapping the built formatters.
    """
    vocab = create_engine_vocabulary(tokenizer)
    # Accept a bare builder for convenience by normalizing it to a one-element sequence.
    if not isinstance(formatter_builders, collections.abc.Sequence):
        formatter_builders = [formatter_builders]
    formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
                  for i in formatter_builders]
    return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
37
38
def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase,
                                           formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                           configs: typing.Sequence[EngineGenerationConfig] | None = None) \
        -> LogitsProcessorList:
    """
    Create a formatter logits processor list.

    Convenience wrapper that packs a single formatter logits processor into a
    ``LogitsProcessorList`` suitable for ``model.generate(logits_processor=...)``.

    :param tokenizer: The tokenizer used to build the engine vocabulary and decode tokens.
    :param formatter_builders: A single builder or a sequence of builders; ``None`` entries
        disable formatting for the corresponding batch element.
    :param configs: Optional per-formatter generation configs; defaults are used when omitted.
    :return: A ``LogitsProcessorList`` containing one ``FormattersLogitsProcessor``.
    """
    return LogitsProcessorList([create_formatter_logits_processor(tokenizer,
                                                                  formatter_builders, configs)])
48
49
class FormattersLogitsProcessor(LogitsProcessor):
    """
    Logit processor that uses formatters to mask batch logits.

    Each element of the batch is paired with one formatter (or ``None`` to leave
    that element unconstrained). On the first call the prompt may be fed to the
    formatters; on subsequent calls exactly one newly generated token per element
    is fed, and logits are masked to the tokens each formatter currently allows.
    """

    def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
                 configs: typing.Sequence[EngineGenerationConfig] | None = None):
        """
        :param formatters: One formatter per batch element; ``None`` disables formatting
            for that element.
        :param eos_token_id: The tokenizer's end-of-sequence token id.
        :param configs: Optional per-formatter configs; defaults to ``EngineGenerationConfig()``
            for each formatter when omitted.
        """
        self._formatters = formatters
        self._eos_token_id = eos_token_id
        # Length of input_ids seen on the previous call; None marks the first
        # (prompt-processing) iteration. Without this initialization, __call__
        # would raise AttributeError on its first invocation.
        self._last_input_id_length = None
        if configs is None:
            configs = [EngineGenerationConfig() for _ in formatters]
        assert len(configs) == len(formatters), \
            f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
        self.configs = configs

    def reset(self) -> None:
        """Reset all formatters and per-generation state so the processor can be reused."""
        # Clear the tracked length so the next __call__ is treated as a fresh
        # first iteration (prompt processing).
        self._last_input_id_length = None
        for f in self._formatters:
            if f is not None:
                f.reset()

    @property
    def formatters_captures(self) -> list[dict[str, typing.Any] | None]:
        """
        Get the captures of the formatters. Each element in the list corresponds to the
        captures of the formatter at the same index. If the formatter is None, the element
        is None.
        """
        return [f.captures if f is not None else None for f in self._formatters]

    def is_completed(self) -> list[bool | None]:
        """
        Check if the formatters are completed. Each boolean in the list corresponds to the
        completion status of the formatter at the same index. If the formatter is None,
        the element is None.
        """
        return [f.is_completed() if f is not None else None for f in self._formatters]

    def __call__(self, input_ids, scores):
        """
        Feed newly generated tokens to the formatters and mask the batch logits.

        :param input_ids: Batch of token id sequences seen so far (batch, seq_len).
        :param scores: Batch of next-token logits (batch, vocab_size), modified and returned.
        """
        assert input_ids.shape[0] == len(self._formatters), (f"Number of formatters({len(self._formatters)})"
                                                             f" must match batch size({input_ids.shape[0]})")
        if self._last_input_id_length is None:  # First iteration: optionally consume the prompt
            self._last_input_id_length = input_ids.shape[1]
            for formatter, config, prompt in zip(self._formatters, self.configs, input_ids):
                if formatter is None:
                    continue
                if config.reset_at_beginning:
                    formatter.reset()
                if config.read_prompt:
                    for token in prompt:
                        formatter.accept_token(token)
        else:
            # Subsequent iterations must extend the sequence by exactly one token.
            assert input_ids.shape[1] == self._last_input_id_length + 1, ("One iteration in generation loop"
                                                                          " must add exactly one token.")
            self._last_input_id_length += 1
            for formatter, input_id in zip(self._formatters, input_ids[:, -1]):
                if input_id != self._eos_token_id and formatter is not None:
                    formatter.accept_token(input_id)
        for i, formatter in enumerate(self._formatters):
            if formatter is None:
                continue
            if formatter.is_completed():
                # Force EOS: disallow everything except the end-of-sequence token.
                scores[i, :] = float("-inf")
                scores[i, self._eos_token_id] = 0.0
                continue
            formatter.compute_allowed_tokens()
            scores[i, :] = formatter.mask_logits(scores[i, :])
        return scores
Configuration for how a KBNF engine should be used in text generation.
Definition config.py:14
Logit processor that uses formatters to mask batch logits.
list[bool|None] is_completed(self)
Check if the formatters are completed.
list[dict[str, typing.Any]|None] formatters_captures(self)
Get the captures of the formatters.
__init__(self, typing.Sequence[FormatterBase|None] formatters, int eos_token_id, typing.Sequence[EngineGenerationConfig] configs=None)
Configuration classes for Formatron.
Definition config.py:1
This module contains the Formatter class and its related classes.
Definition formatter.py:1
kbnf.Vocabulary create_engine_vocabulary(PreTrainedTokenizerBase tokenizer)
Create a vocabulary for the KBNF engine.
LogitsProcessorList create_formatter_logits_processor_list(PreTrainedTokenizerBase tokenizer, typing.Sequence[FormatterBuilder|None]|FormatterBuilder formatter_builders, typing.Sequence[EngineGenerationConfig] configs=None)
Create a formatter logits processor list.
LogitsProcessor create_formatter_logits_processor(PreTrainedTokenizerBase tokenizer, typing.Sequence[FormatterBuilder|None]|FormatterBuilder formatter_builders, typing.Sequence[EngineGenerationConfig] configs=None)
Create a formatter logits processor.