GAN for Diabetic Retinopathy (DR) Image Generation

This project builds a generative adversarial network (GAN) in TensorFlow to generate realistic diabetic retinopathy (DR) images. The GAN pairs a ResNet-based generator with a ResNet-based discriminator.

1. Dataset
  • Source: Kaggle diabetic retinopathy dataset
  • Structure (class folders are discovered at runtime and indexed alphabetically; see the split output further below):
/diabetic-retinopathy/
    Healthy/         -> healthy eyes (no DR), label 0
    Mild DR/         -> mild diabetic retinopathy, label 1
    Moderate DR/     -> moderate diabetic retinopathy, label 2
    Proliferate DR/  -> proliferative diabetic retinopathy, label 3
    Severe DR/       -> severe diabetic retinopathy, label 4
  • Split: 60% training, 20% validation, 20% test
  • Variables: train_paths, train_labels, val_paths, val_labels, test_paths, test_labels
  • Image preprocessing: resize to 128×128×3 (width × height × channels) and normalize to [-1, 1]

2. Preprocessing

def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    img = (tf.cast(img, tf.float32) - 127.5) / 127.5
    return img

train_dataset = tf.data.Dataset.from_tensor_slices(train_paths)
train_dataset = train_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
3. GAN Architecture
  • Generator (ResNet-style): takes a random noise vector (NOISE_DIM = 100), upsamples it through Conv2DTranspose + BatchNorm + LeakyReLU stages, and outputs a 128×128×3 image through a final tanh activation. Code definition: generator = make_generator_model()

  • Discriminator (ResNet-style): takes a 128×128×3 image, downsamples it through residual blocks + AveragePooling2D, and outputs a single value through a sigmoid activation (the probability that the image is real). Code definition: discriminator = build_discriminator(img_shape=(128,128,3))

4. Loss Functions and Optimizers
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

generator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

5. Training Step

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = cross_entropy(tf.ones_like(fake_output), fake_output)
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        disc_loss = (real_loss + fake_loss) / 2

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

6. Training Loop

seed = tf.random.normal([NUM_EXAMPLES, NOISE_DIM])

for epoch in range(1, EPOCHS + 1):
    for image_batch in train_dataset:
        gen_loss, disc_loss = train_step(image_batch)
    print(f"Epoch {epoch}, Generator Loss: {gen_loss:.4f}, Discriminator Loss: {disc_loss:.4f}")
    generate_and_save_images(generator, epoch, seed)

7. Visualization

  • Generated images: after each epoch

  • Discriminator confidence: green = real, red = fake

    preds = discriminator(fake_images, training=False).numpy()
    pred_value = preds[i][0]
    color = 'green' if pred_value > 0.5 else 'red'
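
A minimal, self-contained sketch of that overlay, assuming generator, discriminator, and NOISE_DIM as defined in the notebook below:

import matplotlib.pyplot as plt
import tensorflow as tf

noise = tf.random.normal([16, NOISE_DIM])
fake_images = generator(noise, training=False)
preds = discriminator(fake_images, training=False).numpy()

plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow((fake_images[i] + 1) / 2.0)          # rescale [-1, 1] -> [0, 1]
    pred_value = preds[i][0]
    color = 'green' if pred_value > 0.5 else 'red'  # green = judged real
    plt.title(f"{pred_value:.2f}", color=color)
    plt.axis('off')
plt.show()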

8. Testing / Evaluation

  • GANs do not have "accuracy" like classifiers
  • Check visual quality of generated images
  • Discriminator confidence (0--1) for generated images
  • Optionally compute FID (Fréchet Inception Distance) for quantitative evaluation (see the sketch below)
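
A rough FID sketch. Assumptions: real_images and fake_images are float tensors in [0, 1] with shape (N, H, W, 3), scipy is available, and the standard InceptionV3 pooled features are used; batching is kept simple for clarity.

import numpy as np
import tensorflow as tf
from scipy import linalg

inception = tf.keras.applications.InceptionV3(
    include_top=False, pooling='avg', input_shape=(299, 299, 3))

def inception_activations(images):
    # Resize to InceptionV3's input size and apply its own preprocessing ([0,255] -> [-1,1])
    images = tf.image.resize(images, (299, 299))
    images = tf.keras.applications.inception_v3.preprocess_input(images * 255.0)
    return inception.predict(images, verbose=0)

def compute_fid(real_images, fake_images):
    act_r = inception_activations(real_images)
    act_f = inception_activations(fake_images)
    mu_r, sigma_r = act_r.mean(axis=0), np.cov(act_r, rowvar=False)
    mu_f, sigma_f = act_f.mean(axis=0), np.cov(act_f, rowvar=False)
    covmean = linalg.sqrtm(sigma_r @ sigma_f)
    if np.iscomplexobj(covmean):  # numerical noise can introduce tiny imaginary parts
        covmean = covmean.real
    return float(np.sum((mu_r - mu_f) ** 2)
                 + np.trace(sigma_r + sigma_f - 2 * covmean))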

Note: You can adjust IMG_SIZE, BATCH_SIZE, and NOISE_DIM according to GPU capacity.

Needed Libraries
import os, glob
import random
import warnings


import numpy as np   # linear algebra
import pandas as pd  # data processing
import matplotlib.pyplot as plt
from PIL import Image as PILImage  # aliased so it is not shadowed by IPython's Image below
from IPython.display import Image

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

import tensorflow as tf 
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from sklearn.utils import shuffle
from tensorflow.keras import layers
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image


print("Library Called ")
Library Called 
Split the Data & Build the Class Lists
# Path of Data File
path='/kaggle/input/diabetic-retinopathy'
if not os.path.exists(path):
    print("Dataset folder not found at:", path)
    print("Download and extract the Kaggle dataset to that path.")
    raise SystemExit(1)

# Discover classes and image paths
class_dirs = sorted([d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))])
if not class_dirs:
    print("No class subfolders found. Expected layout: /kaggle/input/diabetic-retinopathy/<class>/*.jpg")
    raise SystemExit(1)

print("Found classes:", class_dirs)

# Collect image paths and integer labels (index = alphabetical class position)
image_paths = [] 
labels = []
for idx, cls in enumerate(class_dirs):
    files = glob.glob(os.path.join(path, cls, "*"))
    files = [f for f in files if os.path.splitext(f)[1].lower() in (".jpg", ".jpeg", ".png")]
    for f in files:
        image_paths.append(f)
        labels.append(idx)

print("Total images:", len(image_paths))
if len(image_paths) == 0:
    print("No images found - check paths.")
    raise SystemExit(1)

# Zip together and shuffle reproducibly
combined=list(zip(image_paths,labels))
random.seed(42)
random.shuffle(combined)
image_paths,labels = zip(*combined)
image_paths=list(image_paths)
labels=list(labels)
n =len(image_paths)

# Splitting fractions
train_frac = 0.60
val_frac = 0.20 
n_train = int(n* train_frac)
n_val = int(n*val_frac)
n_test = n - n_train - n_val

train_paths = image_paths[:n_train]
train_labels =labels[:n_train]
val_paths = image_paths[n_train:n_train+n_val]
val_labels = labels[n_train:n_train+n_val]
test_paths = image_paths[n_train+n_val:]
test_labels = labels[n_train+n_val:]
print("Split sizes -> train:", len(train_paths), "val:", len(val_paths), "test:", len(test_paths))

dist = pd.DataFrame({
    "train": pd.Series(train_labels).value_counts().sort_index(),
    "val": pd.Series(val_labels).value_counts().sort_index(),
    "test": pd.Series(test_labels).value_counts().sort_index()
}).fillna(0).astype(int)
print("\nClass distribution (by index):\n", dist)
Found classes: ['Healthy', 'Mild DR', 'Moderate DR', 'Proliferate DR', 'Severe DR']
Total images: 2750
Split sizes -> train: 1650 val: 550 test: 550

Class distribution (by index):
    train  val  test
0    594  200   206
1    212   77    81
2    547  183   170
3    181   55    54
4    116   35    39
GAN Model Parameters
# Parameters
BUFFER_SIZE = 60000   # shuffle buffer; larger than the dataset, so shuffling is effectively uniform
BATCH_SIZE = 128
NOISE_DIM = 100
# Hyperparameters
LATENT_DIM = 100      # same value as NOISE_DIM; only NOISE_DIM is used below
EPOCHS = 80
NUM_EXAMPLES = 16     # fixed seed samples visualized after each epoch
IMG_SIZE = (128, 128)
Convert them into a TensorFlow dataset
def load_image(image_path):
    img = tf.io.read_file(image_path)
    # The glob above also accepts .png; if the folder mixes formats,
    # tf.image.decode_image(img, channels=3, expand_animations=False) is the safer call
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMG_SIZE)
    img = (tf.cast(img, tf.float32) - 127.5) / 127.5  # Normalize to [-1, 1]
    return img
Building the Pipeline
# Build tf.data pipeline correctly
train_dataset = tf.data.Dataset.from_tensor_slices(train_paths)
train_dataset = train_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_dataset = tf.data.Dataset.from_tensor_slices(val_paths)
val_dataset = val_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices(test_paths)
test_dataset = test_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
for images in train_dataset.take(1):  # take one batch
    plt.figure(figsize=(10, 10))
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow((images[i] * 127.5 + 127.5).numpy().astype("uint8"))
        plt.axis("off")
Build the Generator

There are two ways to build the generator (ResNet-based, or a plain DCGAN-style generator built from scratch); the table below summarizes the differences:

| Concept | DCGAN generator (plain) | ResNet-based generator |
| --- | --- | --- |
| Architecture | Sequential stack of Conv2DTranspose layers | Residual blocks with skip connections |
| Data flow | Plain chain of upsampling convolutions | Each block learns a residual F(x) that is added back to its input |
| Skip connections | ❌ None (plain architecture) | ✅ x + F(x), which preserves gradient flow |
| Gradient flow | Can vanish or explode during training (less stable on complex data) | More stable (gradients pass through the shortcuts) |
| Image quality | Good on simple datasets (MNIST, CIFAR) | Better for high-resolution or complex images (e.g., retinal images) |
| Training stability | Moderate; may need learning-rate tuning | Higher; usually converges more smoothly |
| Flexibility | Easy to implement; fewer parameters | More capable, but slightly heavier |
| Best suited for | Small, low-resolution datasets | Medical, photorealistic, or high-resolution datasets |
First, Build the Residual Block
def residual_block(x, filters):
    # NOTE: the shortcut is added back directly, so the input must already have `filters` channels
    shortcut = x
    x = layers.Conv2D(filters, kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.add([shortcut, x])  # skip connection: x + F(x)
    x = layers.ReLU()(x)
    return x
Second, Build the Generator (takes the noise dimension as a parameter)
def build_generator(noise_dim):
    inputs = layers.Input(shape=(noise_dim,))
    x = layers.Dense(8 * 8 * 256, use_bias=False)(inputs)
    x = layers.Reshape((8, 8, 256))(x)
    x = layers.ReLU()(x)
    
    # Residual blocks at 8×8 resolution
    for _ in range(3):
        x = residual_block(x, 256)
    
    # Upsampling layers: 8 -> 16 -> 32 -> 64
    # (note: this stops at 64×64, not 128×128 — the mismatch is diagnosed and fixed further below)
    x = layers.Conv2DTranspose(128, 4, strides=2, padding='same')(x)
    x = layers.ReLU()(x)
    x = layers.Conv2DTranspose(64, 4, strides=2, padding='same')(x)
    x = layers.ReLU()(x)
    x = layers.Conv2DTranspose(32, 4, strides=2, padding='same')(x)
    x = layers.ReLU()(x)
    
    # Output in [-1, 1] via tanh
    outputs = layers.Conv2D(3, (7,7), padding='same', activation='tanh')(x)
    
    return Model(inputs, outputs, name="generator")
generator = build_generator(NOISE_DIM)
generator.summary()
Model: "generator"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 100)       │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense)       │ (None, 16384)     │  1,638,400 │ input_layer[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ reshape (Reshape)   │ (None, 8, 8, 256) │          0 │ dense[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu (ReLU)        │ (None, 8, 8, 256) │          0 │ reshape[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d (Conv2D)     │ (None, 8, 8, 256) │    590,080 │ re_lu[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalization │ (None, 8, 8, 256) │      1,024 │ conv2d[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_1 (ReLU)      │ (None, 8, 8, 256) │          0 │ batch_normalizat… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_1 (Conv2D)   │ (None, 8, 8, 256) │    590,080 │ re_lu_1[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 8, 8, 256) │      1,024 │ conv2d_1[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 8, 8, 256) │          0 │ re_lu[0][0],      │
│                     │                   │            │ batch_normalizat… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_2 (ReLU)      │ (None, 8, 8, 256) │          0 │ add[0][0]         │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_2 (Conv2D)   │ (None, 8, 8, 256) │    590,080 │ re_lu_2[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 8, 8, 256) │      1,024 │ conv2d_2[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_3 (ReLU)      │ (None, 8, 8, 256) │          0 │ batch_normalizat… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_3 (Conv2D)   │ (None, 8, 8, 256) │    590,080 │ re_lu_3[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 8, 8, 256) │      1,024 │ conv2d_3[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 8, 8, 256) │          0 │ re_lu_2[0][0],    │
│                     │                   │            │ batch_normalizat… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_4 (ReLU)      │ (None, 8, 8, 256) │          0 │ add_1[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_4 (Conv2D)   │ (None, 8, 8, 256) │    590,080 │ re_lu_4[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 8, 8, 256) │      1,024 │ conv2d_4[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_5 (ReLU)      │ (None, 8, 8, 256) │          0 │ batch_normalizat… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_5 (Conv2D)   │ (None, 8, 8, 256) │    590,080 │ re_lu_5[0][0]     │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 8, 8, 256) │      1,024 │ conv2d_5[0][0]    │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 8, 8, 256) │          0 │ re_lu_4[0][0],    │
│                     │                   │            │ batch_normalizat… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_6 (ReLU)      │ (None, 8, 8, 256) │          0 │ add_2[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_transpose    │ (None, 16, 16,    │    524,416 │ re_lu_6[0][0]     │
│ (Conv2DTranspose)   │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_7 (ReLU)      │ (None, 16, 16,    │          0 │ conv2d_transpose… │
│                     │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_transpose_1  │ (None, 32, 32,    │    131,136 │ re_lu_7[0][0]     │
│ (Conv2DTranspose)   │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_8 (ReLU)      │ (None, 32, 32,    │          0 │ conv2d_transpose… │
│                     │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_transpose_2  │ (None, 64, 64,    │     32,800 │ re_lu_8[0][0]     │
│ (Conv2DTranspose)   │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ re_lu_9 (ReLU)      │ (None, 64, 64,    │          0 │ conv2d_transpose… │
│                     │ 32)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_6 (Conv2D)   │ (None, 64, 64, 3) │      4,707 │ re_lu_9[0][0]     │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 5,878,083 (22.42 MB)
 Trainable params: 5,875,011 (22.41 MB)
 Non-trainable params: 3,072 (12.00 KB)

🧱 Build the Discriminator

def residual_block(x, filters, downsample=False):
    shortcut = x

    # First convolution
    x = layers.Conv2D(filters, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(alpha=0.2)(x)

    # Second convolution
    x = layers.Conv2D(filters, (3, 3), padding='same')(x)
    x = layers.BatchNormalization()(x)

    # Optionally downsample
    if downsample:
        x = layers.AveragePooling2D(pool_size=(2, 2))(x)
        shortcut = layers.Conv2D(filters, (1, 1), padding='same')(shortcut)
        shortcut = layers.AveragePooling2D(pool_size=(2, 2))(shortcut)

    # Add skip connection
    x = layers.add([x, shortcut])
    x = layers.LeakyReLU(alpha=0.2)(x)
    return x
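
A quick, illustrative shape check of the downsampling block (the input shape here is arbitrary):

x_in = layers.Input(shape=(64, 64, 64))
y = residual_block(x_in, 128, downsample=True)
print(y.shape)  # (None, 32, 32, 128): spatial size halved, channels projected to 128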

def build_discriminator(img_shape=(128, 128, 3)):
    inp = layers.Input(shape=img_shape)

    x = layers.Conv2D(64, (7,7), strides=2, padding='same')(inp)
    x = layers.LeakyReLU(alpha=0.2)(x)

    # Residual blocks with downsampling
    for filters in [64, 128, 256]:
        x = residual_block(x, filters, downsample=True)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1, activation='sigmoid')(x)

    model = tf.keras.Model(inp, x, name="Discriminator")
    return model

discriminator = build_discriminator()
discriminator.summary()
/usr/local/lib/python3.11/dist-packages/keras/src/layers/activations/leaky_relu.py:41: UserWarning: Argument `alpha` is deprecated. Use `negative_slope` instead.
  warnings.warn(
Model: "Discriminator"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input_layer_1       │ (None, 128, 128,  │          0 │ -                 │
│ (InputLayer)        │ 3)                │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_7 (Conv2D)   │ (None, 64, 64,    │      9,472 │ input_layer_1[0]… │
│                     │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ leaky_re_lu         │ (None, 64, 64,    │          0 │ conv2d_7[0][0]    │
│ (LeakyReLU)         │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_8 (Conv2D)   │ (None, 64, 64,    │     36,928 │ leaky_re_lu[0][0] │
│                     │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 64, 64,    │        256 │ conv2d_8[0][0]    │
│ (BatchNormalizatio… │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ leaky_re_lu_1       │ (None, 64, 64,    │          0 │ batch_normalizat… │
│ (LeakyReLU)         │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_9 (Conv2D)   │ (None, 64, 64,    │     36,928 │ leaky_re_lu_1[0]… │
│                     │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 64, 64,    │        256 │ conv2d_9[0][0]    │
│ (BatchNormalizatio… │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_10 (Conv2D)  │ (None, 64, 64,    │      4,160 │ leaky_re_lu[0][0] │
│                     │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ average_pooling2d   │ (None, 32, 32,    │          0 │ batch_normalizat… │
│ (AveragePooling2D)  │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ average_pooling2d_1 │ (None, 32, 32,    │          0 │ conv2d_10[0][0]   │
│ (AveragePooling2D)  │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_3 (Add)         │ (None, 32, 32,    │          0 │ average_pooling2… │
│                     │ 64)               │            │ average_pooling2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ leaky_re_lu_2       │ (None, 32, 32,    │          0 │ add_3[0][0]       │
│ (LeakyReLU)         │ 64)               │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_11 (Conv2D)  │ (None, 32, 32,    │     73,856 │ leaky_re_lu_2[0]… │
│                     │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 32, 32,    │        512 │ conv2d_11[0][0]   │
│ (BatchNormalizatio… │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ leaky_re_lu_3       │ (None, 32, 32,    │          0 │ batch_normalizat… │
│ (LeakyReLU)         │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_12 (Conv2D)  │ (None, 32, 32,    │    147,584 │ leaky_re_lu_3[0]… │
│                     │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 32, 32,    │        512 │ conv2d_12[0][0]   │
│ (BatchNormalizatio… │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_13 (Conv2D)  │ (None, 32, 32,    │      8,320 │ leaky_re_lu_2[0]… │
│                     │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ average_pooling2d_2 │ (None, 16, 16,    │          0 │ batch_normalizat… │
│ (AveragePooling2D)  │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ average_pooling2d_3 │ (None, 16, 16,    │          0 │ conv2d_13[0][0]   │
│ (AveragePooling2D)  │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_4 (Add)         │ (None, 16, 16,    │          0 │ average_pooling2… │
│                     │ 128)              │            │ average_pooling2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ leaky_re_lu_4       │ (None, 16, 16,    │          0 │ add_4[0][0]       │
│ (LeakyReLU)         │ 128)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_14 (Conv2D)  │ (None, 16, 16,    │    295,168 │ leaky_re_lu_4[0]… │
│                     │ 256)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 16, 16,    │      1,024 │ conv2d_14[0][0]   │
│ (BatchNormalizatio… │ 256)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ leaky_re_lu_5       │ (None, 16, 16,    │          0 │ batch_normalizat… │
│ (LeakyReLU)         │ 256)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_15 (Conv2D)  │ (None, 16, 16,    │    590,080 │ leaky_re_lu_5[0]… │
│                     │ 256)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ batch_normalizatio… │ (None, 16, 16,    │      1,024 │ conv2d_15[0][0]   │
│ (BatchNormalizatio… │ 256)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ conv2d_16 (Conv2D)  │ (None, 16, 16,    │     33,024 │ leaky_re_lu_4[0]… │
│                     │ 256)              │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ average_pooling2d_4 │ (None, 8, 8, 256) │          0 │ batch_normalizat… │
│ (AveragePooling2D)  │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ average_pooling2d_5 │ (None, 8, 8, 256) │          0 │ conv2d_16[0][0]   │
│ (AveragePooling2D)  │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_5 (Add)         │ (None, 8, 8, 256) │          0 │ average_pooling2… │
│                     │                   │            │ average_pooling2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ leaky_re_lu_6       │ (None, 8, 8, 256) │          0 │ add_5[0][0]       │
│ (LeakyReLU)         │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 256)       │          0 │ leaky_re_lu_6[0]… │
│ (GlobalAveragePool… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_1 (Dense)     │ (None, 1)         │        257 │ global_average_p… │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 1,239,361 (4.73 MB)
 Trainable params: 1,237,569 (4.72 MB)
 Non-trainable params: 1,792 (7.00 KB)
Define Loss and Optimizers
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

generator_optimizer = tf.keras.optimizers.Adam(1e-4) # => for generator
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) # => for Discriminator
Show Random Samples from the Training Data
# Rebuild the training pipeline (same as above) just to pull one batch for display
train_dataset = tf.data.Dataset.from_tensor_slices(train_paths)
train_dataset = train_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# Take one batch from dataset
for image_batch in train_dataset.take(1):
    sample_images = image_batch[:16]  # first 16 images
    break

# Denormalize from [-1,1] → [0,1] for display
sample_images = (sample_images + 1) / 2.0  

# Plot them
plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(sample_images[i])
    plt.axis("off")
plt.suptitle("Random Samples from Training Data", fontsize=14)
plt.show()
Visualize the Generator
# visualize your generator
plot_model(
    generator,
    to_file="generator_model.png",   # will save as an image in Kaggle output
    show_shapes=True,                # show layer input/output shapes
    show_layer_names=True,           # show layer names
    expand_nested=True,              # expand residual blocks
    dpi=96
)
Image(filename='generator_model.png')

Diagnosing and Fixing the Output-Size Problem

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
print(generated_image.shape)
(1, 64, 64, 3)

The ResNet generator above only upsamples 8 -> 16 -> 32 -> 64, so it emits 64×64 images while the discriminator expects 128×128. One more stride-2 transposed convolution doubles the spatial size, e.g.:

layers.Conv2DTranspose(3, (5,5), strides=(2,2), padding='same', activation='tanh')

The rebuilt generator below adds that extra upsampling stage and outputs (128, 128, 3) --- a match for the discriminator.

def make_generator_model():
    model = tf.keras.Sequential([
        layers.Dense(8*8*512, use_bias=False, input_shape=(100,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((8, 8, 512)),

        layers.Conv2DTranspose(256, (5,5), strides=(2,2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),  # 16x16

        layers.Conv2DTranspose(128, (5,5), strides=(2,2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),  # 32x32

        layers.Conv2DTranspose(64, (5,5), strides=(2,2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),  # 64x64

        # 👇 ADD ONE MORE UPSAMPLING TO REACH 128x128
        layers.Conv2DTranspose(32, (5,5), strides=(2,2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),  # 128x128

        layers.Conv2DTranspose(3, (5,5), strides=(1,1), padding='same', use_bias=False, activation='tanh')
    ])
    return model
# check again the output 
generator = make_generator_model()
noise = tf.random.normal([1, 100])
print(generator(noise).shape)
/usr/local/lib/python3.11/dist-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
(1, 128, 128, 3)

Build the GAN: Loss and Optimizers

# Loss and optimizers (beta_1=0.5 is the standard DCGAN setting)
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# NOTE: the discriminator ends in a sigmoid, so its outputs are probabilities, not logits.
# Keras detects this and warns during training (see the log below); setting from_logits=False,
# or dropping activation='sigmoid' from the final Dense layer, would make the pairing consistent.

generator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4, beta_1=0.5)

🧠 Define the Training Step

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        # Losses
        gen_loss = cross_entropy(tf.ones_like(fake_output), fake_output)
        real_loss = cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
        disc_loss = (real_loss + fake_loss) / 2

    # Compute gradients
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # Apply updates
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

Define Function to Visualize Results

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)
    predictions = (predictions + 1) / 2.0  # back to [0,1]

    fig = plt.figure(figsize=(4,4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i])
        plt.axis('off')
    plt.suptitle(f"Epoch {epoch}")
    plt.savefig(f'image_at_epoch_{epoch:04d}.png')  # actually save, as the function name promises
    plt.show()

🧪 The Training Loop

seed = tf.random.normal([NUM_EXAMPLES, NOISE_DIM])

from tqdm import tqdm # => to display epochs when training

for epoch in range(1, EPOCHS + 1):
    for image_batch in tqdm(train_dataset, desc=f"Epoch {epoch}/{EPOCHS}"):
        gen_loss, disc_loss = train_step(image_batch)

    print(f"Epoch {epoch}, Generator Loss: {gen_loss:.4f}, Discriminator Loss: {disc_loss:.4f}")
    generate_and_save_images(generator, epoch, seed)
/dist-packages/keras/src/backend/tensorflow/nn.py:780: UserWarning: "`binary_crossentropy` received `from_logits=True`, but the `output` argument was produced by a Sigmoid activation and thus does not represent logits. Was this intended?
  output, from_logits = _get_logits(
Epoch 1/80: 100%|██████████| 13/13 [00:40<00:00,  3.15s/it]
Epoch 1, Generator Loss: 0.8707, Discriminator Loss: 0.4276
Epoch 78/80: 100%|██████████| 13/13 [00:14<00:00,  1.11s/it]
Epoch 78, Generator Loss: 3.5214, Discriminator Loss: 0.0547
Epoch 79/80: 100%|██████████| 13/13 [00:14<00:00,  1.11s/it]
Epoch 79, Generator Loss: 5.3691, Discriminator Loss: 0.0091
Epoch 80/80: 100%|██████████| 13/13 [00:14<00:00,  1.11s/it]
Epoch 80, Generator Loss: 3.7142, Discriminator Loss: 0.0200

🎯 Why there's no "accuracy" in GANs

In a GAN there are two core models:

  • Generator (G): tries to produce fake images that look real.
  • Discriminator (D): tries to tell real images from fake ones.

They are trained adversarially:

  • The generator improves whenever it succeeds in fooling the discriminator.
  • The discriminator improves whenever it correctly identifies fakes.

So instead of tracking accuracy, we track the losses:

  • Generator loss (gen_loss): how well the generator fools the discriminator.
  • Discriminator loss (disc_loss): how well the discriminator separates real from fake.

If both losses stay in a stable range as training progresses (no divergence), training is working. In the run above, the generator loss climbing to 3–5 while the discriminator loss falls toward 0 suggests the discriminator has overpowered the generator; the sketch below shows one way to monitor this.
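
A minimal sketch for recording and plotting per-epoch losses. The gen_losses/disc_losses lists are illustrative additions, not part of the original loop, and this runs in place of (not after) the training loop above:

import numpy as np
import matplotlib.pyplot as plt

gen_losses, disc_losses = [], []
for epoch in range(1, EPOCHS + 1):
    batch_g, batch_d = [], []
    for image_batch in train_dataset:
        g, d = train_step(image_batch)
        batch_g.append(float(g))
        batch_d.append(float(d))
    gen_losses.append(np.mean(batch_g))   # average loss over the epoch
    disc_losses.append(np.mean(batch_d))

plt.plot(gen_losses, label='generator loss')
plt.plot(disc_losses, label='discriminator loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()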

Generate and Plot Samples from the Trained Generator
# Generate noise
noise = tf.random.normal([16, NOISE_DIM])

# Generate fake images
generated_images = generator(noise, training=False)

# Plot them
plt.figure(figsize=(4, 4))
for i in range(generated_images.shape[0]):
    plt.subplot(4, 4, i + 1)
    plt.imshow((generated_images[i] * 127.5 + 127.5).numpy().astype("uint8"))
    plt.axis("off")
plt.show()
Discriminator Confidence on Generated Images
# Generate fake images
noise = tf.random.normal([100, NOISE_DIM])
fake_images = generator(noise, training=False)

# Discriminator output for generated images
predictions = discriminator(fake_images, training=False)

# Check how "real" the discriminator thinks they are
mean_pred = tf.reduce_mean(predictions).numpy()
print(f"Average discriminator confidence (real=1, fake=0): {mean_pred:.3f}")
Average discriminator confidence (real=1, fake=0): 0.007
A mean confidence of 0.007 means the discriminator labels essentially every generated image as fake, which is consistent with the diverging generator loss at the end of training.

Save the Models
generator.save('generator_model.h5')
discriminator.save('discriminator_model.h5')
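To reload the models later, a sketch (note: newer Keras versions prefer the native .keras format over legacy HDF5):

from tensorflow.keras.models import load_model

generator = load_model('generator_model.h5')
discriminator = load_model('discriminator_model.h5')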
Load the Pretrained CNN Classifier
from tensorflow.keras.models import load_model

classifier = load_model('/kaggle/input/cnn/keras/default/1/improved_dr_cnn.h5')
Next Step: Analysis and Interpretation 🔍
N_SAMPLES = 500  # increase if you want
noise = tf.random.normal([N_SAMPLES, NOISE_DIM])
generated_images = generator(noise, training=False)

# Rescale from [-1, 1] back to [0, 1] for the classifier
generated_images = (generated_images + 1) / 2.0
preds = classifier.predict(generated_images)
pred_labels = np.argmax(preds, axis=1)
16/16 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step
class_names = ['Healthy', 'Mild DR', 'Moderate DR','Proliferate DR', 'Severe DR']

plt.figure(figsize=(12, 12))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(generated_images[i])
    plt.title(class_names[pred_labels[i]])
    plt.axis("off")

plt.suptitle("Predicted Classes for GAN Generated Images", fontsize=16)
plt.show()
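
To quantify how the generated samples spread across classes (a quick mode-collapse check; illustrative, reusing pred_labels and class_names from above):

unique, counts = np.unique(pred_labels, return_counts=True)
for u, c in zip(unique, counts):
    print(f"{class_names[u]}: {c}")

The classification report below appears to come from evaluating the CNN classifier on the 550-image validation split (the support column matches the validation class counts); the producing code is not shown. A sketch that would yield an equivalent report, assuming the classifier expects 128×128 inputs in [0, 1]:

from sklearn.metrics import classification_report, confusion_matrix

val_images = np.stack([(load_image(p).numpy() + 1) / 2.0 for p in val_paths])
val_probs = classifier.predict(val_images)   # default batch_size=32 -> 18 steps for 550 images
val_preds = np.argmax(val_probs, axis=1)

print(classification_report(val_labels, val_preds, target_names=class_names))
print("Confusion matrix:\n", confusion_matrix(val_labels, val_preds))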
18/18 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step
                precision    recall  f1-score   support

       Healthy       0.80      0.97      0.87       200
       Mild DR       0.55      0.21      0.30        77
   Moderate DR       0.53      0.81      0.64       183
Proliferate DR       0.00      0.00      0.00        55
     Severe DR       0.00      0.00      0.00        35

      accuracy                           0.65       550
     macro avg       0.38      0.40      0.36       550
  weighted avg       0.54      0.65      0.57       550

Confusion matrix:
 [[194   2   4   0   0]
 [  7  16  54   0   0]
 [ 27   8 148   0   0]
 [  9   3  43   0   0]
 [  7   0  28   0   0]]
/usr/local/lib/python3.11/dist-packages/sklearn/metrics/_classification.py:1344: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))