【从零开始】9. RAG 应用调优-再续(番外篇)

书接上回,上节我们使用了 torch.multiprocessing 的 Pool(池化)技术实现了多进程并发处理。本来我以为可以到此为止了。但实际上我还忽略了一件事...

那就是并发线程数超过可提供进程时会存在大量等待。在资源争抢的情况下会出现显存不足最后爆出 CUDA out of memory 的错误。但我明明已经写了显存检测代码了,为什么不生效呢?代码如下:

python 复制代码
...

class cuda_tools:

    _instance = None
    _initialized = False 

    def __init__(self):
        if not cuda_tools._initialized:
            # 设置显存使用率 80% 为阈值
            self.threshold_percentage = 80
            cuda_tools._initialized = True

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def check_and_clean_gpu_memory(self):
        """
        检查 GPU 显存使用情况,并在使用率高于 self.threshold_percentage 时清理 GPU 显存。

        返回:
            bool: 如果使用率高于 self.threshold_percentage,则返回True,否则返回False。
        """

        if torch.cuda.is_available():
            device = torch.cuda.current_device()
            memory_allocated = torch.cuda.memory_allocated(device)
            memory_total = torch.cuda.get_device_properties(device).total_memory
            # 计算显存使用百分比
            memory_used_percent = (memory_allocated / memory_total) * 100
            return memory_used_percent > self.threshold_percentage
        return False

看上去这段代码好像没有什么问题,但在大神的指点下发现其实会引发两个问题:

  • torch.cuda.memory_allocated() 只能获取到当前被 PyTorch 分配的显存,不包括 CUDA 缓存的显存;
  • 多进程情况下,每个进程都会独立加载模型,无法共享显存;

好吧,既然要改了那就稍微改得更精细一点吧,最起码要做到以下 5 点:

  1. 实时监控:使用单独的线程实时监控每个进程的显存使用情况
  2. 精确控制:使用 pynvml 获取更准确的显存使用信息
  3. 自动清理:当显存使用超过阈值时自动进行清理
  4. 进程隔离:每个进程独立控制自己的显存使用
  5. 安全退出:确保资源正确释放

以下是修改后的代码:

python 复制代码
...
class gpu_memory_manager:
    def __init__(self, check_interval: float = 60):
        """
        构造函数。

        参数:
            check_interval (float): 监控间隔,单位为秒,默认为60秒。

        属性:
            threshold_percentage (float): 内存阈值,单位为百分比,超过该值将清理内存。
            check_interval (float): 监控间隔,单位为秒。
            process_id (int): 当前进程ID。
            _stop_monitoring (threading.Event): 停止监控的事件。
            _monitor_thread (Optional[threading.Thread]): 监控线程。
            handle (nvmlDevice_t): NVML 设备句柄。
        """
        # 阈值
        self.threshold_percentage = 15
        # 检查间隔
        self.check_interval = check_interval
        # 进程ID
        self.process_id = os.getpid()
        # 停止监控
        self._stop_monitoring = threading.Event()
        # 监控线程
        self._monitor_thread: Optional[threading.Thread] = None
        
        # 初始化 NVML
        nvmlInit()
        # 获取设备句柄
        self.handle = nvmlDeviceGetHandleByIndex(0) 
        
    def _get_process_memory_info(self):
        """
        获取当前进程使用的 GPU 内存大小。
        该方法使用 NVML 库获取当前进程使用的 GPU 内存大小。

        返回:
            int: 当前进程使用的 GPU 内存大小,单位为字节。
        """
        try:
            # 获取当前进程的 GPU 内存使用情况
            processes = nvmlDeviceGetComputeRunningProcesses(self.handle)
            for process in processes:
                if process.pid == self.process_id:
                    return process.usedGpuMemory
            return 0
        except:
            return 0
            
    def _get_total_memory(self):
        """
        获取当前 GPU 的总内存大小。
        该方法使用 NVML 库获取当前 GPU 的总内存大小。

        返回:
            int: 当前 GPU 的总内存大小,单位为字节。
        """
        info = nvmlDeviceGetMemoryInfo(self.handle)
        return info.total
        
    def _monitor_memory(self):
        """
        一个循环线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        while not self._stop_monitoring.is_set():
            process_memory = self._get_process_memory_info()
            total_memory = self._get_total_memory()
            
            # 只在进程实际使用显存时检查
            if process_memory > 0:  
                usage_percentage = (process_memory / total_memory) * 100
                # 如果当前进程的 GPU 内存使用率超过了阈值,则执行 clean_memory
                if usage_percentage > self.threshold_percentage:
                    print(f"Process {self.process_id} memory usage ({usage_percentage:.2f}%) exceeded threshold ({self.threshold_percentage}%)")
                    self.clean_memory()
            time.sleep(self.check_interval)
            
    def start_monitoring(self):
        """
        启动监控线程,用于监控当前进程的 GPU 内存使用情况。
        如果当前进程的 GPU 内存使用率超过了 `threshold_percentage`,则执行 `clean_memory` 方法来回收 GPU 内存。
        """
        if self._monitor_thread is None:
            # 清空停止事件
            self._stop_monitoring.clear()
            # 启动监控线程
            self._monitor_thread = threading.Thread(target=self._monitor_memory, daemon=True)
            # 启动线程
            self._monitor_thread.start()
            
    def stop_monitoring(self):
        """
        停止监控线程,确保当前进程的 GPU 内存监控安全退出。
        如果监控线程正在运行,则设置停止事件,并等待线程结束。
        """
        if self._monitor_thread is not None:
            self._stop_monitoring.set()
            self._monitor_thread.join()
            self._monitor_thread = None
            
    def clean_memory(self):
        """
        清理当前进程的 GPU 内存缓存。

        该方法通过调用 PyTorch 的 `torch.cuda.empty_cache()` 来释放未使用的 GPU 内存缓存,
        并通过 `gc.collect()` 强制进行垃圾回收,以便更好地管理内存。

        注意:此方法不会影响已经分配给张量的内存,只会清除未使用缓存。
        """
        torch.cuda.empty_cache()
        gc.collect()
        
    def __del__(self):
        """
        析构函数,用于释放资源。
        该方法在对象销毁时被调用,用于停止监控线程,并释放 NVML 库的资源。
        """
        self.stop_monitoring()
        try:
            # 释放 NVML 资源
            nvmlShutdown()
        except:
            pass

使用的时候也比较简单,如下图:

python 复制代码
class xxx:
    ...
    
    def __init__(self):
        ...
        self.memory_manager = gpu_memory_manager()
        self.memory_manager.start_monitoring()

    def generate_msg():
        try:
            ...
        finally:
            # 确保清理显存
            self.memory_manager.clean_memory()

    def __del__(self):
        self.memory_manager.stop_monitoring()

好了显存监控已经做好了,赶紧实操一下吧。

还是用回之前的压测工具(pressure_util.py),不过这次将 5 线程改为 10 线程,让其超出进程负荷出现资源争抢的情况,结果如下:

bash 复制代码
- 压测完成,运行时间: 726.09秒
- 共完成 100 个任务
- 平均QPS: 0.14
- 清理完成

通过 nvtop 能够得知

进程执行完成后就会回收显存了

这个时间是多了 300 秒,这可能是因为频繁查询显存和清理缓存造成的吧,稍微多做几次调整成合适自己的检查间隔就好了,起码现在无论多少用户进来都不至于爆显存了。

相关推荐
Jacob_AI几秒前
为什么 Bert 的三个 Embedding 可以进行相加?
人工智能·bert·embedding
Narutolxy6 分钟前
在 Windows WSL 上部署 Ollama 和大语言模型:从镜像冗余问题看 Docker 最佳实践20241208
人工智能·docker·语言模型
人工智能培训网44 分钟前
《计算机视觉:瓶颈之辩与未来之路》
人工智能·学习·计算机视觉·人工智能培训·人工智能工程师
volcanical1 小时前
LoRA:低秩分解微调与代码
人工智能
keira6741 小时前
【21天学习AI底层概念】day3 机器学习的三大类型(监督学习、无监督学习、强化学习)分别适用于哪种类型的问题?
人工智能·学习·机器学习
海森大数据2 小时前
人工智能时代的计算化学实验:量子化学与机器学习的融合
大数据·人工智能·神经网络·机器学习
魏+Mtiao15_2 小时前
矩阵源代码部署与功能简介
人工智能·python·线性代数·矩阵·php·音视频
凝眸伏笔2 小时前
【TensorFlow】基本概念:张量、常量、变量、占位符、计算图
人工智能·tensorflow·neo4j
魏+Mtiao15_2 小时前
短视频矩阵系统功能介绍与独立部署流程
java·大数据·人工智能·矩阵
VIT199801062 小时前
AI实现葡萄叶片识别(基于深度学习的葡萄叶片识别)
人工智能·python·深度学习