【从零开始】9. RAG 应用调优-再续(番外篇)

书接上回,上节我们使用了 torch.multiprocessing 的 Pool(池化)技术实现了多进程并发处理。本来我以为可以到此为止了。但实际上我还忽略了一件事...

那就是并发线程数超过可提供进程时会存在大量等待。在资源争抢的情况下会出现显存不足最后爆出 CUDA out of memory 的错误。但我明明已经写了显存检测代码了,为什么不生效呢?代码如下:

python 复制代码
...

class cuda_tools:

    _instance = None
    _initialized = False 

    def __init__(self):
        if not cuda_tools._initialized:
            # 设置显存使用率 80% 为阈值
            self.threshold_percentage = 80
            cuda_tools._initialized = True

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def check_and_clean_gpu_memory(self):
        """
        检查 GPU 显存使用情况,并在使用率高于 self.threshold_percentage 时清理 GPU 显存。

        返回:
            bool: 如果使用率高于 self.threshold_percentage,则返回True,否则返回False。
        """

        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            memory_allocated = torch.cuda.memory_allocated(device)
            memory_total = torch.cuda.get_device_properties(device).total_memory
            # 计算显存使用百分比
            memory_used_percent = (memory_allocated / memory_total) * 100
            return memory_used_percent > self.threshold_percentage
        return False

看上去这段代码好像没有什么问题,但在大神的指点下发现其实会引发两个问题:

  • torch.cuda.memory_allocated() 只能获取到当前被 PyTorch 分配的显存,不包括 CUDA 缓存的显存;
  • 多进程情况下,每个进程都会独立加载模型,无法共享显存;

好吧,既然要改了那就稍微改得更精细一点吧,最起码要做到以下 5 点:

  1. 实时监控:使用单独的线程实时监控每个进程的显存使用情况
  2. 精确控制:使用 pynvml 获取更准确的显存使用信息
  3. 自动清理:当显存使用超过阈值时自动进行清理
  4. 进程隔离:每个进程独立控制自己的显存使用
  5. 安全退出:确保资源正确释放

以下是修改后的代码:

python 复制代码
...
class gpu_memory_manager:
    def __init__(self, check_interval: float = 60):
        """
        构造函数。

        参数:
            check_interval (float): 监控间隔,单位为秒,默认为60秒。

        属性:
            threshold_percentage (float): 内存阈值,单位为百分比,超过该值将清理内存。
            check_interval (float): 监控间隔,单位为秒。
            process_id (int): 当前进程ID。
            _stop_monitoring (threading.Event): 停止监控的事件。
            _monitor_thread (Optional[threading.Thread]): 监控线程。
            handle (nvmlDevice_t): NVML 设备句柄。
        """
        # 阈值
        self.threshold_percentage = 15
        # 检查间隔
        self.check_interval = check_interval
        # 进程ID
        self.process_id = os.getpid()
        # 停止监控
        self._stop_monitoring = threading.Event()
        # 监控线程
        self._monitor_thread: Optional[threading.Thread] = None
        
        # 初始化 NVML
        nvmlInit()
        # 获取设备句柄
        self.handle = nvmlDeviceGetHandleByIndex(0) 
        
    def _get_process_memory_info(self):
        """
        获取当前进程使用的 GPU 内存大小。
        该方法使用 NVML 库获取当前进程使用的 GPU 内存大小。

        返回:
            int: 当前进程使用的 GPU 内存大小,单位为字节。
        """
        try:
            # 获取当前进程的 GPU 内存使用情况
            processes = nvmlDeviceGetComputeRunningProcesses(self.handle)
            for process in processes:
                if process.pid == self.process_id:
                    return process.usedGpuMemory
            return 0
        except:
            return 0
            
    def _get_total_memory(self):
        """
        获取当前 GPU 的总内存大小。
        该方法使用 NVML 库获取当前 GPU 的总内存大小。

        返回:
            int: 当前 GPU 的总内存大小,单位为字节。
        """
        info = nvmlDeviceGetMemoryInfo(self.handle)
        return info.total
        
    def _monitor_memory(self):
        """
        一个循环线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        while not self._stop_monitoring.is_set():
            process_memory = self._get_process_memory_info()
            total_memory = self._get_total_memory()
            
            # 只在进程实际使用显存时检查
            if process_memory > 0:  
                usage_percentage = (process_memory / total_memory) * 100
                # 如果当前进程的 GPU 内存使用率超过了阈值,则执行 clean_memory
                if usage_percentage > self.threshold_percentage:
                    print(f"Process {self.process_id} memory usage ({usage_percentage:.2f}%) exceeded threshold ({self.threshold_percentage}%)")
                    self.clean_memory()
            time.sleep(self.check_interval)
            
    def start_monitoring(self):
        """
        启动监控线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        if self._monitor_thread is None:
            # 清空停止事件
            self._stop_monitoring.clear()
            # 启动监控线程
            self._monitor_thread = threading.Thread(target=self._monitor_memory, daemon=True)
            # 启动线程
            self._monitor_thread.start()
            
    def stop_monitoring(self):
        """
        停止监控线程,确保当前进程的 GPU 内存监控安全退出。
        如果监控线程正在运行,则设置停止事件,并等待线程结束。
        """
        if self._monitor_thread is not None:
            self._stop_monitoring.set()
            self._monitor_thread.join()
            self._monitor_thread = None
            
    def clean_memory(self):
        """
        清理当前进程的 GPU 内存缓存。

        该方法通过调用 PyTorch 的 `torch.cuda.empty_cache()` 来释放未使用的 GPU 内存缓存,
        并通过 `gc.collect()` 强制进行垃圾回收,以便更好地管理内存。

        注意:此方法不会影响已经分配给张量的内存,只会清除未使用缓存。
        """
        torch.cuda.empty_cache()
        gc.collect()
        
    def __del__(self):
        """
        析构函数,用于释放资源。
        该方法在对象销毁时被调用,用于停止监控线程,并释放 NVML 库的资源。
        """
        self.stop_monitoring()
        try:
            # 释放 NVML 资源
            nvmlShutdown()
        except:
            pass

使用的时候也比较简单,如下图:

python 复制代码
class xxx:
    ...
    
    def __init__(self):
        ...
        self.memory_manager = gpu_memory_manager()
        self.memory_manager.start_monitoring()

    def generate_msg():
        try:
            ...
        finally:
            # 确保清理显存
            self.memory_manager.clean_memory()

    def __del__(self):
        self.memory_manager.stop_monitoring()

好了显存监控已经做好了,赶紧实操一下吧。

还是用回之前的压测工具(pressure_util.py),不过这次将 5 线程改为 10 线程,让其超出进程负荷出现资源争抢的情况,结果如下:

bash 复制代码
- 压测完成,运行时间: 726.09秒
- 共完成 100 个任务
- 平均QPS: 0.14
- 清理完成

通过 nvtop 能够得知

进程执行完成后就会回收显存了

这个时间是多了 300 秒,这可能是因为频繁查询显存和清理缓存造成的吧,稍微多做几次调整成合适自己的检查间隔就好了,起码现在无论多少用户进来都不至于爆显存了。

相关推荐
nancy_princess8 小时前
clip实验
人工智能·深度学习
飞哥数智坊8 小时前
TRAE Friends@济南第4次活动:100+极客集结,2小时极限编程燃爆全场!
人工智能
AI自动化工坊8 小时前
ProofShot实战:给AI编码助手添加可视化验证,提升前端开发效率3倍
人工智能·ai·开源·github
飞哥数智坊8 小时前
一场直播涨粉 2 万的背后!OpenClaw + 飞书,正在重塑软件交付的方式
人工智能
飞哥数智坊8 小时前
养虾记第3期:安装、调教、落地,这场沙龙我们全聊了
人工智能
再不会python就不礼貌了8 小时前
从工具到个人助理——AI Agent的原理、演进与安全风险
人工智能·安全·ai·大模型·transformer·ai编程
AI医影跨模态组学8 小时前
Radiother Oncol 空军军医大学西京医院等团队:基于纵向CT的亚区域放射组学列线图预测食管鳞状细胞癌根治性放化疗后局部无复发生存期
人工智能·深度学习·医学影像·影像组学
A尘埃8 小时前
神经网络的激活函数+损失函数
人工智能·深度学习·神经网络·激活函数
没有不重的名么9 小时前
Pytorch深度学习快速入门教程
人工智能·pytorch·深度学习
有为少年9 小时前
告别“唯语料论”:用合成抽象数据为大模型开智
人工智能·深度学习·神经网络·算法·机器学习·大模型·预训练