一.概述

本教程旨在帮助用户快速上手使用 UNet 模型进行图像分割任务,特别是使用 Carvana Image Masking Challenge 数据集进行训练与预测。

UNet 模型是一个经典的卷积神经网络(CNN),广泛应用于医学图像处理、卫星图像分析等领域,主要用于图像分割任务。

本教程使用:

数据集: Carvana Image Masking Challenge数据集(Unet_Train.zip)

训练模型: Unet

通过本教程,用户将了解如何从环境配置、数据准备、模型训练到测试和预测的全过程。无论是初学者还是有一定经验的研究人员,都能通过本教程快速掌握 UNet 模型的使用,并能够在自己的项目中进行定制化训练和预测。

二.快速开始

2.1 环境准备

本教程运行硬件环境:

类别 详细信息
CPU 16 vCPU Intel(R) Xeon(R) Platinum 8352V CPU @ 2.10GHz
GPU RTX 4090 * 1
GPU 显存 24GB
CUDA 版本 10.0
操作系统 Ubuntu 18.04.5 LTS (Bionic Beaver)
Python 版本 3.8.20
PyTorch 版本 2.4.1

创建一个 conda 环境 unet ,并激活环境

conda create -n unet python=3.8
conda activate unet 

2.2 从资源站下载实例源码

mkdir run-unet
cd run-unet
# 从aheadai资源站下载资源包
wget https://mirrors.aheadai.cn/scripts/Unet_example.zip

解压资源包

unzip Unet_example.zip
cd Unet_example

确保解压后文件结构如下:

./Unet_example/
    ├── checkpoints/       # 权重保存文件
    ├── data/              # 数据保存文件
    ├── scripts/           # Bash脚本文件
    ├── unet/              # Unet模块文件
    ├── utils/             # 数据集处理模块文件
    ├── evaluate.py            
    ├── hubconf.py
    ├── predict.py         #预测脚本
    ├── train.py           #训练脚本
    └── requirements.txt   # 其余依赖文件

2.3 安装pytorch

pip install torch==2.4.1 --index-url https://pypi.mirrors.ustc.edu.cn/simple
conda install cudatoolkit=12.1

2.4 安装其他依赖

pip install -r requirements.txt --index-url https://pypi.mirrors.ustc.edu.cn/simple

三.训练教程

3.1下载数据文件

本教程提供两种数据加载方法:

3.11 Bash脚本

通过 Bash 脚本,从 Kaggle 下载 Carvana Image Masking Challenge 的训练数据集。此方法需要你输入 Kaggle 用户名和 API 密钥。

bash scripts/download_data.sh

运行脚本会将数据放入 data/imgs/ data/masks/ 文件夹。

通过此方法加载数据集需调整代码 Train.py(同理,若使用自己提供的数据集训练,则需要对应修改到相应路径下。)

#dir_img = Path("./data/train_hq")
#dir_mask = Path("./data/train_masks")
dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
3.12 AheadAI资源站(推荐)

确保当前目录在 ./run-unet/Unet_example/ ,下载Carvana Image Masking Challenge数据集。文件名为 Unet_Train.zip

cd data
wget https://mirrors.aheadai.cn/data/Unet_Train.zip
unzip Unet_Train.zip
unzip train_masks.zip
unzip train_hq.zip
unzip test_40.zip

此方法无需调整代码。

解压后的 Unet_Train.zip 文件结构为:

./Unet_Train.zip/
    ├── train_hq.zip         # 训练数据集
    ├── train_masks.zip      # 掩码数据集
    └── test_40.zip          # 测试数据集

回到 Unet_example/ 目录进行后续操作:

cd ..

3.2 模型训练

3.21 参数解释

运行 python train.py -h 命令会显示 train.py 脚本的帮助信息。

python train.py -h

具体来说,帮助信息展示了脚本支持的命令行参数及其功能:

--epochs E, -e E             #设置训练的轮数(epochs)
--batch-size B, -b B         #设置每个批次的大小(batch size)
--learning-rate LR, -l LR    #设置学习率(learning rate)
--load LOAD, -f LOAD         #从指定的 .pth 文件中加载预训练模型的权重
--scale SCALE, -s SCALE      #设置图像的缩放因子
--validation VAL, -v VAL     #设置用于验证的数据比例
--amp                        #启用混合精度训练(Automatic Mixed Precision)

比如示例命令:

python train.py --epochs 10 --batch-size 16 --learning-rate 0.001 --scale 0.5 --validation 20.0 --amp

这个命令会训练模型 10epoch,每个批次使用 16 张图像,学习率为 0.001,图像缩放比例为 0.5,并将 20% 的数据用于验证,同时启用混合精度训练。

3.22 开始训练模型

其余参数选择默认:

python train.py --amp

其中,AMP 会使用更低的浮动精度(如半精度浮点数 FP16),在不影响模型准确度的情况下提高训练效率,尤其是在 GPU 上,推荐使用

但在本例中使用AMP可能会出现nan,导致最后预测图片全黑。此情况设置dtype=torch.bfloat16 :

#with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp,dtype=torch.bfloat16):

如果数据集加载成功,会有如下进度条:

然后等待训练结束,训练模型并将其保存后,你可以通过 CLI 轻松测试图像上的输出掩码。

在目录 ./checkpoints/ 下会得到以:

checkpoint_epoch{epoch}.pth

命名格式保存的模型文件供预测使用。

同时在目录 ./wandb/ 下会得到一些日志信息:

./wandb/
    ├── latest-run/            # 存放当前正在进行的训练的实时数据和信息
    ├── run.../                # 存放单次训练运行的所有相关数据和文件
    ├── debug.log              # wandb 运行时的常规调试日志
    ├── debug-internal.log     # wandb 内部操作的调试日志
    └── debug-cli.root.log     # 与 wandb 命令行接口相关的调试日志

3.3 模型测试

3.31 参数解释

运行 python predict.py -h 命令会显示 predict.py 脚本的帮助信息。

python predict.py -h

具体来说,帮助信息展示了脚本支持的命令行参数及其功能:

--model FILE, -m FILE            #指定存储训练好的模型的文件路径(默认为 MODEL.pth)
--input INPUT, -i                #指定输入图像文件的路径
--output INPUT, -o INPUT         #指定输出图像的文件路径
--viz, -v                        #启用可视化模式,用于在处理每张图像时显示图像和预测的掩膜
--no-save, -n                    #如果启用此选项,将不会保存生成的掩膜图像
--scale SCALE, -s SCALE          #设置图像的缩放因子
--mask-threshold MASK_THRESHOLD, -t MASK_THRESHOLD    #指定掩膜的阈值

其中 --input 为必须参数。

比如示例命令:

python predict.py -i image1.jpg -v -n                #进行可视化但不保存结果
python predict.py -i image1.jpg -t 0.7 -o result.png #修改阈值并进行预测
python predict.py -i image1.jpg -o result1.png       #加载默认模型并预测图像
python predict.py -m my_model.pth -i image1.jpg -o output1.png #加载指定模型并进行预测
3.32 图像预测

要预测并保存单个图像,执行以下操作(注意指定训练好的模型文件 .pth 即参数 -m ):

python predict.py -i ./data/test_40/000aa097d423_01.jpg  -o output.jpg -m ./checkpoints/checkpoint_epoch1.pth

可以看到以下输出:

INFO: Model loaded!
INFO: Predicting image ./test_40/000aa097d423_01.jpg ...
INFO: Mask saved to output.jpg

那么你就已经成功了!

同时本教程提供了预训练模型权重:/checkpoints/unet_carvana_scale0.5_epoch2.pth,请读者根据需要替换使用。

python predict.py -i ./data/test_40/000aa097d423_01.jpg  -o output.jpg -m ./checkpoints/unet_carvana_scale0.5_epoch2.pth
本文系作者 @ admin 原创发布在 文档中心 | AheadAI ,未经许可,禁止转载。