跟AI学一手之流媒体服务器

最近又要做一个流媒体项目,简单来说就是给摄像头设备写服务端:服务端要接收设备发来的视频流,解析、转码后再通过 HTTP 发布出去。其实这类需求有现成的方案,比如 SRS、ZLMediaKit 等开源软件,都是发展了十多年的成熟产品。但老板的意思是我们的业务量不大,后面也可能有修改需求,所以让我们用 Python 自己开发——虽然 Python 性能不太行,但业务量不大,也将就能用。于是我开始研究视频编解码和通信协议方面的东西。这里面细节很多,也很复杂,只能先让 AI 写个大概,跑通以后再自己改。经过几天的摸索,终于跑通了一个流程:用 ffmpeg 通过 RTP 协议把一个视频推送到某个端口,我的流媒体服务器监听这个 RTP 端口收流,从中解析出 H.264 数据,再封装为 FLV 格式发送给前端,前端使用 flv.js 播放。废话不多说,直接上代码(rtp_server.py):

python 复制代码
import struct
import os
import time
import asyncio
import threading
import socket
import uuid
import logging
import json
import signal
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Set, Tuple, Any, Callable
from collections import deque

import aiohttp
from aiohttp import web
from pathlib import Path
#-----------------------------------------------------------------------------------------------------------------------
# 日志配置
#-----------------------------------------------------------------------------------------------------------------------
#-----------------------------------------------------------------------------------------------------------------------
# RTP接收器
#-----------------------------------------------------------------------------------------------------------------------

class AsyncRTPReceiver:
    """Asynchronous UDP receiver for RTP packets.

    Binds a datagram endpoint on ``0.0.0.0:<port>`` and hands every datagram
    to the callback given to :meth:`start` (through :class:`RTPProtocol`).
    Traffic counters live in ``self.stats`` and are guarded by ``self._lock``
    so :meth:`get_stats` can be called from any thread.
    """

    def __init__(self, port: int, buffer_size: int = 65536):
        """
        Args:
            port: UDP port to bind.
            buffer_size: kept as an attribute only; nothing in this class
                reads it (NOTE(review): confirm whether it was meant to size
                the socket receive buffer).
        """
        self.port = port
        self.buffer_size = buffer_size
        self.transport = None
        self.protocol = None
        self.running = False
        self.callback = None  # the active packet callback while running
        self.stats = {
            'packets_received': 0,
            'packets_dropped': 0,
            'bytes_received': 0,
            'last_packet_time': 0
        }
        self._lock = threading.RLock()

    async def start(self, callback: Callable[[bytes, Tuple[str, int]], None]) -> bool:
        """Bind the UDP port and start delivering datagrams to *callback*.

        Returns:
            True on success; False if already running or if binding failed
            (the failure is logged).
        """
        if self.running:
            return False

        try:
            # The asyncio docs recommend get_running_loop() over
            # get_event_loop() inside coroutines.
            loop = asyncio.get_running_loop()
            self.transport, self.protocol = await loop.create_datagram_endpoint(
                lambda: RTPProtocol(callback, self.stats, self._lock),
                local_addr=('0.0.0.0', self.port)
            )
        except Exception as e:
            logger.error(f"RTP接收器启动失败: {e}")
            return False

        self.callback = callback
        self.running = True
        logger.info(f"RTP接收器已启动: 端口{self.port}")
        return True

    def stop(self):
        """Close the UDP endpoint (if open) and mark the receiver stopped."""
        self.running = False
        if self.transport:
            self.transport.close()
            self.transport = None
        self.callback = None
        logger.info("RTP接收器已停止")

    def get_stats(self) -> Dict:
        """Return a snapshot copy of the traffic counters."""
        with self._lock:
            return self.stats.copy()


class RTPProtocol(asyncio.DatagramProtocol):
    """asyncio datagram protocol used by AsyncRTPReceiver.

    For every datagram it bumps the shared ``stats`` counters (under
    ``lock``) and forwards the packet to ``callback(data, addr)``.  A callback
    that raises is counted in ``stats['packets_dropped']`` and the exception
    is not propagated.
    """

    def __init__(self, callback, stats, lock):
        super().__init__()
        self._lock = lock
        self.stats = stats
        self.callback = callback

    def connection_made(self, transport):
        """Remember the transport asyncio created for this endpoint."""
        self.transport = transport

    def datagram_received(self, data, addr):
        """Count one datagram and pass it to the callback, if one is set."""
        counters = self.stats
        with self._lock:
            counters['packets_received'] += 1
            counters['bytes_received'] += len(data)
            counters['last_packet_time'] = time.time()

        handler = self.callback
        if not handler:
            return
        try:
            handler(data, addr)
        except Exception:
            # Keep the event loop alive; the failure only shows up as a drop.
            with self._lock:
                counters['packets_dropped'] += 1

    def error_received(self, exc):
        """Log socket-level errors reported by asyncio."""
        logger.error(f"RTP协议错误: {exc}")

    def connection_lost(self, exc):
        """Log that the endpoint was closed."""
        logger.info("RTP连接关闭")
class LoggerManager:
    """Per-name singleton that attaches one console handler to a logger.

    ``LoggerManager(name, level)`` returns the same instance for the same
    *name*; *level* is only applied the first time a name is seen.
    """

    _instances = {}  # logger name -> LoggerManager
    # Attribute set on the handler we create, used to detect it on re-setup.
    _HANDLER_TAG = '_media_server_console'

    def __new__(cls, name: str = 'MediaServer', level: str = 'INFO'):
        instance = cls._instances.get(name)
        if instance is None:
            instance = super().__new__(cls)
            instance._setup_logger(name, level)
            cls._instances[name] = instance
        return instance

    def _setup_logger(self, name: str, level: str):
        """Configure ``logging.getLogger(name)`` with *level* and a console handler."""
        self.logger = logging.getLogger(name)

        # Unknown level names fall back to INFO instead of raising AttributeError.
        resolved_level = getattr(logging, str(level).upper(), None)
        if not isinstance(resolved_level, int):
            resolved_level = logging.INFO
        self.logger.setLevel(resolved_level)

        # logging.getLogger() returns a process-wide object, so a handler added
        # by an earlier setup (e.g. the module imported twice) is still attached;
        # adding another one would print every record twice.
        if any(getattr(h, self._HANDLER_TAG, False) for h in self.logger.handlers):
            return

        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s [%(levelname)s] [%(threadName)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        console.setFormatter(formatter)
        setattr(console, self._HANDLER_TAG, True)
        self.logger.addHandler(console)

logger = LoggerManager().logger

#-----------------------------------------------------------------------------------------------------------------------
# 配置
#-----------------------------------------------------------------------------------------------------------------------

@dataclass
class ServerConfig:
    """Server-wide settings shared by StreamManager, every StreamSession and the HTTP layer.

    NOTE(review): fields marked "appears unused" had no reader in the part of
    the file reviewed; confirm they are consumed elsewhere before relying on them.
    """
    rtp_port: int = 10000  # substituted for {{rtp_port}} in index.html; presumably also the RTP receiver port - confirm
    http_port: int = 8080  # NOTE(review): appears unused - confirm
    http_host: str = '0.0.0.0'  # NOTE(review): appears unused - confirm
    max_streams: int = 100  # StreamManager.get_or_create_stream refuses new sessions at this count
    max_clients_per_stream: int = 50  # StreamSession.add_client rejects viewers beyond this count
    buffer_size: int = 50 * 1024 * 1024  # RingBuffer size per session; preallocated, so memory grows with active streams
    stream_idle_timeout: int = 300  # NOTE(review): appears unused - confirm
    client_idle_timeout: int = 60  # NOTE(review): appears unused; FLVHttpServer hard-codes a 60 s idle limit
    frame_rate: float = 25.0  # FLVMuxer derives synthetic tag timestamps from it (1000 / frame_rate ms per frame)
    max_worker_threads: int = 20  # NOTE(review): appears unused - confirm
    enable_auth: bool = False  # NOTE(review): appears unused - confirm
    auth_token: Optional[str] = None  # NOTE(review): appears unused - confirm
    allowed_ips: List[str] = field(default_factory=list)  # NOTE(review): appears unused - confirm
    enable_stats: bool = True  # NOTE(review): appears unused - confirm
    stats_interval: int = 30  # NOTE(review): appears unused - confirm
    log_level: str = 'INFO'  # NOTE(review): appears unused - confirm
    client_cleanup_interval: int = 30  # NOTE(review): appears unused - confirm

#-----------------------------------------------------------------------------------------------------------------------
# RTP包解析
#-----------------------------------------------------------------------------------------------------------------------

class RTPPacket:
    """Parsed RTP packet (RFC 3550, section 5.1).

    Header fields are exposed as attributes (``version``, ``padding``,
    ``extension``, ``csrc_count``, ``marker``, ``payload_type``,
    ``sequence``, ``timestamp``, ``ssrc``).  ``payload`` is what follows the
    12-byte fixed header, the CSRC list and the optional header extension,
    with any trailing padding removed.

    Raises:
        ValueError: if *data* is shorter than the header it declares, or the
            padding count is inconsistent with the packet length.
    """

    def __init__(self, data: bytes):
        if len(data) < 12:
            raise ValueError("RTP包太短")

        first, second = data[0], data[1]
        self.version = (first >> 6) & 0x03
        self.padding = (first >> 5) & 0x01
        self.extension = (first >> 4) & 0x01
        self.csrc_count = first & 0x0F
        self.marker = (second >> 7) & 0x01
        self.payload_type = second & 0x7F
        self.sequence, self.timestamp, self.ssrc = struct.unpack('>HII', data[2:12])

        # The CSRC list (4 bytes per entry) follows the fixed header.
        offset = 12 + 4 * self.csrc_count
        if self.extension:
            # Extension header: 16-bit profile-specific id + 16-bit length in
            # 32-bit words, not counting this 4-byte header itself.
            if len(data) < offset + 4:
                raise ValueError("RTP包太短")
            ext_words = struct.unpack('>H', data[offset + 2:offset + 4])[0]
            offset += 4 + 4 * ext_words

        end = len(data)
        if offset > end:
            raise ValueError("RTP包太短")
        if self.padding:
            # The last octet holds the number of padding octets, itself included.
            pad = data[-1] if end > offset else 0
            if pad == 0 or pad > end - offset:
                raise ValueError("RTP填充长度无效")
            end -= pad
        self.payload = data[offset:end]

    def is_valid(self) -> bool:
        """True for RTP version 2 packets that carry a non-empty payload."""
        return self.version == 2 and len(self.payload) > 0

#-----------------------------------------------------------------------------------------------------------------------
# 环形缓冲区
#-----------------------------------------------------------------------------------------------------------------------

class RingBuffer:
    """Fixed-size byte ring written by one producer and read by many independent readers.

    Each reader (identified by ``client_id``) has its own cursor in
    ``self.readers``.  Cursors are *absolute* stream offsets (bytes written
    since creation or the last :meth:`clear`), so a caught-up reader can never
    be confused with one that is a full lap behind, and overwriting old data
    never moves a reader backwards.  A reader whose unread bytes were
    overwritten resumes at the oldest byte still held.
    """

    def __init__(self, max_size: int = 50 * 1024 * 1024):
        if max_size <= 0:
            raise ValueError("max_size must be positive")
        self.buffer = bytearray(max_size)   # preallocated storage
        self.max_size = max_size
        self.write_pos = 0                  # ring index of the next write
        self.available = 0                  # bytes currently held (<= max_size)
        self._lock = threading.RLock()
        self.readers: Dict[str, int] = {}   # client_id -> absolute read offset
        self.reader_lock = threading.RLock()
        # Absolute offset of the next write; write_pos == _total_written % max_size.
        self._total_written = 0

    def write(self, data: bytes) -> int:
        """Append *data*, overwriting the oldest bytes once the ring is full.

        Returns:
            Number of bytes stored.  A chunk larger than ``max_size`` keeps
            only its last ``max_size`` bytes.
        """
        with self._lock:
            data_len = len(data)
            if data_len == 0:
                return 0

            if data_len > self.max_size:
                data = data[-self.max_size:]
                data_len = self.max_size

            start = self.write_pos
            end = start + data_len
            if end <= self.max_size:
                self.buffer[start:end] = data
            else:
                first = self.max_size - start
                self.buffer[start:] = data[:first]
                self.buffer[:data_len - first] = data[first:]

            self.write_pos = end % self.max_size
            self._total_written += data_len
            self.available = min(self.available + data_len, self.max_size)
            return data_len

    def read(self, client_id: str, size: int = 8192) -> bytes:
        """Return up to *size* unread bytes for *client_id*.

        The first call for an unknown client registers it at the current end
        of the stream and returns ``b''``: readers only see data written after
        they registered.
        """
        with self._lock:
            with self.reader_lock:
                if client_id not in self.readers:
                    self.readers[client_id] = self._total_written
                    return b''
                offset = self.readers[client_id]

            oldest = self._total_written - self.available
            if offset < oldest:
                # The reader fell more than a full ring behind; its unread data
                # is gone, so skip to the oldest byte still held.
                offset = oldest

            read_size = min(size, self._total_written - offset)
            if read_size <= 0:
                return b''

            start = offset % self.max_size
            end = start + read_size
            if end <= self.max_size:
                data = bytes(self.buffer[start:end])
            else:
                data = bytes(self.buffer[start:]) + bytes(self.buffer[:end - self.max_size])

            with self.reader_lock:
                # Don't resurrect a reader removed while we were copying.
                if client_id in self.readers:
                    self.readers[client_id] = offset + read_size

            return data

    def remove_reader(self, client_id: str):
        """Forget the cursor of *client_id* (no-op if unknown)."""
        with self.reader_lock:
            self.readers.pop(client_id, None)

    def clear(self):
        """Drop all data and all reader cursors."""
        with self._lock:
            self.write_pos = 0
            self.available = 0
            self._total_written = 0
            with self.reader_lock:
                self.readers.clear()

#-----------------------------------------------------------------------------------------------------------------------
# RTP裸流解析器
#-----------------------------------------------------------------------------------------------------------------------

class RTPRawParser:
    """Reassembles H.264 NAL units from RTP payloads (RFC 6184).

    Handles single NAL unit packets (types 1-23), STAP-A (24) and FU-A (28);
    any other packetization type yields nothing.  Returned NAL units keep
    their one-byte NAL header and carry no Annex-B start code.
    """
    NAL_SPS = 7
    NAL_PPS = 8
    NAL_IDR = 5
    NAL_SLICE = 1

    PT_STAP_A = 24
    PT_FU_A = 28

    def __init__(self):
        self.fu_buffers = {}  # (ssrc, timestamp) -> bytearray
        # (ssrc, timestamp) -> RTP sequence number the next FU-A fragment must carry
        self._fu_next_seq = {}
        self._lock = threading.RLock()

    def feed(self, rtp: 'RTPPacket') -> List[Tuple[int, bytes]]:
        """Return the NAL units completed by *rtp* as ``[(nal_type, nal_data), ...]``."""
        with self._lock:
            payload = rtp.payload
            if len(payload) == 0:
                return []

            nal_type = payload[0] & 0x1F

            if nal_type == self.PT_STAP_A:
                return self._parse_stap_a(payload)
            elif nal_type == self.PT_FU_A:
                return self._parse_fu_a(rtp)
            elif 1 <= nal_type <= 23:
                return [(nal_type, payload)]
            else:
                return []

    def _parse_stap_a(self, payload: bytes) -> List[Tuple[int, bytes]]:
        """Split a STAP-A payload into its 16-bit-length-prefixed NAL units.

        Stops at the first entry whose declared size runs past the payload.
        """
        results = []
        offset = 1  # skip the STAP-A NAL header
        while offset + 2 <= len(payload):
            size = struct.unpack('>H', payload[offset:offset+2])[0]
            offset += 2
            if offset + size > len(payload):
                break
            nal_data = payload[offset:offset+size]
            offset += size
            if len(nal_data) > 0:
                results.append((nal_data[0] & 0x1F, nal_data))
        return results

    def _parse_fu_a(self, rtp: 'RTPPacket') -> List[Tuple[int, bytes]]:
        """Accumulate one FU-A fragment and return the NAL unit once its end fragment arrives.

        Fragments must carry consecutive RTP sequence numbers.  A gap means a
        fragment was lost, so the partial NAL unit is discarded instead of
        being emitted corrupted.
        """
        payload = rtp.payload
        if len(payload) < 2:
            return []

        fu_header = payload[1]
        start = (fu_header >> 7) & 0x01
        end = (fu_header >> 6) & 0x01
        fu_nal_type = fu_header & 0x1F
        nal_data = payload[2:]
        key = (rtp.ssrc, rtp.timestamp)
        next_seq = (rtp.sequence + 1) & 0xFFFF  # RTP sequence numbers are 16-bit and wrap

        if start:
            # In non-interleaved mode fragments of different NAL units never
            # interleave, so an unfinished NAL unit from this SSRC lost its end
            # fragment; drop it so fu_buffers cannot grow without bound.
            self._drop_partial(rtp.ssrc)
            # Rebuild the NAL header: F/NRI bits from the FU indicator, type
            # from the FU header.
            nal = bytearray([(payload[0] & 0xE0) | fu_nal_type]) + nal_data
            if end:
                return [(fu_nal_type, bytes(nal))]
            self.fu_buffers[key] = nal
            self._fu_next_seq[key] = next_seq
            return []

        buffered = self.fu_buffers.get(key)
        if buffered is None:
            # Continuation without a start fragment: nothing to attach it to.
            return []
        if rtp.sequence != self._fu_next_seq.get(key):
            self._discard(key)
            return []

        buffered.extend(nal_data)
        if end:
            self._discard(key)
            return [(fu_nal_type, bytes(buffered))]
        self._fu_next_seq[key] = next_seq
        return []

    def _drop_partial(self, ssrc: int):
        """Forget every unfinished FU-A NAL unit belonging to *ssrc*."""
        for key in [k for k in self.fu_buffers if k[0] == ssrc]:
            self._discard(key)

    def _discard(self, key):
        """Remove the reassembly state stored under *key*."""
        self.fu_buffers.pop(key, None)
        self._fu_next_seq.pop(key, None)

    def reset(self):
        """Drop all partially reassembled NAL units."""
        with self._lock:
            self.fu_buffers.clear()
            self._fu_next_seq.clear()

#-----------------------------------------------------------------------------------------------------------------------
# FLV封装器 (简化但确保正确)
#-----------------------------------------------------------------------------------------------------------------------

class FLVMuxer:
    """Wraps H.264 access units into FLV video tags.

    SPS/PPS are remembered via :meth:`feed_sps_pps`.  The AVC sequence header
    (AVCDecoderConfigurationRecord) is emitted in front of the first frame
    after they become known or change.  Frames are dropped until both are
    known.  Tag timestamps are synthesized from a frame counter and the
    configured frame rate.
    """
    CODEC_H264 = 7

    def __init__(self, frame_rate: float = 25.0):
        self.sps = None
        self.pps = None
        self.seq_header_sent = False
        self.timestamp = 0
        self.frame_count = 0
        self.time_base = 1000.0 / frame_rate  # milliseconds per frame
        self._lock = threading.RLock()
        self.stats = {'key_frames': 0, 'inter_frames': 0, 'bytes_written': 0}

    def feed_sps_pps(self, nal_type: int, nal_data: bytes):
        """Store an SPS (type 7) or PPS (type 8).

        A changed value forces the sequence header to be re-sent.
        """
        with self._lock:
            if nal_type == 7:  # SPS
                if self.sps != nal_data:
                    self.sps = nal_data
                    self.seq_header_sent = False
                    logger.info("SPS已更新")
            elif nal_type == 8:  # PPS
                if self.pps != nal_data:
                    self.pps = nal_data
                    self.seq_header_sent = False
                    logger.info("PPS已更新")

    def feed_frame(self, frame_type: int, nal_list: List[Tuple[int, bytes]]) -> Optional[bytes]:
        """Mux one complete frame.

        Args:
            frame_type: FLV frame type, 1 = key frame, 2 = inter frame.
            nal_list: ``(nal_type, nal_data)`` pairs without start codes.

        Returns:
            The FLV bytes to append (possibly preceded by a sequence header),
            or None when there is nothing to emit or SPS/PPS are still unknown.
        """
        with self._lock:
            if not nal_list or not self.sps or not self.pps:
                return None

            result = bytearray()

            if not self.seq_header_sent:
                sh = self._make_sequence_header()
                if sh:
                    result.extend(sh)
                    self.seq_header_sent = True
                    logger.info("FLV序列头已发送")

            self.timestamp = int(self.frame_count * self.time_base)
            self.frame_count += 1

            nal_data_list = [data for _, data in nal_list]

            tag = self._make_video_tag(frame_type, nal_data_list, self.timestamp)
            if tag:
                result.extend(tag)
                if frame_type == 1:
                    self.stats['key_frames'] += 1
                    total_size = sum(len(n) for n in nal_data_list)
                    logger.info(f"封装IDR帧 #{self.frame_count}, "
                              f"切片数={len(nal_data_list)}, 总大小={total_size}, ts={self.timestamp}")
                else:
                    self.stats['inter_frames'] += 1
                self.stats['bytes_written'] += len(tag)

            return bytes(result) if result else None

    def _make_sequence_header(self) -> Optional[bytes]:
        """Build the AVC sequence-header tag, or None if SPS/PPS are missing or too short."""
        # profile_idc, constraint flags and level_idc are copied from SPS bytes
        # 1..3, so the SPS must be at least 4 bytes long.
        if not self.sps or not self.pps or len(self.sps) < 4 or len(self.pps) < 1:
            return None

        config = bytearray()
        config.append(0x01)  # configurationVersion
        config.append(self.sps[1])  # AVCProfileIndication
        config.append(self.sps[2])  # profile_compatibility
        config.append(self.sps[3])  # AVCLevelIndication
        config.append(0xFF)  # reserved bits + lengthSizeMinusOne = 3 (4-byte NALU lengths)
        config.append(0xE1)  # reserved bits + one SPS
        config.extend(struct.pack('>H', len(self.sps)))
        config.extend(self.sps)
        config.append(0x01)  # one PPS
        config.extend(struct.pack('>H', len(self.pps)))
        config.extend(self.pps)

        vdata = bytearray()
        vdata.append((1 << 4) | self.CODEC_H264)  # key frame + H.264
        vdata.append(0x00)  # AVCPacketType 0: sequence header
        vdata.extend(b'\x00\x00\x00')  # composition time
        vdata.extend(config)

        return self._make_tag(0x09, bytes(vdata), 0)

    def _make_video_tag(self, frame_type: int, nal_list: List[bytes], timestamp: int) -> Optional[bytes]:
        """Build a video tag carrying *nal_list* as 4-byte length-prefixed NAL units."""
        if not nal_list:
            return None

        vdata = bytearray()
        vdata.append((frame_type << 4) | self.CODEC_H264)
        vdata.append(0x01)  # AVCPacketType 1: NALUs
        vdata.extend(b'\x00\x00\x00')  # composition time

        for nal_data in nal_list:
            vdata.extend(struct.pack('>I', len(nal_data)))
            vdata.extend(nal_data)

        return self._make_tag(0x09, bytes(vdata), timestamp)

    @staticmethod
    def _make_tag(tag_type: int, data: bytes, timestamp: int) -> bytes:
        """Return an FLV tag: 11-byte header + *data* + 4-byte PreviousTagSize."""
        data_len = len(data)
        ts_low = timestamp & 0xFFFFFF
        ts_high = (timestamp >> 24) & 0xFF

        header = struct.pack('>B', tag_type)           # TagType
        header += struct.pack('>I', data_len)[1:]      # DataSize (3 bytes)
        header += struct.pack('>I', ts_low)[1:]        # Timestamp (3 bytes)
        header += struct.pack('>B', ts_high)           # TimestampExtended
        header += b'\x00\x00\x00'                      # StreamID

        tag = header + data
        tag += struct.pack('>I', 11 + data_len)        # PreviousTagSize
        return tag

    def reset(self):
        """Forget SPS/PPS, timing and statistics."""
        with self._lock:
            self.sps = None
            self.pps = None
            self.seq_header_sent = False
            self.timestamp = 0
            self.frame_count = 0
            self.stats = {'key_frames': 0, 'inter_frames': 0, 'bytes_written': 0}

class FLVHttpServer:
    """Minimal asyncio HTTP server that pushes live FLV streams.

    Serves ``GET /stream/<stream_id>``.  An optional ``.flv`` suffix and any
    query string are ignored.  Every client first receives the FLV file
    header and, when SPS/PPS are known, the AVC sequence header.  After that
    it receives the tags read from the session's ring buffer.
    """

    # FLV file header (version 1, flags 0x01 = video, header size 9) followed
    # by PreviousTagSize0 = 0; the same bytes StreamSession writes.
    _FLV_HEADER = b'FLV\x01\x01\x00\x00\x00\x09\x00\x00\x00\x00'
    # Seconds without any data sent before a viewer is disconnected.
    _IDLE_TIMEOUT = 60

    def __init__(self, stream_manager, host='0.0.0.0', port=8081):
        self.stream_manager = stream_manager
        self.host = host
        self.port = port
        self.server = None
        self.running = False

    async def start(self):
        """Start listening on ``host:port``."""
        self.server = await asyncio.start_server(
            self._handle_client,
            self.host,
            self.port
        )
        self.running = True
        logger.info(f"FLV HTTP服务器启动: {self.host}:{self.port}")

    @staticmethod
    def _parse_stream_id(path: str) -> Optional[str]:
        """Return the id from ``/stream/<id>[.flv][?query]``, or None for any other path."""
        path = path.split('?', 1)[0]
        prefix = '/stream/'
        if not path.startswith(prefix):
            return None
        stream_id = path[len(prefix):]
        if stream_id.endswith('.flv'):
            stream_id = stream_id[:-4]
        return stream_id

    @classmethod
    def _initial_payload(cls, session) -> bytes:
        """Return the bytes a new viewer needs before live tags: FLV header plus AVC sequence header."""
        payload = bytearray(cls._FLV_HEADER)
        muxer = session.muxer
        with muxer._lock:
            seq_header = muxer._make_sequence_header()
        if seq_header:
            payload.extend(seq_header)
        return bytes(payload)

    async def _handle_client(self, reader, writer):
        """Serve one connection: parse the request line, then stream or send an error status."""
        try:
            request_data = await asyncio.wait_for(reader.readuntil(b'\r\n\r\n'), timeout=5)
            request_text = request_data.decode('utf-8', errors='ignore')

            parts = request_text.split('\r\n', 1)[0].split(' ')
            if len(parts) != 3:
                writer.write(b"HTTP/1.1 400 Bad Request\r\n\r\n")
                return
            _method, path, _version = parts

            stream_id = self._parse_stream_id(path)
            session = self.stream_manager.get_stream(stream_id) if stream_id is not None else None
            if not session or not session.running:
                writer.write(b"HTTP/1.1 404 Not Found\r\n\r\n")
                return

            peer = writer.get_extra_info('peername')
            peer_host = peer[0] if peer else 'unknown'
            client_id = f"{peer_host}_{uuid.uuid4().hex[:8]}"
            if not session.add_client(client_id):
                writer.write(b"HTTP/1.1 429 Too Many Requests\r\n\r\n")
                return

            await self._push_stream(session, stream_id, client_id, writer)

        except asyncio.TimeoutError:
            pass
        except Exception as e:
            logger.error(f"FLV HTTP错误: {e}")
        finally:
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass

    async def _push_stream(self, session, stream_id: str, client_id: str, writer):
        """Relay FLV data to a registered viewer until the session stops, the viewer leaves or it idles out.

        The viewer is always removed from *session* on exit, including when
        writing the response headers fails.
        """
        try:
            response = (
                "HTTP/1.1 200 OK\r\n"
                "Content-Type: video/x-flv\r\n"
                "Access-Control-Allow-Origin: *\r\n"
                "Cache-Control: no-cache\r\n"
                "Connection: close\r\n"
                "\r\n"
            )
            writer.write(response.encode())
            # A new ring-buffer reader only receives bytes written after it
            # registers, so the FLV header written when the session was
            # created never reaches it through the buffer. Send it (and the
            # sequence header) directly.
            writer.write(self._initial_payload(session))
            await writer.drain()

            logger.info(f"[FLV] 开始推送: {stream_id} -> {client_id}")

            last_active = time.time()
            while session.running and client_id in session.clients:
                data = session.read_for_client(client_id, 4096)
                if data:
                    writer.write(data)
                    await writer.drain()
                    last_active = time.time()
                else:
                    await asyncio.sleep(0.01)

                if time.time() - last_active > self._IDLE_TIMEOUT:
                    break
        except ConnectionError:
            # The viewer disconnected; nothing further to send.
            pass
        finally:
            session.remove_client(client_id)
            logger.info(f"[FLV] 推送结束: {stream_id} -> {client_id}")

    async def stop(self):
        """Stop accepting connections and wait for the listener to close."""
        if self.server:
            self.server.close()
            await self.server.wait_closed()
        self.running = False
#-----------------------------------------------------------------------------------------------------------------------
# 流会话
#-----------------------------------------------------------------------------------------------------------------------

class StreamSession:
    """One live stream: RTP packets in, FLV bytes out to any number of HTTP viewers.

    Packets are depacketized by ``parser`` into H.264 NAL units and grouped
    into frames (access units).  ``muxer`` turns each frame into FLV tags,
    which are appended to ``buffer``; every viewer reads from it with its own
    cursor.  The muxed output is also mirrored to ``debug/<stream_id>.flv``
    when that file can be created.
    """

    def __init__(self, stream_id: str, config: 'ServerConfig'):
        self.stream_id = stream_id
        self.config = config
        self.buffer = RingBuffer(max_size=config.buffer_size)
        self.parser = RTPRawParser()
        self.muxer = FLVMuxer(frame_rate=config.frame_rate)

        self.running = True
        self.created_at = time.time()
        self.last_frame_time = time.time()
        self.clients: Set[str] = set()
        self.client_lock = threading.RLock()
        self.stats = {'rtp_packets': 0, 'total_frames': 0, 'total_bytes': 0, 'current_fps': 0.0}
        self.fps_counter = 0
        self.fps_start_time = time.time()

        # NAL units of the frame currently being assembled.
        self.frame_buffer: List[Tuple[int, bytes]] = []
        # FLV frame type of that frame: 1 = key frame (contains IDR), 2 = inter frame.
        self.current_frame_type = 2
        # RTP timestamp of the NAL units in frame_buffer (None while empty).
        self._frame_timestamp: Optional[int] = None

        # FLV file header: version 1, flags 0x01 (video only), header size 9, PreviousTagSize0 = 0.
        flv_header = b'FLV\x01\x01\x00\x00\x00\x09\x00\x00\x00\x00'
        self.buffer.write(flv_header)
        logger.info(f"[{stream_id}] FLV头已写入缓冲区: {flv_header.hex()}")

        # Best-effort debug copy of the muxed output.
        self.debug_file = None
        try:
            os.makedirs('debug', exist_ok=True)
            self.debug_file = open(f"debug/{stream_id}.flv", "wb")
            self.debug_file.write(flv_header)
        except (OSError, ValueError):
            self._close_debug_file()

        t = threading.Thread(target=self._monitor, daemon=True)
        t.start()

        logger.info(f"[{stream_id}] 流会话已创建")

    def _monitor(self):
        """While running, refresh ``stats['current_fps']`` about once per second."""
        while self.running:
            time.sleep(1)
            elapsed = time.time() - self.fps_start_time
            if elapsed > 0:
                self.stats['current_fps'] = self.fps_counter / elapsed
                self.fps_counter = 0
                self.fps_start_time = time.time()

    def feed_rtp(self, rtp_data: bytes, addr: Tuple[str, int]):
        """Consume one RTP datagram and publish an FLV tag whenever a frame completes.

        Processing errors are logged, never raised.
        """
        if not self.running:
            return

        try:
            rtp = RTPPacket(rtp_data)
            if not rtp.is_valid():
                return

            self.stats['rtp_packets'] += 1

            # A new RTP timestamp means a new frame started.  If the previous
            # frame's marker packet was lost, publish what was collected.
            if self.frame_buffer and rtp.timestamp != self._frame_timestamp:
                self._flush_frame()

            for nal_type, nal_data in self.parser.feed(rtp):
                if nal_type in (7, 8):
                    self.muxer.feed_sps_pps(nal_type, nal_data)
                    continue
                if nal_type not in (1, 5):
                    continue
                self.frame_buffer.append((nal_type, nal_data))
                self._frame_timestamp = rtp.timestamp
                if nal_type == 5:
                    self.current_frame_type = 1

            # The marker bit flags the last packet of a frame.  Check it only
            # after every NAL unit of the packet (e.g. all slices of a STAP-A)
            # has been collected.
            if rtp.marker and self.frame_buffer:
                self._flush_frame()

        except Exception as e:
            logger.error(f"[{self.stream_id}] RTP处理错误: {e}")

    def _flush_frame(self):
        """Mux the buffered NAL units as one frame and publish the FLV bytes."""
        frame_type, nal_units = self.current_frame_type, self.frame_buffer
        # Reset first so a muxer error cannot merge this frame into the next.
        self.frame_buffer = []
        self.current_frame_type = 2
        self._frame_timestamp = None

        flv_data = self.muxer.feed_frame(frame_type, nal_units)
        if not flv_data:
            return

        self.buffer.write(flv_data)
        self.stats['total_frames'] += 1
        self.stats['total_bytes'] += len(flv_data)
        self.last_frame_time = time.time()
        self.fps_counter += 1

        debug_file = self.debug_file  # local copy: stop() may clear the attribute
        if debug_file:
            try:
                debug_file.write(flv_data)
                debug_file.flush()
            except (OSError, ValueError):
                pass

    def _close_debug_file(self):
        """Close and forget the debug file, ignoring close errors."""
        debug_file, self.debug_file = self.debug_file, None
        if debug_file:
            try:
                debug_file.close()
            except OSError:
                pass

    def add_client(self, client_id: str) -> bool:
        """Register a viewer; False when ``max_clients_per_stream`` is reached."""
        with self.client_lock:
            if len(self.clients) >= self.config.max_clients_per_stream:
                return False
            self.clients.add(client_id)
            return True

    def remove_client(self, client_id: str):
        """Unregister a viewer and drop its ring-buffer cursor."""
        with self.client_lock:
            self.clients.discard(client_id)
            self.buffer.remove_reader(client_id)

    def read_for_client(self, client_id: str, size: int = 8192) -> bytes:
        """Return up to *size* pending FLV bytes for a registered viewer (``b''`` otherwise)."""
        with self.client_lock:
            if client_id not in self.clients:
                return b''
        return self.buffer.read(client_id, size)

    def stop(self):
        """Mark the session stopped and close the debug file."""
        self.running = False
        self._close_debug_file()
        logger.info(f"[{self.stream_id}] 流会话已停止")

#-----------------------------------------------------------------------------------------------------------------------
# 流管理器
#-----------------------------------------------------------------------------------------------------------------------

class StreamManager:
    """Registry of live StreamSession objects.

    Sessions are looked up by stream id.  Devices are mapped to their session
    through a caller-chosen ``device_key``.  A daemon thread removes sessions
    that are no longer running, checking every 60 seconds.
    """

    def __init__(self, config: 'ServerConfig'):
        self.config = config
        self.sessions: Dict[str, 'StreamSession'] = {}
        self._lock = threading.RLock()
        self._device_map: Dict[str, str] = {}  # device_key -> stream_id
        self.running = True

        t = threading.Thread(target=self._cleanup, daemon=True)
        t.start()

    def get_or_create_stream(self, device_key: str, ssrc: int,
                             addr: Tuple[str, int]) -> Optional['StreamSession']:
        """Return the running session for *device_key*, creating one if needed.

        Returns None when ``config.max_streams`` sessions already exist.
        *addr* is accepted for callers but not used here.
        """
        with self._lock:
            stream_id = self._device_map.get(device_key)
            if stream_id is not None:
                session = self.sessions.get(stream_id)
                if session and session.running:
                    return session
                # Stale mapping: the session stopped or disappeared.
                del self._device_map[device_key]
                self.sessions.pop(stream_id, None)

            if len(self.sessions) >= self.config.max_streams:
                logger.warning("流数量已达上限")
                return None

            stream_id = self._new_stream_id(ssrc)
            session = StreamSession(stream_id, self.config)
            self.sessions[stream_id] = session
            self._device_map[device_key] = stream_id

            logger.info(f"创建流: {stream_id} (SSRC={ssrc})")
            return session

    def _new_stream_id(self, ssrc: int) -> str:
        """Build ``stream_<ssrc>_<unix seconds>``, adding a numeric suffix on collision.

        Two devices may share an SSRC.  Without the suffix, a second session
        created in the same second would overwrite the first in ``sessions``,
        leaving it running but unreachable.
        """
        base_id = f"stream_{ssrc}_{int(time.time())}"
        stream_id = base_id
        suffix = 1
        while stream_id in self.sessions:
            stream_id = f"{base_id}_{suffix}"
            suffix += 1
        return stream_id

    def get_stream(self, stream_id: str) -> Optional['StreamSession']:
        """Return the session registered under *stream_id*, or None."""
        with self._lock:
            return self.sessions.get(stream_id)

    def list_streams(self) -> List[str]:
        """Return the ids of all registered sessions."""
        with self._lock:
            return list(self.sessions.keys())

    def _cleanup(self):
        """Background loop: every 60 s drop sessions that are no longer running."""
        while self.running:
            time.sleep(60)
            with self._lock:
                for stream_id, session in list(self.sessions.items()):
                    if not session.running:
                        session.stop()
                        del self.sessions[stream_id]
                        for key, sid in list(self._device_map.items()):
                            if sid == stream_id:
                                del self._device_map[key]

    def get_stats(self) -> Dict:
        """Return per-stream statistics plus stream and viewer totals."""
        with self._lock:
            streams = {}
            total_clients = 0
            for sid, session in self.sessions.items():
                streams[sid] = {
                    'stream_id': sid,
                    'type': 'raw',
                    'clients': len(session.clients),
                    'stats': session.stats.copy(),
                    'uptime': time.time() - session.created_at
                }
                total_clients += len(session.clients)

            return {
                'total_streams': len(self.sessions),
                'total_clients': total_clients,
                'streams': streams
            }

    def stop(self):
        """Stop every session and clear the registry."""
        self.running = False
        with self._lock:
            for session in self.sessions.values():
                session.stop()
            self.sessions.clear()
            self._device_map.clear()

#-----------------------------------------------------------------------------------------------------------------------
# HTTP处理器 (aiohttp)
#-----------------------------------------------------------------------------------------------------------------------

class HTTPHandler:
    """aiohttp request handlers: index page, live FLV streams and JSON APIs.

    ``server`` is the owning server object; the handlers read
    ``server.config``, ``server.stream_manager`` and ``server.start_time``.
    """

    # 13-byte FLV file header: signature "FLV", version 1, type-flags byte,
    # 4-byte header length (9), then PreviousTagSize0 = 0.
    # Type flags: 0x01 = video present, 0x04 = audio present. This server only
    # muxes H.264 video, so 0x01 is the correct value (0x04 would announce an
    # audio-only file; flv.js only coped because the page overrides
    # hasAudio/hasVideo).
    FLV_HEADER = b'FLV\x01\x01\x00\x00\x00\x09\x00\x00\x00\x00'

    def __init__(self, server):
        self.server = server
        # index.html is looked up next to this module.
        self.template_dir = Path(__file__).parent

    async def handle_index(self, request: web.Request) -> web.Response:
        """Serve ``index.html`` with ``{{rtp_port}}`` replaced by the configured RTP port.

        Falls back to a small built-in page when the template is missing or
        cannot be read.
        """
        try:
            template_path = self.template_dir / 'index.html'
            if not template_path.exists():
                return self._get_fallback_html()
            html = template_path.read_text(encoding='utf-8')
            # Substitute the template variable.
            html = html.replace('{{rtp_port}}', str(self.server.config.rtp_port))
            return web.Response(text=html, content_type='text/html', charset='utf-8')
        except Exception as e:
            logger.error(f"读取模板失败: {e}")
            return self._get_fallback_html()

    def _get_fallback_html(self) -> web.Response:
        """Return a minimal page used when ``index.html`` is unavailable.

        Deliberately static (no config access) so it still works when the
        template failure was caused by something else going wrong.
        """
        html = (
            '<!DOCTYPE html><html><head><meta charset="UTF-8">'
            '<title>流媒体服务器</title></head><body>'
            '<h1>流媒体服务器</h1>'
            '<p>index.html not found.</p>'
            '<p>Streams: <a href="/api/streams">/api/streams</a> · '
            'Play: /stream/&lt;stream_id&gt;.flv</p>'
            '</body></html>'
        )
        return web.Response(text=html, content_type='text/html', charset='utf-8')

    async def handle_stream(self, request: web.Request) -> web.StreamResponse:
        """Stream one live session to the client as FLV (``/stream/{stream_id}[.flv]``).

        Raises ``HTTPNotFound`` when the stream is unknown or not running and
        ``HTTPTooManyRequests`` when ``session.add_client`` refuses the client.
        The client slot is released in every case once it was granted,
        including when the connection drops before any data is sent.
        """
        stream_id = request.match_info.get('stream_id', '')
        if stream_id.endswith('.flv'):
            stream_id = stream_id[:-4]

        session = self.server.stream_manager.get_stream(stream_id)
        if not session or not session.running:
            raise web.HTTPNotFound()

        client_id = f"{request.remote}_{uuid.uuid4().hex[:8]}"
        if not session.add_client(client_id):
            raise web.HTTPTooManyRequests()

        response = web.StreamResponse(
            status=200,
            headers={
                'Content-Type': 'video/x-flv',
                'Access-Control-Allow-Origin': '*',
                'Cache-Control': 'no-cache, no-store, must-revalidate',
                'Pragma': 'no-cache',
            }
        )
        try:
            await response.prepare(request)
            await self._write_flv(response, session, client_id)
        finally:
            # Covers prepare() and the header writes too, so an early client
            # disconnect can no longer leak the slot taken by add_client().
            session.remove_client(client_id)

        return response

    async def _write_flv(self, response: web.StreamResponse, session, client_id: str) -> None:
        """Send the FLV header, the AVC sequence header, then buffered data for ``client_id``.

        Returns when the session stops, the client is no longer in
        ``session.clients``, the idle timeout elapses or a write fails.
        Write errors are logged rather than raised.
        """
        try:
            # The FLV file header is sent explicitly for every client.
            await response.write(self.FLV_HEADER)
            await response.drain()

            # Send the AVC sequence header (SPS/PPS) explicitly. Only read the
            # muxer's sps/pps here and do not touch its seq_header_sent state:
            # other clients still rely on it via the shared ring buffer.
            if session.muxer.sps and session.muxer.pps:
                sequence_header = session.muxer._make_sequence_header()
                if sequence_header:
                    await response.write(sequence_header)
                    await response.drain()

            # Then relay data read from the ring buffer for this client.
            last_active = time.time()
            while session.running and client_id in session.clients:
                data = session.read_for_client(client_id, 8192)
                if data:
                    await response.write(data)
                    await response.drain()
                    last_active = time.time()
                else:
                    # Nothing buffered yet; poll again shortly.
                    await asyncio.sleep(0.005)

                if time.time() - last_active > self.server.config.client_idle_timeout:
                    break
        except ConnectionResetError:
            logger.debug(f"[HTTP] 客户端断开: {client_id}")
        except Exception as e:
            logger.warning(f"[HTTP] 推送错误: {client_id}: {e}")

    async def handle_stream_test(self, request: web.Request) -> web.Response:
        """Alternative FLV endpoint that streams through an async-generator response body.

        Unlike ``handle_stream`` it does not send the FLV file header or the
        AVC sequence header itself. NOTE(review): the body length is unknown,
        so aiohttp presumably still uses chunked transfer encoding here, which
        contradicts the original "avoid chunked" intent; confirm before use.
        """
        stream_id = request.match_info.get('stream_id', '')

        if stream_id.endswith('.flv'):
            stream_id = stream_id[:-4]

        if not stream_id:
            raise web.HTTPBadRequest(text='Invalid stream ID')

        session = self.server.stream_manager.get_stream(stream_id)
        if not session or not session.running:
            raise web.HTTPNotFound(text=f'Stream not found: {stream_id}')

        client_id = f"{request.remote}_{uuid.uuid4().hex[:8]}"

        if not session.add_client(client_id):
            raise web.HTTPTooManyRequests(text='Too many clients')

        logger.info(f"[HTTP] 开始推送: {stream_id} -> {client_id}")

        # Async generator that yields buffered data for this client.
        async def flv_stream():
            last_active = time.time()
            try:
                while session.running and client_id in session.clients:
                    data = session.read_for_client(client_id, 8192)
                    if data:
                        yield data
                        last_active = time.time()
                    else:
                        await asyncio.sleep(0.01)

                    if time.time() - last_active > self.server.config.client_idle_timeout:
                        logger.info(f"[HTTP] 客户端超时: {client_id}")
                        break
            except Exception as e:
                logger.error(f"[HTTP] 推送错误: {e}")
            finally:
                session.remove_client(client_id)
                logger.info(f"[HTTP] 推送结束: {stream_id} -> {client_id}")

        return web.Response(
            body=flv_stream(),
            content_type='video/x-flv',
            headers={
                'Access-Control-Allow-Origin': '*',
                'Cache-Control': 'no-cache, no-store, must-revalidate',
                'Pragma': 'no-cache',
                'Expires': '0',
                'Connection': 'close',
                'X-Content-Type-Options': 'nosniff',
            }
        )

    async def handle_api_streams(self, request: web.Request) -> web.Response:
        """API: list stream ids with the total client count and a timestamp."""
        streams = self.server.stream_manager.list_streams()
        stats = self.server.stream_manager.get_stats()
        return web.json_response({
            'streams': streams,
            'total': len(streams),
            'total_clients': stats.get('total_clients', 0),
            'timestamp': time.time()
        })

    async def handle_api_stats(self, request: web.Request) -> web.Response:
        """API: detailed stream-manager statistics plus server version and uptime."""
        stats = self.server.stream_manager.get_stats()
        stats['timestamp'] = time.time()
        stats['version'] = '1.0'
        stats['uptime'] = time.time() - self.server.start_time if self.server.start_time else 0
        return web.json_response(stats)

    async def handle_health(self, request: web.Request) -> web.Response:
        """Health check: always reports ``ok`` with the current timestamp."""
        return web.json_response({
            'status': 'ok',
            'timestamp': time.time()
        })

#-----------------------------------------------------------------------------------------------------------------------
# 主服务器
#-----------------------------------------------------------------------------------------------------------------------

class MediaServer:
    """Top-level server wiring RTP ingest, the stream manager and HTTP output.

    Components (addresses taken from ``self.config``):
      * aiohttp app on ``http_host:http_port`` (index page, ``/stream/{id}``,
        ``/api/streams``, ``/api/stats``, ``/health``);
      * ``FLVHttpServer`` on ``http_host:http_port + 1``;
      * ``AsyncRTPReceiver`` on ``rtp_port`` (created in ``start()``).
    """

    def __init__(self, config: Optional[ServerConfig] = None):
        self.config = config or ServerConfig()
        self.stream_manager = StreamManager(self.config)
        self.rtp_receiver = None
        self.http_app = None
        self.http_runner = None
        self.running = False
        self.start_time = None
        self.http_handler = HTTPHandler(self)

        # Route SIGINT/SIGTERM to a graceful async shutdown. loop.add_signal_handler
        # is Unix-only in asyncio; elsewhere fall back to signal.signal handlers.
        try:
            loop = asyncio.get_event_loop()
            for sig in (signal.SIGINT, signal.SIGTERM):
                loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(self._shutdown(s)))
        except NotImplementedError:
            signal.signal(signal.SIGINT, self._sync_signal)
            signal.signal(signal.SIGTERM, self._sync_signal)

        # Use self.config: the ``config`` argument is None when the default is
        # used, and ``config.http_host`` would then raise AttributeError.
        self.flv_server = FLVHttpServer(
            self.stream_manager, self.config.http_host, self.config.http_port + 1
        )

    def _sync_signal(self, signum, frame):
        """signal.signal-style handler: schedule an async shutdown."""
        asyncio.create_task(self._shutdown(signum))

    async def _shutdown(self, sig=None):
        """Signal callback target; ``sig`` is the signal number (unused)."""
        await self.stop()

    async def start(self) -> bool:
        """Start the FLV HTTP server, the aiohttp app and the RTP receiver.

        Returns True on success. On any failure the error is logged,
        ``stop()`` tears down whatever was already started, and False is
        returned.
        """
        logger.info("=" * 60)
        logger.info("🚀 启动流媒体服务器")
        logger.info("=" * 60)

        self.running = True
        self.start_time = time.time()
        try:
            # Inside the try so a failure here is logged and cleaned up like
            # every other startup error instead of escaping with running=True.
            await self.flv_server.start()
            logger.info(f"✅ FLV HTTP: http://{self.config.http_host}:{self.config.http_port + 1}")

            # HTTP service
            self.http_app = web.Application()

            # Routes
            self.http_app.router.add_get('/', self.http_handler.handle_index)
            self.http_app.router.add_get('/stream/{stream_id}', self.http_handler.handle_stream)
            self.http_app.router.add_get('/api/streams', self.http_handler.handle_api_streams)
            self.http_app.router.add_get('/api/stats', self.http_handler.handle_api_stats)
            self.http_app.router.add_get('/health', self.http_handler.handle_health)

            # CORS headers on every response returned by a handler.
            @web.middleware
            async def cors_middleware(request, handler):
                response = await handler(request)
                response.headers['Access-Control-Allow-Origin'] = '*'
                response.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS'
                response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
                return response

            self.http_app.middlewares.append(cors_middleware)

            self.http_runner = web.AppRunner(self.http_app)
            await self.http_runner.setup()

            site = web.TCPSite(self.http_runner, self.config.http_host, self.config.http_port)
            await site.start()

            logger.info(f"✅ HTTP: http://{self.config.http_host}:{self.config.http_port}")

            # RTP service
            self.rtp_receiver = AsyncRTPReceiver(self.config.rtp_port)

            def on_rtp(data, addr):
                try:
                    self._process_rtp(data, addr)
                except Exception as e:
                    logger.error(f"RTP处理错误: {e}")

            # AsyncRTPReceiver.start() reports some failures by returning False
            # rather than raising; treat that as a startup failure instead of
            # announcing success.
            if await self.rtp_receiver.start(on_rtp) is False:
                raise RuntimeError(f"RTP 端口 {self.config.rtp_port} 监听失败")
            logger.info(f"✅ RTP: 端口 {self.config.rtp_port}")

            logger.info("=" * 60)
            logger.info("✅ 服务器启动成功!")
            logger.info(f"   推流: ffmpeg -re -i test.mp4 -an -c:v libx264 -f rtp rtp://127.0.0.1:{self.config.rtp_port}")
            logger.info(f"   播放: http://localhost:{self.config.http_port}")
            logger.info("=" * 60)

            return True

        except Exception as e:
            logger.error(f"启动失败: {e}")
            await self.stop()
            return False

    def _process_rtp(self, data: bytes, addr: Tuple[str, int]):
        """Parse one RTP datagram and feed it to the session for (ip, port, SSRC).

        Invalid packets are dropped; errors are logged at debug level.
        """
        try:
            rtp = RTPPacket(data)
            if not rtp.is_valid():
                return

            device_key = f"{addr[0]}:{addr[1]}:{rtp.ssrc}"
            session = self.stream_manager.get_or_create_stream(device_key, rtp.ssrc, addr)

            if session:
                session.feed_rtp(data, addr)
        except Exception as e:
            logger.debug(f"RTP处理错误: {e}")

    async def stop(self):
        """Stop all components once; a no-op when the server is not running."""
        if not self.running:
            return

        logger.info("正在停止服务器...")
        self.running = False

        if self.rtp_receiver:
            self.rtp_receiver.stop()

        if self.http_runner:
            await self.http_runner.cleanup()
        await self.flv_server.stop()
        self.stream_manager.stop()

        logger.info("✅ 服务器已停止")

    async def run_forever(self):
        """Idle until ``running`` is cleared or the task is cancelled, then stop."""
        try:
            while self.running:
                await asyncio.sleep(1)
        except asyncio.CancelledError:
            pass
        finally:
            await self.stop()

#-----------------------------------------------------------------------------------------------------------------------
# 入口
#-----------------------------------------------------------------------------------------------------------------------

async def main() -> int:
    """Parse command-line options, start a MediaServer and run it until stopped.

    Returns 0 after a normal shutdown and 1 when the server fails to start,
    so the caller can use it as the process exit status.
    """
    import argparse

    parser = argparse.ArgumentParser(description='流媒体服务器')
    parser.add_argument('--rtp-port', type=int, default=10000,
                        help='RTP 收流端口 (默认 10000)')
    parser.add_argument('--http-port', type=int, default=8080,
                        help='HTTP 端口 (默认 8080; FLV 服务使用该端口 +1)')
    parser.add_argument('--http-host', type=str, default='0.0.0.0',
                        help='HTTP 监听地址 (默认 0.0.0.0)')

    args = parser.parse_args()

    config = ServerConfig()
    config.rtp_port = args.rtp_port
    config.http_port = args.http_port
    config.http_host = args.http_host

    server = MediaServer(config)

    if not await server.start():
        # start() has already logged the reason and cleaned up.
        return 1
    await server.run_forever()
    return 0

if __name__ == "__main__":
    import sys

    # Use main()'s return value as the process exit status (non-zero when the
    # server fails to start); a None result exits with status 0.
    sys.exit(asyncio.run(main()))

前端页面index.html

html 复制代码
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>流媒体服务器</title>
    <style>
        * { margin:0; padding:0; box-sizing:border-box; }
        body { font-family:Arial,sans-serif; background:#1a1a2e; color:#e0e0e0; padding:20px; }
        h1 { text-align:center; margin-bottom:20px; }
        .streams { display:grid; grid-template-columns:repeat(auto-fill, minmax(400px,1fr)); gap:20px; }
        .card { background:#16213e; border-radius:12px; overflow:hidden; }
        .card video { width:100%; background:#000; display:block; }
        .card .info { padding:10px 15px; font-size:13px; display:flex; justify-content:space-between; }
        /* Status dot; the live/error modifier class is set by updateStreamStatus(). */
        .dot { width:8px; height:8px; border-radius:50%; display:inline-block; }
        .dot.live { background:#4caf50; }
        .dot.error { background:#f44336; }
        .empty { text-align:center; padding:60px; color:#666; grid-column:1/-1; }
        button { padding:10px 20px; border:none; border-radius:8px; cursor:pointer; background:#0f3460; color:#fff; }
    </style>
</head>
<body>
    <h1>📡 流媒体服务器</h1>
    <div style="text-align:center;margin-bottom:20px;">
        <span id="status">流: 0</span>
        <button onclick="refresh()" style="margin-left:10px;">🔄 刷新</button>
    </div>
    <!-- Stream cards are appended here by createCard(); the .empty placeholder is
         removed when the first card is added. The rtp_port template placeholder
         below is substituted by the server when it serves this page. -->
    <div id="streams" class="streams">
        <div class="empty">等待推流...<br>
        <small style="color:#888;">ffmpeg -re -i test.mp4 -an -c:v libx264 -f rtp rtp://127.0.0.1:{{rtp_port}}</small></div>
    </div>

    <!-- flv.js 1.6.2 player library (switched to a more reliable CDN) -->
    <script src="https://unpkg.com/flv.js@1.6.2/dist/flv.min.js"></script>
    <script>
        // Verify that flv.js actually loaded; if not, show an error message and
        // throw, which aborts the rest of this script.
        if (typeof flvjs === 'undefined') {
            document.getElementById('streams').innerHTML = 
                '<div class="empty">❌ flv.js 加载失败<br><small>请刷新页面或检查网络连接</small></div>';
            throw new Error('flv.js not loaded');
        }

        // sid -> flv.js player instance
        let players = {};
        // Stream ids that currently have a card on the page.
        let currentStreams = new Set();

        /**
         * Poll /api/streams and reconcile the page with the server:
         * remove cards for vanished streams, add cards (and players) for new ones.
         */
        async function refresh() {
            try {
                const response = await fetch('/api/streams');
                const payload = await response.json();
                document.getElementById('status').textContent = '流: ' + payload.total;

                const live = new Set(payload.streams);

                // Tear down cards whose stream no longer exists on the server.
                Array.from(currentStreams)
                    .filter(sid => !live.has(sid))
                    .forEach(sid => {
                        destroyPlayer(sid);
                        const card = document.getElementById('card_' + sid);
                        if (card) card.remove();
                        currentStreams.delete(sid);
                    });

                // Create a card for each new stream, then its player shortly after.
                live.forEach(sid => {
                    if (currentStreams.has(sid)) return;
                    currentStreams.add(sid);
                    createCard(sid);
                    setTimeout(() => initPlayer(sid), 100);
                });
            } catch (e) {
                console.error('刷新失败:', e);
            }
        }

        /** Pause, unload, detach and destroy the player for `sid`, then forget it. */
        function destroyPlayer(sid) {
            const player = players[sid];
            if (!player) return;

            try {
                player.pause();
                player.unload();
                player.detachMediaElement();
                player.destroy();
            } catch (e) {
                console.warn('销毁播放器失败:', e);
            }
            delete players[sid];
        }

        /**
         * Append a card (video element + status line) for stream `sid` to #streams,
         * removing the "waiting" placeholder if present.
         *
         * The card is built with DOM APIs and textContent rather than an innerHTML
         * template, so a stream id containing markup characters cannot inject HTML.
         */
        function createCard(sid) {
            const container = document.getElementById('streams');
            const empty = container.querySelector('.empty');
            if (empty) empty.remove();

            const card = document.createElement('div');
            card.className = 'card';
            card.id = 'card_' + sid;

            const video = document.createElement('video');
            video.id = 'v_' + sid;
            ['controls', 'autoplay', 'muted', 'playsinline'].forEach(attr => video.setAttribute(attr, ''));
            // The muted attribute on a script-created element does not set the
            // current muted state; set the property too so autoplay is allowed.
            video.muted = true;
            video.style.cssText = 'width:100%;min-height:240px;background:#000;';

            const info = document.createElement('div');
            info.className = 'info';

            const label = document.createElement('span');
            label.style.fontSize = '12px';
            label.title = sid;
            // Only add an ellipsis when the id is actually truncated.
            label.textContent = sid.length > 20 ? sid.substring(0, 20) + '...' : sid;

            const state = document.createElement('span');
            const dot = document.createElement('span');
            dot.className = 'dot';
            dot.id = 'dot_' + sid;
            const statusText = document.createElement('span');
            statusText.id = 'st_' + sid;
            statusText.style.fontSize = '12px';
            statusText.textContent = '加载中...';
            state.append(dot, ' ', statusText);

            info.append(label, state);
            card.append(video, info);
            container.appendChild(card);
        }

        /**
         * Create an flv.js live player for stream `sid` and attach it to the card's
         * <video> element. No-op if a player already exists for `sid`.
         */
        function initPlayer(sid) {
            if (players[sid]) return;

            const videoEl = document.getElementById('v_' + sid);
            if (!videoEl) {
                console.warn('Video元素不存在:', sid);
                return;
            }

            if (!flvjs.isSupported()) {
                console.error('当前浏览器不支持flv.js');
                updateStreamStatus(sid, '不支持', 'error');
                return;
            }

            try {
                const player = flvjs.createPlayer({
                    type: 'flv',
                    url: '/stream/' + sid + '.flv',
                    isLive: true,
                    hasAudio: false,
                    hasVideo: true,
                    cors: true
                }, {
                    enableWorker: true,
                    enableStashBuffer: false,
                    stashInitialSize: 128,
                    autoCleanupSourceBuffer: true,
                    autoCleanupMaxBackwardDuration: 30,
                    autoCleanupMinBackwardDuration: 10,
                    liveBufferLatencyChasing: true,
                    liveBufferLatencyMaxLatency: 1.5,
                    liveBufferLatencyMinRemain: 0.2
                });

                player.attachMediaElement(videoEl);
                player.load();

                // Autoplay: if the browser blocks play(), retry on the first click/tap.
                const autoplay = () => {
                    const playPromise = videoEl.play();
                    if (playPromise && playPromise.catch) {
                        playPromise.catch(() => {
                            const resume = () => {
                                videoEl.play().catch(() => {});
                                videoEl.removeEventListener('click', resume);
                                videoEl.removeEventListener('touchstart', resume);
                            };
                            videoEl.addEventListener('click', resume, { once: true });
                            videoEl.addEventListener('touchstart', resume, { once: true });
                        });
                    }
                };

                videoEl.addEventListener('loadedmetadata', autoplay, { once: true });

                // Errors: show the error state; on network errors reconnect after 2 s.
                player.on(flvjs.Events.ERROR, (type, detail) => {
                    console.error('[' + sid + '] 播放错误:', type, detail);
                    updateStreamStatus(sid, '错误', 'error');

                    if (type === flvjs.ErrorTypes.NETWORK_ERROR) {
                        setTimeout(() => {
                            // Skip if this player was destroyed or replaced meanwhile.
                            if (players[sid] !== player) return;
                            try {
                                // flv.js refuses a second load() while the previous
                                // load is still active, so unload first.
                                player.unload();
                                player.load();
                            } catch (e) {
                                console.error('重连失败:', e);
                            }
                        }, 2000);
                    }
                });

                // Statistics: show the current download speed as the status text.
                player.on(flvjs.Events.STATISTICS_INFO, (stats) => {
                    if (stats && stats.speed) {
                        updateStreamStatus(sid, Math.round(stats.speed) + ' KB/s', 'live');
                    }
                });

                players[sid] = player;
                console.log('播放器初始化成功:', sid);

            } catch (e) {
                console.error('播放器初始化失败:', sid, e);
                updateStreamStatus(sid, '初始化失败', 'error');
            }
        }

        /** Set the status text and the dot's modifier class ('live' / 'error') for `sid`. */
        function updateStreamStatus(sid, text, status) {
            const label = document.getElementById('st_' + sid);
            const dot = document.getElementById('dot_' + sid);

            if (label) label.textContent = text;
            if (dot) dot.className = `dot ${status}`;
        }

        // Poll the stream list every 5 s, and once right away.
        setInterval(refresh, 5000);
        refresh();

        // Destroy every player when the page is being unloaded.
        window.addEventListener('beforeunload', () => {
            for (const sid of Object.keys(players)) {
                try {
                    players[sid].destroy();
                } catch (e) {
                    console.warn('清理播放器失败:', sid, e);
                }
            }
            players = {};
        });
    </script>
</body>
</html>

测试命令

bash 复制代码
# Push test.mp4 in real time (-re) as a video-only (-an) H.264 RTP stream to the server's default RTP port 10000
ffmpeg -re -i test.mp4 -an -c:v libx264 -f rtp rtp://127.0.0.1:10000

使用流程:先启动服务器(python rtp_server.py),然后执行上面的推流命令,再在浏览器访问前端 xxx:8080,就可以预览视频了。后面开始研究接入 GB28181 协议,GB28181 使用 PS(MPEG Program Stream)封装,跟裸 RTP 不一样,不知道还要踩多少坑,好在这个流程调通了,调了好几天,也不知道改了多少版了,come on!