Image Generation: Implementing a Simple Diffusion Model from Scratch with PyTorch


Preface
- Due to my limited ability, errors and omissions are inevitable; corrections and feedback are welcome.
- For more content, see the Python Daily Tips, OpenCV-Python Mini Applications, YOLO Series, Natural Language Processing, and AI Hybrid Programming Practice columns, or my personal homepage:
- Ultralytics: Speed Estimation with YOLO11
- Ultralytics: Object Tracking with YOLO11
- Ultralytics: Object Counting with YOLO11
- Ultralytics: Object Blurring with YOLO11
- AI Hybrid Programming Practice: Calling Python ONNX from C++ for YOLOv8 Inference
- AI Hybrid Programming Practice: Calling a Packaged DLL from C++ for YOLOv8 Instance Segmentation
- AI Hybrid Programming Practice: Calling Python ONNX from C++ for Image Super-Resolution
- AI Hybrid Programming Practice: Calling Python AgentOCR from C++ for Text Recognition
- Understanding PatchCore Anomaly Detection Through a Simple Worked Example
- Converting a YOLO-Format Instance Segmentation Dataset to COCO Format with Python
- YOLOv8 Ultralytics: Training an RT-DETR Real-Time Object Detection Model with the Ultralytics Framework
- Face Disguise Detection Based on DETR
- Training YOLOv7 on a Custom Dataset (Mask Detection)
- Training YOLOv8 on a Custom Dataset (Football Detection)
- YOLOv5: Accelerating YOLOv5 Model Inference with TensorRT
- YOLOv5: IoU, GIoU, DIoU, CIoU, EIoU
- Fun with Jetson Nano (5): Accelerating YOLOv5 Object Detection with TensorRT
- YOLOv5: Adding SE, CBAM, CoordAtt, and ECA Attention Mechanisms
- YOLOv5: Reading the yolov5s.yaml Configuration File and Adding a Small-Object Detection Layer
- Converting a COCO-Format Instance Segmentation Dataset to YOLO Format with Python
- YOLOv5: Training Your Own Instance Segmentation Model with Version 7.0 (Vehicles, Pedestrians, Road Signs, Lane Lines, etc.)
- Trying the Stable Diffusion Open-Source Project for Free with Kaggle GPU Resources
- Stable Diffusion: Deploying the Stable Diffusion WebUI on a Server for AI Image Generation (v2.0)
- Stable Diffusion: Fine-Tuning a LoRA Model on Your Own Dataset (v2.0)
Environment Requirements
```bash
Package Version Editable project location
------------------------------------- ------------------- -------------------------
absl-py 1.4.0
absolufy-imports 0.3.1
accelerate 1.9.0
aiofiles 22.1.0
aiohappyeyeballs 2.6.1
aiohttp 3.12.15
aiosignal 1.4.0
aiosqlite 0.21.0
alabaster 1.0.0
albucore 0.0.24
albumentations 2.0.8
ale-py 0.11.2
alembic 1.16.5
altair 5.5.0
annotated-types 0.7.0
annoy 1.17.3
ansicolors 1.1.8
antlr4-python3-runtime 4.9.3
anyio 4.11.0
anywidget 0.9.18
argon2-cffi 25.1.0
argon2-cffi-bindings 21.2.0
args 0.1.0
array_record 0.7.2
arrow 1.3.0
arviz 0.21.0
astropy 7.1.0
astropy-iers-data 0.2025.7.21.0.41.39
asttokens 3.0.0
astunparse 1.6.3
atpublic 5.1
attrs 25.3.0
audioread 3.0.1
Authlib 1.6.4
autograd 1.8.0
babel 2.17.0
backcall 0.2.0
backports.tarfile 1.2.0
bayesian-optimization 3.1.0
beartype 0.21.0
beautifulsoup4 4.13.4
betterproto 2.0.0b7
bigframes 2.12.0
bigquery-magics 0.10.1
black 25.9.0
bleach 6.2.0
blinker 1.9.0
blis 1.3.0
blobfile 3.0.0
blosc2 3.6.1
bokeh 3.7.3
Boruta 0.4.3
boto3 1.40.39
botocore 1.40.39
Bottleneck 1.4.2
bq_helper 0.4.1 /root/src/BigQuery_Helper
bqplot 0.12.45
branca 0.8.1
Brotli 1.1.0
build 1.2.2.post1
CacheControl 0.14.3
cachetools 5.5.2
Cartopy 0.24.1
catalogue 2.0.10
catboost 1.2.8
category_encoders 2.7.0
certifi 2025.8.3
cesium 0.12.4
cffi 2.0.0
chardet 5.2.0
charset-normalizer 3.4.3
Chessnut 0.4.1
chex 0.1.90
clarabel 0.11.1
click 8.3.0
click-plugins 1.1.1.2
cligj 0.7.2
clint 0.5.1
cloudpathlib 0.21.1
cloudpickle 3.1.1
cmake 3.31.6
cmdstanpy 1.2.5
colorama 0.4.6
colorcet 3.1.0
colorlog 6.9.0
colorlover 0.3.0
colour 0.1.5
comm 0.2.3
community 1.0.0b1
confection 0.1.5
cons 0.4.7
contourpy 1.3.2
coverage 7.10.7
cramjam 2.10.0
cryptography 46.0.1
cuda-bindings 12.9.2
cuda-pathfinder 1.2.3
cuda-python 12.9.2
cudf-cu12 25.2.2
cudf-polars-cu12 25.6.0
cufflinks 0.17.3
cuml-cu12 25.2.1
cupy-cuda12x 13.6.0
curl_cffi 0.12.0
cuvs-cu12 25.2.1
cvxopt 1.3.2
cvxpy 1.6.7
cycler 0.12.1
cyipopt 1.5.0
cymem 2.0.11
Cython 3.0.12
cytoolz 1.0.1
daal 2025.8.0
dacite 1.9.2
dask 2024.12.1
dask-cuda 25.2.0
dask-cudf-cu12 25.2.2
dask-expr 1.1.21
dataclasses-json 0.6.7
dataproc-spark-connect 0.8.3
datasets 4.1.1
db-dtypes 1.4.3
dbus-python 1.2.18
deap 1.4.3
debugpy 1.8.15
decorator 4.4.2
deepdiff 8.6.1
defusedxml 0.7.1
Deprecated 1.2.18
diffusers 0.34.0
dill 0.4.0
dipy 1.11.0
distributed 2024.12.1
distributed-ucxx-cu12 0.42.0
distro 1.9.0
dlib 19.24.6
dm-tree 0.1.9
dnspython 2.8.0
docker 7.1.0
docstring_parser 0.17.0
docstring-to-markdown 0.17
docutils 0.21.2
dopamine_rl 4.1.2
duckdb 1.3.2
earthengine-api 1.5.24
easydict 1.13
easyocr 1.7.2
editdistance 0.8.1
eerepr 0.1.2
einops 0.8.1
eli5 0.13.0
email-validator 2.3.0
emoji 2.15.0
en_core_web_sm 3.8.0
entrypoints 0.4
et_xmlfile 2.0.0
etils 1.13.0
etuples 0.3.10
execnb 0.1.14
Farama-Notifications 0.0.4
fastai 2.8.4
fastapi 0.116.1
fastcore 1.8.11
fastdownload 0.0.7
fastjsonschema 2.21.1
fastprogress 1.0.3
fastrlock 0.8.3
fasttext 0.9.3
fasttransform 0.0.2
featuretools 1.31.0
ffmpy 0.6.1
filelock 3.19.1
filetype 1.2.0
fiona 1.10.1
firebase-admin 6.9.0
Flask 3.1.1
flatbuffers 25.2.10
flax 0.10.6
folium 0.20.0
fonttools 4.59.0
fqdn 1.5.1
frozendict 2.4.6
frozenlist 1.7.0
fsspec 2025.9.0
funcy 2.0
fury 0.12.0
future 1.0.0
fuzzywuzzy 0.18.0
gast 0.6.0
gatspy 0.3
gcsfs 2025.3.0
GDAL 3.8.4
gdown 5.2.0
geemap 0.35.3
gensim 4.3.3
geocoder 1.38.1
geographiclib 2.0
geojson 3.2.0
geopandas 0.14.4
geopy 2.4.1
ghapi 1.0.8
gin-config 0.5.0
gitdb 4.0.12
GitPython 3.1.45
glob2 0.7
google 2.0.3
google-adk 1.14.1
google-ai-generativelanguage 0.6.15
google-api-core 1.34.1
google-api-python-client 2.177.0
google-auth 2.40.3
google-auth-httplib2 0.2.0
google-auth-oauthlib 1.2.2
google-cloud-aiplatform 1.105.0
google-cloud-appengine-logging 1.6.2
google-cloud-audit-log 0.3.2
google-cloud-automl 1.0.1
google-cloud-bigquery 3.25.0
google-cloud-bigquery-connection 1.18.3
google-cloud-bigtable 2.32.0
google-cloud-core 2.4.3
google-cloud-dataproc 5.21.0
google-cloud-datastore 2.21.0
google-cloud-firestore 2.21.0
google-cloud-functions 1.20.4
google-cloud-iam 2.19.1
google-cloud-language 2.17.2
google-cloud-logging 3.12.1
google-cloud-resource-manager 1.14.2
google-cloud-secret-manager 2.24.0
google-cloud-spanner 3.56.0
google-cloud-speech 2.33.0
google-cloud-storage 2.19.0
google-cloud-trace 1.16.2
google-cloud-translate 3.12.1
google-cloud-videointelligence 2.16.2
google-cloud-vision 3.10.2
google-colab 1.0.0
google-crc32c 1.7.1
google-genai 1.27.0
google-generativeai 0.8.5
google-pasta 0.2.0
google-resumable-media 2.7.2
googleapis-common-protos 1.70.0
googledrivedownloader 1.1.0
gpxpy 1.6.2
gradio 5.38.1
gradio_client 1.11.0
graphviz 0.21
greenlet 3.2.3
groovy 0.1.2
grpc-google-iam-v1 0.14.2
grpc-interceptor 0.15.4
grpcio 1.75.1
grpcio-status 1.49.0rc1
grpclib 0.4.8
gspread 6.2.1
gspread-dataframe 4.0.0
gym 0.25.2
gym-notices 0.0.8
gymnasium 0.29.0
h11 0.16.0
h2 4.3.0
h2o 3.46.0.7
h5netcdf 1.6.3
h5py 3.14.0
haversine 2.9.0
hdbscan 0.8.40
hep_ml 0.8.0
hf_transfer 0.1.9
hf-xet 1.1.10
highspy 1.11.0
holidays 0.77
holoviews 1.21.0
hpack 4.1.0
html5lib 1.1
httpcore 1.0.9
httpimport 1.4.1
httplib2 0.22.0
httpx 0.28.1
httpx-sse 0.4.1
huggingface-hub 1.0.0rc2
humanize 4.12.3
hyperframe 6.1.0
hyperopt 0.2.7
ibis-framework 9.5.0
id 1.5.0
idna 3.10
igraph 0.11.9
ImageHash 4.3.1
imageio 2.37.0
imageio-ffmpeg 0.6.0
imagesize 1.4.1
imbalanced-learn 0.13.0
immutabledict 4.2.1
importlib_metadata 8.7.0
importlib_resources 6.5.2
imutils 0.5.4
in-toto-attestation 0.9.3
inflect 7.5.0
iniconfig 2.1.0
intel-cmplr-lib-rt 2024.2.0
intel-cmplr-lib-ur 2024.2.0
intel-openmp 2024.2.0
ipyevents 2.0.2
ipyfilechooser 0.6.0
ipykernel 6.17.1
ipyleaflet 0.20.0
ipympl 0.9.7
ipyparallel 8.8.0
ipython 7.34.0
ipython-genutils 0.2.0
ipython_pygments_lexers 1.1.1
ipython-sql 0.5.0
ipytree 0.2.2
ipywidgets 8.1.5
isoduration 20.11.0
isoweek 1.3.3
itsdangerous 2.2.0
Janome 0.5.0
jaraco.classes 3.4.0
jaraco.context 6.0.1
jaraco.functools 4.2.1
jax 0.5.2
jax-cuda12-pjrt 0.5.1
jax-cuda12-plugin 0.5.1
jaxlib 0.5.1
jedi 0.19.2
jeepney 0.9.0
jieba 0.42.1
Jinja2 3.1.6
jiter 0.10.0
jmespath 1.0.1
joblib 1.5.2
json5 0.12.1
jsonpatch 1.33
jsonpickle 4.1.1
jsonpointer 3.0.0
jsonschema 4.25.0
jsonschema-specifications 2025.4.1
jupyter_client 8.6.3
jupyter-console 6.1.0
jupyter_core 5.8.1
jupyter-events 0.12.0
jupyter_kernel_gateway 2.5.2
jupyter-leaflet 0.20.0
jupyter-lsp 1.5.1
jupyter_server 2.12.5
jupyter_server_fileid 0.9.3
jupyter_server_terminals 0.5.3
jupyter_server_ydoc 0.8.0
jupyter-ydoc 0.2.5
jupyterlab 3.6.8
jupyterlab-lsp 3.10.2
jupyterlab_pygments 0.3.0
jupyterlab_server 2.27.3
jupyterlab_widgets 3.0.15
jupytext 1.17.2
kaggle 1.7.4.5
kaggle-environments 1.18.0
kagglehub 0.3.13
keras 3.8.0
keras-core 0.1.7
keras-cv 0.9.0
keras-hub 0.18.1
keras-nlp 0.18.1
keras-tuner 1.4.7
keyring 25.6.0
keyrings.google-artifactregistry-auth 1.1.2
kiwisolver 1.4.8
kornia 0.8.1
kornia_rs 0.1.9
kt-legacy 1.0.5
langchain 0.3.27
langchain-core 0.3.72
langchain-text-splitters 0.3.9
langcodes 3.5.0
langid 1.1.6
langsmith 0.4.8
language_data 1.3.0
lark 1.3.0
launchpadlib 1.10.16
lazr.restfulclient 0.14.4
lazr.uri 1.0.6
lazy_loader 0.4
learntools 0.3.5
libclang 18.1.1
libcudf-cu12 25.2.2
libcugraph-cu12 25.6.0
libcuml-cu12 25.2.1
libcuvs-cu12 25.2.1
libkvikio-cu12 25.2.1
libpysal 4.9.2
libraft-cu12 25.2.0
librmm-cu12 25.6.0
librosa 0.11.0
libucx-cu12 1.18.1
libucxx-cu12 0.42.0
lightgbm 4.6.0
lightning-utilities 0.15.2
lime 0.2.0.1
line_profiler 5.0.0
linkify-it-py 2.0.3
llvmlite 0.43.0
lml 0.2.0
locket 1.0.0
logical-unification 0.4.6
lxml 5.4.0
Mako 1.3.10
mamba 0.11.3
marisa-trie 1.2.1
Markdown 3.8.2
markdown-it-py 4.0.0
MarkupSafe 3.0.2
marshmallow 3.26.1
matplotlib 3.7.2
matplotlib-inline 0.1.7
matplotlib-venn 1.1.2
mcp 1.15.0
mdit-py-plugins 0.4.2
mdurl 0.1.2
minify_html 0.16.4
miniKanren 1.0.5
missingno 0.5.2
mistune 0.8.4
mizani 0.13.5
mkl 2025.2.0
mkl-fft 1.3.8
mkl-random 1.2.4
mkl-service 2.4.1
mkl-umath 0.1.1
ml_collections 1.1.0
ml-dtypes 0.4.1
mlcrate 0.2.0
mlxtend 0.23.4
mne 1.10.1
model-signing 1.0.1
more-itertools 10.7.0
moviepy 1.0.3
mpld3 0.5.11
mpmath 1.3.0
msgpack 1.1.1
multidict 6.6.4
multimethod 1.12
multipledispatch 1.0.0
multiprocess 0.70.16
multitasking 0.0.12
murmurhash 1.0.13
music21 9.3.0
mypy_extensions 1.1.0
namex 0.1.0
narwhals 1.48.1
natsort 8.4.0
nbclassic 1.3.1
nbclient 0.5.13
nbconvert 6.4.5
nbdev 2.4.5
nbformat 5.10.4
ndindex 1.10.0
nest-asyncio 1.6.0
networkx 3.5
nibabel 5.3.2
nilearn 0.10.4
ninja 1.13.0
nltk 3.9.1
notebook 6.5.4
notebook_shim 0.2.4
numba 0.60.0
numba-cuda 0.2.0
numexpr 2.11.0
numpy 1.26.4
nvidia-cublas-cu12 12.5.3.2
nvidia-cuda-cupti-cu12 12.5.82
nvidia-cuda-nvcc-cu12 12.5.82
nvidia-cuda-nvrtc-cu12 12.5.82
nvidia-cuda-runtime-cu12 12.5.82
nvidia-cudnn-cu12 9.3.0.75
nvidia-cufft-cu12 11.2.3.61
nvidia-curand-cu12 10.3.6.82
nvidia-cusolver-cu12 11.6.3.83
nvidia-cusparse-cu12 12.5.1.3
nvidia-cusparselt-cu12 0.6.2
nvidia-ml-py 12.575.51
nvidia-nccl-cu12 2.21.5
nvidia-nvcomp-cu12 4.2.0.11
nvidia-nvjitlink-cu12 12.5.82
nvidia-nvtx-cu12 12.4.127
nvtx 0.2.13
nx-cugraph-cu12 25.6.0
oauth2client 4.1.3
oauthlib 3.3.1
odfpy 1.4.1
olefile 0.47
omegaconf 2.3.0
onnx 1.18.0
open_spiel 1.6.1
openai 1.97.1
opencv-contrib-python 4.12.0.88
opencv-python 4.12.0.88
opencv-python-headless 4.12.0.88
openpyxl 3.1.5
openslide-bin 4.0.0.8
openslide-python 1.4.2
opentelemetry-api 1.37.0
opentelemetry-exporter-gcp-trace 1.9.0
opentelemetry-resourcedetector-gcp 1.9.0a0
opentelemetry-sdk 1.37.0
opentelemetry-semantic-conventions 0.58b0
opt_einsum 3.4.0
optax 0.2.5
optree 0.16.0
optuna 4.5.0
orbax-checkpoint 0.11.19
orderly-set 5.5.0
orjson 3.11.0
osqp 1.0.4
overrides 7.7.0
packaging 25.0
pandas 2.2.3
pandas-datareader 0.10.0
pandas-gbq 0.29.2
pandas-profiling 3.6.6
pandas-stubs 2.2.2.240909
pandasql 0.7.3
pandocfilters 1.5.1
panel 1.7.5
papermill 2.6.0
param 2.2.1
parso 0.8.4
parsy 2.1
partd 1.4.2
path 17.1.1
path.py 12.5.0
pathos 0.3.2
pathspec 0.12.1
patsy 1.0.1
pdf2image 1.17.0
peewee 3.18.2
peft 0.16.0
pettingzoo 1.24.0
pexpect 4.9.0
phik 0.12.5
pickleshare 0.7.5
pillow 11.3.0
pip 24.1.2
platformdirs 4.4.0
plotly 5.24.1
plotly-express 0.4.1
plotnine 0.14.5
pluggy 1.6.0
plum-dispatch 2.5.7
ply 3.11
polars 1.25.0
pooch 1.8.2
portpicker 1.5.2
pox 0.3.6
ppft 1.7.7
preprocessing 0.1.13
preshed 3.0.10
prettytable 3.16.0
proglog 0.1.12
progressbar2 4.5.0
prometheus_client 0.22.1
promise 2.3
prompt_toolkit 3.0.51
propcache 0.3.2
prophet 1.1.7
proto-plus 1.26.1
protobuf 3.20.3
psutil 7.1.0
psycopg2 2.9.10
psygnal 0.14.0
ptyprocess 0.7.0
pudb 2025.1.1
puremagic 1.30
py-cpuinfo 9.0.0
py4j 0.10.9.7
pyaml 25.7.0
PyArabic 0.6.15
pyarrow 19.0.1
pyasn1 0.6.1
pyasn1_modules 0.4.2
pybind11 3.0.1
pycairo 1.28.0
pyclipper 1.3.0.post6
pycocotools 2.0.10
pycparser 2.23
pycryptodome 3.23.0
pycryptodomex 3.23.0
pycuda 2025.1.2
pydantic 2.12.0a1
pydantic_core 2.37.2
pydantic-settings 2.11.0
pydata-google-auth 1.9.1
pydegensac 0.1.2
pydicom 3.0.1
pydot 3.0.4
pydotplus 2.0.2
PyDrive 1.3.1
PyDrive2 1.21.3
pydub 0.25.1
pyemd 1.0.0
pyerfa 2.0.1.5
pyexcel-io 0.6.7
pyexcel-ods 0.6.0
pygame 2.6.1
pygit2 1.18.0
pygltflib 1.16.5
Pygments 2.19.2
PyGObject 3.42.0
PyJWT 2.10.1
pyLDAvis 3.4.1
pylibcudf-cu12 25.2.2
pylibcugraph-cu12 25.6.0
pylibraft-cu12 25.2.0
pymc 5.25.1
pymc3 3.11.4
pymongo 4.15.1
Pympler 1.1
pynndescent 0.5.13
pynvjitlink-cu12 0.5.2
pynvml 12.0.0
pyogrio 0.11.0
pyomo 6.9.2
PyOpenGL 3.1.9
pyOpenSSL 25.3.0
pyparsing 3.0.9
pypdf 6.1.0
pyperclip 1.9.0
pyproj 3.7.1
pyproject_hooks 1.2.0
pyshp 2.3.1
PySocks 1.7.1
pyspark 3.5.1
pytensor 2.31.7
pytesseract 0.3.13
pytest 8.4.1
python-apt 0.0.0
python-bidi 0.6.6
python-box 7.3.2
python-dateutil 2.9.0.post0
python-dotenv 1.1.1
python-json-logger 3.3.0
python-louvain 0.16
python-lsp-jsonrpc 1.1.2
python-lsp-server 1.13.1
python-multipart 0.0.20
python-slugify 8.0.4
python-snappy 0.7.3
python-utils 3.9.1
pytokens 0.1.10
pytools 2025.2.4
pytorch-ignite 0.5.2
pytorch-lightning 2.5.5
pytz 2025.2
PyUpSet 0.1.1.post7
pyviz_comms 3.0.6
PyWavelets 1.8.0
PyYAML 6.0.3
pyzmq 26.2.1
qgrid 1.3.1
qtconsole 5.7.0
QtPy 2.4.3
raft-dask-cu12 25.2.0
rapids-dask-dependency 25.2.0
rapids-logger 0.1.1
ratelim 0.1.6
ray 2.49.2
referencing 0.36.2
regex 2025.9.18
requests 2.32.5
requests-oauthlib 2.0.0
requests-toolbelt 1.0.0
requirements-parser 0.9.0
rfc3161-client 1.0.5
rfc3339-validator 0.1.4
rfc3986-validator 0.1.1
rfc3987-syntax 1.1.0
rfc8785 0.1.4
rgf-python 3.12.0
rich 14.1.0
rmm-cu12 25.2.0
roman-numerals-py 3.1.0
rpds-py 0.26.0
rpy2 3.5.17
rsa 4.9.1
rtree 1.4.1
ruff 0.12.5
s3fs 0.4.2
s3transfer 0.14.0
safehttpx 0.1.6
safetensors 0.5.3
scikit-image 0.25.2
scikit-learn 1.2.2
scikit-learn-intelex 2025.8.0
scikit-multilearn 0.2.0
scikit-optimize 0.10.2
scikit-plot 0.3.7
scikit-surprise 1.1.4
scipy 1.15.3
scooby 0.10.1
scs 3.2.7.post2
seaborn 0.12.2
SecretStorage 3.3.3
securesystemslib 1.3.1
segment_anything 1.0
semantic-version 2.10.0
semver 3.0.4
Send2Trash 1.8.3
sentence-transformers 4.1.0
sentencepiece 0.2.0
sentry-sdk 2.33.2
setuptools 75.2.0
setuptools-scm 9.2.0
shap 0.44.1
shapely 2.1.2
shellingham 1.5.4
Shimmy 1.3.0
sigstore 4.0.0
sigstore-models 0.0.5
sigstore-rekor-types 0.0.18
simple-parsing 0.1.7
simpleitk 2.5.2
simplejson 3.20.1
simsimd 6.5.0
siphash24 1.8
six 1.17.0
sklearn-compat 0.1.3
sklearn-pandas 2.2.0
slicer 0.0.7
smart_open 7.3.0.post1
smmap 5.0.2
sniffio 1.3.1
snowballstemmer 3.0.1
sortedcontainers 2.4.0
soundfile 0.13.1
soupsieve 2.7
soxr 0.5.0.post1
spacy 3.8.7
spacy-legacy 3.0.12
spacy-loggers 1.0.5
spanner-graph-notebook 1.1.7
Sphinx 8.2.3
sphinx-rtd-theme 0.2.4
sphinxcontrib-applehelp 2.0.0
sphinxcontrib-devhelp 2.0.0
sphinxcontrib-htmlhelp 2.1.0
sphinxcontrib-jsmath 1.0.1
sphinxcontrib-qthelp 2.0.0
sphinxcontrib-serializinghtml 2.0.0
SQLAlchemy 2.0.41
sqlalchemy-spanner 1.16.0
sqlglot 25.20.2
sqlparse 0.5.3
squarify 0.4.4
srsly 2.5.1
sse-starlette 3.0.2
stable-baselines3 2.1.0
stanio 0.5.1
starlette 0.47.2
statsmodels 0.14.5
stopit 1.1.2
stringzilla 3.12.5
stumpy 1.13.0
sympy 1.13.1
tables 3.10.2
tabulate 0.9.0
tbb 2022.2.0
tbb4py 2022.2.0
tblib 3.1.0
tcmlib 1.4.0
tenacity 8.5.0
tensorboard 2.18.0
tensorboard-data-server 0.7.2
tensorflow 2.18.0
tensorflow-cloud 0.1.5
tensorflow-datasets 4.9.9
tensorflow_decision_forests 1.11.0
tensorflow-hub 0.16.1
tensorflow-io 0.37.1
tensorflow-io-gcs-filesystem 0.37.1
tensorflow-metadata 1.17.2
tensorflow-probability 0.25.0
tensorflow-text 2.18.1
tensorstore 0.1.74
termcolor 3.1.0
terminado 0.18.1
testpath 0.6.0
text-unidecode 1.3
textblob 0.19.0
texttable 1.7.0
tf_keras 2.18.0
tf-slim 1.1.0
Theano 1.0.5
Theano-PyMC 1.1.2
thinc 8.3.6
threadpoolctl 3.6.0
tifffile 2025.6.11
tiktoken 0.9.0
timm 1.0.19
tinycss2 1.4.0
tokenizers 0.21.2
toml 0.10.2
tomlkit 0.13.3
toolz 1.0.0
torch 2.6.0+cu124
torchao 0.10.0
torchaudio 2.6.0+cu124
torchdata 0.11.0
torchinfo 1.8.0
torchmetrics 1.8.2
torchsummary 1.5.1
torchtune 0.6.1
torchvision 0.21.0+cu124
tornado 6.5.2
TPOT 0.12.1
tqdm 4.67.1
traitlets 5.7.1
traittypes 0.2.1
transformers 4.53.3
treelite 4.4.1
treescope 0.1.9
triton 3.2.0
trx-python 0.3
tsfresh 0.21.0
tuf 6.0.0
tweepy 4.16.0
typeguard 4.4.4
typer 0.16.0
typer-slim 0.19.2
types-python-dateutil 2.9.0.20250822
types-pytz 2025.2.0.20250516
types-setuptools 80.9.0.20250529
typing_extensions 4.15.0
typing-inspect 0.9.0
typing-inspection 0.4.1
tzdata 2025.2
tzlocal 5.3.1
uc-micro-py 1.0.3
ucx-py-cu12 0.42.0
ucxx-cu12 0.42.0
ujson 5.11.0
umap-learn 0.5.9.post2
umf 0.11.0
update-checker 0.18.0
uri-template 1.3.0
uritemplate 4.2.0
urllib3 2.5.0
urwid 3.0.3
urwid_readline 0.15.1
uvicorn 0.35.0
vega-datasets 0.9.0
visions 0.8.1
vtk 9.3.1
wadllib 1.3.6
Wand 0.6.13
wandb 0.21.0
wasabi 1.1.3
watchdog 6.0.0
wavio 0.0.9
wcwidth 0.2.13
weasel 0.4.1
webcolors 24.11.1
webencodings 0.5.1
websocket-client 1.8.0
websockets 15.0.1
Werkzeug 3.1.3
wheel 0.45.1
widgetsnbextension 4.0.14
woodwork 0.31.0
wordcloud 1.9.4
wrapt 1.17.2
wurlitzer 3.1.1
xarray 2025.7.1
xarray-einstats 0.9.1
xgboost 2.0.3
xlrd 2.0.2
xvfbwrapper 0.2.14
xxhash 3.5.0
xyzservices 2025.4.0
y-py 0.6.2
yarl 1.20.1
ydata-profiling 4.17.0
ydf 0.9.0
yellowbrick 1.5
yfinance 0.2.65
ypy-websocket 0.8.4
zict 3.0.0
zipp 3.23.0
zstandard 0.23.0
```

Related Background
- Python is a cross-platform programming language: a high-level scripting language that combines interpretation, compilation, interactivity, and object orientation. Originally designed for writing automation scripts, it is now increasingly used for standalone, large-scale projects as the language has gained new features.
- PyTorch is a deep learning framework that packages many networks and deep-learning utilities so we can call them directly instead of writing everything ourselves. It comes in CPU and GPU builds; other frameworks include TensorFlow, Caffe, and so on. PyTorch was released by Facebook AI Research (FAIR) and builds on Torch. It is a Python-based scientific computing package that provides two high-level features: (1) tensor computation with strong GPU acceleration (like NumPy), and (2) automatic differentiation for building deep neural networks.
- Diffusion models are a powerful class of generative models that have achieved breakthrough results in recent years on image, audio, and text generation. They simulate a process of gradually adding noise (the forward process) and gradually removing it (the reverse process), learning how to reconstruct realistic data from pure noise.
- Core idea
Diffusion models are inspired by non-equilibrium thermodynamics:
- Forward process: Gaussian noise is added to real data (such as an image) step by step until, after enough steps, the data becomes pure random noise.
- Reverse process: a neural network is trained to denoise step by step, eventually producing a new sample that resembles the original data.
- The procedure is a kind of "destroy and rebuild": first blur a clear image until it is unrecognizable, then teach the model to recover a clear image from the blur (a minimal sketch of the forward process follows below).
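For reference, the standard DDPM forward process can be written in closed form as x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, which is roughly what `DDPMScheduler.add_noise` in diffusers implements. The sketch below is my own illustration of that idea and differs slightly from the simplified linear mix used later in this article:

```python
import torch

# A minimal sketch of the standard DDPM forward (noising) process,
# assuming a simple linear beta schedule.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)           # noise schedule
alphas_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of (1 - beta)

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) in closed form."""
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alphas_bar[t].view(-1, 1, 1, 1)
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

# Example: noise a batch of 8 fake 28x28 "images" at random timesteps
x0 = torch.rand(8, 1, 28, 28)
t = torch.randint(0, T, (8,))
x_t = q_sample(x0, t)
print(x_t.shape)  # torch.Size([8, 1, 28, 28])
```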
Implementation
Importing the required libraries
```python
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
```

Preparing the dataset
```python
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
```

```bash
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([2, 2, 0, 8, 8, 6, 2, 2])
```

Defining the noising function
- Suppose you had never read a diffusion-model paper but knew that the process involves adding noise. How would you implement it?
- One simple approach is to control the degree of noising with a single parameter. If we introduce a parameter `amount` that specifies how much noise to add, we can do the following:

```python
noise = torch.rand_like(x)
noisy_x = (1 - amount) * x + amount * noise
```

- If `amount = 0`, we return the input unchanged. If `amount` reaches 1, we get back noise with no trace of the input `x`.
- Mixing the input with noise this way keeps the output in the same range (0 to 1).
```python
def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    amount = amount.view(-1, 1, 1, 1)  # Sort shape so broadcasting works
    return x * (1 - amount) + noise * amount
```
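As a quick sanity check (my own addition, assuming the imports above), `corrupt` can be applied to a whole batch at once with a different `amount` per image:

```python
import torch

x = torch.rand(8, 1, 28, 28)      # a fake batch of 8 "images"
amount = torch.linspace(0, 1, 8)  # noise level increases from left to right
noisy_x = corrupt(x, amount)

print(noisy_x.shape)              # torch.Size([8, 1, 28, 28])
# noisy_x[0] equals x[0] (amount 0), while noisy_x[-1] is pure noise (amount 1)
```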
Defining the network model

```python
class BasicUNet(nn.Module):
    """A minimal UNet implementation."""

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        self.act = nn.SiLU()  # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x))         # Through the layer and the activation function
            if i < 2:                  # For all but the third (final) down layer:
                h.append(x)            # Storing output for skip connection
                x = self.downscale(x)  # Downscale ready for the next layer
        for i, l in enumerate(self.up_layers):
            if i > 0:                  # For all except the first up layer
                x = self.upscale(x)    # Upscale
                x += h.pop()           # Fetching stored output (skip connection)
            x = self.act(l(x))         # Through the layer and the activation function
        return x


net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
sum([p.numel() for p in net.parameters()])
```

```bash
torch.Size([8, 1, 28, 28])
309057
```
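As an aside, `UNet2DModel` from diffusers is imported at the top but never used here. If you prefer a ready-made UNet, a rough sketch of a comparably sized configuration might look like the following; the block sizes and types are my own guesses for an MNIST-sized model, not values from the original tutorial:

```python
import torch
from diffusers import UNet2DModel

# An illustrative (assumed) configuration for a small MNIST-sized UNet;
# only plain up/down blocks are used, so there are no attention layers.
model = UNet2DModel(
    sample_size=28,
    in_channels=1,
    out_channels=1,
    layers_per_block=1,
    block_out_channels=(32, 64, 64),
    down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
)

x = torch.rand(8, 1, 28, 28)
t = torch.zeros(8, dtype=torch.long)  # UNet2DModel also expects a timestep
print(model(x, t).sample.shape)       # expected: torch.Size([8, 1, 28, 28])
```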
Training the model
```python
# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network
net = BasicUNet()
net.to(device)

# Our loss function
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Get some data and prepare the corrupted version
        x = x.to(device)  # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device)  # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount)  # Create our noisy x

        # # Optional: create subplots (and call plt.show()) to inspect a sample pair
        # fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        # axs[0].imshow(x[0].cpu().squeeze(), cmap='Greys')        # original image
        # axs[0].set_title(f'Original Image {x[0].shape}')
        # axs[0].axis('off')
        # axs[1].imshow(noisy_x[0].cpu().squeeze(), cmap='Greys')  # noisy image
        # axs[1].set_title(f'Noisy Image {noisy_x[0].shape}')
        # axs[1].axis('off')

        # Get the model prediction
        pred = net(noisy_x)

        # Calculate the loss
        loss = loss_fn(pred, x)  # How close is the output to the true 'clean' x?

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
```

```bash
Finished epoch 0. Average loss for this epoch: 0.025291
Finished epoch 1. Average loss for this epoch: 0.019563
Finished epoch 2. Average loss for this epoch: 0.017892
```
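It can be convenient to save the trained weights so the sampling experiments below can be rerun without retraining. A minimal sketch using standard PyTorch APIs (my own addition; the file name is arbitrary):

```python
import torch

# Save the trained denoiser's weights
torch.save(net.state_dict(), "basic_unet_mnist.pth")

# ...and restore them later before sampling
net = BasicUNet().to(device)
net.load_state_dict(torch.load("basic_unet_mnist.pth", map_location=device))
net.eval()
```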

Diffusion sampling
```python
#@markdown Visualizing model predictions on noisy inputs:

# Fetch some data
x, y = next(iter(train_dataloader))
x = x[:8]  # Only using the first 8 for easy plotting

# Corrupt with a range of amounts
amount = torch.linspace(0, 1, x.shape[0])  # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Get the model predictions
with torch.no_grad():
    preds = net(noised_x.to(device)).detach().cpu()

# Plot
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0, 1), cmap='Greys');
```

Output:

At low noise amounts the predictions look quite good! But as `amount` gets large, the model has less to work with, and at `amount = 1` it outputs a blurry result close to the dataset mean.
```python
#@markdown Sampling strategy: break the process into 5 steps and move 1/5th of the way there each time:
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)  # Start from random noise
step_history = [x.detach().cpu()]
pred_output_history = []

for i in range(n_steps):
    with torch.no_grad():  # No need to track gradients during inference
        pred = net(x)  # Predict the denoised x0
    pred_output_history.append(pred.detach().cpu())  # Store model output for plotting
    mix_factor = 1 / (n_steps - i)  # How much we move towards the prediction
    x = x * (1 - mix_factor) + pred * mix_factor  # Move part of the way there
    step_history.append(x.detach().cpu())  # Store step for plotting

fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0, 0].set_title('x (model input)')
axs[0, 1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')
```

Output:

```python
#@markdown Showing more results, using 40 sampling steps
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    noise_amount = torch.ones((x.shape[0],)).to(device) * (1 - (i / n_steps))  # Starting high, going low (computed here but not used by this simple model)
    with torch.no_grad():
        pred = net(x)
    mix_factor = 1 / (n_steps - i)
    x = x * (1 - mix_factor) + pred * mix_factor

fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
```

Output:

The results are not great, but there are some recognizable digits! You can try training for longer (say 10 or 20 epochs) and tweaking the model configuration, learning rate, optimizer, and so on. If you want a slightly harder dataset, try FashionMNIST; a possible swap is sketched below.
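For example, assuming the rest of the pipeline stays unchanged, switching to FashionMNIST only requires swapping the dataset line (untested sketch; FashionMNIST images are also 1x28x28, so the model needs no changes):

```python
import torchvision
from torch.utils.data import DataLoader

dataset = torchvision.datasets.FashionMNIST(
    root="fashion_mnist/",  # download location (arbitrary)
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
```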
References
[1] https://huggingface.co/learn/diffusion-course/