This module integrates the transformers library by providing convenience utilities.
import collections.abc
import typing

from transformers import LogitsProcessor, LogitsProcessorList, PreTrainedTokenizerBase
def create_engine_vocabulary(tokenizer: PreTrainedTokenizerBase) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine from a Hugging Face tokenizer.

    :param tokenizer: The tokenizer whose vocabulary is converted.
    :return: A ``kbnf.Vocabulary`` built from two id-keyed mappings:
        token id -> ``kbnf.Token`` (with decoded characters) and
        token id -> original vocabulary string.
    """
    # NOTE(review): the ``def`` header was reconstructed from a mangled source;
    # verify the signature against the upstream module.
    vocab = tokenizer.get_vocab()
    # Presumably recovers each token's original characters from the tokenizer's
    # internal byte/whitespace encoding — see ``get_original_characters``.
    new_vocab = get_original_characters(vocab)
    return kbnf.Vocabulary({token_id: kbnf.Token(token) for token, token_id in new_vocab.items()},
                           {token_id: token for token, token_id in vocab.items()})
def create_formatter_logits_processor(tokenizer: PreTrainedTokenizerBase,
                                      formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                      configs: typing.Sequence[EngineGenerationConfig] | None = None) -> LogitsProcessor:
    """
    Create a formatter logits processor.

    :param tokenizer: The tokenizer used to build the engine vocabulary and decode tokens.
    :param formatter_builders: A single builder or a sequence of builders;
        ``None`` entries leave the corresponding batch row unconstrained.
    :param configs: Optional per-formatter generation configs; defaults are
        applied downstream when ``None``.
    :return: A ``FormattersLogitsProcessor`` driving all formatters.
    """
    # NOTE(review): the ``def`` header, the vocabulary creation and the return
    # statement were reconstructed from a mangled source; verify upstream.
    vocab = create_engine_vocabulary(tokenizer)
    # Accept a bare FormatterBuilder for convenience by wrapping it in a list.
    if not isinstance(formatter_builders, collections.abc.Sequence):
        formatter_builders = [formatter_builders]
    formatters = [builder.build(vocab, lambda tokens: tokenizer.decode(tokens)) if builder is not None else None
                  for builder in formatter_builders]
    return FormattersLogitsProcessor(formatters, tokenizer.eos_token_id, configs)
def create_formatter_logits_processor_list(tokenizer: PreTrainedTokenizerBase,
                                           formatter_builders: typing.Sequence[FormatterBuilder | None] | FormatterBuilder,
                                           configs: typing.Sequence[EngineGenerationConfig] | None = None) \
        -> LogitsProcessorList:
    """
    Create a formatter logits processor list.

    Convenience wrapper: builds a single formatter logits processor and wraps
    it in a ``LogitsProcessorList`` for direct use with ``model.generate``.

    :param tokenizer: The tokenizer used to build the engine vocabulary.
    :param formatter_builders: A single builder or a sequence of builders.
    :param configs: Optional per-formatter generation configs.
    """
    # NOTE(review): the ``def`` header was reconstructed from a mangled source.
    return LogitsProcessorList([create_formatter_logits_processor(tokenizer,
                                                                  formatter_builders, configs)])
class FormattersLogitsProcessor(LogitsProcessor):
    """
    Logit processor that uses formatters to mask batch logits.

    Each element of the batch is paired with one formatter (or ``None`` for an
    unconstrained row). On every call, newly generated tokens are fed to the
    formatters and disallowed tokens are masked out of the scores.
    """

    def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
                 configs: typing.Sequence[EngineGenerationConfig] | None = None):
        """
        :param formatters: One formatter (or ``None``) per batch element.
        :param eos_token_id: Token id forced once a formatter completes.
        :param configs: Optional per-formatter configs; defaults are created
            when ``None``.
        """
        self._formatters = formatters
        self._eos_token_id = eos_token_id
        # None marks "first call not seen yet"; see __call__.
        self._last_input_id_length = None
        if configs is None:
            configs = [EngineGenerationConfig() for _ in formatters]
        # FIX(review): the f-string prefix was detached from its literal in the
        # original, so the {…} placeholders would never interpolate.
        assert len(configs) == len(formatters), \
            f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
        self.configs = configs

    def reset(self) -> None:
        """Reset the iteration state so a fresh generation can start."""
        self._last_input_id_length = None

    @property
    def formatters_captures(self) -> list[typing.Any | None]:
        """
        Get the captures of the formatters. Each element in the list corresponds to the
        captures of the formatter at the same index. If the formatter is None, the element
        is None as well.
        """
        return [f.captures if f is not None else None for f in self._formatters]

    def is_completed(self) -> list[bool | None]:
        """
        Check if the formatters are completed. Each boolean in the list corresponds to the
        completion status of the formatter at the same index. If the formatter is None,
        the element is None as well.
        """
        return [f.is_completed() if f is not None else None for f in self._formatters]

    def __call__(self, input_ids, scores):
        # assumes input_ids is (batch, seq) and scores is (batch, vocab) — TODO confirm
        assert input_ids.shape[0] == len(self._formatters), \
            (f"Number of formatters({len(self._formatters)})"
             f" must match batch size({input_ids.shape[0]})")
        if self._last_input_id_length is None:
            # First iteration: optionally reset formatters and feed the prompt.
            self._last_input_id_length = input_ids.shape[1]
            for formatter, config, prompt in zip(self._formatters, self.configs, input_ids):
                if formatter is None:
                    continue
                if config.reset_at_beginning:
                    formatter.reset()
                if config.read_prompt:
                    for token in prompt:
                        formatter.accept_token(token)
        else:
            # Subsequent iterations: exactly one new token per sequence.
            assert input_ids.shape[1] == self._last_input_id_length + 1, \
                ("One iteration in generation loop"
                 " must add exactly one token.")
            self._last_input_id_length += 1
            for formatter, input_id in zip(self._formatters, input_ids[:, -1]):
                if formatter is None:
                    continue
                formatter.accept_token(input_id)
        for i, formatter in enumerate(self._formatters):
            if formatter is None:
                continue
            if formatter.is_completed():
                # Completed formatter: only EOS remains viable.
                # NOTE(review): the EOS line was reconstructed from a mangled
                # source — verify against upstream.
                scores[i, :] = float("-inf")
                scores[i, self._eos_token_id] = 0.0
                continue
            formatter.compute_allowed_tokens()
            scores[i, :] = formatter.mask_logits(scores[i, :])
        return scores