This module integrates the vllm library by providing convenience utilities.
12from vllm.transformers_utils.tokenizer
import AnyTokenizer
Logit processor that uses formatters to mask batch logits.
def __init__(self, formatters: typing.Sequence[FormatterBase | None], eos_token_id: int,
             configs: typing.Sequence[EngineGenerationConfig] | None = None):
    """
    Create a logits processor that constrains each batch sequence with its formatter.

    Args:
        formatters: One formatter per sequence in the batch; a None entry leaves
            that sequence unconstrained.
        eos_token_id: The tokenizer's EOS token id.
        configs: Per-formatter generation configs. If None, a default
            EngineGenerationConfig is created for each formatter.
    """
    # NOTE(review): the attribute assignments here were elided from the mangled
    # paste; they are reconstructed from reads visible elsewhere in this chunk
    # (self._formatters, self._debug_counter) — verify against upstream.
    self._formatters = formatters
    self._eos_token_id = eos_token_id
    if configs is None:
        # One config per formatter so the pairing below always lines up.
        configs = [EngineGenerationConfig() for _ in formatters]
    # In the paste the bare `f` was split from its string literal, which would
    # have turned the f-string into a plain string with literal {…} placeholders;
    # rejoined here so the message interpolates.
    assert len(configs) == len(formatters), \
        f"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
    self.configs = configs
    # Used by __call__'s batch-size diagnostics.
    self._debug_counter = 0
37 return [f.captures
if f
is not None else None for f
in self.
_formatters]
41 Check if the formatters are completed. Each boolean in the list corresponds to the
42 completion status of the formatter at the same index.
44 return [f.is_completed()
if f
is not None else None for f
in self.
_formatters]
# NOTE(review): reset()'s body is elided from this chunk; the signature suggests it
# clears per-request state and returns None — recover from upstream before editing.
46 def reset(self) -> None:
# NOTE(review): this span is the logits-processor __call__ from a line-wrapped,
# partially elided paste (leading integers are fused original line numbers; tokens
# are split mid-statement, e.g. the bare `f` detached from its string literal,
# which would disable f-string interpolation). Roughly half the interior lines
# (orig 61-62, 65, 67, 69-71, 73, 75, 77-78, 81, 83-88, 90-91, 93-94, 101-102)
# are missing, so the code is annotated rather than reconstructed — recover the
# original before relying on this block.
59 def __call__(self, prompt, generated_tokens, logits):
60 result = next(self.
_iter,
None)
# Raised (on elided control flow) when the batch holds more sequences than formatters.
63 raise ValueError(f
"Batch size {self._debug_counter} "
64 f
"is greater than number of formatters({len(self._formatters)})!")
# First step for this sequence: nothing generated yet.
66 if len(generated_tokens) == 0:
68 formatter, config = result
72 if config.reset_at_beginning
and formatter.is_completed():
# Presumably feeds prompt tokens to the formatter when configured — confirm upstream.
74 if config.read_prompt:
76 formatter.accept_token(token)
79 assert result
is None, (f
"Batch size {self._debug_counter} "
80 f
"is less than number of formatters({len(self._formatters)})!")
82 result = next(self.
_iter)
# Skips over formatters that have already completed — TODO confirm against upstream.
89 while formatter.is_completed():
92 formatter, _ = next(self.
_iter)
# Subsequent steps: feed the most recently generated token to the formatter.
95 if len(generated_tokens) != 0:
96 input_id = generated_tokens[-1]
97 if not formatter.is_completed():
98 formatter.accept_token(input_id)
# A completed formatter masks all logits; presumably EOS is re-allowed on the
# elided lines 101-102 — verify upstream.
99 if formatter.is_completed():
100 logits[:] = float(
"-inf")
103 formatter.compute_allowed_tokens()
def create_engine_vocabulary(tokenizer: AnyTokenizer,
                             vocab_processors: typing.Optional[list[typing.Callable]] = None) -> kbnf.Vocabulary:
    """
    Create a vocabulary for the KBNF engine.

    Args:
        tokenizer: The tokenizer.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
    """
    # NOTE(review): the def line was elided from the mangled paste; the name and
    # first parameter are reconstructed from the docstring, the visible second
    # parameter, and the AnyTokenizer import — verify against upstream.
    vocab = tokenizer.get_vocab()
    # Undo tokenizer-specific character mangling before handing tokens to kbnf.
    new_vocab = get_original_characters(vocab, vocab_processors)
    # First dict: kbnf.Token per entry of the unmangled vocab; second dict inverts
    # the raw vocab (presumably id -> token string — confirm kbnf's expectation).
    return kbnf.Vocabulary({k: kbnf.Token(v) for k, v in new_vocab.items()},
                           {v: k for k, v in vocab.items()})
124 formatter_builders: typing.Sequence[FormatterBuilder |
None] | FormatterBuilder,
125 configs: typing.Sequence[EngineGenerationConfig] =
None,
126 vocab_processors: typing.Optional[list[typing.Callable]] =
None) \
127 -> FormattersLogitsProcessor:
129 Create a formatter logits processor.
132 formatter_builders: The formatter builders.
133 configs: The engine generation configurations.
134 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
135 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
137 tokenizer = llm.get_tokenizer()
139 if not isinstance(formatter_builders, collections.abc.Sequence):
140 formatter_builders = [formatter_builders]
141 formatters = [i.build(vocab,
lambda tokens: tokenizer.decode(tokens))
if i
is not None else None
142 for i
in formatter_builders]