修改Spatial-MLLM项目,使其专注于无人机航拍视频的空间理解。以下是修改方案和关键代码实现:
修改思路
- 输入处理:将原项目的视频+文本输入改为单一无人机航拍视频/图像输入
- 问题生成:自动生成空间理解相关的问题(无需用户输入文本)
- 模型适配:调整视觉编码器处理航拍图像特征
- 输出优化:聚焦空间关系、物体定位和场景结构的分析
关键代码修改
1. 输入处理模块 (video_processor.py
)
python
import cv2
import torch
from PIL import Image
from torchvision import transforms
class DroneVideoProcessor:
def __init__(self, frame_stride=5, target_size=224):
self.frame_stride = frame_stride
self.transform = transforms.Compose([
transforms.Resize((target_size, target_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def load_video(self, video_path):
"""处理无人机航拍视频,提取关键帧"""
cap = cv2.VideoCapture(video_path)
frames = []
frame_count = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
if frame_count % self.frame_stride == 0:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = Image.fromarray(frame)
frame = self.transform(frame)
frames.append(frame)
frame_count += 1
cap.release()
return torch.stack(frames) # [T, C, H, W]
def load_image(self, image_path):
"""处理单张航拍图像"""
img = Image.open(image_path).convert('RGB')
return self.transform(img).unsqueeze(0) # [1, C, H, W]
2. 问题生成模块 (question_generator.py
)
python
class SpatialQuestionGenerator:
def generate_questions(self, input_type):
"""自动生成空间理解问题"""
base_questions = {
'video': [
"描述视频中的空间布局和物体位置关系",
"分析场景中物体的三维空间分布",
"识别视频中的地形特征和空间结构",
"解释物体间的相对位置和空间关系",
"总结场景的整体空间组织方式"
],
'image': [
"分析图像中的深度信息和空间层次",
"描述物体的空间分布和方位关系",
"识别场景中的前景、中景和背景元素",
"解释透视关系和空间尺度",
"评估场景的空间复杂度和结构特征"
]
}
return base_questions[input_type]
3. 模型适配模块 (spatial_mllm.py
)
python
class SpatialMLLM(nn.Module):
def __init__(self, vision_encoder, text_decoder):
super().__init__()
self.vision_encoder = vision_encoder
self.text_decoder = text_decoder
# 添加航拍特征适配层
self.drone_adapter = nn.Sequential(
nn.Linear(vision_encoder.output_dim, 1024),
nn.ReLU(),
nn.Linear(1024, text_decoder.hidden_size)
)
def forward(self, video_frames, generated_questions):
# 提取航拍视觉特征
visual_features = self.vision_encoder(video_frames) # [T, D]
# 适配无人机视角特征
drone_features = self.drone_adapter(visual_features.mean(dim=0)) # [D]
# 处理每个生成的问题
answers = []
for question in generated_questions:
# 融合视觉特征和问题文本
input_text = f"无人机航拍分析: {question} 视觉上下文:"
output = self.text_decoder(
input_text,
context_vector=drone_features
)
answers.append(output)
return answers
4. 主执行模块 (run_drone_spatial.py
)
python
from video_processor import DroneVideoProcessor
from question_generator import SpatialQuestionGenerator
from spatial_mllm import SpatialMLLM
def main(input_path):
# 1. 输入处理
processor = DroneVideoProcessor()
if input_path.endswith(('.mp4', '.avi')):
frames = processor.load_video(input_path)
input_type = 'video'
else: # 图像输入
frames = processor.load_image(input_path)
input_type = 'image'
# 2. 生成空间理解问题
question_gen = SpatialQuestionGenerator()
questions = question_gen.generate_questions(input_type)
# 3. 加载模型 (使用原项目预训练权重)
model = SpatialMLLM(
vision_encoder=load_pretrained_vision_encoder(),
text_decoder=load_pretrained_text_decoder()
)
# 4. 执行空间理解分析
with torch.no_grad():
results = model(frames, questions)
# 5. 输出结果
print("\n无人机航拍空间理解分析报告:")
for q, a in zip(questions, results):
print(f"\n问题: {q}")
print(f"分析: {a}\n{'='*50}")
if __name__ == "__main__":
import sys
main(sys.argv[1])
使用方式
bash
# 分析航拍视频
python run_drone_spatial.py drone_video.mp4
# 分析航拍图像
python run_drone_spatial.py aerial_photo.jpg
输出示例
无人机航拍空间理解分析报告:
问题: 描述视频中的空间布局和物体位置关系
分析: 视频显示城市区域的空间布局呈现网格状结构。建筑物高度由市中心向郊区递减,形成梯度分布。主干道两侧建筑物密度较高,与支路形成层次结构。公园区域(约占总面积15%)作为开放空间位于东南象限,与商业区形成鲜明对比...
==================================================================
关键技术点
-
航拍特征增强:
- 添加
drone_adapter
模块专门处理俯视视角特征 - 采用多帧时空融合策略处理视频序列
- 添加
-
空间分析维度:
- 深度估计与层次分离(前景/背景)
- 物体相对位置关系分析
- 区域分割与功能分区识别
- 三维空间重建(高度/密度分布)
- 动态物体轨迹预测(仅视频模式)
-
优化策略:
python# 在video_processor.py中添加 def enhance_aerial_features(self, frames): """航拍图像增强处理""" # 1. 对比度增强(突出地形特征) # 2. 边缘增强(强化建筑轮廓) # 3. 色度校正(补偿大气散射) # 4. 小目标检测增强 return enhanced_frames