MambaIR图像修复推理过程
MambaIR是基于Mamba的改进算法,专注于通过深度学习技术实现从低分辨率图像(LR,Low-Resolution)到高分辨率图像(HR,High-Resolution)的精细重建。该方法引入了多个关键模块,其中包括创新性的 2DSSM模块(2D Spatial-Spectral Mapping),以增强模型对图像纹理、边缘和色彩信息的保真度。
原码链接:https://github.com/csguoh/mambair
项目开发团队给出的版本(与笔者略有不同)如下:
- Ubuntu 20.04 版本
- CUDA 11.7 的
- Python 3.9 版本
- PyTorch 2.0.1 + cu117 版本
1. 运行环境
类别 | 详细信息 |
---|---|
CPU | 12 vCPU Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz |
CPU 核心数 | 12核心 |
GPU | RTX 3090 |
GPU 显存 | 24 GB |
CUDA 版本 | 11.8 |
操作系统 | ubuntu20.04 |
Python 版本 | 3.8 |
PyTorch 版本 | 2.0.0 |
2. 源码下载
如果要得到每tokens时延、显存利用率等参数请用第三条指令而不是git
git clone https://github.com/csguoh/MambaIR.git
mkdir MambaIR && cd MambaIR
wget https://mirrors.aheadai.cn/scripts/MambaIR.zip
unzip MambaIR.zip
3. conda环境创建与配置
conda init bash
source ~/.bashrc
conda create -n mambair python==3.9
conda activate mambair
pip install pynvml
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
conda env update --name mambair --file environment.yaml -----------12.6
environment.yaml的执行过程比较漫长,请耐心等待
如果create -f environment.yaml
报错找不到packaging:
ModuleNotFoundError: No module named 'packaging'
请下载对应的包:
conda install packaging
conda env update --name mambair --file environment.yaml
4. 数据集与权重文件下载
wget https://mirrors.aheadai.cn/data/SR.zip
wget https://mirrors.aheadai.cn/scripts/classicSRx4.pth
unzip SR.zip -d ./datasets/
mv classicSRx4.pth experiments/pretrained_models/
5. 开始推理
下载对应的yml配置文件(将配置文件中的路径进行修改)
wget https://mirrors.aheadai.cn/scripts/luo_test_SR_x4.yml
mv luo_test_SR_x4.yml ./options/test/
在项目根目录运行推理命令:
python basicsr/test.py -opt options/test/luo_test_SR_x4.yml
可以使用的参数如下:
-opt: 指定主配置文件。
--launcher: 选择分布式训练或作业调度器。
--auto_resume 和 --debug: 控制训练恢复或调试模式。
--local-rank: 管理 GPU 分配。
--force_yml: 动态覆盖配置文件中的部分参数,适合快速调整实验。
如果报错如下:
(mambair) root@autodl-container-18b640a1a0-4f72471d:~/autodl-tmp/MambaIR# python basicsr/test.py -opt options/test/test_MambaIR_SR_x4.yml
Traceback (most recent call last):
File "basicsr/test.py", line 7, in <module>
from basicsr.data import build_dataloader, build_dataset
ModuleNotFoundError: No module named 'basicsr'
是因为basicsr
作为本项目的一部分没有包含在 requirements.txt
或 environment.yaml
中。但本地模块(包括 basicsr
)并未被正确地注册到 Python 环境中,导致环境配置的过程中并没有导入这个模块导致找不到。在test.py文件的第五行加入下面的代码,将路径改成项目文件夹的绝对路径:
sys.path.append('/root/autodl-tmp/MambaIR')
6. 输出结果
得到的log日志部分如下:
wget https://mirrors.aheadai.cn/log/test_test_MambaIR_SR_x4_20241204_101324.log
2024-12-04 10:13:24,275 INFO:
name: test_MambaIR_SR_x4
model_type: MambaIRModel
scale: 4
num_gpu: 1
manual_seed: 10
datasets:[
test_1:[
name: Set5
type: PairedImageDataset
dataroot_gt: /root/autodl-tmp/MambaIR/datasets/SR/Set5/HR
dataroot_lq: /root/autodl-tmp/MambaIR/datasets/SR/Set5/LR_bicubic/X4
filename_tmpl: {}x4
io_backend:[
type: disk
]
phase: test
scale: 4
]
Tokens/s: 3575.397920304924
First Token Delay (s): 0.7794525623321533
Token Delay (s): 3.333930872625643e-06
End-to-End Latency (s 5492.53214263916
Memory Usage (GB): 0.08478641510009766
Power Consumption (W): 258.899
7. 用自己的3840*2160分辨率图片进行推理
将本地图片导入文件夹datasets/cars/
wget https://mirrors.aheadai.cn/data/mambair_cars.zip
unzip mambair_cars.zip -d ./datasets/cars/
编写配置文件drone_cars_SR_x4.yml:其中有三个路径需要修改[TODO]
# general settings
name: drone_cars_SR_x4
model_type: MambaIRModel
scale: 4
num_gpu: 1
manual_seed: 10
datasets:
test_1: # the 1st test dataset
name: cars
type: PairedImageDataset
dataroot_gt: /root/autodl-tmp/MambaIR/datasets/cars [TODO]
dataroot_lq: /root/autodl-tmp/MambaIR/datasets/cars [TODO]
filename_tmpl: '{}'
io_backend:
type: disk
# network structures
network_g:
type: MambaIR
upscale: 4
in_chans: 3
img_size: 64
img_range: 1.
d_state: 16
depths: [6, 6, 6, 6, 6, 6]
embed_dim: 180
mlp_ratio: 2
upsampler: 'pixelshuffle'
resi_connection: '1conv'
# path
path:
pretrain_network_g: /root/autodl-tmp/MambaIR/experiments/pretrained_models/classicSRx4.pth [TODO]
strict_load_g: true
# validation settings
val:
save_img: true
suffix: ~ # add suffix to saved images, if None, use exp name
# metrics:
# psnr: # metric name, can be arbitrary
# type: calculate_psnr
# crop_border: 4
# test_y_channel: true
# ssim:
# type: calculate_ssim
# crop_border: 4
# test_y_channel: true
注意:batch_size的修改在MambaIR\basicsr\data__init__.py的78行:
dataloader_args = dict(dataset=dataset, batch_size=5, shuffle=False, num_workers=0)
执行命令:
python basicsr/test.py -opt options/test/drone_cars_SR_x4.yml
本文系作者 @
admin
原创发布在 文档中心 | AheadAI ,未经许可,禁止转载。
有帮助?
评论