"""
This module integrates the RWKV library by providing convenience utilities.
"""
   11__all__ = [
"create_engine_vocabulary", 
"PIPELINE", 
"PIPELINE_ARGS"]
 
   14    A wrapper for the arguments of the pipeline of RWKV. 
   27                 engine_gen_config=EngineGenerationConfig()):
 
   28        super().
__init__(temperature, top_p, top_k, alpha_frequency, alpha_presence, alpha_decay, token_ban, token_stop,
 
 
 
   35    Create a vocabulary for the KBNF engine. 
   37    assert WORD_NAME == 
'rwkv_vocab_v20230424', 
"Only world vocabulary is supported!" 
   38    return kbnf.Vocabulary({k: Token(v) 
for k, v 
in tokenizer.idx2token.items()},
 
   39                           {k: v.decode(
"UTF-8", errors=
"replace") 
for k, v 
in 
   40                            tokenizer.idx2token.items()})
 
 
   45    A wrapper for the pipeline of RWKV. 
   48    def __init__(self, model, WORD_NAME, formatter_builder: FormatterBuilder = 
None):  
 
   51        formatter = formatter_builder.build(vocabulary, 
lambda tokens: self.tokenizer.decode(tokens))
 
   52        if formatter 
is not None:
 
   57    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=
None, state=
None):
 
 
   62        if args.engine_gen_config.reset_at_beginning 
and self.
formatter and self.
formatter.is_completed():
 
   64        for i 
in range(token_count):
 
   66            tokens = self.encode(ctx) 
if i == 0 
else [token]
 
   68                if i == 0 
and args.engine_gen_config.read_prompt:
 
   71            while len(tokens) > 0:
 
   72                out, state = self.model.forward(tokens[:args.chunk_len], state)
 
   73                tokens = tokens[args.chunk_len:]
 
   76            for n 
in args.token_ban:
 
   77                out[n] = -float(
'inf')
 
   79                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
 
   82                formatter.compute_allowed_tokens()
 
   83                out = out[:len(self.tokenizer.idx2token) + 1]  
 
   84                out = formatter.mask_logits(out)
 
   86            token = self.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
 
   89            if token 
in args.token_stop:
 
   92            for xxx 
in occurrence:
 
   93                occurrence[xxx] *= args.alpha_decay
 
   95            ttt = self.decode([token])
 
   97            if ttt 
in ' \t0123456789':
 
   99            if token 
not in occurrence:
 
  100                occurrence[token] = www
 
  102                occurrence[token] += www
 
  106            tmp = self.decode(all_tokens[out_last:])
 
  107            if '\ufffd' not in tmp: