This module integrates the RWKV library by providing convenience utilities.
# Public API of this module: the names exported by a star-import.
__all__ = [
    "create_engine_vocabulary",
    "PIPELINE",
    "PIPELINE_ARGS",
]
A wrapper for the arguments of the RWKV pipeline.
27 engine_gen_config=EngineGenerationConfig()):
28 super().
__init__(temperature, top_p, top_k, alpha_frequency, alpha_presence, alpha_decay, token_ban, token_stop,
Create a vocabulary for the KBNF engine.
37 assert WORD_NAME ==
'rwkv_vocab_v20230424',
"Only world vocabulary is supported!"
38 return kbnf.Vocabulary({k: Token(v)
for k, v
in tokenizer.idx2token.items()},
39 {k: v.decode(
"UTF-8", errors=
"replace")
for k, v
in
40 tokenizer.idx2token.items()})
A wrapper for the RWKV pipeline.
48 def __init__(self, model, WORD_NAME, formatter_builder: FormatterBuilder =
None):
51 formatter = formatter_builder.build(vocabulary,
lambda tokens: self.tokenizer.decode(tokens))
52 if formatter
is not None:
57 def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=
None, state=
None):
62 if args.engine_gen_config.reset_at_beginning
and self.
formatter and self.
formatter.is_completed():
64 for i
in range(token_count):
66 tokens = self.encode(ctx)
if i == 0
else [token]
68 if i == 0
and args.engine_gen_config.read_prompt:
71 while len(tokens) > 0:
72 out, state = self.model.forward(tokens[:args.chunk_len], state)
73 tokens = tokens[args.chunk_len:]
76 for n
in args.token_ban:
77 out[n] = -float(
'inf')
79 out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
82 formatter.compute_allowed_tokens()
83 out = out[:len(self.tokenizer.idx2token) + 1]
84 out = formatter.mask_logits(out)
86 token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
89 if token
in args.token_stop:
92 for xxx
in occurrence:
93 occurrence[xxx] *= args.alpha_decay
95 ttt = self.decode([token])
97 if ttt
in ' \t0123456789':
99 if token
not in occurrence:
100 occurrence[token] = www
102 occurrence[token] += www
106 tmp = self.decode(all_tokens[out_last:])
107 if '\ufffd' not in tmp: