U-Net:使用 PyTorch 进行语义分割
一.概述
本教程旨在帮助用户快速上手使用 UNet 模型进行图像分割任务,特别是使用 Carvana Image Masking Challenge 数据集进行训练与预测。
UNet 模型是一个经典的卷积神经网络(CNN),广泛应用于医学图像处理、卫星图像分析等领域,主要用于图像分割任务。
本教程使用:
数据集: Carvana Image Masking Challenge数据集(Unet_Train.zip)
训练模型: Unet
通过本教程,用户将了解如何从环境配置、数据准备、模型训练到测试和预测的全过程。无论是初学者还是有一定经验的研究人员,都能通过本教程快速掌握 UNet 模型的使用,并能够在自己的项目中进行定制化训练和预测。
二.快速开始
2.1 环境准备
本教程运行硬件环境:
类别 | 详细信息 |
---|---|
CPU | 16 vCPU Intel(R) Xeon(R) Platinum 8352V CPU @ 2.10GHz |
GPU | RTX 4090 * 1 |
GPU 显存 | 24GB |
CUDA 版本 | 10.0 |
操作系统 | Ubuntu 18.04.5 LTS (Bionic Beaver) |
Python 版本 | 3.8.20 |
PyTorch 版本 | 2.4.1 |
创建一个 conda 环境 unet ,并激活环境
conda create -n unet python=3.8
conda activate unet
2.2 从资源站下载实例源码
mkdir run-unet
cd run-unet
# 从aheadai资源站下载资源包
wget https://mirrors.aheadai.cn/scripts/Unet_example.zip
解压资源包
unzip Unet_example.zip
cd Unet_example
确保解压后文件结构如下:
./Unet_example/
├── checkpoints/ # 权重保存文件
├── data/ # 数据保存文件
├── scripts/ # Bash脚本文件
├── unet/ # Unet模块文件
├── utils/ # 数据集处理模块文件
├── evaluate.py
├── hubconf.py
├── predict.py #预测脚本
├── train.py #训练脚本
└── requirements.txt # 其余依赖文件
2.3 安装pytorch
pip install torch==2.4.1 --index-url https://pypi.mirrors.ustc.edu.cn/simple
conda install cudatoolkit=12.1
2.4 安装其他依赖
pip install -r requirements.txt --index-url https://pypi.mirrors.ustc.edu.cn/simple
三.训练教程
3.1下载数据文件
本教程提供两种数据加载方法:
3.11 Bash脚本
通过 Bash 脚本,从 Kaggle 下载 Carvana Image Masking Challenge 的训练数据集。此方法需要你输入 Kaggle 用户名和 API 密钥。
bash scripts/download_data.sh
运行脚本会将数据放入 data/imgs/
和 data/masks/
文件夹。
通过此方法加载数据集需调整代码 Train.py
。(同理,若使用自己提供的数据集训练,则需要对应修改到相应路径下。)
#dir_img = Path("./data/train_hq")
#dir_mask = Path("./data/train_masks")
dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
3.12 AheadAI资源站(推荐)
确保当前目录在 ./run-unet/Unet_example/
,下载Carvana Image Masking Challenge数据集。文件名为 Unet_Train.zip
。
cd data
wget https://mirrors.aheadai.cn/data/Unet_Train.zip
unzip Unet_Train.zip
unzip train_masks.zip
unzip train_hq.zip
unzip test_40.zip
此方法无需调整代码。
解压后的 Unet_Train.zip
文件结构为:
./Unet_Train.zip/
├── train_hq.zip # 训练数据集
├── train_masks.zip # 掩码数据集
└── test_40.zip # 测试数据集
回到 Unet_example/
目录进行后续操作:
cd ..
3.2 模型训练
3.21 参数解释
运行 python train.py -h
命令会显示 train.py
脚本的帮助信息。
python train.py -h
具体来说,帮助信息展示了脚本支持的命令行参数及其功能:
--epochs E, -e E #设置训练的轮数(epochs)
--batch-size B, -b B #设置每个批次的大小(batch size)
--learning-rate LR, -l LR #设置学习率(learning rate)
--load LOAD, -f LOAD #从指定的 .pth 文件中加载预训练模型的权重
--scale SCALE, -s SCALE #设置图像的缩放因子
--validation VAL, -v VAL #设置用于验证的数据比例
--amp #启用混合精度训练(Automatic Mixed Precision)
比如示例命令:
python train.py --epochs 10 --batch-size 16 --learning-rate 0.001 --scale 0.5 --validation 20.0 --amp
这个命令会训练模型 10 个 epoch,每个批次使用 16 张图像,学习率为 0.001,图像缩放比例为 0.5,并将 20% 的数据用于验证,同时启用混合精度训练。
3.22 开始训练模型
其余参数选择默认:
python train.py --amp
其中,AMP 会使用更低的浮动精度(如半精度浮点数 FP16),在不影响模型准确度的情况下提高训练效率,尤其是在 GPU 上,推荐使用。
但在本例中使用AMP可能会出现nan
,导致最后预测图片全黑。此情况设置dtype=torch.bfloat16 :
#with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp,dtype=torch.bfloat16):
如果数据集加载成功,会有如下进度条:
然后等待训练结束,训练模型并将其保存后,你可以通过 CLI 轻松测试图像上的输出掩码。
在目录 ./checkpoints/
下会得到以:
checkpoint_epoch{epoch}.pth
命名格式保存的模型文件供预测使用。
同时在目录 ./wandb/
下会得到一些日志信息:
./wandb/
├── latest-run/ # 存放当前正在进行的训练的实时数据和信息
├── run.../ # 存放单次训练运行的所有相关数据和文件
├── debug.log # wandb 运行时的常规调试日志
├── debug-internal.log # wandb 内部操作的调试日志
└── debug-cli.root.log # 与 wandb 命令行接口相关的调试日志
3.3 模型测试
3.31 参数解释
运行 python predict.py -h
命令会显示 predict.py
脚本的帮助信息。
python predict.py -h
具体来说,帮助信息展示了脚本支持的命令行参数及其功能:
--model FILE, -m FILE #指定存储训练好的模型的文件路径(默认为 MODEL.pth)
--input INPUT, -i #指定输入图像文件的路径
--output INPUT, -o INPUT #指定输出图像的文件路径
--viz, -v #启用可视化模式,用于在处理每张图像时显示图像和预测的掩膜
--no-save, -n #如果启用此选项,将不会保存生成的掩膜图像
--scale SCALE, -s SCALE #设置图像的缩放因子
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD #指定掩膜的阈值
其中 --input 为必须参数。
比如示例命令:
python predict.py -i image1.jpg -v -n #进行可视化但不保存结果
python predict.py -i image1.jpg -t 0.7 -o result.png #修改阈值并进行预测
python predict.py -i image1.jpg -o result1.png #加载默认模型并预测图像
python predict.py -m my_model.pth -i image1.jpg -o output1.png #加载指定模型并进行预测
3.32 图像预测
要预测并保存单个图像,执行以下操作(注意指定训练好的模型文件 .pth 即参数 -m ):
python predict.py -i ./data/test_40/000aa097d423_01.jpg -o output.jpg -m ./checkpoints/checkpoint_epoch1.pth
可以看到以下输出:
INFO: Model loaded!
INFO: Predicting image ./test_40/000aa097d423_01.jpg ...
INFO: Mask saved to output.jpg
那么你就已经成功了!
同时本教程提供了预训练模型权重:/checkpoints/unet_carvana_scale0.5_epoch2.pth
,请读者根据需要替换使用。
python predict.py -i ./data/test_40/000aa097d423_01.jpg -o output.jpg -m ./checkpoints/unet_carvana_scale0.5_epoch2.pth
评论