import re
import typing
from functools import lru_cache

__all__ = [
    "get_original_characters",
    "update_vocab_0xHH",
    "update_vocab_sentencepiece",
    "update_vocab_dot_G",
]


def _multiple_replace(replacements: typing.Dict[bytes, bytes], regex: re.Pattern[bytes], text: bytes) -> bytes:
    # For each match, substitute the corresponding value from the replacements dictionary.
    return regex.sub(lambda mo: replacements[mo.group()], text)
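
# Illustrative sketch of how _multiple_replace is used (the single entry below
# is hypothetical; real callers pass the full mangling table):
#
#   replacements = {"\u0120".encode("UTF-8"): b" "}
#   regex = re.compile(b"(%s)" % b"|".join(map(re.escape, replacements)))
#   _multiple_replace(replacements, regex, "\u0120hello".encode("UTF-8"))
#   # -> b" hello"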


def get_original_characters(vocab: typing.Dict[str, int],
                            processors: typing.Optional[list[typing.Callable]] = None) -> typing.Dict[int, bytes]:
    """
    Get a vocabulary of original characters unmangled to raw UTF-8 bytes by the provided processors.

    Args:
        vocab: The mangled vocabulary.
        processors: List of callables with signature (token_to_char: typing.Dict[bytes, bytes]) -> None.
            Callables can be used to "unmangle" encoded characters to original characters.
            If None, processors will be auto-detected.
    """
    old_char_to_new_char = {}
    assert len(set(vocab.values())) == len(vocab), "Vocabulary contains duplicate token IDs!"
    if processors is None:
        processors = autodetect_processors(vocab)
    for update_vocab in processors:
        update_vocab(old_char_to_new_char)
    # Build a single alternation over all mangled sequences, longest first,
    # so that longer keys take precedence over any of their prefixes.
    regex = re.compile(b"(%s)" % b"|".join(
        sorted(map(re.escape, old_char_to_new_char.keys()), key=len, reverse=True)))
    new_vocab = {}
    for k in vocab:
        token_id = vocab[k]
        new_k = _multiple_replace(old_char_to_new_char, regex, k.encode("UTF-8"))
        new_vocab[token_id] = new_k
    return new_vocab
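
# Sketch of typical usage (the vocabulary below is illustrative; processors are
# passed explicitly so the example does not depend on detection thresholds):
#
#   vocab = {"<0x0A>": 0, "\u2581the": 1}
#   get_original_characters(vocab, [update_vocab_0xHH, update_vocab_sentencepiece])
#   # -> {0: b"\n", 1: b" the"}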


def autodetect_processors(vocab: typing.Dict[str, int]) -> typing.List[typing.Callable]:
    """
    Autodetect vocabulary processors.
    """
    result = []
    # Llama-style tokenizers expose raw bytes as <0xHH> tokens.
    llama_present = any(i.find('<0xF0>') != -1 for i in vocab.keys())
    # Sentencepiece marks word boundaries with ▁ (U+2581); GPT-2 byte-level
    # encoders mark them with Ġ (U+0120). Require more than 20% of tokens to
    # contain the marker so a few stray characters do not trigger a processor.
    underscore_present = (len([1 for i in vocab.keys() if i.find('\u2581') != -1]) / len(vocab)) > 0.2
    g_present = (len([1 for i in vocab.keys() if i.find('\u0120') != -1]) / len(vocab)) > 0.2
    if llama_present:
        result.append(update_vocab_0xHH)
    if underscore_present:
        result.append(update_vocab_sentencepiece)
    elif g_present:
        result.append(update_vocab_dot_G)
    return result
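
# Illustrative detection example (token strings are hypothetical):
#
#   autodetect_processors({"\u0120hello": 0, "\u0120world": 1, "!": 2})
#   # -> [update_vocab_dot_G]   (2/3 of tokens carry Ġ; no ▁ or <0xHH> tokens)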


def update_vocab_0xHH(token_to_char: typing.Dict[bytes, bytes]):
    """
    Vocabulary processor for <0xHH> tokens (used in llama tokenizers).
    """
    for j in range(256):
        # Map e.g. b"<0x0A>" back to the raw byte b"\n".
        token_to_char[("<0x" + f"{j:02x}".upper() + ">").encode("UTF-8")] = bytes([j])


def update_vocab_sentencepiece(token_to_char: typing.Dict[bytes, bytes]):
    """
    Vocabulary processor for the ▁ token (used in sentencepiece tokenizers).
    """
    # Sentencepiece replaces spaces with ▁ (U+2581, LOWER ONE EIGHTH BLOCK).
    token_to_char["\u2581".encode("UTF-8")] = b" "


def update_vocab_dot_G(token_to_char: typing.Dict[bytes, bytes]):
    """
    Vocabulary processor for GPT2-style token mangling, like from \\n to Ġ (used in huggingface bytelevel preprocessors).
    """
    token_to_char.update(huggingface_bytelevel_decoder())


@lru_cache()
def huggingface_bytelevel_decoder():
    """
    Build the inverse of the GPT-2 byte-to-unicode mangling table.
    """
    # Printable bytes that the byte-level encoder keeps as-is.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    # Every remaining byte is shifted to a printable codepoint at 256 + n,
    # which is how e.g. the space byte 0x20 becomes Ġ (U+0120).
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n).encode("UTF-8") for n in cs]
    for i in range(len(bs)):
        bs[i] = bytes([bs[i]])
    # Map each mangled character (as UTF-8 bytes) back to its raw byte.
    return dict(zip(cs, bs))
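

if __name__ == "__main__":
    # Minimal self-check sketch; not part of the module API. The demo vocabulary
    # and expected output are illustrative, with processors chosen explicitly.
    demo_vocab = {"<0x0A>": 0, "\u2581world": 1, "plain": 2}
    restored = get_original_characters(demo_vocab, [update_vocab_0xHH, update_vocab_sentencepiece])
    assert restored == {0: b"\n", 1: b" world", 2: b"plain"}, restored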