
一、数据库操作编程实现(PostgreSQL + Python)
1. 数据库连接与初始化
使用psycopg2库实现PostgreSQL数据库连接,并初始化核心数据表:
python
import json
import os

import psycopg2
from psycopg2.extras import RealDictCursor
# Database connection settings. Credentials can be overridden via
# environment variables so secrets do not have to live in source control;
# the fallbacks preserve the original development defaults.
DB_CONFIG = {
    "dbname": os.environ.get("MEDICAL_AI_DB_NAME", "medical_ai_db"),
    "user": os.environ.get("MEDICAL_AI_DB_USER", "postgres"),
    "password": os.environ.get("MEDICAL_AI_DB_PASSWORD", "your_password"),
    "host": os.environ.get("MEDICAL_AI_DB_HOST", "localhost"),
    "port": os.environ.get("MEDICAL_AI_DB_PORT", "5432"),
}
def init_db():
    """Create the core tables (idempotent via IF NOT EXISTS).

    Tables: patients, trained_models, synthesis_tasks, synthetic_images,
    evaluation_metrics. The connection is always closed, even when a DDL
    statement fails (the original leaked it on error).
    """
    conn = psycopg2.connect(**DB_CONFIG)
    try:
        # psycopg2 cursors are context managers; the cursor is closed on exit.
        with conn.cursor() as cursor:
            # Patients table
            cursor.execute("""
CREATE TABLE IF NOT EXISTS patients (
patient_hash_id VARCHAR(64) PRIMARY KEY,
age INT,
gender VARCHAR(10),
diagnosis_code VARCHAR(20),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""")
            # Trained models table
            cursor.execute("""
CREATE TABLE IF NOT EXISTS trained_models (
model_id SERIAL PRIMARY KEY,
architecture_name VARCHAR(50),
version_tag VARCHAR(20),
checkpoint_path VARCHAR(255),
hyperparameters JSONB,
is_active BOOLEAN DEFAULT TRUE,
trained_at TIMESTAMP
);
""")
            # Synthesis tasks table
            cursor.execute("""
CREATE TABLE IF NOT EXISTS synthesis_tasks (
task_id SERIAL PRIMARY KEY,
model_id INT REFERENCES trained_models(model_id),
prompt_condition TEXT,
batch_size INT DEFAULT 1,
status VARCHAR(20) CHECK (status IN ('QUEUED', 'RUNNING', 'COMPLETED', 'FAILED')),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""")
            # Synthetic images table
            cursor.execute("""
CREATE TABLE IF NOT EXISTS synthetic_images (
synth_id SERIAL PRIMARY KEY,
task_id INT REFERENCES synthesis_tasks(task_id),
s3_storage_path VARCHAR(255) NOT NULL,
seed_used BIGINT,
guidance_scale FLOAT,
generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""")
            # Evaluation metrics table
            cursor.execute("""
CREATE TABLE IF NOT EXISTS evaluation_metrics (
eval_id SERIAL PRIMARY KEY,
task_id INT REFERENCES synthesis_tasks(task_id),
metric_name VARCHAR(20),
metric_value FLOAT,
evaluated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
""")
        conn.commit()
    finally:
        conn.close()
    print("数据库表初始化完成!")
# Run schema initialization when the script is executed.
init_db()
2. 数据增删改查操作封装
python
class MedicalAIDB:
    """Thin CRUD wrapper around the medical-AI PostgreSQL schema.

    Holds one connection with a RealDictCursor, so fetched rows behave
    like dicts. Supports explicit `close()` and `with` usage; `__del__`
    remains as a best-effort fallback.
    """

    def __init__(self):
        self.conn = psycopg2.connect(**DB_CONFIG)
        self.cursor = self.conn.cursor(cursor_factory=RealDictCursor)

    def close(self):
        """Release the cursor and connection; safe to call repeatedly."""
        try:
            self.cursor.close()
        except Exception:
            pass
        try:
            self.conn.close()
        except Exception:
            pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def __del__(self):
        # Best-effort cleanup: __init__ may have failed before the
        # attributes existed, and psycopg2 may already be torn down at
        # interpreter shutdown — never raise from here.
        try:
            self.close()
        except Exception:
            pass

    def insert_trained_model(self, arch_name, version, checkpoint_path, hyperparams):
        """Insert a trained-model row and return the generated model_id.

        `hyperparams` is any JSON-serializable mapping (stored as JSONB).
        """
        sql = """
INSERT INTO trained_models (architecture_name, version_tag, checkpoint_path, hyperparameters, trained_at)
VALUES (%s, %s, %s, %s, CURRENT_TIMESTAMP)
RETURNING model_id;
"""
        self.cursor.execute(sql, (arch_name, version, checkpoint_path, json.dumps(hyperparams)))
        model_id = self.cursor.fetchone()["model_id"]
        self.conn.commit()
        return model_id

    def create_synthesis_task(self, model_id, prompt, batch_size=1):
        """Create a QUEUED synthesis task and return its task_id."""
        sql = """
INSERT INTO synthesis_tasks (model_id, prompt_condition, batch_size, status)
VALUES (%s, %s, %s, 'QUEUED')
RETURNING task_id;
"""
        self.cursor.execute(sql, (model_id, prompt, batch_size))
        task_id = self.cursor.fetchone()["task_id"]
        self.conn.commit()
        return task_id

    def update_task_status(self, task_id, status):
        """Set a task's status (must satisfy the table's CHECK constraint)."""
        sql = "UPDATE synthesis_tasks SET status = %s WHERE task_id = %s;"
        self.cursor.execute(sql, (status, task_id))
        self.conn.commit()

    def insert_synthetic_image(self, task_id, s3_path, seed, guidance_scale=0.0):
        """Record one generated image (object-store path + sampling seed)."""
        sql = """
INSERT INTO synthetic_images (task_id, s3_storage_path, seed_used, guidance_scale)
VALUES (%s, %s, %s, %s);
"""
        self.cursor.execute(sql, (task_id, s3_path, seed, guidance_scale))
        self.conn.commit()

    def insert_evaluation_metric(self, task_id, metric_name, metric_value):
        """Record one quality metric (e.g. FID, SSIM) for a task."""
        sql = """
INSERT INTO evaluation_metrics (task_id, metric_name, metric_value)
VALUES (%s, %s, %s);
"""
        self.cursor.execute(sql, (task_id, metric_name, metric_value))
        self.conn.commit()
# --- Usage example ---
db = MedicalAIDB()
# Register a DDPM model record
ddpm_hyperparams = {
    "diffusion_steps": 1000,
    "beta_start": 0.0001,
    "beta_end": 0.02,
    "lr": 1e-4,
    "batch_size": 8
}
model_id = db.insert_trained_model("DDPM-UNet", "v1.0", "/models/ddpm_lung_ct.pt", ddpm_hyperparams)
# Create a synthesis task for the new model
task_id = db.create_synthesis_task(model_id, "Lung CT with 5mm nodule in upper lobe", batch_size=4)
print(f"创建合成任务ID: {task_id}")
二、MinIO对象存储操作(医学影像文件管理)
python
from minio import Minio
from minio.error import S3Error
class MedicalImageStorage:
    """MinIO-backed object storage for synthetic medical image files."""

    def __init__(self, endpoint, access_key, secret_key, secure=False):
        self.client = Minio(endpoint, access_key=access_key, secret_key=secret_key, secure=secure)
        self.bucket_name = "medical-synthetic-images"
        # Make sure the target bucket exists before any transfer happens.
        if not self.client.bucket_exists(self.bucket_name):
            self.client.make_bucket(self.bucket_name)

    def upload_image(self, local_path, object_name):
        """Upload a local file; return its s3:// URI, or None on failure."""
        try:
            self.client.fput_object(self.bucket_name, object_name, local_path)
        except S3Error as e:
            print(f"上传失败: {e}")
            return None
        return f"s3://{self.bucket_name}/{object_name}"

    def download_image(self, object_name, local_path):
        """Fetch an object into a local file; True on success, False on error."""
        try:
            self.client.fget_object(self.bucket_name, object_name, local_path)
        except S3Error as e:
            print(f"下载失败: {e}")
            return False
        return True
# --- Usage example ---
storage = MedicalImageStorage("localhost:9000", "minio_access_key", "minio_secret_key")
# Upload a synthesized DICOM file for the task created above
s3_path = storage.upload_image("/tmp/synthetic_ct.dcm", f"task_{task_id}/image_1.dcm")
if s3_path:
    db.insert_synthetic_image(task_id, s3_path, seed=42)
三、DDPM医学影像合成模型实现(PyTorch)
1. 扩散模型核心模块
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os
# UNet backbone (simplified).
class UNet(nn.Module):
    """Simplified UNet noise predictor for the DDPM.

    NOTE(review): `in_channels` must already count the extra
    time-embedding channel that forward() concatenates onto x — callers
    pass a tensor with in_channels - 1 channels (e.g. in_channels=2 for
    single-channel CT slices). Input H/W must be divisible by 8.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        self.down1 = nn.Conv2d(in_channels, base_channels, 4, 2, 1)
        self.down2 = nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1)
        self.down3 = nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1)
        self.up1 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1)
        self.up2 = nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1)
        self.up3 = nn.ConvTranspose2d(base_channels, out_channels, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.bn2 = nn.BatchNorm2d(base_channels * 2)
        self.bn3 = nn.BatchNorm2d(base_channels * 4)

    def forward(self, x, t):
        # Simplified scalar time embedding broadcast as an extra channel.
        # BUG FIX: reshape t to (B, 1, 1, 1) before broadcasting — the raw
        # (B,) tensor only broadcast against (B, 1, H, W) when W == B, and
        # then varied along the width axis instead of the batch axis.
        t = t.float().reshape(-1, 1, 1, 1) / 1000
        t_emb = torch.zeros_like(x[:, :1, :, :]) + t
        x = torch.cat([x, t_emb], dim=1)
        # Encoder
        d1 = F.relu(self.bn1(self.down1(x)))
        d2 = F.relu(self.bn2(self.down2(d1)))
        d3 = F.relu(self.bn3(self.down3(d2)))
        # Decoder with additive skip connections; sigmoid keeps output in [0, 1].
        u1 = F.relu(self.up1(d3))
        u2 = F.relu(self.up2(u1 + d2))
        u3 = torch.sigmoid(self.up3(u2 + d1))
        return u3
# DDPM wrapper: forward() computes the training loss, sample() generates.
class DDPM(nn.Module):
    """Denoising diffusion model with a linear beta schedule.

    `model` is the noise predictor (called as model(x_t, t)); `n_T` is
    the number of diffusion steps.
    """

    def __init__(self, model, betas=(1e-4, 0.02), n_T=1000):
        super().__init__()
        self.model = model
        self.n_T = n_T
        # Precompute the linear beta schedule and cumulative alpha products.
        self.betas = torch.linspace(betas[0], betas[1], n_T)
        self.alphas = 1 - self.betas
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)

    def forward(self, x):
        """Training step: noise x at a random timestep, return MSE loss
        between the true and predicted noise."""
        t = torch.randint(0, self.n_T, (x.shape[0],), device=x.device)
        noise = torch.randn_like(x)
        # BUG FIX: move the schedule to x.device BEFORE indexing with t —
        # indexing a CPU tensor with a CUDA index tensor raises an error.
        alpha_bar_t = self.alphas_bar.to(x.device)[t].reshape(-1, 1, 1, 1)
        # Closed-form forward diffusion to timestep t.
        x_t = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
        noise_pred = self.model(x_t, t)
        return F.mse_loss(noise_pred, noise)

    def sample(self, n_samples, img_size, device):
        """Ancestral sampling: iterate from pure noise back to an image."""
        x = torch.randn(n_samples, 1, img_size, img_size, device=device)
        for t in reversed(range(self.n_T)):
            t_tensor = torch.tensor([t], device=device)
            alpha_t = self.alphas[t].to(device)
            alpha_bar_t = self.alphas_bar[t].to(device)
            noise_pred = self.model(x, t_tensor)
            # No fresh noise is injected at the final (t == 0) step.
            noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
            # BUG FIX: betas[t] also needs to be on `device` for the multiply.
            x = (1 / torch.sqrt(alpha_t)) * (
                x - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * noise_pred
            ) + torch.sqrt(self.betas[t].to(device)) * noise
        return x
2. 医学影像数据集与训练流程
python
# Dataset over a flat directory of PNG / DICOM medical images.
class MedicalImageDataset(Dataset):
    """Yields single-channel float tensors of shape (1, img_size, img_size)
    scaled to [0, 1].

    Files are matched by extension (.png / .dcm, case-insensitive) and
    sorted so the ordering is deterministic (os.listdir order is arbitrary).
    """

    def __init__(self, data_dir, img_size=128):
        self.data_dir = data_dir
        self.img_size = img_size
        self.file_list = sorted(
            f for f in os.listdir(data_dir) if f.lower().endswith(('.png', '.dcm'))
        )

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        if file_path.lower().endswith('.png'):
            img = Image.open(file_path).convert('L')
        else:
            # DICOM branch (requires pydicom); imported lazily so PNG-only
            # datasets work without the dependency.
            import pydicom
            dicom = pydicom.dcmread(file_path)
            # NOTE(review): dividing raw DICOM pixels by 255 below assumes
            # 8-bit data; CT HU values would leave [0, 1] — confirm the
            # source data's bit depth.
            img = Image.fromarray(dicom.pixel_array).convert('L')
        img = img.resize((self.img_size, self.img_size))
        img = np.array(img).astype(np.float32) / 255.0
        img = torch.from_numpy(img).unsqueeze(0)
        return img
# Training entry point.
def train_ddpm():
    """Train the DDPM on the lung-CT dataset.

    Checkpoints the model and writes 4 sample PNGs every 10 epochs.
    Hyperparameters are fixed here to match the example's DB record.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_dir = "/data/lung_ct_dataset"
    img_size = 128
    batch_size = 8
    epochs = 100
    lr = 1e-4
    # BUG FIX: make sure the output directories exist before saving —
    # torch.save / Image.save fail otherwise.
    os.makedirs("/models", exist_ok=True)
    os.makedirs("/samples", exist_ok=True)
    # Dataset, model, optimizer
    dataset = MedicalImageDataset(data_dir, img_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # in_channels=2: one image channel + one time-embedding channel.
    unet = UNet(in_channels=2, out_channels=1).to(device)
    ddpm = DDPM(unet, n_T=1000).to(device)
    optimizer = torch.optim.Adam(ddpm.parameters(), lr=lr)
    # Training loop
    for epoch in range(epochs):
        ddpm.train()
        total_loss = 0
        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()
            loss = ddpm(batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        # Checkpoint and sample every 10 epochs.
        if (epoch + 1) % 10 == 0:
            torch.save(ddpm.state_dict(), f"/models/ddpm_epoch_{epoch+1}.pt")
            ddpm.eval()
            with torch.no_grad():
                samples = ddpm.sample(n_samples=4, img_size=img_size, device=device)
            for i, sample in enumerate(samples):
                sample = (sample.squeeze().cpu().numpy() * 255).astype(np.uint8)
                Image.fromarray(sample).save(f"/samples/epoch_{epoch+1}_sample_{i}.png")
# Kick off training when the script runs.
train_ddpm()
3. 合成影像生成与DICOM格式转换
python
# Generate synthetic images and persist them as DICOM files.
def generate_synthetic_images(model_path, task_id, num_samples=4):
    """Sample `num_samples` images from a DDPM checkpoint, wrap each in a
    minimal CT DICOM file, upload to MinIO and register in the database.

    `model_path` must hold a DDPM state_dict; requires pydicom.
    """
    import pydicom
    from pydicom.dataset import Dataset, FileDataset
    from pydicom.uid import ExplicitVRLittleEndian

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img_size = 128
    # Load the pretrained model (map_location handles CPU-only hosts).
    unet = UNet(in_channels=2, out_channels=1).to(device)
    ddpm = DDPM(unet, n_T=1000).to(device)
    ddpm.load_state_dict(torch.load(model_path, map_location=device))
    ddpm.eval()
    # Sample images
    with torch.no_grad():
        samples = ddpm.sample(n_samples=num_samples, img_size=img_size, device=device)
    storage = MedicalImageStorage("localhost:9000", "minio_access_key", "minio_secret_key")
    db = MedicalAIDB()
    for i, sample in enumerate(samples):
        # NOTE(review): sigmoid output is in [0, 1]; *255 gives an 8-bit
        # value range stored in uint16 — confirm the intended CT scaling.
        img_array = (sample.squeeze().cpu().numpy() * 255).astype(np.uint16)
        # DICOM file meta information
        file_meta = Dataset()
        file_meta.MediaStorageSOPClassUID = pydicom.uid.CTImageStorage
        file_meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
        file_meta.TransferSyntaxUID = ExplicitVRLittleEndian
        temp_path = f"/tmp/synth_{task_id}_{i}.dcm"
        ds = FileDataset(temp_path, {}, file_meta=file_meta, preamble=b"\0"*128)
        ds.SOPClassUID = file_meta.MediaStorageSOPClassUID
        ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID
        ds.PatientName = f"Synthetic_{task_id}"
        ds.PatientID = f"Synth-{task_id}-{i}"
        ds.Modality = "CT"
        ds.SeriesDescription = "Synthetic Lung CT"
        ds.Rows = img_size
        ds.Columns = img_size
        ds.PixelSpacing = [1.0, 1.0]
        ds.SliceThickness = 1.0
        # BUG FIX: pixel-encoding attributes are mandatory for uint16
        # PixelData; without them viewers cannot decode the image.
        ds.SamplesPerPixel = 1
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.BitsAllocated = 16
        ds.BitsStored = 16
        ds.HighBit = 15
        ds.PixelRepresentation = 0
        ds.PixelData = img_array.tobytes()
        # Save temp file, upload, record, clean up.
        ds.save_as(temp_path)
        s3_path = storage.upload_image(temp_path, f"task_{task_id}/image_{i}.dcm")
        # BUG FIX: upload_image returns None on failure — only record
        # successful uploads instead of inserting NULL-ish paths.
        if s3_path:
            db.insert_synthetic_image(task_id, s3_path, seed=42+i)
        os.remove(temp_path)
    print(f"成功生成{num_samples}张合成影像,任务ID: {task_id}")
# Usage example: generate 4 images for task 1 with the final checkpoint.
generate_synthetic_images("/models/ddpm_epoch_100.pt", task_id=1, num_samples=4)
四、合成影像质量评估实现
1. FID计算实现
python
import torch
import torch.nn as nn
from torchvision.models import inception_v3
import numpy as np
from scipy.linalg import sqrtm
class FIDEvaluator:
    """Fréchet Inception Distance between real and generated image batches."""

    def __init__(self, device):
        self.device = device
        # Pretrained InceptionV3 feature extractor.
        self.inception = inception_v3(pretrained=True, transform_input=False).to(device)
        # BUG FIX: nn.Sequential(*children()[:-1]) breaks InceptionV3 —
        # the aux-logits branch and the functional pooling/flatten in its
        # forward() are not plain sequential children. Replacing `fc` with
        # Identity makes the model return the 2048-d pooled features.
        self.inception.fc = nn.Identity()
        self.inception.eval()
        # Alias kept for backward compatibility with the old attribute name.
        self.feature_extractor = self.inception

    def get_features(self, images):
        """Return (N, 2048) Inception features as a numpy array.

        Input: (N, 1 or 3, H, W) tensors scaled to [0, 1].
        """
        # Replicate greyscale to 3 channels.
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)
        # InceptionV3 expects 299x299 inputs in [-1, 1].
        images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
        images = (images - 0.5) * 2.0
        with torch.no_grad():
            features = self.inception(images)
        features = features.view(features.shape[0], -1)
        return features.cpu().numpy()

    def calculate_fid(self, real_images, fake_images):
        """FID = |mu1 - mu2|^2 + Tr(S1 + S2 - 2*sqrt(S1 S2))."""
        real_features = self.get_features(real_images.to(self.device))
        fake_features = self.get_features(fake_images.to(self.device))
        # Gaussian statistics of each feature set.
        mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
        mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)
        ssdiff = np.sum((mu1 - mu2)**2.0)
        covmean = sqrtm(sigma1.dot(sigma2))
        # Numerical noise can introduce tiny imaginary components.
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid
# Usage example: FID between a real batch and saved sample PNGs.
def evaluate_synthetic_images():
    """Compute FID between 16 real CT slices and 16 synthetic samples,
    then store the metric for task 1.

    NOTE(review): only 4 distinct synthetic PNGs are reused 4x each
    (i % 4), which biases the synthetic statistics — confirm intended.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fid_evaluator = FIDEvaluator(device)
    # Load a batch of real reference images
    real_dataset = MedicalImageDataset("/data/lung_ct_dataset", img_size=128)
    real_loader = DataLoader(real_dataset, batch_size=16, shuffle=False)
    real_images = next(iter(real_loader))
    # Load synthetic images saved during training
    synth_images = []
    for i in range(16):
        img = Image.open(f"/samples/epoch_100_sample_{i%4}.png").convert('L')
        img = img.resize((128, 128))
        img = np.array(img).astype(np.float32) / 255.0
        synth_images.append(torch.from_numpy(img).unsqueeze(0))
    synth_images = torch.stack(synth_images)
    # Compute FID
    fid_score = fid_evaluator.calculate_fid(real_images, synth_images)
    print(f"FID Score: {fid_score:.2f}")
    # Persist the metric
    db = MedicalAIDB()
    db.insert_evaluation_metric(task_id=1, metric_name="FID", metric_value=fid_score)
evaluate_synthetic_images()
2. SSIM计算实现
python
import torch
import torch.nn.functional as F
def calculate_ssim(img1, img2, window_size=11, sigma=1.5, data_range=1.0):
    """Structural similarity index between two single-channel image batches.

    Inputs are (N, 1, H, W) tensors scaled to [0, data_range]; returns the
    mean SSIM as a 0-dim tensor (1.0 for identical images).
    """
    # Centered Gaussian window. BUG FIX: the original used arange(0..w-1)
    # without centering, which produced an asymmetric, shifted window.
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    gaussian = torch.exp(-coords**2 / (2 * sigma**2))
    gaussian = gaussian / gaussian.sum()
    window_1d = gaussian.unsqueeze(1)
    window_2d = window_1d @ window_1d.t()
    window = window_2d.unsqueeze(0).unsqueeze(0).to(img1.device)
    pad = window_size // 2
    # Local means
    mu1 = F.conv2d(img1, window, padding=pad, groups=1)
    mu2 = F.conv2d(img2, window, padding=pad, groups=1)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    # Local variances and covariance
    sigma1_sq = F.conv2d(img1*img1, window, padding=pad, groups=1) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=pad, groups=1) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=pad, groups=1) - mu1_mu2
    # BUG FIX: the stabilizing constants scale with the dynamic range; the
    # callers pass [0, 1] tensors, so the original 255-based constants
    # dwarfed the statistics and forced SSIM towards 1.
    C1 = (0.01 * data_range)**2
    C2 = (0.03 * data_range)**2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
# Usage example: SSIM between one real and one synthetic slice.
def evaluate_ssim():
    """Compute SSIM for a single real/synthetic image pair and store it
    as a metric for task 1."""
    # Load one real and one synthetic image
    real_img = Image.open("/data/lung_ct_dataset/real_001.png").convert('L')
    synth_img = Image.open("/samples/epoch_100_sample_0.png").convert('L')
    # Preprocess: resize and scale to [0, 1]
    real_img = real_img.resize((128, 128))
    synth_img = synth_img.resize((128, 128))
    real_tensor = torch.from_numpy(np.array(real_img)).float().unsqueeze(0).unsqueeze(0) / 255.0
    synth_tensor = torch.from_numpy(np.array(synth_img)).float().unsqueeze(0).unsqueeze(0) / 255.0
    # Compute SSIM
    ssim_score = calculate_ssim(real_tensor, synth_tensor).item()
    print(f"SSIM Score: {ssim_score:.4f}")
    # Persist the metric
    db = MedicalAIDB()
    db.insert_evaluation_metric(task_id=1, metric_name="SSIM", metric_value=ssim_score)
evaluate_ssim()
五、完整系统集成与任务调度
python
import threading
import time
from queue import Queue
class SynthesisTaskScheduler:
    """Background scheduler that runs queued synthesis tasks.

    A daemon worker thread drains an in-process queue: for each task it
    generates images, stores them (MinIO + DB), evaluates FID/SSIM and
    updates the task status in the database.
    """

    def __init__(self):
        self.task_queue = Queue()
        self.db = MedicalAIDB()
        self.storage = MedicalImageStorage("localhost:9000", "minio_access_key", "minio_secret_key")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pre-load generator checkpoints keyed by architecture name.
        self.models = {
            "DDPM": self.load_model("/models/ddpm_epoch_100.pt"),
            "LDM": self.load_model("/models/ldm_lung_ct.pt"),
        }
        # Daemon worker runs for the lifetime of the process.
        self.worker_thread = threading.Thread(target=self.process_tasks, daemon=True)
        self.worker_thread.start()

    def load_model(self, model_path):
        """Load a pretrained model; DDPM vs LDM is chosen from the filename."""
        if "ddpm" in model_path.lower():
            unet = UNet(in_channels=2, out_channels=1).to(self.device)
            model = DDPM(unet, n_T=1000).to(self.device)
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        else:
            # Example LDM loading (project-specific dependency).
            # BUG FIX: load_from_checkpoint already restores the weights;
            # the original additionally called load_state_dict on the raw
            # Lightning checkpoint dict, which is not a state_dict.
            from ldm.models.diffusion.ddpm import LatentDiffusion
            model = LatentDiffusion.load_from_checkpoint(model_path).to(self.device)
        model.eval()
        return model

    def submit_task(self, model_name, prompt, batch_size=1):
        """Create a DB task for the named model and enqueue it; returns task_id."""
        # BUG FIX: psycopg2's cursor.execute() returns None — the row must
        # be fetched from the cursor in a separate step.
        self.db.cursor.execute(
            "SELECT model_id FROM trained_models WHERE architecture_name = %s AND is_active = TRUE",
            (model_name,)
        )
        row = self.db.cursor.fetchone()
        if row is None:
            raise ValueError(f"no active model named {model_name!r}")
        model_id = row["model_id"]
        # Create the DB record, then hand the work to the queue.
        task_id = self.db.create_synthesis_task(model_id, prompt, batch_size)
        self.task_queue.put({
            "task_id": task_id,
            "model_name": model_name,
            "prompt": prompt,
            "batch_size": batch_size
        })
        return task_id

    def process_tasks(self):
        """Worker loop: generate, store, evaluate, and update each task."""
        while True:
            # Blocking get — no busy-wait polling as in the original.
            task = self.task_queue.get()
            task_id = task["task_id"]
            model_name = task["model_name"]
            batch_size = task["batch_size"]
            temp_paths = []
            try:
                self.db.update_task_status(task_id, "RUNNING")
                model = self.models[model_name]
                if model_name == "DDPM":
                    samples = model.sample(n_samples=batch_size, img_size=128, device=self.device)
                else:
                    # LDM generation conditioned on the text prompt.
                    samples = model.sample(prompt=task["prompt"], batch_size=batch_size)
                for i in range(batch_size):
                    img_array = (samples[i].squeeze().cpu().numpy() * 255).astype(np.uint16)
                    temp_path = f"/tmp/synth_{task_id}_{i}.dcm"
                    self.save_as_dicom(img_array, temp_path)
                    temp_paths.append(temp_path)
                    s3_path = self.storage.upload_image(temp_path, f"task_{task_id}/image_{i}.dcm")
                    # Only record uploads that actually succeeded.
                    if s3_path:
                        self.db.insert_synthetic_image(task_id, s3_path, seed=42+i)
                # BUG FIX: evaluate BEFORE deleting the temp DICOMs —
                # evaluate_task re-reads them from /tmp; the original
                # removed them first and then tried to open them.
                self.evaluate_task(task_id, num_images=batch_size)
                self.db.update_task_status(task_id, "COMPLETED")
            except Exception as e:
                print(f"任务处理失败: {e}")
                self.db.update_task_status(task_id, "FAILED")
            finally:
                # Temp files are cleaned up whether the task succeeded or not.
                for p in temp_paths:
                    try:
                        os.remove(p)
                    except OSError:
                        pass
                self.task_queue.task_done()

    def save_as_dicom(self, img_array, save_path):
        """Write a 2-D uint16 numpy array to `save_path` as a minimal CT DICOM."""
        import pydicom
        from pydicom.dataset import Dataset, FileDataset
        file_meta = Dataset()
        file_meta.MediaStorageSOPClassUID = pydicom.uid.CTImageStorage
        file_meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
        file_meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian
        ds = FileDataset(save_path, {}, file_meta=file_meta, preamble=b"\0"*128)
        ds.Rows, ds.Columns = img_array.shape
        ds.PixelSpacing = [1.0, 1.0]
        ds.SliceThickness = 1.0
        ds.Modality = "CT"
        ds.PatientName = "Synthetic"
        ds.PatientID = "Synth-001"
        # BUG FIX: pixel-encoding attributes are mandatory for uint16
        # PixelData; without them the file cannot be decoded.
        ds.SamplesPerPixel = 1
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.BitsAllocated = 16
        ds.BitsStored = 16
        ds.HighBit = 15
        ds.PixelRepresentation = 0
        ds.PixelData = img_array.tobytes()
        ds.save_as(save_path)

    def evaluate_task(self, task_id, num_images=4):
        """Compute FID and SSIM for a finished task and store the metrics.

        Reads the task's temp DICOM files from /tmp, so it must run before
        process_tasks() cleans them up. `num_images` defaults to 4 for
        backward compatibility; process_tasks passes the real batch size.
        """
        import pydicom
        # Load the freshly generated synthetic images.
        synth_images = []
        for i in range(num_images):
            temp_path = f"/tmp/synth_{task_id}_{i}.dcm"
            dicom = pydicom.dcmread(temp_path)
            img = Image.fromarray(dicom.pixel_array).convert('L')
            img = img.resize((128, 128))
            img = np.array(img).astype(np.float32) / 255.0
            synth_images.append(torch.from_numpy(img).unsqueeze(0))
        synth_images = torch.stack(synth_images)
        # Reference batch of the SAME size from the real dataset (the
        # original hardcoded 4, breaking SSIM for other batch sizes).
        real_dataset = MedicalImageDataset("/data/lung_ct_dataset", img_size=128)
        real_loader = DataLoader(real_dataset, batch_size=num_images, shuffle=False)
        real_images = next(iter(real_loader))
        # Compute and persist the metrics.
        fid_evaluator = FIDEvaluator(self.device)
        fid_score = fid_evaluator.calculate_fid(real_images, synth_images)
        ssim_score = calculate_ssim(real_images, synth_images).item()
        self.db.insert_evaluation_metric(task_id, "FID", fid_score)
        self.db.insert_evaluation_metric(task_id, "SSIM", ssim_score)
# Start the task scheduler (spawns the background worker thread).
scheduler = SynthesisTaskScheduler()
# Submit a synthesis task
task_id = scheduler.submit_task(
    model_name="DDPM",
    prompt="Lung CT with 5mm nodule in upper lobe",
    batch_size=4
)
print(f"提交合成任务ID: {task_id}")
以上代码实现了从数据库架构设计、医学影像合成模型训练、合成任务调度到质量评估的完整流程,涵盖了医疗AI领域合成数据生成的核心技术环节。代码可根据实际需求进一步优化,例如添加分布式训练支持、优化DICOM元数据配置、集成更多评估指标等。