MambaIR是基于Mamba的改进算法,专注于通过深度学习技术实现从低分辨率图像(LR,Low-Resolution)到高分辨率图像(HR,High-Resolution)的精细重建。该方法引入了多个关键模块,其中包括创新性的 2DSSM模块(2D Spatial-Spectral Mapping),以增强模型对图像纹理、边缘和色彩信息的保真度。

原码链接:https://github.com/csguoh/mambair

项目开发团队给出的版本(与笔者略有不同)如下:

  • Ubuntu 20.04 版本
  • CUDA 11.7 的
  • Python 3.9 版本
  • PyTorch 2.0.1 + cu117 版本

1. 运行环境

类别 详细信息
CPU 12 vCPU Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
CPU 核心数 12核心
GPU RTX 3090
GPU 显存 24 GB
CUDA 版本 11.8
操作系统 ubuntu20.04
Python 版本 3.8
PyTorch 版本 2.0.0

2. 源码下载

如果要得到每tokens时延、显存利用率等参数请用第三条指令而不是git

git clone https://github.com/csguoh/MambaIR.git
mkdir MambaIR && cd MambaIR
wget https://mirrors.aheadai.cn/scripts/MambaIR.zip
unzip MambaIR.zip

3. conda环境创建与配置

conda init bash
source ~/.bashrc
conda create -n mambair python==3.9
conda activate mambair
pip install pynvml
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
conda env update --name mambair --file environment.yaml  -----------12.6

environment.yaml的执行过程比较漫长,请耐心等待

如果create -f environment.yaml报错找不到packaging:

ModuleNotFoundError: No module named 'packaging'

请下载对应的包:

conda install packaging
conda env update --name mambair --file environment.yaml

4. 数据集与权重文件下载

wget https://mirrors.aheadai.cn/data/SR.zip
wget https://mirrors.aheadai.cn/scripts/classicSRx4.pth
unzip SR.zip -d ./datasets/
mv classicSRx4.pth experiments/pretrained_models/

5. 开始推理

下载对应的yml配置文件(将配置文件中的路径进行修改)

wget https://mirrors.aheadai.cn/scripts/luo_test_SR_x4.yml
mv luo_test_SR_x4.yml ./options/test/

在项目根目录运行推理命令:

python basicsr/test.py -opt options/test/luo_test_SR_x4.yml

可以使用的参数如下:

-opt: 指定主配置文件。
--launcher: 选择分布式训练或作业调度器。
--auto_resume 和 --debug: 控制训练恢复或调试模式。
--local-rank: 管理 GPU 分配。
--force_yml: 动态覆盖配置文件中的部分参数,适合快速调整实验。

如果报错如下:

(mambair) root@autodl-container-18b640a1a0-4f72471d:~/autodl-tmp/MambaIR# python basicsr/test.py -opt options/test/test_MambaIR_SR_x4.yml
Traceback (most recent call last):
  File "basicsr/test.py", line 7, in <module>
    from basicsr.data import build_dataloader, build_dataset
ModuleNotFoundError: No module named 'basicsr'

是因为basicsr 作为本项目的一部分没有包含在 requirements.txtenvironment.yaml 中。但本地模块(包括 basicsr)并未被正确地注册到 Python 环境中,导致环境配置的过程中并没有导入这个模块导致找不到。在test.py文件的第五行加入下面的代码,将路径改成项目文件夹的绝对路径

sys.path.append('/root/autodl-tmp/MambaIR')

6. 输出结果

得到的log日志部分如下:

wget https://mirrors.aheadai.cn/log/test_test_MambaIR_SR_x4_20241204_101324.log
2024-12-04 10:13:24,275 INFO: 
  name: test_MambaIR_SR_x4
  model_type: MambaIRModel
  scale: 4
  num_gpu: 1
  manual_seed: 10
  datasets:[
    test_1:[
      name: Set5
      type: PairedImageDataset
      dataroot_gt: /root/autodl-tmp/MambaIR/datasets/SR/Set5/HR
      dataroot_lq: /root/autodl-tmp/MambaIR/datasets/SR/Set5/LR_bicubic/X4
      filename_tmpl: {}x4
      io_backend:[
        type: disk
      ]
      phase: test
      scale: 4
    ]
Tokens/s: 3575.397920304924
First Token Delay (s): 0.7794525623321533
Token Delay (s): 3.333930872625643e-06
End-to-End Latency (s 5492.53214263916
Memory Usage (GB): 0.08478641510009766
Power Consumption (W): 258.899 

7. 用自己的3840*2160分辨率图片进行推理

将本地图片导入文件夹datasets/cars/

wget https://mirrors.aheadai.cn/data/mambair_cars.zip
unzip mambair_cars.zip -d ./datasets/cars/

编写配置文件drone_cars_SR_x4.yml:其中有三个路径需要修改[TODO]

# general settings
name: drone_cars_SR_x4
model_type: MambaIRModel
scale: 4
num_gpu: 1
manual_seed: 10

datasets:
  test_1:  # the 1st test dataset
    name: cars
    type: PairedImageDataset
    dataroot_gt: /root/autodl-tmp/MambaIR/datasets/cars    [TODO]
    dataroot_lq: /root/autodl-tmp/MambaIR/datasets/cars    [TODO]
    filename_tmpl: '{}'
    io_backend:
      type: disk

# network structures
network_g:
  type: MambaIR
  upscale: 4
  in_chans: 3
  img_size: 64
  img_range: 1.
  d_state: 16
  depths: [6, 6, 6, 6, 6, 6]
  embed_dim: 180
  mlp_ratio: 2
  upsampler: 'pixelshuffle'
  resi_connection: '1conv'

# path
path:
  pretrain_network_g: /root/autodl-tmp/MambaIR/experiments/pretrained_models/classicSRx4.pth    [TODO]
  strict_load_g: true

# validation settings
val:
  save_img: true
  suffix: ~  # add suffix to saved images, if None, use exp name

#  metrics:
#    psnr: # metric name, can be arbitrary
#      type: calculate_psnr
#      crop_border: 4
#      test_y_channel: true
#    ssim:
#      type: calculate_ssim
#      crop_border: 4
#      test_y_channel: true

注意:batch_size的修改在MambaIR\basicsr\data__init__.py的78行:

dataloader_args = dict(dataset=dataset, batch_size=5, shuffle=False, num_workers=0)

执行命令:

python basicsr/test.py -opt options/test/drone_cars_SR_x4.yml
本文系作者 @ admin 原创发布在 文档中心 | AheadAI ,未经许可,禁止转载。