解锁医疗AI新引擎:从数据库与编程语言透视合成数据生成(代码部分)

一、数据库操作编程实现(PostgreSQL + Python)

1. 数据库连接与初始化

使用psycopg2库实现PostgreSQL数据库连接,并初始化核心数据表:

python
import json
import os

import psycopg2
from psycopg2.extras import RealDictCursor

# Database connection settings.
# Credentials are read from the environment so that no real password needs
# to live in source control; the defaults reproduce the original local-dev
# configuration when the variables are unset.
DB_CONFIG = {
    "dbname": os.environ.get("MEDAI_DB_NAME", "medical_ai_db"),
    "user": os.environ.get("MEDAI_DB_USER", "postgres"),
    "password": os.environ.get("MEDAI_DB_PASSWORD", "your_password"),
    "host": os.environ.get("MEDAI_DB_HOST", "localhost"),
    "port": os.environ.get("MEDAI_DB_PORT", "5432"),
}

def init_db():
    """Create the five core tables of the synthesis pipeline (idempotent).

    Tables: patients, trained_models, synthesis_tasks, synthetic_images,
    evaluation_metrics.  Every statement uses CREATE TABLE IF NOT EXISTS,
    so calling this repeatedly is safe.
    """
    ddl_statements = (
        # Patients: de-identified records keyed by a hash.
        """
        CREATE TABLE IF NOT EXISTS patients (
            patient_hash_id VARCHAR(64) PRIMARY KEY,
            age INT,
            gender VARCHAR(10),
            diagnosis_code VARCHAR(20),
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        """,
        # Trained generator models: checkpoint location + JSONB hyperparameters.
        """
        CREATE TABLE IF NOT EXISTS trained_models (
            model_id SERIAL PRIMARY KEY,
            architecture_name VARCHAR(50),
            version_tag VARCHAR(20),
            checkpoint_path VARCHAR(255),
            hyperparameters JSONB,
            is_active BOOLEAN DEFAULT TRUE,
            trained_at TIMESTAMP
        );
        """,
        # Synthesis tasks with a CHECK-constrained status lifecycle.
        """
        CREATE TABLE IF NOT EXISTS synthesis_tasks (
            task_id SERIAL PRIMARY KEY,
            model_id INT REFERENCES trained_models(model_id),
            prompt_condition TEXT,
            batch_size INT DEFAULT 1,
            status VARCHAR(20) CHECK (status IN ('QUEUED', 'RUNNING', 'COMPLETED', 'FAILED')),
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        """,
        # Generated images plus the sampling parameters that produced them.
        """
        CREATE TABLE IF NOT EXISTS synthetic_images (
            synth_id SERIAL PRIMARY KEY,
            task_id INT REFERENCES synthesis_tasks(task_id),
            s3_storage_path VARCHAR(255) NOT NULL,
            seed_used BIGINT,
            guidance_scale FLOAT,
            generated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        """,
        # Quality metrics (FID, SSIM, ...) recorded per task.
        """
        CREATE TABLE IF NOT EXISTS evaluation_metrics (
            eval_id SERIAL PRIMARY KEY,
            task_id INT REFERENCES synthesis_tasks(task_id),
            metric_name VARCHAR(20),
            metric_value FLOAT,
            evaluated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        """,
    )

    # Fix vs. original: connection/cursor are always released and the
    # transaction rolls back on error.  Note psycopg2's "with conn" commits
    # on success and rolls back on exception but does NOT close the
    # connection, hence the explicit try/finally.
    conn = psycopg2.connect(**DB_CONFIG)
    try:
        with conn:
            with conn.cursor() as cursor:
                for ddl in ddl_statements:
                    cursor.execute(ddl)
    finally:
        conn.close()
    print("数据库表初始化完成!")

# Run the one-off schema initialization (side effect at import time).
init_db()
2. 数据增删改查操作封装
python
class MedicalAIDB:
    """Thin data-access layer over the synthesis pipeline tables.

    Holds a single connection and a RealDictCursor (rows come back as
    dicts keyed by column name).  Also usable as a context manager::

        with MedicalAIDB() as db:
            db.create_synthesis_task(...)
    """

    def __init__(self):
        self.conn = psycopg2.connect(**DB_CONFIG)
        self.cursor = self.conn.cursor(cursor_factory=RealDictCursor)

    def close(self):
        """Release cursor and connection; safe to call more than once."""
        for resource in (getattr(self, "cursor", None), getattr(self, "conn", None)):
            try:
                if resource is not None:
                    resource.close()
            except Exception:
                pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def __del__(self):
        # Best-effort fallback only.  The original raised if __init__ had
        # failed or during interpreter shutdown; __del__ must never raise.
        try:
            self.close()
        except Exception:
            pass

    def insert_trained_model(self, arch_name, version, checkpoint_path, hyperparams):
        """Insert a model record and return its generated model_id.

        `hyperparams` is any JSON-serializable mapping; it is stored in
        the JSONB column.
        """
        sql = """
        INSERT INTO trained_models (architecture_name, version_tag, checkpoint_path, hyperparameters, trained_at)
        VALUES (%s, %s, %s, %s, CURRENT_TIMESTAMP)
        RETURNING model_id;
        """
        self.cursor.execute(sql, (arch_name, version, checkpoint_path, json.dumps(hyperparams)))
        model_id = self.cursor.fetchone()["model_id"]
        self.conn.commit()
        return model_id

    def create_synthesis_task(self, model_id, prompt, batch_size=1):
        """Create a QUEUED synthesis task and return its task_id."""
        sql = """
        INSERT INTO synthesis_tasks (model_id, prompt_condition, batch_size, status)
        VALUES (%s, %s, %s, 'QUEUED')
        RETURNING task_id;
        """
        self.cursor.execute(sql, (model_id, prompt, batch_size))
        task_id = self.cursor.fetchone()["task_id"]
        self.conn.commit()
        return task_id

    def update_task_status(self, task_id, status):
        """Set a task's status (value must satisfy the table's CHECK constraint)."""
        sql = "UPDATE synthesis_tasks SET status = %s WHERE task_id = %s;"
        self.cursor.execute(sql, (status, task_id))
        self.conn.commit()

    def insert_synthetic_image(self, task_id, s3_path, seed, guidance_scale=0.0):
        """Record one generated image's storage path and sampling metadata."""
        sql = """
        INSERT INTO synthetic_images (task_id, s3_storage_path, seed_used, guidance_scale)
        VALUES (%s, %s, %s, %s);
        """
        self.cursor.execute(sql, (task_id, s3_path, seed, guidance_scale))
        self.conn.commit()

    def insert_evaluation_metric(self, task_id, metric_name, metric_value):
        """Record one quality metric (e.g. FID, SSIM) for a task."""
        sql = """
        INSERT INTO evaluation_metrics (task_id, metric_name, metric_value)
        VALUES (%s, %s, %s);
        """
        self.cursor.execute(sql, (task_id, metric_name, metric_value))
        self.conn.commit()

# Usage example: register a trained DDPM model and queue a synthesis task.
db = MedicalAIDB()
# Insert a DDPM model record
ddpm_hyperparams = {
    "diffusion_steps": 1000,
    "beta_start": 0.0001,
    "beta_end": 0.02,
    "lr": 1e-4,
    "batch_size": 8
}
model_id = db.insert_trained_model("DDPM-UNet", "v1.0", "/models/ddpm_lung_ct.pt", ddpm_hyperparams)
# Queue a text-conditioned synthesis task for that model
task_id = db.create_synthesis_task(model_id, "Lung CT with 5mm nodule in upper lobe", batch_size=4)
print(f"创建合成任务ID: {task_id}")

二、MinIO对象存储操作(医学影像文件管理)

python
from minio import Minio
from minio.error import S3Error

class MedicalImageStorage:
    """MinIO-backed object storage for synthetic medical images.

    All objects live in a single fixed bucket which is created on first
    use if it does not exist yet.
    """

    def __init__(self, endpoint, access_key, secret_key, secure=False):
        self.client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=secure,
        )
        self.bucket_name = "medical-synthetic-images"
        # Make sure the target bucket exists before any transfer.
        if not self.client.bucket_exists(self.bucket_name):
            self.client.make_bucket(self.bucket_name)

    def upload_image(self, local_path, object_name):
        """Upload one image file; return its s3:// URI, or None on failure."""
        try:
            self.client.fput_object(self.bucket_name, object_name, local_path)
        except S3Error as e:
            print(f"上传失败: {e}")
            return None
        return f"s3://{self.bucket_name}/{object_name}"

    def download_image(self, object_name, local_path):
        """Download an object to `local_path`; return True on success."""
        try:
            self.client.fget_object(self.bucket_name, object_name, local_path)
        except S3Error as e:
            print(f"下载失败: {e}")
            return False
        return True

# Usage example: connect to MinIO and upload one synthetic DICOM file.
storage = MedicalImageStorage("localhost:9000", "minio_access_key", "minio_secret_key")
# Upload the generated DICOM file
s3_path = storage.upload_image("/tmp/synthetic_ct.dcm", f"task_{task_id}/image_1.dcm")
if s3_path:
    db.insert_synthetic_image(task_id, s3_path, seed=42)

三、DDPM医学影像合成模型实现(PyTorch)

1. 扩散模型核心模块
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image
import os

# Simplified UNet backbone used as the DDPM noise predictor.
class UNet(nn.Module):
    """Three-level encoder/decoder with additive skip connections.

    `in_channels` must already include the extra timestep channel that
    forward() concatenates internally — e.g. the training code builds
    ``UNet(in_channels=2)`` for single-channel images.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        # Stride-2 convs halve the spatial size at each of 3 levels.
        self.down1 = nn.Conv2d(in_channels, base_channels, 4, 2, 1)
        self.down2 = nn.Conv2d(base_channels, base_channels*2, 4, 2, 1)
        self.down3 = nn.Conv2d(base_channels*2, base_channels*4, 4, 2, 1)

        # Transposed convs mirror the encoder and double spatial size.
        self.up1 = nn.ConvTranspose2d(base_channels*4, base_channels*2, 4, 2, 1)
        self.up2 = nn.ConvTranspose2d(base_channels*2, base_channels, 4, 2, 1)
        self.up3 = nn.ConvTranspose2d(base_channels, out_channels, 4, 2, 1)

        self.bn1 = nn.BatchNorm2d(base_channels)
        self.bn2 = nn.BatchNorm2d(base_channels*2)
        self.bn3 = nn.BatchNorm2d(base_channels*4)

    def forward(self, x, t):
        """Predict the noise in `x` at diffusion step(s) `t`.

        `t` may be a scalar, a shape-(1,) tensor, or a per-sample
        shape-(B,) tensor.
        """
        # Simplified time embedding: a constant extra channel of t/1000.
        # Bug fix: a batched `t` of shape (B,) previously failed to
        # broadcast against (B, 1, H, W); reshape it to (B, 1, 1, 1) first.
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, device=x.device)
        t = t.to(device=x.device, dtype=x.dtype).reshape(-1, 1, 1, 1)
        t_emb = torch.zeros_like(x[:, :1, :, :]) + t / 1000
        x = torch.cat([x, t_emb], dim=1)

        # Encoder path.
        d1 = F.relu(self.bn1(self.down1(x)))
        d2 = F.relu(self.bn2(self.down2(d1)))
        d3 = F.relu(self.bn3(self.down3(d2)))

        # Decoder path with additive skip connections.  The output layer
        # is deliberately linear: the target (Gaussian noise) is unbounded,
        # so the original sigmoid made the MSE noise objective unfittable.
        u1 = F.relu(self.up1(d3))
        u2 = F.relu(self.up2(u1 + d2))
        u3 = self.up3(u2 + d1)

        return u3

# DDPM wrapper: forward diffusion for training, ancestral sampling for inference.
class DDPM(nn.Module):
    """Denoising Diffusion Probabilistic Model.

    Parameters
    ----------
    model : callable -- noise predictor invoked as ``model(x_t, t)``
    betas : (float, float) -- endpoints of the linear noise schedule
    n_T : int -- number of diffusion steps
    """

    def __init__(self, model, betas=(1e-4, 0.02), n_T=1000):
        super().__init__()
        self.model = model
        self.n_T = n_T

        # Precompute the schedule.  Registered as NON-persistent buffers so
        # `.to(device)` moves them automatically (the original kept plain
        # tensors and had to re-.to() them on every access) while state_dict
        # keys stay identical to the original implementation.
        betas_t = torch.linspace(betas[0], betas[1], n_T)
        self.register_buffer("betas", betas_t, persistent=False)
        self.register_buffer("alphas", 1 - betas_t, persistent=False)
        self.register_buffer("alphas_bar", torch.cumprod(1 - betas_t, dim=0), persistent=False)

    def forward(self, x):
        """Training step: sample t, add noise, return the noise-prediction MSE."""
        t = torch.randint(0, self.n_T, (x.shape[0],), device=x.device)
        noise = torch.randn_like(x)

        # Closed-form forward diffusion q(x_t | x_0).
        alpha_bar_t = self.alphas_bar[t].reshape(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise

        noise_pred = self.model(x_t, t)
        return F.mse_loss(noise_pred, noise)

    def sample(self, n_samples, img_size, device):
        """Ancestral sampling: denoise pure Gaussian noise for n_T steps."""
        x = torch.randn(n_samples, 1, img_size, img_size, device=device)

        for t in reversed(range(self.n_T)):
            t_tensor = torch.tensor([t], device=device)
            alpha_t = self.alphas[t]
            alpha_bar_t = self.alphas_bar[t]

            noise_pred = self.model(x, t_tensor)

            # Fresh noise at every step except the final one (t == 0).
            noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)

            x = (1 / torch.sqrt(alpha_t)) * (
                x - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * noise_pred
            ) + torch.sqrt(self.betas[t]) * noise

        return x
2. 医学影像数据集与训练流程
python
# Medical-image dataset handling PNG and DICOM inputs.
class MedicalImageDataset(Dataset):
    """Grayscale dataset of .png / .dcm files in a flat directory.

    Each item is resized to (img_size, img_size), scaled to [0, 1] and
    returned as a float32 tensor of shape (1, H, W).
    """

    def __init__(self, data_dir, img_size=128):
        self.data_dir = data_dir
        self.img_size = img_size
        # Sorted for a deterministic sample order across runs and
        # platforms (os.listdir order is arbitrary).
        self.file_list = sorted(
            f for f in os.listdir(data_dir) if f.endswith(('.png', '.dcm'))
        )

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.file_list[idx])

        if file_path.endswith('.png'):
            img = Image.open(file_path).convert('L')
        else:
            # DICOM branch: import lazily so pydicom is only required when
            # .dcm files are actually read.
            import pydicom
            dicom = pydicom.dcmread(file_path)
            img = Image.fromarray(dicom.pixel_array).convert('L')

        img = img.resize((self.img_size, self.img_size))
        arr = np.array(img).astype(np.float32) / 255.0
        return torch.from_numpy(arr).unsqueeze(0)

# Training entry point for the DDPM.
def train_ddpm():
    """Train the DDPM on the lung-CT dataset with periodic checkpoints.

    Every 10 epochs the model state_dict is saved and a few images are
    sampled for visual inspection.
    """
    # ---- configuration ----
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_dir = "/data/lung_ct_dataset"
    model_dir = "/models"
    sample_dir = "/samples"
    img_size = 128
    batch_size = 8
    epochs = 100
    lr = 1e-4

    # Create output directories up front so the first checkpoint/sample
    # save does not fail (the original assumed they already existed).
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)

    dataset = MedicalImageDataset(data_dir, img_size)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # in_channels=2: one image channel plus the timestep channel that
    # UNet.forward concatenates internally.
    unet = UNet(in_channels=2, out_channels=1).to(device)
    ddpm = DDPM(unet, n_T=1000).to(device)

    optimizer = torch.optim.Adam(ddpm.parameters(), lr=lr)

    for epoch in range(epochs):
        ddpm.train()
        total_loss = 0.0

        for batch in dataloader:
            batch = batch.to(device)
            optimizer.zero_grad()

            loss = ddpm(batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # max(..., 1) guards against an empty dataset directory.
        avg_loss = total_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        # Every 10 epochs: checkpoint and render a few samples.
        if (epoch + 1) % 10 == 0:
            torch.save(ddpm.state_dict(), f"{model_dir}/ddpm_epoch_{epoch+1}.pt")

            ddpm.eval()
            with torch.no_grad():
                samples = ddpm.sample(n_samples=4, img_size=img_size, device=device)

            for i, sample in enumerate(samples):
                # Clip first: DDPM outputs are unbounded and a raw uint8
                # cast would wrap out-of-range values.
                arr = np.clip(sample.squeeze().cpu().numpy(), 0.0, 1.0)
                arr = (arr * 255).astype(np.uint8)
                Image.fromarray(arr).save(f"{sample_dir}/epoch_{epoch+1}_sample_{i}.png")

# Kick off training (long-running side effect at import time).
train_ddpm()
3. 合成影像生成与DICOM格式转换
python
# Generate synthetic images and persist them as DICOM files.
def generate_synthetic_images(model_path, task_id, num_samples=4):
    """Sample `num_samples` CT slices from a trained DDPM and register them.

    Loads the checkpoint at `model_path`, samples images, wraps each one in
    a minimal CT DICOM file, uploads it to MinIO and records successful
    uploads in the database under `task_id`.
    """
    # pydicom imports hoisted out of the per-image loop (requires pydicom).
    import pydicom
    from pydicom.dataset import Dataset, FileDataset
    from pydicom.uid import ExplicitVRLittleEndian

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img_size = 128

    # Rebuild the architecture and load weights; map_location lets a
    # CUDA-trained checkpoint load on a CPU-only host.
    unet = UNet(in_channels=2, out_channels=1).to(device)
    ddpm = DDPM(unet, n_T=1000).to(device)
    ddpm.load_state_dict(torch.load(model_path, map_location=device))
    ddpm.eval()

    with torch.no_grad():
        samples = ddpm.sample(n_samples=num_samples, img_size=img_size, device=device)

    storage = MedicalImageStorage("localhost:9000", "minio_access_key", "minio_secret_key")
    db = MedicalAIDB()

    for i, sample in enumerate(samples):
        # Clip to [0, 1] first: DDPM outputs are unbounded and a raw cast
        # to uint16 would wrap negative values.
        img_array = (np.clip(sample.squeeze().cpu().numpy(), 0.0, 1.0) * 255).astype(np.uint16)

        # Minimal DICOM file meta information.
        file_meta = Dataset()
        file_meta.MediaStorageSOPClassUID = pydicom.uid.CTImageStorage
        file_meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
        file_meta.TransferSyntaxUID = ExplicitVRLittleEndian

        temp_path = f"/tmp/synth_{task_id}_{i}.dcm"
        ds = FileDataset(temp_path, {}, file_meta=file_meta, preamble=b"\0"*128)
        ds.SOPClassUID = file_meta.MediaStorageSOPClassUID
        ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID
        ds.PatientName = f"Synthetic_{task_id}"
        ds.PatientID = f"Synth-{task_id}-{i}"
        ds.Modality = "CT"
        ds.SeriesDescription = "Synthetic Lung CT"
        ds.Rows = img_size
        ds.Columns = img_size
        ds.PixelSpacing = [1.0, 1.0]
        ds.SliceThickness = 1.0
        # Image Pixel module attributes required to interpret the uint16
        # buffer (the original omitted them, producing unreadable files).
        ds.SamplesPerPixel = 1
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.BitsAllocated = 16
        ds.BitsStored = 16
        ds.HighBit = 15
        ds.PixelRepresentation = 0
        ds.PixelData = img_array.tobytes()

        ds.save_as(temp_path)
        s3_path = storage.upload_image(temp_path, f"task_{task_id}/image_{i}.dcm")

        # upload_image returns None on failure and s3_storage_path is
        # NOT NULL — only record successful uploads.
        if s3_path:
            db.insert_synthetic_image(task_id, s3_path, seed=42+i)

        os.remove(temp_path)

    print(f"成功生成{num_samples}张合成影像,任务ID: {task_id}")

# Usage example (side effect at import time; requires the trained checkpoint).
generate_synthetic_images("/models/ddpm_epoch_100.pt", task_id=1, num_samples=4)

四、合成影像质量评估实现

1. FID计算实现
python
import torch
import torch.nn as nn
from torchvision.models import inception_v3
import numpy as np
from scipy.linalg import sqrtm

class FIDEvaluator:
    """Fréchet Inception Distance between real and synthetic image batches."""

    def __init__(self, device):
        self.device = device
        # Pretrained InceptionV3 as the feature extractor.  Fix vs.
        # original: stacking children() into an nn.Sequential mis-executes
        # the InceptionAux branch in the middle of the stream; the standard
        # approach is to replace the final fc with Identity and run a
        # normal forward pass, which yields the 2048-d pooled features.
        self.inception = inception_v3(pretrained=True, transform_input=False).to(device)
        self.inception.fc = nn.Identity()
        self.inception.eval()
        # Kept for backward compatibility with callers that referenced it.
        self.feature_extractor = self.inception

    def get_features(self, images):
        """Return pooled Inception features as an (N, 2048) numpy array.

        Single-channel input is tiled to 3 channels, resized to 299x299
        and rescaled from [0, 1] to the [-1, 1] range InceptionV3 expects.
        """
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)
        images = F.interpolate(images, size=(299, 299), mode='bilinear', align_corners=False)
        images = (images - 0.5) * 2.0

        with torch.no_grad():
            features = self.inception(images)

        return features.view(features.shape[0], -1).cpu().numpy()

    def calculate_fid(self, real_images, fake_images):
        """Compute FID = |mu1-mu2|^2 + Tr(S1 + S2 - 2*sqrt(S1*S2))."""
        real_features = self.get_features(real_images.to(self.device))
        fake_features = self.get_features(fake_images.to(self.device))

        # Feature statistics per distribution.
        mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
        mu2, sigma2 = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)

        ssdiff = np.sum((mu1 - mu2)**2.0)
        covmean = sqrtm(sigma1.dot(sigma2))

        # sqrtm may return a complex matrix due to numerical noise.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

# Usage example
def evaluate_synthetic_images():
    """Compare one batch of real CT slices against saved sample PNGs via FID.

    NOTE(review): task_id=1 is hard-coded and the sample paths assume the
    training run above already produced /samples/epoch_100_sample_*.png.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fid_evaluator = FIDEvaluator(device)
    
    # Load a batch of real reference images
    real_dataset = MedicalImageDataset("/data/lung_ct_dataset", img_size=128)
    real_loader = DataLoader(real_dataset, batch_size=16, shuffle=False)
    real_images = next(iter(real_loader))
    
    # Load the synthetic images (the 4 saved samples are reused cyclically)
    synth_images = []
    for i in range(16):
        img = Image.open(f"/samples/epoch_100_sample_{i%4}.png").convert('L')
        img = img.resize((128, 128))
        img = np.array(img).astype(np.float32) / 255.0
        synth_images.append(torch.from_numpy(img).unsqueeze(0))
    synth_images = torch.stack(synth_images)
    
    # Compute the FID score
    fid_score = fid_evaluator.calculate_fid(real_images, synth_images)
    print(f"FID Score: {fid_score:.2f}")
    
    # Persist the metric
    db = MedicalAIDB()
    db.insert_evaluation_metric(task_id=1, metric_name="FID", metric_value=fid_score)

evaluate_synthetic_images()
2. SSIM计算实现
python
import torch
import torch.nn.functional as F

def calculate_ssim(img1, img2, window_size=11, sigma=1.5, data_range=1.0):
    """Mean SSIM between two image batches of shape (N, C, H, W).

    Two fixes vs. the original implementation:
    * the Gaussian window is centred on the middle tap (the original used
      ``arange(w)**2``, which peaks at the first pixel), and
    * the stabilising constants scale with ``data_range`` (1.0 for the
      [0, 1] tensors used throughout this pipeline) instead of assuming a
      255 dynamic range, which made SSIM saturate near 1 for all inputs.
    """
    # Centred, normalised 1-D Gaussian; outer product gives the 2-D window.
    coords = torch.arange(window_size, dtype=torch.float32) - (window_size - 1) / 2.0
    gaussian = torch.exp(-coords**2 / (2 * sigma**2))
    gaussian = gaussian / gaussian.sum()
    window_1d = gaussian.unsqueeze(1)
    window = (window_1d @ window_1d.t()).unsqueeze(0).unsqueeze(0).to(img1.device)

    pad = window_size // 2

    # Local means.
    mu1 = F.conv2d(img1, window, padding=pad, groups=1)
    mu2 = F.conv2d(img2, window, padding=pad, groups=1)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance.
    sigma1_sq = F.conv2d(img1*img1, window, padding=pad, groups=1) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding=pad, groups=1) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding=pad, groups=1) - mu1_mu2

    # Standard SSIM stabilising constants scaled by the dynamic range.
    C1 = (0.01 * data_range)**2
    C2 = (0.03 * data_range)**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return ssim_map.mean()

# Usage example
def evaluate_ssim():
    """Compute SSIM between one real slice and one synthetic sample, then persist it.

    NOTE(review): the file paths and task_id=1 are hard-coded examples.
    """
    # Load one real and one synthetic image
    real_img = Image.open("/data/lung_ct_dataset/real_001.png").convert('L')
    synth_img = Image.open("/samples/epoch_100_sample_0.png").convert('L')
    
    # Preprocess: resize, then scale to [0, 1] tensors of shape (1, 1, H, W)
    real_img = real_img.resize((128, 128))
    synth_img = synth_img.resize((128, 128))
    
    real_tensor = torch.from_numpy(np.array(real_img)).float().unsqueeze(0).unsqueeze(0) / 255.0
    synth_tensor = torch.from_numpy(np.array(synth_img)).float().unsqueeze(0).unsqueeze(0) / 255.0
    
    # Compute the SSIM score
    ssim_score = calculate_ssim(real_tensor, synth_tensor).item()
    print(f"SSIM Score: {ssim_score:.4f}")
    
    # Persist the metric
    db = MedicalAIDB()
    db.insert_evaluation_metric(task_id=1, metric_name="SSIM", metric_value=ssim_score)

evaluate_ssim()

五、完整系统集成与任务调度

python
import threading
import time
from queue import Queue

class SynthesisTaskScheduler:
    """Background scheduler that turns queued synthesis requests into
    DICOM files in MinIO, tracked in PostgreSQL.

    A single daemon worker thread consumes an in-process queue; each task
    transitions QUEUED -> RUNNING -> COMPLETED/FAILED in the database.
    """

    def __init__(self):
        self.task_queue = Queue()
        self.db = MedicalAIDB()
        self.storage = MedicalImageStorage("localhost:9000", "minio_access_key", "minio_secret_key")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Pretrained generators, keyed by architecture name.
        self.models = {
            "DDPM": self.load_model("/models/ddpm_epoch_100.pt"),
            "LDM": self.load_model("/models/ldm_lung_ct.pt")
        }

        # Daemon thread: dies together with the main process.
        self.worker_thread = threading.Thread(target=self.process_tasks, daemon=True)
        self.worker_thread.start()

    def load_model(self, model_path):
        """Load a pretrained generator from `model_path` and set eval mode."""
        if "ddpm" in model_path.lower():
            unet = UNet(in_channels=2, out_channels=1).to(self.device)
            model = DDPM(unet, n_T=1000).to(self.device)
            # map_location lets CUDA checkpoints load on CPU-only hosts.
            model.load_state_dict(torch.load(model_path, map_location=self.device))
        else:
            # LDM checkpoints restore their own weights inside
            # load_from_checkpoint; the original additionally called
            # load_state_dict(torch.load(...)) on the raw Lightning
            # checkpoint dict, which cannot work.
            from ldm.models.diffusion.ddpm import LatentDiffusion
            model = LatentDiffusion.load_from_checkpoint(model_path).to(self.device)

        model.eval()
        return model

    def submit_task(self, model_name, prompt, batch_size=1):
        """Create a DB task for `model_name` and enqueue it; returns task_id.

        Raises ValueError if no active model with that name exists.
        """
        # Bug fix: psycopg2's cursor.execute() returns None, so the row
        # must be fetched from the cursor in a separate step (the original
        # chained .fetchone() onto the execute() call).
        self.db.cursor.execute(
            "SELECT model_id FROM trained_models WHERE architecture_name = %s AND is_active = TRUE",
            (model_name,)
        )
        row = self.db.cursor.fetchone()
        if row is None:
            raise ValueError(f"No active model named {model_name!r}")
        model_id = row["model_id"]

        # Create the database task record.
        task_id = self.db.create_synthesis_task(model_id, prompt, batch_size)

        # Hand the work item to the background worker.
        self.task_queue.put({
            "task_id": task_id,
            "model_name": model_name,
            "prompt": prompt,
            "batch_size": batch_size
        })

        return task_id

    def process_tasks(self):
        """Worker loop: consume tasks and run the full synthesis pipeline."""
        while True:
            # Blocking get replaces the original poll-and-sleep busy loop.
            task = self.task_queue.get()
            task_id = task["task_id"]
            model_name = task["model_name"]
            batch_size = task["batch_size"]

            try:
                self.db.update_task_status(task_id, "RUNNING")

                # Generate a batch with the requested model.
                model = self.models[model_name]
                if model_name == "DDPM":
                    samples = model.sample(n_samples=batch_size, img_size=128, device=self.device)
                else:
                    # LDM path: text-conditioned sampling.
                    samples = model.sample(prompt=task["prompt"], batch_size=batch_size)

                # Convert, store and register every generated image.
                temp_paths = []
                for i in range(batch_size):
                    # Clip to [0, 1] before the uint16 cast; raw model
                    # output is unbounded.
                    img_array = (np.clip(samples[i].squeeze().cpu().numpy(), 0.0, 1.0) * 255).astype(np.uint16)
                    temp_path = f"/tmp/synth_{task_id}_{i}.dcm"
                    self.save_as_dicom(img_array, temp_path)
                    temp_paths.append(temp_path)

                    s3_path = self.storage.upload_image(temp_path, f"task_{task_id}/image_{i}.dcm")

                    # upload_image returns None on failure; the column is
                    # NOT NULL, so only record successful uploads.
                    if s3_path:
                        self.db.insert_synthetic_image(task_id, s3_path, seed=42+i)

                # Evaluate BEFORE deleting: evaluate_task re-reads the
                # temporary DICOM files (the original removed them first,
                # which made every evaluation fail).
                self.evaluate_task(task_id, batch_size)

                for temp_path in temp_paths:
                    os.remove(temp_path)

                self.db.update_task_status(task_id, "COMPLETED")

            except Exception as e:
                print(f"任务处理失败: {e}")
                self.db.update_task_status(task_id, "FAILED")

            finally:
                self.task_queue.task_done()

    def save_as_dicom(self, img_array, save_path):
        """Write a uint16 grayscale array to `save_path` as a minimal CT DICOM."""
        import pydicom
        from pydicom.dataset import Dataset, FileDataset

        file_meta = Dataset()
        file_meta.MediaStorageSOPClassUID = pydicom.uid.CTImageStorage
        file_meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
        file_meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian

        ds = FileDataset(save_path, {}, file_meta=file_meta, preamble=b"\0"*128)
        ds.SOPClassUID = file_meta.MediaStorageSOPClassUID
        ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID
        ds.Rows, ds.Columns = img_array.shape
        ds.PixelSpacing = [1.0, 1.0]
        ds.SliceThickness = 1.0
        ds.Modality = "CT"
        ds.PatientName = "Synthetic"
        ds.PatientID = "Synth-001"
        # Image Pixel module attributes required to interpret the uint16
        # buffer (missing in the original output).
        ds.SamplesPerPixel = 1
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.BitsAllocated = 16
        ds.BitsStored = 16
        ds.HighBit = 15
        ds.PixelRepresentation = 0
        ds.PixelData = img_array.tobytes()

        ds.save_as(save_path)

    def evaluate_task(self, task_id, num_images=4):
        """Compute FID and SSIM over the task's temp DICOM files and persist them.

        `num_images` defaults to 4 for backward compatibility; the worker
        passes the actual batch size (the original always read 4 files).
        """
        import pydicom

        # Reload the generated images from their temporary DICOM files.
        synth_images = []
        for i in range(num_images):
            temp_path = f"/tmp/synth_{task_id}_{i}.dcm"
            dicom = pydicom.dcmread(temp_path)
            img = Image.fromarray(dicom.pixel_array).convert('L')
            img = img.resize((128, 128))
            arr = np.array(img).astype(np.float32) / 255.0
            synth_images.append(torch.from_numpy(arr).unsqueeze(0))
        synth_images = torch.stack(synth_images)

        # One batch of real images as the reference distribution.
        real_dataset = MedicalImageDataset("/data/lung_ct_dataset", img_size=128)
        real_loader = DataLoader(real_dataset, batch_size=num_images, shuffle=False)
        real_images = next(iter(real_loader))

        # Compute and persist both metrics.
        fid_evaluator = FIDEvaluator(self.device)
        fid_score = fid_evaluator.calculate_fid(real_images, synth_images)
        ssim_score = calculate_ssim(real_images, synth_images).item()

        self.db.insert_evaluation_metric(task_id, "FID", fid_score)
        self.db.insert_evaluation_metric(task_id, "SSIM", ssim_score)

# Start the scheduler (loads the models and spawns the worker thread)
scheduler = SynthesisTaskScheduler()

# Submit a synthesis task
task_id = scheduler.submit_task(
    model_name="DDPM",
    prompt="Lung CT with 5mm nodule in upper lobe",
    batch_size=4
)
print(f"提交合成任务ID: {task_id}")

以上代码实现了从数据库架构设计、医学影像合成模型训练、合成任务调度到质量评估的完整流程,涵盖了医疗AI领域合成数据生成的核心技术环节。代码可根据实际需求进一步优化,例如添加分布式训练支持、优化DICOM元数据配置、集成更多评估指标等。

相关推荐
AI弟1 小时前
第13章 迁移学习:让AI学会“举一反三“的艺术
人工智能·机器学习·迁移学习
ccLianLian1 小时前
MaskCLIP+
人工智能·计算机视觉
艾莉丝努力练剑1 小时前
【C++:C++11收尾】解构C++可调用对象:从入门到精通,掌握function包装器与bind适配器包装器详解
java·开发语言·c++·人工智能·c++11·右值引用
卿雪1 小时前
MySQL【索引】篇:索引的分类、B+树、创建索引的原则、索引失效的情况...
java·开发语言·数据结构·数据库·b树·mysql·golang
cipher1 小时前
删库之夜V2·天网恢恢
服务器·数据库·git
a***56061 小时前
【Navicat+MySQL】 在Navicat内创建管理数据库、数据库表。
数据库·mysql·oracle
赵渝强老师1 小时前
【赵渝强老师】PostgreSQL锁的类型
数据库·postgresql
一点事1 小时前
oracle:密码过期处理
数据库·oracle
李景琰1 小时前
Java 25+AI+物联网+区块链融合平台:架构设计与企业级实现
java·人工智能·物联网·区块链