# Validation metric that drives both early stopping and best-checkpoint
# selection; lower is better (mode="min" in both callbacks below).
_VAL_METRIC = "val/body_pix"

# Stop training once the validation metric has gone `args.early_stop`
# consecutive validation checks without improving.
early_stop_callback = EarlyStopping(
    monitor=_VAL_METRIC, mode="min", patience=args.early_stop
)

# Keep only the single best checkpoint (lowest val/body_pix) in args.save_path.
ckpt_callback = ModelCheckpoint(
    monitor=_VAL_METRIC,
    save_top_k=1,
    mode="min",
    dirpath=args.save_path,
    # Alternative naming that embeds epoch and score:
    # filename='body_pixel-epoch={epoch}-val_body_pix={val/body_pix:.4f}',
    filename=f"best-{args.save_name}",
    auto_insert_metric_name=False,
)

# Keep the checkpoint of the most recent epoch. With no `monitor` and
# save_top_k=1, each save replaces the previous one. This must fire every
# epoch: an interval of args.max_epochs would only trigger on the final epoch,
# so a run stopped early by EarlyStopping would never write a "last"
# checkpoint (and max_epochs=-1, i.e. unlimited, is not a valid interval).
last_callback = ModelCheckpoint(
    every_n_epochs=1,
    save_top_k=1,
    dirpath=args.save_path,
    filename=f"last-{args.save_name}",
    auto_insert_metric_name=False,
)

trainer = pl.Trainer(
    accelerator=accelerator,
    devices=args.devices,
    strategy=strategy,
    # sync_batchnorm=True,
    default_root_dir=args.default_log_path,
    # gradient_clip_val=args.gradient_clip_val,
    # bfloat16 mixed precision; presumably the target hardware supports bf16 — confirm.
    precision="bf16-mixed",
    max_epochs=args.max_epochs,
    callbacks=[early_stop_callback, ckpt_callback, last_callback],
)