This module integrates the vllm library by providing convenience utilities.
11from vllm.transformers_utils.tokenizer
import AnyTokenizer
    Logit processor that uses formatters to mask batch logits.
19 def __init__(self, formatters: typing.Sequence[FormatterBase |
None], eos_token_id: int,
20 configs: typing.Sequence[EngineGenerationConfig] |
None =
None):
26 assert len(configs) == len(formatters), \
27 f
"Number of formatters({len(formatters)}) must match number of configs({len(configs)})"
34 return [f.captures
if f
is not None else None for f
in self.
_formatters]
38 Check if the formatters are completed. Each boolean in the list corresponds to the
39 completion status of the formatter at the same index.
41 return [f.is_completed()
if f
is not None else None for f
in self.
_formatters]
43 def reset(self) -> None:
54 def __call__(self, prompt, generated_tokens, logits):
55 result = next(self.
_iter,
None)
58 raise ValueError(f
"Batch size {self._debug_counter} "
59 f
"is greater than number of formatters({len(self._formatters)})!")
60 if len(generated_tokens) == 0:
62 formatter, config = result
65 if config.reset_at_beginning
and formatter.is_completed():
67 if config.read_prompt:
69 formatter.accept_token(token)
71 assert result
is None, (f
"Batch size {self._debug_counter} "
72 f
"is less than number of formatters({len(self._formatters)})!")
74 result = next(self.
_iter)
79 while formatter.is_completed():
82 formatter, _ = next(self.
_iter)
85 if len(generated_tokens) != 0:
86 input_id = generated_tokens[-1]
87 if not formatter.is_completed():
88 formatter.accept_token(input_id)
90 if formatter.is_completed():
91 logits[:] = float(
"-inf")
94 formatter.compute_allowed_tokens()
95 logits = formatter.mask_logits(logits)
100 vocab_processors: typing.Optional[list[typing.Callable]] =
None) -> kbnf.Vocabulary:
102 Create a vocabulary for the KBNF engine.
104 tokenizer: The tokenizer.
105 vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
106 Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
108 vocab = tokenizer.get_vocab()
109 new_vocab = get_original_characters(vocab, vocab_processors)
110 return kbnf.Vocabulary({k: kbnf.Token(v)
for k, v
in new_vocab.items()}, {
111 v: k
for k, v
in vocab.items()})
115 formatter_builders: typing.Sequence[FormatterBuilder |
None] | FormatterBuilder,
116 configs: typing.Sequence[EngineGenerationConfig] =
None,
117 vocab_processors: typing.Optional[list[typing.Callable]] =
None) \
118 -> FormattersLogitsProcessor:
    Create a formatter logits processor.

    Args:
        formatter_builders: The formatter builders.
        configs: The engine generation configurations.
        vocab_processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes])->None.
            Callables can be used to "unmangle" encoded characters to original characters. If None, processors will be auto-detected.
128 tokenizer = llm.get_tokenizer()
130 if not isinstance(formatter_builders, collections.abc.Sequence):
131 formatter_builders = [formatter_builders]
132 formatters = [i.build(vocab,
lambda tokens: tokenizer.decode(tokens))
if i
is not None else None
133 for i
in formatter_builders]