BERT 微调实验指南
BERT(Bidirectional Encoder Representations from Transformers)是 Google 于 2018 年提出的一种预训练语言模型,其架构基于 Transformer。BERT 以其双向上下文理解能力,彻底改变了自然语言处理(NLP)领域,成为许多 NLP 任务(如问答系统、文本分类、情感分析等)的基石。
BERT 提供两种主要模型:
- BERT-Base:12 层 Transformer,110M 参数。
- BERT-Large:24 层 Transformer,340M 参数。
以下实验教程使用的是 wikitext-2 数据集,基于 Masked Language Model (MLM) 任务,对 BERT-Base 模型 进行微调。
环境配置
1. 运行环境
文档末附有完整的pip管理的包及其版本。
类别 | 详细信息 |
---|---|
CPU | AMD EPYC 9654 96-Core Processor |
CPU 核心数 | 96 核心 / 192 线程 |
GPU | NVIDIA GeForce RTX 4090 |
GPU 显存 | 24 GB |
CUDA 版本 | 12.2 |
操作系统 | Ubuntu 22.04.3 LTS (Jammy Jellyfish) |
Python 版本 | 3.9.20 |
PyTorch 版本 | 2.5.1+cu121 |
2. 安装 Conda
参考这篇文章: https://docs.aheadai.cn/60.html
3. 创建 Conda 环境
创建新的 Conda 环境:
conda create -n bert python=3.9 -y
conda activate bert
4. 安装必要的 Python 包
在深度学习环境中,确保 Python、PyTorch 和 CUDA 版本的兼容性至关重要。
参考这篇文章,选择合适的PyTorch版本:https://docs.aheadai.cn/196.html
# 安装 PyTorch,根据 CUDA 版本匹配
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# 安装 Hugging Face 及其他工具
pip install transformers datasets tqdm numpy pynvml tensorboard tf-keras accelerate ipywidgets
准备分词器和数据集
因为 Hugging Face 的 transformers
库已经提供了预训练模型的高效加载和使用接口,可以完全脱离 BERT 源码,直接加载模型和分词器。
直接使用以下命令下载 BERT_wikitext-2-raw-v1.zip
,并解压:
wget https://mirrors.aheadai.cn/data/BERT_wikitext-2-raw-v1.zip
unzip BERT_wikitext-2-raw-v1.zip -d ./BERT_wikitext-2-raw-v1
解压后,确保文件结构如下:
./BERT_wikitext-2-raw-v1/
├── bert_model_base/ # base模型权重
├── bert_model_large/ # large模型权重
├── bert_tokenizer/ # 分词器
└── data/ # 数据集
训练脚本
下载训练脚本 run_bert_training.py
,注意按需求修改【TODO】
处的参数。
wget https://mirrors.aheadai.cn/scripts/run-bert-finetune.py
主要注意以下参数的修改:
# 1. 确保文件读取路径正确
TOKENIZER_PATH = "./BERT_wikitext-2-raw-v1/BERT_wikitext-2-raw-v1/bert_tokenizer"
BERT_MODEL_PATH = "./BERT_wikitext-2-raw-v1/BERT_wikitext-2-raw-v1/bert_model_large" # 两个模型baseh和arge根据需要选择
DATASET_PATH = "./BERT_wikitext-2-raw-v1/BERT_wikitext-2-raw-v1/data"
# 2. 训练迭代次数和epoch次数
MAX_STEPS = 200
NUM_EPOCHS = 2
# 3. 当前使用的GPU单价
GPU_COST_PER_HOUR = 1.98 # GPU 价格
# 4. log名称和保存位置
logging.basicConfig(filename="training_log_large_fp16.log", level=logging.INFO, format="%(asctime)s - %(message)s")
LOG_DIR = "./logs"
# 5. 训练精度选择
# 5.1默认精度是单精度FP32。可以选择开启混合精度AMP如下
USE_FP16 = True
# 5.2也可以选择强制执行FP16精度。
USE_FP16 = False
# 然后在调用模型后,添加model.half()
model = BertForMaskedLM.from_pretrained(BERT_MODEL_PATH)
model.half()
# 6. GPU选择与设置
PREFERRED_GPU = 0 # 首选GPU编号
USE_MULTIPLE_GPUS = False # 是否使用多卡训练
运行实验
1. 运行训练脚本
python run_bert_training.py
2. 查看日志
日志文件 training_log.log
包含:
- 平均吞吐量
- GPU 资源使用(显存、功耗)
- 成本计算(10,000 步训练总时间与成本)
制作自己的微调数据集
微调 BERT 需要准备符合格式的数据集,具体步骤如下:
数据格式要求
BERT 微调的典型任务分为以下几类,每种任务有不同的数据格式:
- 语言模型预训练(Masked Language Model, MLM)
- 文本分类(Text Classification)
- 序列标注(Sequence Labeling,如命名实体识别)
- 问答(Question Answering)
这里以 MLM 任务(如使用自己的语料库进行语言模型微调)为例。
数据格式示例
-
BERT 的 MLM 任务需要纯文本文件,每一行是一段文本。
示例:Machine learning is great. Deep learning is a subset of machine learning. BERT is a powerful language model.
步骤一:准备数据
-
收集数据
- 将你的语料整理成一个或多个
.txt
文件。 - 每一行是一个独立的句子或段落,句子之间不要有空行。
- 将你的语料整理成一个或多个
-
确保文本质量
- 移除不必要的符号或乱码。
- 如果语料很大,可以随机抽样一部分数据进行微调。
步骤二:加载自定义数据
在训练脚本中加载数据时,使用 Dataset
或 datasets
库。以下是如何加载自定义文本数据的代码:
1. 将文本数据转换为 Dataset
格式
from datasets import Dataset
import os
def load_local_dataset(file_path):
"""从本地加载数据集"""
with open(file_path, 'r', encoding='utf-8') as f:
return [{"text": line.strip()} for line in f]
# 加载自定义数据
file_path = "./my_custom_data.txt" # 替换为你的数据路径
data = load_local_dataset(file_path)
# 转换为 Dataset 格式
dataset = Dataset.from_list(data)
print(dataset)
2. 对数据进行分词和处理
from transformers import BertTokenizer
# 加载分词器
tokenizer = BertTokenizer.from_pretrained("./BERT_wikitext-2-raw-v1/BERT_wikitext-2-raw-v1/bert_tokenizer")
# 分词与编码
def tokenize_function(examples):
tokenized = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
tokenized["labels"] = tokenized["input_ids"].copy() # MLM 任务需要 input_ids 作为标签
return tokenized
# 对数据进行分词
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
步骤三:划分训练集和验证集
为了训练稳定,建议将数据集分为 训练集 和 验证集。
# 划分数据集(例如 80% 训练,20% 验证)
train_size = int(0.8 * len(tokenized_dataset))
train_dataset = tokenized_dataset.select(range(train_size))
eval_dataset = tokenized_dataset.select(range(train_size, len(tokenized_dataset)))
print(f"训练集大小: {len(train_dataset)}")
print(f"验证集大小: {len(eval_dataset)}")
步骤四:微调模型
使用准备好的数据集进行微调,以下是训练部分代码:
from transformers import BertForMaskedLM, Trainer, TrainingArguments
# 加载预训练模型
model = BertForMaskedLM.from_pretrained("./BERT_wikitext-2-raw-v1/BERT_wikitext-2-raw-v1/bert_model")
# 设置训练参数
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
save_strategy="epoch",
per_device_train_batch_size=32,
logging_dir="./logs",
logging_steps=10,
num_train_epochs=3,
save_total_limit=1,
load_best_model_at_end=True
)
# 创建 Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
# 训练模型
trainer.train()
步骤五:验证和保存模型
微调完成后,可以验证模型效果并保存模型:
# 保存微调后的模型
model.save_pretrained("./finetuned_bert_model")
tokenizer.save_pretrained("./finetuned_bert_model")
附:Python 环境中的已安装包(pip list
)
包名 | 版本 |
---|---|
absl-py | 2.1.0 |
accelerate | 1.1.1 |
aiohappyeyeballs | 2.4.3 |
aiohttp | 3.11.2 |
aiosignal | 1.3.1 |
asttokens | 2.4.1 |
astunparse | 1.6.3 |
async-timeout | 5.0.1 |
attrs | 24.2.0 |
certifi | 2024.8.30 |
charset-normalizer | 3.4.0 |
comm | 0.2.2 |
datasets | 3.1.0 |
debugpy | 1.8.8 |
decorator | 5.1.1 |
dill | 0.3.8 |
exceptiongroup | 1.2.2 |
executing | 2.1.0 |
filelock | 3.13.1 |
flatbuffers | 24.3.25 |
frozenlist | 1.5.0 |
fsspec | 2024.2.0 |
gast | 0.6.0 |
google-pasta | 0.2.0 |
grpcio | 1.68.0 |
h5py | 3.12.1 |
huggingface-hub | 0.26.2 |
idna | 3.10 |
importlib_metadata | 8.5.0 |
ipykernel | 6.29.5 |
ipywidgets | 8.1.5 |
ipython | 8.18.1 |
jedi | 0.19.2 |
Jinja2 | 3.1.3 |
jupyter_client | 8.6.3 |
jupyter_core | 5.7.2 |
jupyterlab-widgets | 3.0.13 |
keras | 3.6.0 |
libclang | 18.1.1 |
Markdown | 3.7 |
markdown-it-py | 3.0.0 |
MarkupSafe | 2.1.5 |
matplotlib-inline | 0.1.7 |
mdurl | 0.1.2 |
ml-dtypes | 0.4.1 |
mpmath | 1.3.0 |
multidict | 6.1.0 |
multiprocess | 0.70.16 |
namex | 0.0.8 |
nest-asyncio | 1.6.0 |
networkx | 3.2.1 |
numpy | 1.26.3 |
nvidia-cublas-cu12 | 12.1.3.1 |
nvidia-cuda-cupti-cu12 | 12.1.105 |
nvidia-cuda-nvrtc-cu12 | 12.1.105 |
nvidia-cuda-runtime-cu12 | 12.1.105 |
nvidia-cudnn-cu12 | 9.1.0.70 |
nvidia-cufft-cu12 | 11.0.2.54 |
nvidia-curand-cu12 | 10.3.2.106 |
nvidia-cusolver-cu12 | 11.4.5.107 |
nvidia-cusparse-cu12 | 12.1.0.106 |
nvidia-nccl-cu12 | 2.21.5 |
nvidia-nvjitlink-cu12 | 12.1.105 |
nvidia-nvtx-cu12 | 12.1.105 |
opt_einsum | 3.4.0 |
optree | 0.13.1 |
packaging | 24.2 |
pandas | 2.2.3 |
parso | 0.8.4 |
pexpect | 4.9.0 |
pillow | 10.2.0 |
pip | 24.2 |
platformdirs | 4.3.6 |
prompt_toolkit | 3.0.48 |
propcache | 0.2.0 |
protobuf | 5.28.3 |
psutil | 6.1.0 |
ptyprocess | 0.7.0 |
pure_eval | 0.2.3 |
pyarrow | 18.0.0 |
Pygments | 2.18.0 |
pynvml | 11.5.3 |
PySocks | 1.7.1 |
python-dateutil | 2.9.0.post0 |
pytz | 2024.2 |
PyYAML | 6.0.2 |
pyzmq | 26.2.0 |
regex | 2024.11.6 |
requests | 2.32.3 |
rich | 13.9.4 |
safetensors | 0.4.5 |
setuptools | 75.1.0 |
six | 1.16.0 |
stack-data | 0.6.3 |
sympy | 1.13.1 |
tensorboard | 2.18.0 |
tensorboard-data-server | 0.7.2 |
tensorflow | 2.18.0 |
tensorflow-io-gcs-filesystem | 0.37.1 |
termcolor | 2.5.0 |
tf_keras | 2.18.0 |
tokenizers | 0.20.3 |
torch | 2.5.1+cu121 |
torchaudio | 2.5.1+cu121 |
torchvision | 0.20.1+cu121 |
tornado | 6.4.1 |
tqdm | 4.67.0 |
traitlets | 5.14.3 |
transformers | 4.46.2 |
triton | 3.1.0 |
typing_extensions | 4.9.0 |
tzdata | 2024.2 |
urllib3 | 2.2.3 |
wcwidth | 0.2.13 |
Werkzeug | 3.1.3 |
wheel | 0.44.0 |
widgetsnbextension | 4.0.13 |
wrapt | 1.16.0 |
xxhash | 3.5.0 |
yarl | 1.17.1 |
zipp | 3.21.0 |
评论