联邦学习FedAvg实例教程(AutoDL)
一.概述
移动通信设备中有许多有用的数据,训练模型后可以提高用户体验。但是,这些数据通常敏感或很庞大,不能直接上传到数据中心,使用传统的方法训练模型。据此提出联邦学习,将训练数据分布在移动设备上,通过聚合本地计算的更新来学习共享模型。
二.快速开始
2.1 环境准备
运行硬件环境:
类别 | 详细信息 |
---|---|
CPU | 16 vCPU Intel(R) Xeon(R) Platinum 8481C |
GPU | RTX 4090D * 1 |
GPU 显存 | 24GB |
CUDA 版本 | 11.7 |
操作系统 | Ubuntu 22.04.3 LTS |
Python 版本 | 3.8 |
PyTorch 版本 | 1.13.1 |
创建一个 conda 环境 fedavg ,python 版本选择 3.8,并激活环境
conda create -n fedavg python=3.8
conda activate fedavg
2.2 安装pytorch及其他库
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install nvidia-ml-py
2.3 下载实例源码
cd autodl-tmp
wget https://mirrors.aheadai.cn/pkgs/FedAvg_Python_image_classification.zip
unzip FedAvg_Python_image_classification
三.教程
在本目录下,在命令行中执行下面的命令:
cd FedAvg_Python_image_classification
python main.py -c ./utils/conf.json
我们还提供了文件 main_log.py
,可以评估实例中的一些参数,运行前请先安装必要库
pip install nvidia-ml-py
python main_log.py -c ./utils/conf.jsonpip install nvidia-ml-py
若运行文件 main_log.py
,结束后会生成fedavg_training_log.txt
,内容参考如下
2024-12-17 15:55:16,701 - INFO - Epoch: 19, Accuracy: 43.4400, Loss: 2.0756, Samples/sec: 6960.48, Step time (s): 14.4267, Epoch time (s): 22.6099, Memory Usage (%): 42.29, Power Consumption (W): 51.83
3.1 服务端
横向联邦学习的服务端的主要功能是将被选择的客户端上传的本地模型进行模型聚合。但这里需要特别注意的是,事实上,对于一个功能完善的联邦学习框架,比如我们将在后面介绍的FATE平台,服务端的功能要复杂得多,比如服务端需要对各个客户端节点进行网络监控、对失败节点发出重连信号等。本章由于是在本地模拟的,不涉及网络通信细节和失败故障等处理,因此不讨论这些功能细节,仅涉及模型聚合功能。
下面我们首先定义一个服务端类 Server,类中的主要函数包括以下几个。
- 定义构造函数。在构造函数中,服务端的工作包括:第一,将配置信息拷贝到服务端中;第二,按照配置中的模型信息获取模型,这里我们使用 torchvision 的 models 模块内置的 ResNet-18 模型。
class Server(object):
def __init__(self, conf, eval_dataset):
self.conf = conf
self.global_model = models.get_model(self.conf["model_name"])
self.eval_loader = torch.utils.data.DataLoader(eval_dataset,
batch_size=self.conf["batch_size"], shuffle=True)
- 定义模型聚合函数。前面我们提到服务端的主要功能是进行模型的聚合,因此定义构造函数后,我们需要在类中定义模型聚合函数,通过接收客户端上传的模型,使用聚合函数更新全局模型。聚合方案有很多种,本节我们采用经典的 FedAvg 算法。
def model_aggregate(self, weight_accumulator):
for name, data in self.global_model.state_dict().items():
update_per_layer = weight_accumulator[name] * self.conf["lambda"]
if data.type() != update_per_layer.type():
data.add_(update_per_layer.to(torch.int64))
else:
data.add_(update_per_layer)
- 定义模型评估函数。对当前的全局模型,利用评估数据评估当前的全局模型性能。通常情况下,服务端的评估函数主要对当前聚合后的全局模型进行分析,用于判断当前的模型训练是需要进行下一轮迭代、还是提前终止,或者模型是否出现发散退化的现象。根据不同的结果,服务端可以采取不同的措施策略。
def model_eval(self):
self.global_model.eval()
total_loss = 0.0
correct = 0
dataset_size = 0
for batch_id, batch in enumerate(self.eval_loader):
data, target = batch
dataset_size += data.size()[0]
if torch.cuda.is_available():
data = data.cuda()
target = target.cuda()
output = self.global_model(data)
total_loss += torch.nn.functional.cross_entropy(output, target,
reduction='sum').item() # sum up batch loss
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
acc = 100.0 * (float(correct) / float(dataset_size))
total_l = total_loss / dataset_size
return acc, total_l
3.2 客户端
横向联邦学习的客户端主要功能是接收服务端的下发指令和全局模型,利用本地数据进行局部模型训练。与前一节一样,对于一个功能完善的联邦学习框架,客户端的功能同样相当复杂,比如需要考虑本地的资源(CPU、内存等)是否满足训练需要、当前的网络中断、当前的训练由于受到外界因素影响而中断等。读者如果对这些设计细节感兴趣,可以查看当前流行的联邦学习框架源代码和文档,比如FATE,获取更多的实现细节。本节我们仅考虑客户端本地的模型训练细节。我们首先定义客户端类 Client ,类中的主要函数包括以下两种。
- 定义构造函数。在客户端构造函数中,客户端的主要工作包括:首先,将配置信息拷贝到客户端中;然后,按照配置中的模型信息获取模型,通常由服务端将模型参数传递给客户端,客户端将该全局模型覆盖掉本地模型;最后,配置本地训练数据,在本案例中,我们通过 torchvision 的 datasets 模块获取 cifar10 数据集后按客户端 ID 切分,不同的客户端拥有不同的子数据集,相互之间没有交集。
class Client(object):
def __init__(self, conf, model, train_dataset, id = -1):
self.conf = conf
self.local_model = models.get_model(self.conf["model_name"])
self.client_id = id
self.train_dataset = train_dataset
all_range = list(range(len(self.train_dataset)))
data_len = int(len(self.train_dataset) / self.conf['no_models'])
train_indices = all_range[id * data_len: (id + 1) * data_len]
self.train_loader = torch.utils.data.DataLoader(self.train_dataset,
batch_size=conf["batch_size"], sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))
- 定义模型本地训练函数。本例是一个图像分类的例子,因此,我们使用交叉熵作为本地模型的损失函数,利用梯度下降来求解并更新参数值,实现细节如下面代码块所示。
def local_train(self, model):
for name, param in model.state_dict().items():
self.local_model.state_dict()[name].copy_(param.clone())
optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],
momentum=self.conf['momentum'])
self.local_model.train()
for e in range(self.conf["local_epochs"]):
for batch_id, batch in enumerate(self.train_loader):
data, target = batch
if torch.cuda.is_available():
data = data.cuda()
target = target.cuda()
optimizer.zero_grad()
output = self.local_model(data)
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
print("Epoch %d done." % e)
diff = dict()
for name, data in self.local_model.state_dict().items():
diff[name] = (data - model.state_dict()[name])
return diff
3.3 整合
当配置文件、服务端类和客户端类都定义完毕,我们将这些信息组合起来。首先,读取配置文件信息。
with open(args.conf, 'r') as f:
conf = json.load(f)
接下来,我们将分别定义一个服务端对象和多个客户端对象,用来模拟横向联邦训练场景。
train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
server = Server(conf, eval_datasets)
clients = []
for c in range(conf["no_models"]):
clients.append(Client(conf, server.global_model, train_datasets, c))
每一轮的迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练,被选中的客户端调用本地训练接口local_train进行本地训练,最后服务端调用模型聚合函数 model_aggregate 来更新全局模型,代码如下所示。
for e in range(conf["global_epochs"]):
candidates = random.sample(clients, conf["k"])
weight_accumulator = {}
for name, params in server.global_model.state_dict().items():
weight_accumulator[name] = torch.zeros_like(params)
for c in candidates:
diff = c.local_train(server.global_model)
for name, params in server.global_model.state_dict().items():
weight_accumulator[name].add_(diff[name])
server.model_aggregate(weight_accumulator)
acc, loss = server.model_eval()
print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))
3.5 配置信息
本案例的配置信息在:conf.json
,读者可以根据实际需要修改。
- model_name:模型名称
- no_models:客户端数量
- type:数据集信息
- global_epochs:全局迭代次数,即服务端与客户端的通信迭代次数
- local_epochs:本地模型训练迭代次数
- k:每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练。
- batch_size:本地训练每一轮的样本数
- lr,momentum,lambda:本地训练的超参数设置
附:Python 环境中的已安装包(pip list
)
Package | Version |
---|---|
Brotli | 1.0.9 |
certifi | 2024.8.30 |
charset-normalizer | 3.3.2 |
idna | 3.7 |
mkl-fft | 1.3.8 |
mkl-random | 1.2.4 |
mkl-service | 2.4.0 |
numpy | 1.24.3 |
nvidia-ml-py | 12.560.30 |
pillow | 10.4.0 |
pip | 24.2 |
psutil | 6.1.0 |
PySocks | 1.7.1 |
requests | 2.32.3 |
setuptools | 75.1.0 |
torch | 1.13.1 |
torchaudio | 0.13.1 |
torchvision | 0.14.1 |
typing_extensions | 4.11.0 |
urllib3 | 2.2.3 |
wheel | 0.44.0 |
评论