# VAE model implementation: study notes

# [python code listing]
import os
from datetime import datetime
import glob
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

# ========== Paths ==========
# Folder that is globbed for "*.jpg" training images.
data_root = "./faces/img_align_celeba/img_align_celeba"
# Output folder for checkpoints and image grids; created at import time.
save_dir = "./vae_test_checkpoints"
os.makedirs(save_dir, exist_ok=True)

# ========== Hyperparameters ==========
# Mini-batch size handed to the DataLoader.
batch_size = 128
# Learning rate handed to the Adam optimizer in train().
lr = 2e-4
# Number of passes over the dataset performed by train().
num_epochs = 300
# Size of the latent vector z (passed to ConvVAE and used when sampling).
latent_dim = 128
# A checkpoint and preview images are written every `sample_every` epochs
# (and additionally after epoch 1).
sample_every = 5
# How many images go into each reconstruction / sample preview.
num_sample_images = 5
# Side length (pixels) the images are resized and cropped to.
image_size = 64
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every image: resize, center-crop to a square of
# side `image_size`, then convert to a tensor (per the torchvision docs,
# ToTensor yields float values scaled to [0, 1]).
# NOTE: the name is spelled `tranform` (sic); the dataset construction below
# refers to it by this exact spelling, so it must not be renamed in isolation.
tranform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])

class CelebADataset(Dataset):
    """Image dataset over every ``*.jpg`` file directly inside one folder.

    Items are unlabeled: indexing returns only the (optionally transformed)
    image, converted to RGB.

    Args:
        root: Directory that is searched (non-recursively) for ``*.jpg`` files.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned. ``None`` returns the PIL image unchanged.
    """

    def __init__(self, root, transform=None):
        self.root = root
        # glob returns entries in filesystem order, which varies between
        # machines and runs; sorting makes the index -> file mapping stable.
        self.paths = sorted(glob.glob(os.path.join(root, "*.jpg")))
        self.transform = transform

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and apply ``transform`` if set."""
        img_path = self.paths[index]
        # Image.open is lazy and keeps the file open; convert() forces the
        # decode and returns an independent in-memory image, so the file
        # handle can be released deterministically right afterwards.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img
# Build the dataset over `data_root` and wrap it in a shuffling, multi-worker
# loader that yields batches of `batch_size` images.
dataset = CelebADataset(root=data_root, transform=tranform)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel square images.

    The encoder is four stride-2 4x4 convolutions (3 -> ch -> 2ch -> 4ch ->
    8ch), each halving the spatial size. Its flattened output feeds two
    linear heads that produce the mean and log-variance of the latent
    distribution. The decoder maps a latent vector back to the encoder's
    feature shape with a linear layer, then mirrors the encoder with four
    stride-2 transposed convolutions (each doubling the spatial size) and a
    final Sigmoid, so outputs lie in [0, 1].

    Args:
        latent_dim: Size of the latent vector.
        ch: Channel count of the first encoder layer; deeper layers use
            2x, 4x and 8x this value.
        image_size: Side length of the square input images. The
            reconstruction has the same height/width as the input only when
            this is a multiple of 16 (four halvings, then four doublings).
    """

    def __init__(self, latent_dim=128, ch=64, image_size=64):
        super().__init__()
        self.latent_dim = latent_dim
        self.ch = ch
        self.image_size = image_size

        # Encoder: image -> (8*ch, image_size // 16, image_size // 16)
        self.enc = nn.Sequential(
            nn.Conv2d(3, ch, 4, 2, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch * 2, 4, 2, 1),
            nn.BatchNorm2d(ch * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch * 2, ch * 4, 4, 2, 1),
            nn.BatchNorm2d(ch * 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch * 4, ch * 8, 4, 2, 1),
            nn.BatchNorm2d(ch * 8),
            nn.ReLU(inplace=True),
        )

        # Discover the encoder's output shape with a single dummy pass.
        # The probe runs in eval mode under no_grad so that (a) the BatchNorm
        # running statistics are not updated by an all-zero fake image, and
        # (b) a batch of one with a 1x1 feature map (image_size == 16) is
        # accepted, which training-mode BatchNorm would reject.
        self.enc.eval()
        with torch.no_grad():
            probe = self.enc(torch.zeros(1, 3, image_size, image_size))
        self.enc.train()
        self._feat_shape = probe.shape[1:]  # (channels, height, width)
        feat_dim = probe.view(1, -1).size(1)

        self.fc_mu = nn.Linear(feat_dim, latent_dim)
        self.fc_logvar = nn.Linear(feat_dim, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, feat_dim)

        # Decoder: mirrors the encoder; Sigmoid keeps outputs in [0, 1].
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(ch * 8, ch * 4, 4, 2, 1),
            nn.BatchNorm2d(ch * 4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ch * 4, ch * 2, 4, 2, 1),
            nn.BatchNorm2d(ch * 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ch * 2, ch, 4, 2, 1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ch, 3, 4, 2, 1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape ``(batch, latent_dim)``."""
        h = torch.flatten(self.enc(x), 1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Noise is drawn regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``(batch, latent_dim)`` to images in [0, 1]."""
        h = self.fc_dec(z)
        h = h.view(h.size(0), *self._feat_shape)
        return self.dec(h)

    def forward(self, x):
        """Encode, sample and decode; return ``(x_recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def vae_loss(x, x_recon, mu, logvar):
    """Compute the summed VAE objective for one batch.

    Args:
        x: Target images.
        x_recon: Reconstructed images, same shape as ``x``.
        mu: Latent means.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Tuple ``(total, recon_loss, kld)`` of scalar tensors, where
        ``recon_loss`` is the squared error summed over all elements, ``kld``
        is the KL divergence to a standard normal summed over all elements,
        and ``total`` is their sum. Nothing is averaged over the batch.
    """
    reconstruction = F.mse_loss(x_recon, x, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) terms, before the -1/2 factor.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()
    total = reconstruction + divergence
    return total, reconstruction, divergence

def save_checkpoint(model, optim, epoch, path):
    """Serialize the training state to ``path`` with ``torch.save``.

    The saved object is a dict with three keys: ``"epoch"`` (the given epoch
    number), ``"model_state_dict"`` (``model.state_dict()``) and
    ``"optim_state"`` (``optim.state_dict()``).
    """
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optim_state=optim.state_dict(),
        ),
        path,
    )

def save_image_grid(tensor, filename, nrow = 8):
    """Clamp a batch of images to [0, 1] and write it to ``filename`` as a grid.

    Args:
        tensor: Batch of images to save.
        filename: Destination path handed to ``utils.save_image``.
        nrow: Forwarded to ``utils.save_image`` as its ``nrow`` argument
            (per the torchvision docs, the number of images in each grid row).
    """
    clipped = tensor.clamp(0, 1)
    utils.save_image(clipped, filename, nrow=nrow, padding=2)

def _save_epoch_visuals(model, ref_imgs, epoch):
    """Write the reconstruction grid and the prior-sample grid for ``epoch``."""
    was_training = model.training
    model.eval()
    with torch.no_grad():
        # Reconstructions: originals in the first row, reconstructions in the
        # second, so every column compares one image with its reconstruction.
        recon_imgs, _, _ = model(ref_imgs)
        combined = torch.cat([ref_imgs, recon_imgs], dim=0)
        save_image_grid(
            combined,
            os.path.join(save_dir, f"recon_epoch{epoch}.png"),
            nrow=max(1, ref_imgs.size(0)),
        )

        # Generated samples: decode latent vectors drawn from N(0, I).
        z = torch.randn(num_sample_images, latent_dim, device=device)
        samples = model.decode(z)
        save_image_grid(samples, os.path.join(save_dir, f"sample_epoch{epoch}.png"), nrow=8)
    model.train(was_training)


def train():
    """Train a ConvVAE on ``loader`` for ``num_epochs`` epochs.

    Each epoch runs one Adam-optimized pass over the data, logs the batch
    loss every 100 batches and prints the per-sample average losses at the
    end. After epoch 1 and every ``sample_every`` epochs it saves a
    checkpoint plus two preview grids (reconstructions and random samples)
    into ``save_dir``.

    The preview reconstructions use the first ``num_sample_images`` images
    of the epoch's last training batch; this avoids starting a second
    DataLoader iterator (and its worker processes) just to fetch a handful
    of images.
    """
    model = ConvVAE(latent_dim=latent_dim, image_size=image_size).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    n_samples = len(loader.dataset)
    n_batches = len(loader)

    for epoch in range(1, num_epochs + 1):
        model.train()
        epoch_loss = 0.0
        epoch_recon = 0.0
        epoch_kld = 0.0

        for batch_idx, imgs in enumerate(loader):
            imgs = imgs.to(device, non_blocking=True)
            optim.zero_grad()
            recon_imgs, mu, logvar = model(imgs)
            loss, recon_l, kld = vae_loss(imgs, recon_imgs, mu, logvar)
            loss.backward()
            optim.step()

            # Read each scalar back from the device once and reuse it for
            # both the running totals and the log line.
            loss_val = loss.item()
            recon_val = recon_l.item()
            kld_val = kld.item()
            epoch_loss += loss_val
            epoch_recon += recon_val
            epoch_kld += kld_val

            if batch_idx % 100 == 0:
                print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
                      f"Epoch {epoch}/{num_epochs} Batch {batch_idx}/{n_batches} "
                      f"Loss {loss_val:.4f} "
                      f"(recon {recon_val:.4f}, kld {kld_val:.4f})")

        # The losses are sums over the batch, so dividing the epoch totals by
        # the dataset size gives per-sample averages.
        print(f"=== Epoch {epoch} finished. Avg loss: {epoch_loss / n_samples:.4f} "
              f"(recon {epoch_recon / n_samples:.4f}, kld {epoch_kld / n_samples:.4f}) ===")

        if epoch % sample_every == 0 or epoch == 1:
            ckpt = os.path.join(save_dir, f"vae_epoch{epoch}.pth")
            save_checkpoint(model, optim, epoch, ckpt)
            # `imgs` is the last training batch of this epoch, already on
            # `device`.
            _save_epoch_visuals(model, imgs[:num_sample_images], epoch)
    print("Training complete.")

# Script entry point: report the selected device and dataset size, then start
# training. The guard keeps training from starting when this module is
# imported rather than run directly.
if __name__ == "__main__":
    print("Starting training on device:", device)
    print("Dataset size:", len(dataset))
    train()

# Source: https://www.bilibili.com/video/BV1xFxMz1EMS/?spm_id_from=333.337.search-card.all.click&vd_source=cc430eea67dc70ace8ab1fd4f7c63123