本文最后更新于 2024-11-19 16:10

通用图像分割模型:Segment Anything Model(SAM)训练教程

一 概述

SAM是由Meta的FAIR实验室发布的图像分割模型,具有十分强大的零样本泛化能力,模型官网:https://segment-anything.com/。读者可以自行前往官网查看模型效果

二 快速开始

2.1 准备环境

创建conda环境:

本教程选择pytorch作为训练框架,要注意Python与PyTorch的版本对应。

conda create -n sam python=3.9 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/

激活环境:

conda activate sam

下载项目源码:

wget https://mirrors.aheadai.cn/pkgs/segment-anything-main.zip
# 或者也可以从官网去下载,官网链接:https://github.com/facebookresearch/segment-anything

解压:

unzip segment-anything-main.zip 

安装依赖:

cd segment-anything-main
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple

安装CUDA(以CUDA11.8为例)
如何安装合适版本的Pytorch参考https://docs.aheadai.cn/196.html

wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
sh cuda_11.8.0_520.61.05_linux.run
export PATH=/usr/local/cuda-11.8/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH

安装pytorch

pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu113/torch_stable.html -i https://pypi.mirrors.ustc.edu.cn/simple

安装其他依赖:

pip install opencv-python pycocotools matplotlib onnxruntime onnx tqdm monai scikit-image opencv-python-headless
 --default-time 99999 -i https://pypi.mirrors.ustc.edu.cn/simple

【可选】将刚刚新建的sam环境添加到jupyter中

pip install ipykernel
python -m ipykernel install --user --name=sam

2.2 训练教程

下载数据集:

cd .. #先回到上一级地址或任何你想存放数据压缩包的地方,目前还在segment-anything-main里面
wget https://mirrors.aheadai.cn/data/SAM-dataset-551images.zip

下载初始权重文件:(这里选择的是sam_vit_h_4b8939.pth)

wget https://mirrors.aheadai.cn/pkgs/sam_vit_h_4b8939.pth
或前往官网自行下载:https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

下载代码:

wget https://mirrors.aheadai.cn/scripts/sam-gen-npz.py
wget https://mirrors.aheadai.cn/scripts/sam-train.py
# 根据代码中的提示信息,修改相应存储路径

解压数据集:

unzip SAM-dataset-551images.zip

生成npz

python sam-gen-npz.py
# 注意修改代码中的路径

开始训练:

python sam-train.py

如果报错:

Traceback (most recent call last):
File "/online1/ssd/qiql/run-sam/model/train.py", line 141, in 
 mask_predictions, _ = sam_model.module.mask_decoder(
File "/home/qiql/.conda/envs/sam/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
 return forward_call(*input, **kwargs)
File "/home/qiql/segment-anything-main/segment_anything/modeling/mask_decoder.py", line 94, in forward
 masks, iou_pred = self.predict_masks(
File "/home/qiql/segment-anything-main/segment_anything/modeling/mask_decoder.py", line 127, in predict_masks
 src = src + dense_prompt_embeddings
RuntimeError: The size of tensor a (9216) must match the size of tensor b (96) at non-singleton dimension 0

或:

Traceback (most recent call last):
File "/root/autodl-tmp/run-sam/sam-train.py", line 152, in <module>
mask_predictions, _ = sam_model.module.mask_decoder(
File "/root/miniconda3/envs/sam/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/root/autodl-tmp/run-sam/segment-anything-main/segment_anything/modeling/mask_decoder.py", line 94, in forward
masks, iou_pred = self.predict_masks(
File "/root/autodl-tmp/run-sam/segment-anything-main/segment_anything/modeling/mask_decoder.py", line 126, in predict_masks
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
RuntimeError: CUDA out of memory. Tried to allocate 36.00 GiB (GPU 0; 23.68 GiB total capacity; 2.83 GiB already allocated; 18.71 GiB free; 2.89 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

解决办法:

找到SAM源码中的 mask_decoder.py 文件,找到以下代码块:

src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)  # 直接扩展,无条件判断

将其修改为:

if image_embeddings.shape[0] != tokens.shape[0]:  # 动态判断批次大小是否相等
 src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
 src = image_embeddings  # 如果大小相等,直接使用image_embeddings

这里做一下简单的解释:
原代码中,无条件扩展 image_embeddings,即使在 image_embeddings 和 tokens 大小已经匹配时也进行扩展,导致张量大小不匹配错误。此外,当 batch_size 较大时,这种不必要的扩展会显著增加显存占用,可能引发显存不足的问题。修改后的代码通过动态判断批次大小,仅在必要时进行扩展,避免了张量大小不匹配和显存不足的错误,从而支持更大的 batch_size,同时优化了显存使用效率。*

三 数据处理

3.1 问题背景

SAM的初始模型,具有一些基本的图像分割能力,但是当图像灰度接近时效果就会变差,比如以下这张图:

image-20241108103456951

如果需要将里面的区域提取出来,如下图:

image-20241108103330646

使用传统的算法很难做到,同样的,使用初始权重的SAM模型也很难做到,所以需要使用自己的数据集进行二次训练,或者称之为微调

3.2 数据标注

首先,需要准备一些数据,笔者准备了551张原图,使用Labelme进行手动标注。事实上,在SAM发布之后,有一些教程可以搭建基于SAM的标注流程,可以加快标注的速度,Labelme的安装过程可以参考教程:https://blog.csdn.net/m0_62473142/article/details/134316259

但由于SAM模型本身对这些工业图片的分割效果就不好,所以笔者没有采用SAM用来辅助标注。

标注过程如下:

  1. 点击菜单栏的创建多边形
  2. 选点落在图片中的黑色轮廓上,所有点首位相连即该图片标注完毕
  3. 自定义标签名,后续所有图片都应该是这个标签名

3.3 json转mask

做完数据标注后,得到的都是json文件,但训练SAM的过程中,需要的是图片,所以需要把json转为抠出来的图片

转换代码:

import os
import json
from PIL import Image, ImageDraw

source_directory = r'./data/json'
target_directory = r'./data/mask'

# 确保目标目录存在
os.makedirs(target_directory, exist_ok=True)

# 遍历源目录中的所有文件
for filename in os.listdir(source_directory):
    if filename.endswith('.json'):
        json_path = os.path.join(source_directory, filename)

        # 加载JSON文件
        with open(json_path, 'r') as file:
            data = json.load(file)

        # 创建一个与原图像相同大小的空白图像
        width, height = data['imageWidth'], data['imageHeight']
        mask = Image.new('L', (width, height), 0)

        # 对于JSON中的每个标注
        for shape in data['shapes']:
            polygon = [tuple(point) for point in shape['points']]
            ImageDraw.Draw(mask).polygon(polygon, outline=1, fill=255)

        # 构建目标文件路径并保存掩码图像
        mask_filename = filename.replace('.json', '.bmp')
        mask_path = os.path.join(target_directory, mask_filename)
        mask.save(mask_path)

做完上面的操作之后,按照本文第二节中介绍的操作,就可以开始训练了,训练得到的bestSAM.pth

3.4 比较结果

示例代码:

#比较微调后模型和原始模型分割效果,单个图像
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
model_type = 'vit_h'
checkpoint = './model/sam_vit_h_4b8939.pth'
checkpoint1 = './model/bestSAM.pth'
device = 'cuda:2'  
device1 = 'cuda:3'
sam_model_orig = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model_orig.to(device)
sam_model_orig = torch.nn.DataParallel(sam_model_orig, device_ids=[0, 1, 2, 3])

sam_model = sam_model_registry[model_type](checkpoint=checkpoint1)
sam_model.to(device1)
sam_model = torch.nn.DataParallel(sam_model, device_ids=[0, 1, 2, 3])

from segment_anything import sam_model_registry, SamPredictor
predictor_tuned = SamPredictor(sam_model.module)
predictor_original = SamPredictor(sam_model_orig.module)

image = cv2.imread('./data/ALD-dataset-all/after_FFT/5mchs.bmp')

predictor_tuned.set_image(image)
predictor_original.set_image(image)

input_bbox = np.array(bbox_coords[keys[0]])

input_point = np.array([[180, 180]])
input_label = np.array([1])

masks_tuned, _, _ = predictor_tuned.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_bbox,
    multimask_output=True,
)

masks_orig, _, _ = predictor_original.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_bbox,
    multimask_output=True,
)

%matplotlib inline 
_, axs = plt.subplots(1, 2, figsize=(25, 25))

axs[0].imshow(image)
show_mask(masks_tuned[0], axs[0])
show_box(input_bbox, axs[0])
axs[0].set_title('Mask with Tuned Model', fontsize=26)
axs[0].axis('off')

axs[1].imshow(image)
show_mask(masks_orig[0], axs[1])
show_box(input_bbox, axs[1])
axs[1].set_title('Mask with Untuned Model', fontsize=26)
axs[1].axis('off')

plt.show()

如果效果不达预期,可以尝试增加训练的epoch,或者增加数据集数量,或者更换初始权重文件,SAM官方一共提供了三个初始权重文件

本文系作者 @ admin 原创发布在 文档中心 | AheadAI ,未经许可,禁止转载。