"""This module integrates the transformers library by providing convenience utilities."""
8from transformers
import LogitsProcessor, PreTrainedTokenizerBase, LogitsProcessorList
# Public API of this module: the names exported by a star-import.
# The stray leading "14" (line-number residue fused onto the identifier) made
# the statement a syntax error; it is removed here.
__all__ = [
    "create_engine_vocabulary",
    "create_formatter_logits_processor",
    "create_formatter_logits_processor_list",
    "FormattersLogitsProcessor",
]
17 vocab_processors: typing.Optional[list[typing.Callable]] =
None) -> kbnf.Vocabulary:
19 Create a vocabulary for the KBNF engine.
21 tokenizer: The tokenizer.
22 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
23 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
25 vocab = tokenizer.get_vocab()
26 new_vocab = get_original_characters(vocab, vocab_processors)
27 return kbnf.Vocabulary({k: kbnf.Token(v)
for k, v
in new_vocab.items()},
28 {v: k
for k, v
in vocab.items()})
32 formatter_builders: typing.Sequence[FormatterBuilder |
None] | FormatterBuilder,
33 configs: typing.Sequence[EngineGenerationConfig] =
None,
34 vocab_processors: typing.Optional[list[typing.Callable]] =
None) -> LogitsProcessor:
36 Create a formatter logits processor.
38 tokenizer: The tokenizer.
39 formatter_builders: The formatter builders.
40 configs: The engine generation configurations.
41 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
42 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
45 if not isinstance(formatter_builders, collections.abc.Sequence):
46 formatter_builders = [formatter_builders]
47 formatters = [i.build(vocab,
lambda tokens: tokenizer.decode(tokens))
if i
is not None else None
48 for i
in formatter_builders]
53 formatter_builders: typing.Sequence[FormatterBuilder |
None] | FormatterBuilder,
54 configs: typing.Sequence[EngineGenerationConfig] =
None,
55 vocab_processors: typing.Optional[list[typing.Callable]] =
None) \
56 -> LogitsProcessorList:
58 Create a formatter logits processor list.
60 tokenizer: The tokenizer.
61 formatter_builders: The formatter builders.
62 configs: The engine generation configurations.
63 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
64 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
67 formatter_builders, configs, vocab_processors)])
72 Logit processor that uses formatters to mask batch logits.
75 def __init__(self, formatters: typing.Sequence[FormatterBase |
None], eos_token_id: int,
76 configs: typing.Sequence[EngineGenerationConfig] |
None =
None):
82 assert len(configs) == len(formatters), \
83 f
"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
86 def reset(self) -> None:
95 Get the captures of the formatters. Each element in the list corresponds to the
96 captures of the formatter at the same index. If the formatter is None, the element
99 return [f.captures
if f
is not None else None for f
in self.
_formatters]
103 Check if the formatters are completed. Each boolean in the list corresponds to the
104 completion status of the formatter at the same index. If the formatter is None,
107 return [f.is_completed()
if f
is not None else None for f
in self.
_formatters]
109 def __call__(self, input_ids, scores):
110 assert input_ids.shape[0] == len(self.
_formatters), (f
"Number of formatters({len(self._formatters)})"
111 f
" must match batch size({input_ids.shape[0]})")
115 if formatter
is None:
117 if config.reset_at_beginning:
119 if config.read_prompt:
121 formatter.accept_token(token)
124 " must add exactly one token.")
126 for formatter, input_id
in zip(self.
_formatters, input_ids[:, -1]):
127 if formatter
is not None and not formatter.is_completed():
128 formatter.accept_token(input_id)
130 if formatter
is None:
132 if formatter.is_completed():
133 scores[i, :] = float(
"-inf")
136 formatter.compute_allowed_tokens()
137 scores[i, :] = formatter.mask_logits(scores[i, :])