"""
This module integrates the transformers library by providing convenience utilities.
"""
import collections.abc
import typing
import kbnf
import torch
from transformers import LogitsProcessor, PreTrainedTokenizerBase, LogitsProcessorList
from formatron.config import EngineGenerationConfig
from formatron.formatter import FormatterBase, FormatterBuilder
from formatron.integrations.utils import get_original_characters


__all__ = ["create_engine_vocabulary",
           "create_formatter_logits_processor",
           "create_formatter_logits_processor_list",
           "FormattersLogitsProcessor"]


def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase,
                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine.

    tokenizer: The tokenizer.
    vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
        Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    vocab = tokenizer.get_vocab()
    new_vocab = get_original_characters(vocab, vocab_processors)
    return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
                           {v: k for k, v in vocab.items()})
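

# Usage sketch: building a KBNF vocabulary directly from a Hugging Face
# tokenizer. "gpt2" is an arbitrary checkpoint chosen purely for illustration.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   vocab = create_engine_vocabulary(tokenizer)  # vocab processors auto-detected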


def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,
                                      formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                      configs: typing.Sequence[EngineGenerationConfig] = None,
                                      vocab_processors: typing.Optional[list[typing.Callable]] = None) -> LogitsProcessor:
    """
    Create a formatter logits processor.

    tokenizer: The tokenizer.
    formatter_builders: The formatter builders.
    configs: The engine generation configurations.
    vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
        Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    vocab = create_engine_vocabulary(tokenizer, vocab_processors)
    if not isinstance(formatter_builders, collections.abc.Sequence):
        formatter_builders = [formatter_builders]
    formatters = [i.build(vocab, lambda tokens: tokenizer.decode(tokens)) if i is not None else None
                  for i in formatter_builders]
    return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
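

# Usage sketch: the returned processor composes with other transformers logits
# processors; `my_formatter_builder` is a hypothetical FormatterBuilder instance.
#
#   from transformers import LogitsProcessorList, MinLengthLogitsProcessor
#
#   processor = create_formatter_logits_processor(tokenizer, my_formatter_builder)
#   processors = LogitsProcessorList(
#       [MinLengthLogitsProcessor(5, tokenizer.eos_token_id), processor])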


def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase,
                                           formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                           configs: typing.Sequence[EngineGenerationConfig] = None,
                                           vocab_processors: typing.Optional[list[typing.Callable]] = None) \
        -> LogitsProcessorList:
    """
    Create a formatter logits processor list.

    tokenizer: The tokenizer.
    formatter_builders: The formatter builders.
    configs: The engine generation configurations.
    vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
        Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    return LogitsProcessorList([create_formatter_logits_processor(tokenizer,
                                                                  formatter_builders, configs, vocab_processors)])
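

# Usage sketch for end-to-end constrained generation; "gpt2" and
# `my_formatter_builder` are illustrative placeholders.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   logits_processor = create_formatter_logits_processor_list(tokenizer, my_formatter_builder)
#   inputs = tokenizer(["The answer is"], return_tensors="pt")
#   output = model.generate(**inputs, logits_processor=logits_processor, max_new_tokens=100)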


class FormattersLogitsProcessor(LogitsProcessor):
    """
    Logit processor that uses formatters to mask batch logits.
    """

    def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
                 configs: typing.Sequence[EngineGenerationConfig] | None = None):
        self._formatters = formatters
        self._eos_token_id = eos_token_id
        self._last_input_id_length = None
        if configs is None:
            configs = [EngineGenerationConfig() for _ in formatters]
        assert len(configs) == len(formatters), \
            f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
        self.configs = configs
        self._mask_logits_fn = FormattersLogitsProcessor._get_fastest_compatible_logits_mask_fn()

    @staticmethod
    def _get_fastest_compatible_logits_mask_fn():
        # Fallback: mask each row's logits in pure Python.
        def default_mask_logits_fn(bit_masks, formatter, scores, i):
            scores[i, :] = formatter.mask_logits(scores[i, :])
        try:
            # Fast path: apply the precomputed bit mask in place with kbnf's Triton kernel.
            from kbnf.triton_logits_mask import mask_logits_inplace

            def fast_mask_logits_fn(bit_masks, formatter, scores, i):
                mask_logits_inplace(scores[i, :], bit_masks[i, :], [formatter._engine])
            return fast_mask_logits_fn
        except ImportError:
            return default_mask_logits_fn

    def reset(self) -> None:
        self._last_input_id_length = None
        for f in self._formatters:
            if f is not None:
                f.reset()

    @property
    def formatters_captures(self) -> list[dict[str, typing.Any] | None]:
        """
        Get the captures of the formatters. Each element in the list corresponds to the
        captures of the formatter at the same index. If the formatter is None, the element
        will be None as well.
        """
        return [f.captures if f is not None else None for f in self._formatters]

    def is_completed(self) -> list[bool | None]:
        """
        Check if the formatters are completed. Each boolean in the list corresponds to the
        completion status of the formatter at the same index. If the formatter is None,
        the corresponding element will be None as well.
        """
        return [f.is_completed() if f is not None else None for f in self._formatters]
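
    # Inspection sketch: after generation, per-sequence captures and completion
    # status line up with batch indices. `logits_processor` is assumed to be the
    # LogitsProcessorList built by create_formatter_logits_processor_list above.
    #
    #   processor = logits_processor[0]
    #   processor.formatters_captures  # e.g. [{"answer": ...}] or [None]
    #   processor.is_completed()       # e.g. [True] or [None]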

    def __call__(self, input_ids, scores):
        assert input_ids.shape[0] == len(self._formatters), (f"Number of formatters({len(self._formatters)})"
                                                             f" must match batch size({input_ids.shape[0]})")
        if self._last_input_id_length is None:  # First iteration
            self._last_input_id_length = input_ids.shape[1]
            for formatter, config, prompt in zip(self._formatters, self.configs, input_ids):
                if formatter is None:
                    continue
                if config.reset_at_beginning:
                    formatter.reset()
                if config.read_prompt:
                    for token in prompt:
                        formatter.accept_token(token)
            # Pinned CPU buffer for the per-row token bit masks.
            self._bit_masks = torch.empty((scores.shape[0], (scores.shape[1] + 31) // 32),
                                          dtype=torch.int32, device='cpu', pin_memory=True)
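            # Sizing note (illustrative arithmetic): each int32 word packs 32
            # tokens, so a 32000-token vocabulary needs (32000 + 31) // 32 = 1000
            # words, i.e. 4000 bytes of pinned mask per sequence in the batch.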
139 " must add exactly one token.")
            for formatter, input_id in zip(self._formatters, input_ids[:, -1]):
                if formatter is not None and not formatter.is_completed():
                    formatter.accept_token(input_id)
        for i, formatter in enumerate(self._formatters):
            if formatter is None:
                continue
            if formatter.is_completed():
                # A completed formatter only permits EOS from here on.
                scores[i, :] = float("-inf")
                scores[i, self._eos_token_id] = 0.0
                continue
            formatter.compute_allowed_tokens()
            self._mask_logits_fn(self._bit_masks, formatter, scores, i)
        return scores