【rk】——rk3588推理获得logits

说明:使用 Python(ctypes)调用 RKLLM 运行时库,在 RK3588 上以 RKLLM_INFER_GET_LOGITS 模式进行推理,获取每个输入 token 对应的 logits,并据此计算交叉熵与困惑度(PPL)。

代码:

python 复制代码
import ctypes
import sys
import os
import subprocess
import resource
import threading
import time
import argparse
import json
from flask import Flask, request, jsonify, Response
from transformers import AutoTokenizer
import numpy as np
from tqdm import tqdm

# Load the RKLLM runtime shared library. The path is relative, so the script
# has to be started from the directory that contains `lib/`.
rkllm_lib = ctypes.CDLL('lib/librkllmrt.so')
# Opaque handle type used for the model instance, plus a NULL userdata pointer
# (rkllm_run below is called with None rather than this object).
RKLLM_Handle_t = ctypes.c_void_p
userdata = ctypes.c_void_p(None)

# NOTE: LLMCallState, RKLLMInputType and RKLLMInferMode are all plain aliases
# of ctypes.c_int, so the constants below are attributes attached to the one
# shared ctypes.c_int class, not to three separate enum types.

# State codes passed as the third argument of the inference callback.
LLMCallState = ctypes.c_int
LLMCallState.RKLLM_RUN_NORMAL  = 0
LLMCallState.RKLLM_RUN_WAITING  = 1
LLMCallState.RKLLM_RUN_FINISH  = 2
LLMCallState.RKLLM_RUN_ERROR   = 3

# Selector for which member of RKLLMInputUnion carries the input.
RKLLMInputType = ctypes.c_int
RKLLMInputType.RKLLM_INPUT_PROMPT      = 0
RKLLMInputType.RKLLM_INPUT_TOKEN       = 1
RKLLMInputType.RKLLM_INPUT_EMBED       = 2
RKLLMInputType.RKLLM_INPUT_MULTIMODAL  = 3

# Inference mode stored in RKLLMInferParam.mode; this script uses
# RKLLM_INFER_GET_LOGITS.
RKLLMInferMode = ctypes.c_int
RKLLMInferMode.RKLLM_INFER_GENERATE = 0
RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER = 1
RKLLMInferMode.RKLLM_INFER_GET_LOGITS = 2
class RKLLMExtendParam(ctypes.Structure):
    """ctypes layout of the extended parameters embedded in RKLLMParam.

    The order and width of the fields define the bytes handed to the native
    library. NOTE(review): presumably this must match the struct of the same
    name in the rkllm.h that ships with librkllmrt.so - confirm on upgrade.
    """
    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        # Bit mask of CPU cores; RKLLM.__init__ sets bits 4-7 or bits 0-3.
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("n_batch", ctypes.c_uint8),
        ("use_cross_attn", ctypes.c_int8),
        # Padding bytes; never written by this script.
        ("reserved", ctypes.c_uint8 * 104)
    ]

class RKLLMParam(ctypes.Structure):
    """ctypes layout of the model/sampling parameters passed to rkllm_init.

    RKLLM.__init__ obtains an instance from rkllm_createDefaultParam() and then
    overrides model_path, the sampling fields, the context/new-token limits,
    skip_special_token and parts of extend_param; every other field keeps
    whatever the library returned. Field order defines the binary layout.
    """
    _fields_ = [
        ("model_path", ctypes.c_char_p),
        ("max_context_len", ctypes.c_int32),
        ("max_new_tokens", ctypes.c_int32),
        ("top_k", ctypes.c_int32),
        ("n_keep", ctypes.c_int32),
        ("top_p", ctypes.c_float),
        ("temperature", ctypes.c_float),
        ("repeat_penalty", ctypes.c_float),
        ("frequency_penalty", ctypes.c_float),
        ("presence_penalty", ctypes.c_float),
        ("mirostat", ctypes.c_int32),
        ("mirostat_tau", ctypes.c_float),
        ("mirostat_eta", ctypes.c_float),
        ("skip_special_token", ctypes.c_bool),
        ("is_async", ctypes.c_bool),
        ("img_start", ctypes.c_char_p),
        ("img_end", ctypes.c_char_p),
        ("img_content", ctypes.c_char_p),
        # Nested by value (not a pointer).
        ("extend_param", RKLLMExtendParam),
    ]

class RKLLMLoraAdapter(ctypes.Structure):
    """ctypes layout describing a LoRA adapter (path, name, scale).

    Declared for completeness; nothing in the visible script instantiates it.
    """
    _fields_ = [
        ("lora_adapter_path", ctypes.c_char_p),
        ("lora_adapter_name", ctypes.c_char_p),
        ("scale", ctypes.c_float)
    ]

class RKLLMEmbedInput(ctypes.Structure):
    """ctypes layout of an embedding input: a float pointer plus a token count.

    Member of RKLLMInputUnion; not used by the visible script.
    """
    _fields_ = [
        ("embed", ctypes.POINTER(ctypes.c_float)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMTokenInput(ctypes.Structure):
    """ctypes layout of a token-id input: an int32 pointer plus its length.

    This is the union member RKLLM.run_with_ids fills in. The pointer refers to
    memory owned by the caller's NumPy array, which therefore has to stay alive
    for the duration of the native call.
    """
    _fields_ = [
        ("input_ids", ctypes.POINTER(ctypes.c_int32)),
        ("n_tokens", ctypes.c_size_t)
    ]

class RKLLMMultiModalInput(ctypes.Structure):
    """ctypes layout of a multimodal input (text prompt plus image embeddings).

    Member of RKLLMInputUnion; not used by the visible script.
    """
    _fields_ = [
        ("prompt", ctypes.c_char_p),
        ("image_embed", ctypes.POINTER(ctypes.c_float)),
        ("n_image_tokens", ctypes.c_size_t),
        ("n_image", ctypes.c_size_t),
        ("image_width", ctypes.c_size_t),
        ("image_height", ctypes.c_size_t)
    ]

class RKLLMInputUnion(ctypes.Union):
    """Union of the possible input payloads; all members share one storage.

    RKLLM.run_with_ids writes `token_input` together with
    input_type = RKLLM_INPUT_TOKEN. NOTE(review): presumably the other members
    pair with the remaining RKLLMInputType values - confirm against rkllm.h.
    """
    _fields_ = [
        ("prompt_input", ctypes.c_char_p),
        ("embed_input", RKLLMEmbedInput),
        ("token_input", RKLLMTokenInput),
        ("multimodal_input", RKLLMMultiModalInput)
    ]

class RKLLMInput(ctypes.Structure):
    """ctypes layout of one inference request passed by pointer to rkllm_run.

    RKLLM.run_with_ids zero-fills the structure and then sets only input_type
    and input_data.token_input, so role stays NULL and enable_thinking False.
    """
    _fields_ = [
        ("role", ctypes.c_char_p),
        ("enable_thinking", ctypes.c_bool),
        ("input_type", RKLLMInputType),
        ("input_data", RKLLMInputUnion)
    ]

class RKLLMLoraParam(ctypes.Structure):
    """ctypes layout naming the LoRA adapter to use for a run.

    Referenced by pointer from RKLLMInferParam.lora_params, which this script
    leaves NULL.
    """
    _fields_ = [
        ("lora_adapter_name", ctypes.c_char_p)
    ]

class RKLLMPromptCacheParam(ctypes.Structure):
    """ctypes layout of the prompt-cache options (save flag and file path).

    Referenced by pointer from RKLLMInferParam.prompt_cache_params, which this
    script leaves NULL.
    """
    _fields_ = [
        ("save_prompt_cache", ctypes.c_int),
        ("prompt_cache_path", ctypes.c_char_p)
    ]

class RKLLMInferParam(ctypes.Structure):
    """ctypes layout of the per-run inference options passed to rkllm_run.

    RKLLM.__init__ zero-fills one instance, then sets mode to
    RKLLM_INFER_GET_LOGITS and keep_history to 0; both pointers stay NULL.
    """
    _fields_ = [
        ("mode", RKLLMInferMode),
        ("lora_params", ctypes.POINTER(RKLLMLoraParam)),
        ("prompt_cache_params", ctypes.POINTER(RKLLMPromptCacheParam)),
        ("keep_history", ctypes.c_int)
    ]

class RKLLMResultLastHiddenLayer(ctypes.Structure):
    """ctypes layout of the last-hidden-layer part of RKLLMResult.

    Not read by the callbacks in this script, but it must stay declared because
    it precedes `logits` inside RKLLMResult and so affects that field's offset.
    """
    _fields_ = [
        ("hidden_states", ctypes.POINTER(ctypes.c_float)),
        ("embd_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMResultLogits(ctypes.Structure):
    """ctypes layout of the logits part of RKLLMResult.

    ppl_callback_impl interprets `logits` as a row-major float array of shape
    (num_tokens, vocab_size).
    """
    _fields_ = [
        ("logits", ctypes.POINTER(ctypes.c_float)),
        ("vocab_size", ctypes.c_int),
        ("num_tokens", ctypes.c_int)
    ]

class RKLLMPerfStat(ctypes.Structure):
    """ctypes layout of the performance counters at the end of RKLLMResult.

    Not read anywhere in the visible script.
    """
    _fields_ = [
        ("prefill_time_ms", ctypes.c_float),
        ("prefill_tokens", ctypes.c_int),
        ("generate_time_ms", ctypes.c_float),
        ("generate_tokens", ctypes.c_int),
        ("memory_usage_mb", ctypes.c_float)
    ]

class RKLLMResult(ctypes.Structure):
    """ctypes layout of the result object the library hands to the callback.

    callback_impl reads `text`; ppl_callback_impl reads `logits`. The nested
    structures are embedded by value, so their declarations determine the
    offsets of the fields that follow them.
    """
    _fields_ = [
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int),
        ("last_hidden_layer", RKLLMResultLastHiddenLayer),
        ("logits", RKLLMResultLogits),
        ("perf", RKLLMPerfStat)
    ]

# Lock intended to serialize access to the model; it is not acquired anywhere
# in the visible part of this script.
lock = threading.Lock()

# Flag intended to signal that the model is busy; not read or written anywhere
# in the visible part of this script.
is_blocking = False

# Module-level state shared between the native callbacks and the main script.
system_prompt = ''
global_text = []          # text pieces appended by callback_impl
global_state = -1         # last LLMCallState seen by callback_impl (-1 = none yet)
split_byte_data = bytes(b"") # Used to store the segmented byte data
global_logits = None      # logits array published by ppl_callback_impl
global_input_ids_len = 0  # token count of the current input, set by the main loop

# NOTE: the name is misspelled ("recevied"); kept as-is because code outside
# this view may refer to it.
recevied_messages = []

# Define the callback function
def callback_impl(result, userdata, state):
    """Text-generation callback: collect decoded output into ``global_text``.

    Not the callback registered with the library in this script (that is
    ``ppl_callback_impl``); kept for generate-mode use.

    Args:
        result: ``POINTER(RKLLMResult)``; only ``result.contents.text`` is read.
        userdata: Opaque pointer supplied by the library; unused.
        state: One of the ``LLMCallState`` codes.

    Returns:
        Always 0.

    Side effects:
        Updates ``global_state`` for FINISH/ERROR/NORMAL, extends
        ``global_text`` with the decoded characters, and keeps an incomplete
        trailing UTF-8 sequence in ``split_byte_data`` until the next call.
    """
    global global_text, global_state, split_byte_data
    if state == LLMCallState.RKLLM_RUN_FINISH:
        global_state = state
        print("\n")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        global_state = state
        print("run error")
        sys.stdout.flush()
    elif state == LLMCallState.RKLLM_RUN_NORMAL:
        global_state = state
        raw = result.contents.text
        # A NULL text pointer is exposed as None; nothing to append then.
        if raw:
            pending = split_byte_data + raw
            try:
                decoded = pending.decode('utf-8')
                split_byte_data = b""
            except UnicodeDecodeError as err:
                tail = pending[err.start:]
                if len(tail) < 4:
                    # A multi-byte character was cut at the chunk boundary
                    # (UTF-8 sequences are at most 4 bytes): emit the valid
                    # prefix and hold the tail back for the next chunk.
                    decoded = pending[:err.start].decode('utf-8')
                    split_byte_data = tail
                else:
                    # Genuinely invalid bytes: do not buffer forever.
                    decoded = pending.decode('utf-8', errors='replace')
                    split_byte_data = b""
            # `+=` on the list extends it character by character, exactly as
            # the original statement did.
            global_text += decoded
    return 0


def ppl_callback_impl(result, userdata, state):
    """Logits callback: publish the logits of the current run in ``global_logits``.

    Args:
        result: ``POINTER(RKLLMResult)``; only ``result.contents.logits`` is read.
        userdata: Opaque pointer supplied by the library; unused.
        state: One of the ``LLMCallState`` codes.

    Returns:
        Always 0.

    Raises:
        Exception: for any state other than NORMAL/FINISH. The callback is
            invoked through ctypes, so the exception cannot reach the code that
            called ``rkllm_run``; ctypes only reports it on stderr. For that
            reason ``global_logits`` is also reset to ``None`` on
            RKLLM_RUN_ERROR so the failure is visible to the caller.

    Side effects:
        Sets ``global_logits`` to a ``(num_tokens, vocab_size)`` float array
        owned by Python, or to ``None`` when no usable logits were delivered.
    """
    global global_input_ids_len, global_logits
    if state == LLMCallState.RKLLM_RUN_NORMAL:
        logits_result = result.contents.logits
        num_tokens = logits_result.num_tokens
        vocab_size = logits_result.vocab_size
        if global_input_ids_len != num_tokens:
            print(f"input_ids_len:{global_input_ids_len}, num_tokens:{num_tokens}")
        if not logits_result.logits or num_tokens <= 0 or vocab_size <= 0:
            # NULL pointer or empty result: wrapping it would read invalid memory.
            global_logits = None
        else:
            # The buffer belongs to the runtime and nothing here guarantees it
            # stays valid once the callback returns, while the main loop reads
            # the logits after rkllm_run() has finished. Copy instead of keeping
            # a zero-copy view (costs num_tokens * vocab_size * 4 bytes).
            global_logits = np.ctypeslib.as_array(
                logits_result.logits, shape=(num_tokens, vocab_size)
            ).copy()
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        pass
    else:
        if state == LLMCallState.RKLLM_RUN_ERROR:
            # Never let the caller reuse logits from an earlier run.
            global_logits = None
        raise Exception("ppl Call Error")
    return 0
    

# Connect the callback function between the Python side and the C++ side.
# callback_type describes a C function taking (RKLLMResult*, void*, int) and
# returning int. The callback registered here is ppl_callback_impl, not
# callback_impl. The wrapper is bound to a module-level name so that it stays
# referenced (ctypes callback objects must outlive their use by native code).
callback_type = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.POINTER(RKLLMResult), ctypes.c_void_p, ctypes.c_int)
callback = callback_type(ppl_callback_impl)

# Define the RKLLM class, which includes initialization, inference, and release operations for the RKLLM model in the dynamic library
class RKLLM(object):
    """Thin wrapper around one RKLLM model instance in logits mode.

    Construction loads the model through ``rkllm_init`` with the module-level
    ``callback`` and prepares an ``RKLLMInferParam`` whose mode is
    ``RKLLM_INFER_GET_LOGITS``. Results are delivered through that callback,
    not through return values of the methods here.
    """

    def __init__(self, model_path, lora_model_path = None, prompt_cache_path = None, platform = "rk3588"):
        """Load the model and bind the native entry points.

        Args:
            model_path: Path of the converted ``.rkllm`` model file.
            lora_model_path: Accepted for signature compatibility; unused.
            prompt_cache_path: Accepted for signature compatibility; unused.
            platform: Target SoC name; "rk3576"/"rk3588" select CPU cores 4-7,
                anything else selects cores 0-3.

        Raises:
            SystemExit: with status 1 when ``rkllm_init`` reports failure.
        """
        self.rkllm_createDefaultParam = rkllm_lib.rkllm_createDefaultParam
        self.rkllm_createDefaultParam.argtypes = []
        self.rkllm_createDefaultParam.restype = RKLLMParam
        rkllm_param = self.rkllm_createDefaultParam()
        rkllm_param.model_path = bytes(model_path, 'utf-8')
        rkllm_param.top_k = 1
        rkllm_param.top_p = 0.95
        rkllm_param.temperature = 0.8
        rkllm_param.repeat_penalty = 1.1
        rkllm_param.frequency_penalty = 0.0
        rkllm_param.presence_penalty = 0.0

        rkllm_param.max_new_tokens = 1
        rkllm_param.max_context_len = 2048
        rkllm_param.skip_special_token = ctypes.c_bool(True)
        rkllm_param.extend_param.base_domain_id = 0
        rkllm_param.extend_param.embed_flash = 1

        rkllm_param.extend_param.enabled_cpus_num = 4
        if platform.lower() in ["rk3576", "rk3588"]:
            rkllm_param.extend_param.enabled_cpus_mask = (1 << 4)|(1 << 5)|(1 << 6)|(1 << 7)
        else:
            rkllm_param.extend_param.enabled_cpus_mask = (1 << 0)|(1 << 1)|(1 << 2)|(1 << 3)

        # The structure owns the bytes object behind model_path. Keep it on the
        # instance so that pointer stays valid for the model's whole lifetime,
        # in case the library retains it after rkllm_init returns.
        self.rkllm_param = rkllm_param

        self.handle = RKLLM_Handle_t()

        self.rkllm_init = rkllm_lib.rkllm_init
        self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
        self.rkllm_init.restype = ctypes.c_int
        ret = self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
        if ret != 0:
            print("\nrkllm init failed\n")
            # A failed load must not look like success to the calling shell.
            sys.exit(1)
        print("\nrkllm init success!\n")

        self.rkllm_run = rkllm_lib.rkllm_run
        self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
        self.rkllm_run.restype = ctypes.c_int

        self.rkllm_clear_kv_cache = rkllm_lib.rkllm_clear_kv_cache
        self.rkllm_clear_kv_cache.argtypes = [RKLLM_Handle_t, ctypes.c_int, ctypes.POINTER(ctypes.c_int), ctypes.POINTER(ctypes.c_int)]
        self.rkllm_clear_kv_cache.restype = ctypes.c_int

        self.rkllm_destroy = rkllm_lib.rkllm_destroy
        self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
        self.rkllm_destroy.restype = ctypes.c_int

        # Declare the prototype like the other entry points so the handle is
        # always marshalled as a pointer.
        self.rkllm_abort = rkllm_lib.rkllm_abort
        self.rkllm_abort.argtypes = [RKLLM_Handle_t]
        self.rkllm_abort.restype = ctypes.c_int

        self.rkllm_infer_params = RKLLMInferParam()
        ctypes.memset(ctypes.byref(self.rkllm_infer_params), 0, ctypes.sizeof(RKLLMInferParam))
        self.rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LOGITS
        self.rkllm_infer_params.keep_history = 0

    def run_with_ids(self, *param):
        """Run one inference on a sequence of token ids.

        Args:
            *param: Exactly three positional values ``(role, enable_thinking,
                ids)``. ``role`` and ``enable_thinking`` are accepted for call
                compatibility but are not forwarded to the library. ``ids`` is
                any array-like of token ids.

        Returns:
            The integer status returned by ``rkllm_run`` (0 on success).
        """
        role, enable_thinking, ids = param
        # The library reads a raw int32 buffer: force dtype and contiguity
        # instead of trusting the caller, and keep the array in a local so the
        # buffer outlives the native call.
        token_ids = np.ascontiguousarray(ids, dtype=np.int32).reshape(-1)

        rkllm_input = RKLLMInput()
        ctypes.memset(ctypes.byref(rkllm_input), 0, ctypes.sizeof(RKLLMInput))

        rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_TOKEN
        rkllm_input.input_data.token_input.input_ids = token_ids.ctypes.data_as(ctypes.POINTER(ctypes.c_int32))
        rkllm_input.input_data.token_input.n_tokens = token_ids.shape[0]
        ret = self.rkllm_run(self.handle, ctypes.byref(rkllm_input), ctypes.byref(self.rkllm_infer_params), None)
        if ret != 0:
            print("rkllm_run failed with code {}".format(ret))
        return ret

    def abort(self):
        """Call ``rkllm_abort`` on this model and return its status code."""
        return self.rkllm_abort(self.handle)

    def release(self):
        """Destroy the native model instance; safe to call more than once."""
        if self.handle:
            self.rkllm_destroy(self.handle)
            # Reset to a NULL handle so a second release() is a no-op.
            self.handle = RKLLM_Handle_t()


def cross_entropy_loss_numpy(logits, targets):
    """Return the mean cross-entropy of ``targets`` under ``logits``.

    Args:
        logits: 2-D array, one row of unnormalized scores per position.
        targets: Integer class index for each row of ``logits``.

    Returns:
        The average negative log-softmax value at the target indices.
    """
    # Subtract the per-row maximum first so that exp() cannot overflow.
    row_max = np.max(logits, axis=1, keepdims=True)
    centered = logits - row_max
    log_partition = np.log(np.sum(np.exp(centered), axis=1, keepdims=True))
    log_softmax = centered - log_partition

    row_index = np.arange(logits.shape[0])
    per_row_loss = -log_softmax[row_index, targets]
    return np.mean(per_row_loss)


if __name__ == "__main__":
    # Evaluate perplexity of an RKLLM model over pre-tokenized samples.
    #
    # Each entry of the JSON file supplies `input_ids` (the full token
    # sequence) and `gen_index` (first position whose next-token prediction is
    # scored). The loss of a sample is the mean cross-entropy of tokens
    # gen_index+1 .. end given the logits at positions gen_index .. end-1.
    parser = argparse.ArgumentParser()
    parser.add_argument('--rkllm_model_path', type=str, default="models/Qwen3-0.6B-rk3588-w8a8.rkllm", help='Absolute path of the converted RKLLM model on the Linux board;')
    parser.add_argument('--target_platform', type=str, default="rk3588", help='Target platform: e.g., rk3588/rk3576;')
    parser.add_argument('--lora_model_path', type=str, help='Absolute path of the lora_model on the Linux board;')
    parser.add_argument('--prompt_cache_path', type=str, help='Absolute path of the prompt_cache file on the Linux board;')
    # The dataset and the sample cap used to be hard-coded; the defaults keep
    # the previous behavior. For the 1.7B model pass e.g.
    #   --rkllm_model_path models/Qwen3-1.7B-rk3588-w8a8.rkllm
    #   --json_file assert/gsm8k_17d799_qwen3_1.7b_pred.json
    parser.add_argument('--json_file', type=str, default="assert/gsm8k_17d799_qwen3_0.6B_pred.json", help='JSON file mapping sample keys to {"input_ids": [...], "gen_index": int};')
    parser.add_argument('--max_samples', type=int, default=100, help='Maximum number of samples to evaluate;')
    args = parser.parse_args()
    json_file = args.json_file

    if not os.path.exists(args.rkllm_model_path):
        print("Error: Please provide the correct rkllm model path, and ensure it is the absolute path on the board.")
        sys.stdout.flush()
        sys.exit(1)

    if not (args.target_platform in ["rk3588", "rk3576", "rv1126b", "rk3562"]):
        print("Error: Please specify the correct target platform: rk3588/rk3576/rv1126b/rk3562.")
        sys.stdout.flush()
        sys.exit(1)

    if args.lora_model_path:
        if not os.path.exists(args.lora_model_path):
            print("Error: Please provide the correct lora_model path, and advise it is the absolute path on the board.")
            sys.stdout.flush()
            sys.exit(1)

    if args.prompt_cache_path:
        if not os.path.exists(args.prompt_cache_path):
            print("Error: Please provide the correct prompt_cache_file path, and advise it is the absolute path on the board.")
            sys.stdout.flush()
            sys.exit(1)

    if not os.path.exists(json_file):
        print("Error: json file not found: {}".format(json_file))
        sys.stdout.flush()
        sys.exit(1)

    # Fix frequency. Argument list instead of a shell string: same command,
    # but nothing is interpreted by a shell.
    subprocess.run(["sudo", "bash", "fix_freq_{}.sh".format(args.target_platform)])

    # Set resource limit. Raising the hard limit needs privileges; failing to
    # do so should not abort the evaluation.
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (102400, 102400))
    except (ValueError, OSError) as err:
        print("Warning: could not raise RLIMIT_NOFILE: {}".format(err))

    # Initialize RKLLM model
    print("=========init....===========")
    sys.stdout.flush()
    model_path = args.rkllm_model_path
    rkllm_model = RKLLM(model_path, args.lora_model_path, args.prompt_cache_path, args.target_platform)
    print("load rkllm model from : {}".format(model_path))
    print("==============================")
    sys.stdout.flush()

    try:
        ## json_file
        with open(json_file, "r", encoding="utf-8") as f:
            json_datas = json.load(f)

        loss = 0.0
        num_batches = 0
        for item in tqdm(json_datas.values()):
            if num_batches >= args.max_samples:
                break

            gen_index = item['gen_index']
            input_ids = np.array(item['input_ids'], dtype=np.int32)
            print("input_ids len:", input_ids.shape[0])

            # Start every sample from an empty KV cache.
            ret = rkllm_model.rkllm_clear_kv_cache(rkllm_model.handle, 1, None, None)
            if ret != 0:
                print("clear kv cache failed")

            # This block runs at module scope, so these assignments update the
            # module globals the callback reads and writes. Clearing the logits
            # first guarantees a failed run cannot be scored with the previous
            # sample's logits.
            global_input_ids_len = input_ids.shape[0]
            global_logits = None
            rkllm_model.run_with_ids("user", False, input_ids)

            if global_logits is None:
                print("skip sample: inference produced no logits")
                continue
            if global_logits.shape[0] != input_ids.shape[0]:
                # Logits and labels could not be aligned position by position.
                print("skip sample: logits rows do not match input length")
                continue

            shift_logits = global_logits[gen_index:-1, :]
            shift_labels = input_ids[gen_index+1:]
            if shift_labels.shape[0] == 0 or shift_logits.shape[0] != shift_labels.shape[0]:
                # Nothing to score (gen_index at/after the last token); the
                # mean of an empty loss would be NaN and poison the total.
                print("skip sample: no target tokens after gen_index")
                continue

            ce_loss = cross_entropy_loss_numpy(shift_logits, shift_labels)
            print("ce Loss: {:.4f}, ppl loss: {:.4f}".format(ce_loss, np.exp(ce_loss)))
            loss += ce_loss
            num_batches += 1

        if num_batches == 0:
            print("no valid samples were evaluated")
        else:
            ppl_loss = np.exp(loss / num_batches)
            print("ppl loss: {:.4f}".format(ppl_loss))
    finally:
        # Always free the native model, even when evaluation fails.
        rkllm_model.release()
相关推荐
AI服务老曹1 分钟前
破局异构计算与海量协议:基于 Docker 容器化的国标 GB28181/RTSP 边缘计算 AI 视频管理平台架构设计与源码交付实践
人工智能·docker·边缘计算
俊哥V1 分钟前
每日 AI 研究简报 · 2026-06-09
人工智能·ai
计算机安禾3 分钟前
【数据库系统原理】第14篇:关系模式的语义约束:函数依赖的公理系统与闭包计算
人工智能·算法·机器学习
bluetata3 分钟前
Agentic AI 解读:从认知跃升到企业落地实战指南
人工智能
量化君也4 分钟前
快速入门量化交易都要学些什么?
大数据·人工智能·python·算法·金融
o561-6o623o7鹿6 分钟前
陈,生理实验系统虚实结合型 生理学实验系统 生理学实验系统软件 生物机能实验系统
人工智能
Tbisnic10 分钟前
AI大模型学习 第十天:让程序“指挥”大模型 —— 从对话到工具调用
人工智能·python·ai·大模型·react·cot·提示词工程
婷婷81610 分钟前
我的前端项目构建时间从 8 分钟降到 40 秒,这 5 个优化起了关键作用
人工智能
大任视点16 分钟前
从云经济学之父,到人工智能经济学奠基人
大数据·人工智能·业界资讯
光锥智能17 分钟前
库克“谢幕”,苹果AI“起航”?|苹果2026WWDC
人工智能