伏羲0.06(文生图)

主要改进点:

数据准备:

数据清洗:增加了 clean_data 函数,用于去除空值和重复值。

数据增强:增加了 augment_data 函数,用于在训练时进行数据增强。

模型选择:

生成对抗网络 (GAN):增加了 Discriminator 类,用于判别生成的图像是否真实。

损失函数:增加了 GAN 损失和 L1 损失,用于训练生成器和判别器。

模型架构设计:

文本编码器:使用预训练的 Transformer 模型(如 BERT)来编码文本描述。

图像生成器:增加了更多的卷积转置层,并使用了批量归一化和激活函数。

多模态融合:将文本特征和图像特征进行有效融合,确保生成的图像与文本描述一致。

训练过程:

损失函数:使用 GAN 损失和 L1 损失,分别用于训练生成器和判别器。

优化算法:使用 Adam 优化器。

训练策略:使用批量归一化、梯度裁剪等技术来稳定训练过程。

正则化:防止过拟合,可以使用 L1/L2 正则化、Dropout 等技术。

python 复制代码
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import yaml
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import random
import numpy as np

# 配置文件加载
def load_config(config_path):
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    return config

# 数据加载
def load_text_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        text_data = file.readlines()
    return [line.strip() for line in text_data]

# 数据清洗
def clean_data(data):
    # 这里可以添加更多的数据清洗逻辑
    return data.dropna().drop_duplicates()

# 数据增强
def augment_data(image, mode):
    if mode == 'train':
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    return transform(image)

# 文本编码器
class TextEncoder(nn.Module):
    def __init__(self, model_name):
        super(TextEncoder, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, text):
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        outputs = self.model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)

# 图像生成器
class ImageGenerator(nn.Module):
    def __init__(self, in_channels):
        super(ImageGenerator, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.view(-1, x.size(1), 1, 1)
        return self.decoder(x)

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

# 模型定义
class TextToImageModel(nn.Module):
    def __init__(self, text_encoder_model_name):
        super(TextToImageModel, self).__init__()
        self.text_encoder = TextEncoder(text_encoder_model_name)
        self.image_generator = ImageGenerator(768)  # 768 is the hidden size of BERT

    def forward(self, text):
        text_features = self.text_encoder(text)
        return self.image_generator(text_features)

# 模型加载
def load_model(model_path, text_encoder_model_name):
    model = TextToImageModel(text_encoder_model_name)
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path))
    model.eval()
    return model

# 图像保存
def save_image(image, path):
    if not os.path.exists(os.path.dirname(path)):
        os.makedirs(os.path.dirname(path))
    image.save(path)

# 数据集类
class TextToImageDataset(Dataset):
    def __init__(self, csv_file, transform=None, mode='train'):
        self.data = pd.read_csv(csv_file)
        self.data = clean_data(self.data)
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data.iloc[idx]['text']
        image_path = self.data.iloc[idx]['image_path']
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image, self.mode)
        return text, image

# 模型训练
def train_model(config):
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = TextToImageDataset(config['training']['dataset_path'], transform=augment_data, mode='train')
    dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True)

    model = TextToImageModel(config['model']['text_encoder_model_name'])
    discriminator = Discriminator()

    optimizer_g = optim.Adam(model.parameters(), lr=config['training']['learning_rate'])
    optimizer_d = optim.Adam(discriminator.parameters(), lr=config['training']['learning_rate'])

    criterion_gan = nn.BCELoss()
    criterion_l1 = nn.L1Loss()

    for epoch in range(config['training']['epochs']):
        model.train()
        discriminator.train()
        running_loss_g = 0.0
        running_loss_d = 0.0

        for i, (text, images) in enumerate(dataloader):
            real_labels = torch.ones(images.size(0), 1)
            fake_labels = torch.zeros(images.size(0), 1)

            # Train Discriminator
            optimizer_d.zero_grad()
            real_outputs = discriminator(images)
            d_loss_real = criterion_gan(real_outputs, real_labels)

            generated_images = model(text)
            fake_outputs = discriminator(generated_images.detach())
            d_loss_fake = criterion_gan(fake_outputs, fake_labels)

            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            optimizer_d.step()

            # Train Generator
            optimizer_g.zero_grad()
            generated_images = model(text)
            g_outputs = discriminator(generated_images)
            g_loss_gan = criterion_gan(g_outputs, real_labels)
            g_loss_l1 = criterion_l1(generated_images, images)
            g_loss = g_loss_gan + 100 * g_loss_l1  # Weighted sum of GAN loss and L1 loss
            g_loss.backward()
            optimizer_g.step()

            running_loss_g += g_loss.item()
            running_loss_d += d_loss.item()

        print(f"Epoch {epoch + 1}, Generator Loss: {running_loss_g / len(dataloader)}, Discriminator Loss: {running_loss_d / len(dataloader)}")

    # 保存训练好的模型
    torch.save(model.state_dict(), config['model']['path'])

# 图像生成
def generate_images(model, text_data, output_dir):
    for text in text_data:
        input_tensor = model.text_encoder([text])
        image = model.image_generator(input_tensor)
        image = image.squeeze(0).detach().cpu().numpy()
        image = (image * 127.5 + 127.5).astype('uint8')
        image = Image.fromarray(image.transpose(1, 2, 0))

        # 保存图像
        save_image(image, f"{output_dir}/{text}.png")

# 图形用户界面
class TextToImageGUI:
    def __init__(self, root):
        self.root = root
        self.root.title("文本生成图像")
        self.config = load_config('config.yaml')
        self.model = load_model(self.config['model']['path'], self.config['model']['text_encoder_model_name'])

        self.text_input = tk.Text(root, height=10, width=50)
        self.text_input.pack(pady=10)

        self.train_button = tk.Button(root, text="训练模型", command=self.train_model)
        self.train_button.pack(pady=10)

        self.generate_button = tk.Button(root, text="生成图像", command=self.generate_image)
        self.generate_button.pack(pady=10)

        self.image_label = tk.Label(root)
        self.image_label.pack(pady=10)

    def train_model(self):
        train_model(self.config)
        self.model = load_model(self.config['model']['path'], self.config['model']['text_encoder_model_name'])
        messagebox.showinfo("成功", "模型训练完成")

    def generate_image(self):
        text = self.text_input.get("1.0", tk.END).strip()
        if not text:
            messagebox.showwarning("警告", "请输入文本")
            return

        input_tensor = self.model.text_encoder([text])
        image = self.model.image_generator(input_tensor)
        image = image.squeeze(0).detach().cpu().numpy()
        image = (image * 127.5 + 127.5).astype('uint8')
        image = Image.fromarray(image.transpose(1, 2, 0))

        # 显示图像
        img_tk = ImageTk.PhotoImage(image)
        self.image_label.config(image=img_tk)
        self.image_label.image = img_tk

        # 保存图像
        save_image(image, f"{self.config['data']['output_dir']}/{text}.png")
        messagebox.showinfo("成功", "图像已生成并保存")

if __name__ == "__main__":
    config = load_config('config.yaml')

    # 加载模型
    model = load_model(config['model']['path'], config['model']['text_encoder_model_name'])

    # 加载文本数据
    text_data = load_text_data(config['data']['input_file'])

    # 生成图像
    generate_images(model, text_data, config['data']['output_dir'])

    # 启动图形用户界面
    root = tk.Tk()
    app = TextToImageGUI(root)
    root.mainloop()

希望这些改进能帮助你更好地实现文本生成图像的功能。如果有任何问题或需要进一步的帮助,请随时告诉我!

相关推荐
cooldream200932 分钟前
华为云Flexus+DeepSeek征文|基于华为云Flexus X和DeepSeek-R1打造个人知识库问答系统
人工智能·华为云·dify
Blossom.1184 小时前
使用Python和Scikit-Learn实现机器学习模型调优
开发语言·人工智能·python·深度学习·目标检测·机器学习·scikit-learn
scdifsn5 小时前
动手学深度学习12.7. 参数服务器-笔记&练习(PyTorch)
pytorch·笔记·深度学习·分布式计算·数据并行·参数服务器
DFminer5 小时前
【LLM】fast-api 流式生成测试
人工智能·机器人
郄堃Deep Traffic5 小时前
机器学习+城市规划第十四期:利用半参数地理加权回归来实现区域带宽不同的规划任务
人工智能·机器学习·回归·城市规划
海盗儿6 小时前
Attention Is All You Need (Transformer) 以及Transformer pytorch实现
pytorch·深度学习·transformer
GIS小天6 小时前
AI+预测3D新模型百十个定位预测+胆码预测+去和尾2025年6月7日第101弹
人工智能·算法·机器学习·彩票
阿部多瑞 ABU6 小时前
主流大语言模型安全性测试(三):阿拉伯语越狱提示词下的表现与分析
人工智能·安全·ai·语言模型·安全性测试
cnbestec6 小时前
Xela矩阵三轴触觉传感器的工作原理解析与应用场景
人工智能·线性代数·触觉传感器
不爱写代码的玉子6 小时前
HALCON透视矩阵
人工智能·深度学习·线性代数·算法·计算机视觉·矩阵·c#