AI智能
改变未来

分布式机器学习中的模型架构

在上一篇博文《分布式机器学习中的模型聚合》(链接:https://www.cnblogs.com/orion-orion/p/15635803.html</font>)中,我们关注了在分布式机器学习中模型聚合(参数通信)的问题,但是对每一个client具体的模型架构设计和参数优化方法还没有讨论。本篇文章我们关注具体模型结构设计和参数优化。首先,在我follow的这篇篇论文[1]中(代码参见[2])不同的client有一个集成模型,而每一个集成模型由多个模型分量组成,可以表示为如下图:接下来我们就自顶向下地分层次展示Client、Learners_ensemble和每个Learner的设计原理。

1. Client

Client是每一个Client任务节点的类设计,它提供与模型类相似的

get_next_batch()

方法和

step

方法供我们前面在博客《分布式机器学习中的模型聚合》中讲过的

aggregator

类调用。但是我们需要认识到,我们在操纵Client时,实际上就在操纵其

Learners_ensemble

,也就是在操纵所有的

Learner

模型分量。

它包含的方法核心如下:

其具体代码实现如下:

class Client(object):r"""一个Client任务节点"""def __init__(self,learners_ensemble,train_iterator,val_iterator,test_iterator,logger,local_steps,tune_locally=False):# 本地的learners_ensemble模型self.learners_ensemble = learners_ensembleself.n_learners = len(self.learners_ensemble)self.tune_locally = tune_locally# 表示是否进行本地调整,我们先化繁为简,略过这一功能if self.tune_locally:self.tuned_learners_ensemble = deepcopy(self.learners_ensemble)else:self.tuned_learners_ensemble = None# 表示是否为二分类问题self.binary_classification_flag = self.learners_ensemble.is_binary_classification# 需要保存train,val,test的DataLoader(因为每个Client对应一个不同的数据集)# 保存DataLoader的好处是只需要对象初始化时设置好DataLoader,后续step时便不用传入数据# 这里"iterator"其实都是torch.utils.data.DataLoader对象# 使用前需要使用iter(train_iterator)来转换为迭代器(用for迭代的话默认转型)self.train_iterator = train_iteratorself.val_iterator = val_iteratorself.test_iterator = test_iterator# 由train_iterator创造迭代器self.train_loader = iter(self.train_iterator)self.n_train_samples = len(self.train_iterator.dataset)self.n_test_samples = len(self.test_iterator.dataset)# 记录每一个分量模型中每一个样本的权重(0~1之间)self.samples_weights = torch.ones(self.n_learners, self.n_train_samples) / self.n_learnersself.local_steps = local_stepsself.counter = 0 # 记录进行优化步骤step的次数self.logger = loggerdef get_next_batch(self):"""带异常判断(安全)地从train_loader(由train_iterator)构建的迭代器中读一个batch如果数据集已经读至末尾,则循环读取"""try:batch = next(self.train_loader)except StopIteration:self.train_loader = iter(self.train_iterator)batch = next(self.train_loader)return batchdef step(self, single_batch_flag=False, *args, **kwargs):"""进行client的一个训练step:参数 single_batch_flag: 若为true, client只使用一个batch进行更新:返回 clients_updates: ()"""self.counter += 1 # 迭代步数+1self.update_sample_weights()self.update_learners_weights()# 最终的优化步落实到learners_ensemble上if single_batch_flag:batch = self.get_next_batch()# 若已设定了一次只使用一个batch,则从train_loader中读一个batchclient_updates = \\self.learners_ensemble.fit_batch(batch=batch,weights=self.samples_weights)else:# 否则,将迭代器train_iterator传入client_updates = \\self.learners_ensemble.fit_epochs(iterator=self.train_iterator,n_epochs=self.local_steps,weights=self.samples_weights)return client_updatesdef write_logs(self):r"""记录train和test数据的loss和acc,后面控制台会打印输出。注意,此处评估调用tuned_learners_ensemble中的evaluate_iterator()方法进行模型评估并记录,evaluate_iterator()方法具体实现我们后面会介绍"""def update_sample_weights(self):# 此方法用于更新每个样本的权重,# 在MixtureClient任务类中重写passdef update_learners_weights(self):# 此方法用于更新每个分量模型的权重,# 在MixtureClient任务类中重写pass

注意,以上Client类还未对

update_learners_weights

update_sample_weights

这两个方法进行定义。定义在如下的MixtureClient中:

class MixtureClient(Client):def update_sample_weights(self):all_losses = self.learners_ensemble.gather_losses(self.val_iterator)self.samples_weights = F.softmax((torch.log(self.learners_ensemble.learners_weights) - all_losses.T), dim=1).Tdef update_learners_weights(self):self.learners_ensemble.learners_weights = self.samples_weights.mean(dim=1)

2. Learners_ensemble

Learners_ensemble是多个分量模型的集成。在优化模型时需要分别对多个模型分量进行优化。在模型输出时,采用多个分量模型加权平均的输出方式。

\\bm{y}_{t} = \\sum_{m=1}^M w_{tm}h(\\mathbf{X}_t; \\bm{\\theta}_{tm})

此外,Learners_ensemble还提供

evaluate_iterator()

方法来完成对模型的评估(该方法得到的评估数值是所有模型分量的平均),供上层

Client

类调用。

它包含的方法核心如下:

其具体代码实现如下:

class LearnersEnsemble(object):"""由多个分量Learners集成的LearnersEnsemble.(是一个可迭代对象,重写了_iter_,_getitem_,_len_方法)"""def __init__(self, learners, learners_weights):self.learners = learnersself.learners_weights = learners_weights# 假设所有learners的特征维度一样self.model_dim = self.learners[0].model_dim# 布尔标识是分类还是回归任务self.is_binary_classification = self.learners[0].is_binary_classification# 默认所有learners的device和metric一样self.device = self.learners[0].deviceself.metric = self.learners[0].metricdef fit_batch(self, batch, weights):"""使用一个batch更新各learner分量.:参数 batch: 元组 (x, y, indices):参数 weights: tensor类型,每个样本对应的权重(可为None):返回 client_updates: np.array类型,大小为(n_learners, model_dim): 用于衡量ensemble中每个learner的新旧参数之间的差异"""#记录每一个learners的参数的每一个维度的更新量client_updates = torch.zeros(len(self.learners), self.model_dim)for learner_id, learner in enumerate(self.learners):old_params = learner.get_param_tensor()if weights is not None:learner.fit_batch(batch=batch, weights=weights[learner_id])else:learner.fit_batch(batch=batch, weights=None)params = learner.get_param_tensor()client_updates[learner_id] = (params - old_params)return client_updates.cpu().numpy()def fit_epochs(self, iterator, n_epochs, weights=None):"""多次遍历训练集(即多个epochs)更新各learner分量.:参数 n_epochs: 使用训练集的epochs轮数:参数 weights: tensor类型,每个样本对应的权重(可为None):返回 client_updates: np.array类型,大小为(n_learners, model_dim): 用于衡量ensemble中每个learner的新旧参数之间的差异"""client_updates = torch.zeros(len(self.learners), self.model_dim)for learner_id, learner in enumerate(self.learners):old_params = learner.get_param_tensor()if weights is not None:learner.fit_epochs(iterator, n_epochs, weights=weights[learner_id])else:learner.fit_epochs(iterator, n_epochs, weights=None)params = learner.get_param_tensor()client_updates[learner_id] = (params - old_params)return client_updates.cpu().numpy()def evaluate_iterator(self, iterator):"""用迭代器指向的数据评估learners.:参数 iterator: yields x, y, indices:返回: global_loss, global_acc(测试数据的)"""if self.is_binary_classification:criterion = nn.BCELoss(reduction="none")else:criterion = nn.NLLLoss(reduction="none")for learner in self.learners:# 将各learner模型设置为evaluation模式learner.model.eval()global_loss = 0.global_metric = 0.n_samples = 0with torch.no_grad():for (x, y, _) in iterator:x = x.to(self.device).type(torch.float32)y = y.to(self.device)n_samples += y.size(0)y_pred = 0.for learner_id, learner in enumerate(self.learners):# 注意一,这里sigmoid和softmax写在model类外,更具灵活性,# 但一般我们仍然将其看做分类器h(x)的一部分# 注意二,此处实质上采用各分类器输出进行加权平均集成if self.is_binary_classification:y_pred += self.learners_weights[learner_id] * torch.sigmoid(learner.model(x))else:y_pred += self.learners_weights[learner_id] * F.softmax(learner.model(x), dim=1)y_pred = torch.clamp(y_pred, min=0., max=1.)if self.is_binary_classification:y = y.type(torch.float32).unsqueeze(1)global_loss += criterion(y_pred, y).sum().item()y_pred = torch.logit(y_pred, eps=1e-10)else:global_loss += criterion(torch.log(y_pred), y).sum().item()global_metric += self.metric(y_pred, y).item()return global_loss / n_samples, global_metric / n_samplesdef gather_losses(self, iterator):"""汇集各learner模型关于迭代的所有样本的losses:参数 iterator::返回: tensor (n_learners, n_samples) ,各learner关于所迭代的数据集所有样本的loss"""n_samples = len(iterator.dataset)all_losses = torch.zeros(len(self.learners), n_samples)for learner_id, learner in enumerate(self.learners):all_losses[learner_id] = learner.gather_losses(iterator)return all_lossesdef free_memory(self):"""释放模型权重"""for learner in self.learners:learner.free_memory()def free_gradients(self):"""释放模型梯度"""for learner in self.learners:learner.free_gradients()# 以下三个方法说明LearnersEnsemble是个可迭代对象def __iter__(self):return LearnersEnsembleIterator(self)def __len__(self):return len(self.learners)def __getitem__(self, idx):return self.learners[idx]

3. Learner

Learner相当于在具体的诸如CNN、RNN等模型之上进行的一层包装,实现了模型训练的接口,其属性

Learner.model

即具体的模型对象,来自类似与下列的模型类:

class CIFAR10CNN(nn.Module):def __init__(self, num_classes):super(CIFAR10CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(32, 64, 5)self.fc1 = nn.Linear(64 * 5 * 5, 2048)self.output = nn.Linear(2048, num_classes)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 64 * 5 * 5)x = F.relu(self.fc1(x))x = self.output(x)return x

它包含的方法核心如下:

其具体代码实现如下:

class Learner:"""负责训练并评估一个(深度)学习器属性----------model (nn.Module): learner训练的模型criterion (torch.nn.modules.loss): 训练`model`所用的损失函数,这里我们设置reduction="none",也就是默认一个batch的loss返回一个向量而不求和/平均metric (fn): 模型评价指标对应的函数, 输入两个向量输出一标量device (str or torch.device):optimizer (torch.optim.Optimizer):lr_scheduler (torch.optim.lr_scheduler):is_binary_classification (bool): 是否将labels转换为float, 如果使用 `BCELoss`(用于分类的交叉熵损失函数)这里必须要设置为True方法------optimizer_step: 进行一轮优化迭代, 需要梯度已经被计算完毕fit_batch: 对一个批量进行一轮优化迭代fit_epoch: 单次遍历iterator中的得到的所有样本,进行一系列批量的迭代fit_epochs: 多次遍历将从iterator指向的训练集gather_losses:收集iterator迭代器所有样本的loss并拼接输出get_param_tensor: 获取获取一个flattened后的`model`的参数free_memory: 释放模型权重free_gradients: 释放模型梯度"""def __init__(self, model,criterion,metric,device,optimizer,lr_scheduler=None,is_binary_classification=False):self.model = model.to(device)self.criterion = criterion.to(device)self.metric = metricself.device = deviceself.optimizer = optimizerself.lr_scheduler = lr_schedulerself.is_binary_classification = is_binary_classificationself.model_dim = int(self.get_param_tensor().shape[0])def optimizer_step(self):"""执行一轮优化迭代,调用之前需要反向传播先算好梯度(即已调用loss.backward())"""self.optimizer.step()if self.lr_scheduler:self.lr_scheduler.step()def fit_batch(self, batch, weights=None):"""基于来自`iterator`的一个batch的样本执行一轮优化迭代:参数 batch (元组(x, y, indices))::参数 weights(tensor): 每个样本的权重,可为none:返回: loss.detach(), metric.detach()(训练数据)"""self.model.train()x, y, indices = batchx = x.to(self.device).type(torch.float32)y = y.to(self.device)if self.is_binary_classification:y = y.type(torch.float32).unsqueeze(1)self.optimizer.zero_grad()y_pred = self.model(x)loss_vec = self.criterion(y_pred, y)metric = self.metric(y_pred, y) / len(y)if weights is not None:weights = weights.to(self.device)loss = (loss_vec.T @ weights[indices]) / loss_vec.size(0)else:loss = loss_vec.mean()loss.backward()self.optimizer.step()if self.lr_scheduler:self.lr_scheduler.step()return loss.detach(), metric.detach()def fit_epoch(self, iterator, weights=None):"""将来自`iterator`的所有batches遍历一次,进行优化迭代:参数 iterator(torch.utils.data.DataLoader)::参数 weights(torch.tensor): 存储每个样本权重的向量,可为None:return: loss.detach(), metric.detach() (训练数据)"""self.model.train()global_loss = 0.global_metric = 0.n_samples = 0for x, y, indices in iterator:x = x.to(self.device).type(torch.float32)y = y.to(self.device)n_samples += y.size(0)if self.is_binary_classification:y = y.type(torch.float32).unsqueeze(1)self.optimizer.zero_grad()y_pred = self.model(x)loss_vec = self.criterion(y_pred, y)if weights is not None:weights = weights.to(self.device)loss = (loss_vec.T @ weights[indices]) / loss_vec.size(0)else:loss = loss_vec.mean()loss.backward()self.optimizer.step()global_loss += loss.detach() * loss_vec.size(0)global_metric += self.metric(y_pred, y).detach()return global_loss / n_samples, global_metric / n_samplesdef gather_losses(self, iterator):"""计算来自iterator的样本中的所有losses并拼接为all_losses:参数 iterator(torch.utils.data.DataLoader)::return: 所有来自iterator.dataset样本的losses拼成的tensor"""self.model.eval()n_samples = len(iterator.dataset)all_losses = torch.zeros(n_samples, device=self.device)with torch.no_grad():for (x, y, indices) in iterator:x = x.to(self.device).type(torch.float32)y = y.to(self.device)if self.is_binary_classification:y = y.type(torch.float32).unsqueeze(1)y_pred = self.model(x)all_losses[indices] = self.criterion(y_pred, y).squeeze()return all_lossesdef fit_epochs(self, iterator, n_epochs, weights=None):"""执行多个n_epochs的训练:参数 iterator(torch.utils.data.DataLoader)::参数 n_epochs(int)::参数 weights: 每个样本权重的向量,可为None:返回: None"""for step in range(n_epochs):self.fit_epoch(iterator, weights)if self.lr_scheduler is not None:self.lr_scheduler.step()def get_param_tensor(self):"""将所有模型参数做为一个flattened的一维张量输出:返回: torch.tensor"""param_list = []for param in self.model.parameters():param_list.append(param.data.view(-1, ))return torch.cat(param_list)def get_grad_tensor(self):"""将 `model` 所有参数的梯度做为flattened的一维张量输出:返回: torch.tensor"""grad_list = []for param in self.model.parameters():if param.grad is not None:grad_list.append(param.grad.data.view(-1, ))return torch.cat(grad_list)def free_memory(self):"""释放模型权重"""del self.optimizerdel self.modeldef free_gradients(self):"""释放模型梯度"""self.optimizer.zero_grad(set_to_none=True)

3. Client、Learners_ensemble、Learner的对比

三者的对比架构图如下:

其中,$A$方法指向$B$方法的箭头代表在$A$方法中调用$B$方法。

我们可以看到,我们在上一篇博文《分布式机器学习中的模型聚合》(链接:https://www.cnblogs.com/orion-orion/p/15635803.html</font>)中所调用的函数

client.step()

以及

client.write_logs()

下层其实还封装着这么多的实现。

需要指出的是,模型的梯度计算和参数更新最终是要落实到

Learner

类去完成,不过模型的评估我们直接在

LearnersEnsemble

类即可完成,而不需要在

Learner

类去单独设计一个方法。

4. 模型测试

我们采用CIFAR10 数据集对论文提出的模型进行测试,可以看到测试效果不错。预设迭代200各epoch,迭代了9各epoch我们就已经达到 Train Acc: 73.587%,Test Acc: 70.577% ,虽然和论文最终宣称的78.1%尚差距,不过最终应该能达到该精度,可见论文声称的结果很大程度上还是靠谱的。

==> Clients initialization..===> Building data iterators..0%|          | 0/80 [00:00<?, ?it/s]4%|▍         | 3/80 [00:00<00:03, 24.48it/s]9%|▉         | 7/80 [00:00<00:02, 30.10it/s]14%|█▍        | 11/80 [00:00<00:02, 25.71it/s]18%|█▊        | 14/80 [00:00<00:02, 23.93it/s]21%|██▏       | 17/80 [00:00<00:02, 24.12it/s]26%|██▋       | 21/80 [00:00<00:02, 25.36it/s]30%|███       | 24/80 [00:00<00:02, 25.34it/s]34%|███▍      | 27/80 [00:01<00:02, 24.98it/s]39%|███▉      | 31/80 [00:01<00:01, 28.59it/s]46%|████▋     | 37/80 [00:01<00:01, 33.12it/s]51%|█████▏    | 41/80 [00:01<00:01, 32.48it/s]56%|█████▋    | 45/80 [00:01<00:01, 24.59it/s]66%|██████▋   | 53/80 [00:01<00:00, 31.57it/s]72%|███████▎  | 58/80 [00:01<00:00, 33.56it/s]78%|███████▊  | 62/80 [00:02<00:00, 33.12it/s]86%|████████▋ | 69/80 [00:02<00:00, 37.24it/s]91%|█████████▏| 73/80 [00:02<00:00, 36.86it/s]98%|█████████▊| 78/80 [00:02<00:00, 35.90it/s]100%|██████████| 80/80 [00:02<00:00, 31.41it/s]===> Initializing clients..0%|          | 0/80 [00:00<?, ?it/s]1%|▏         | 1/80 [00:13<18:05, 13.75s/it]2%|▎         | 2/80 [00:13<07:26,  5.73s/it]4%|▍         | 3/80 [00:13<04:03,  3.16s/it]5%|▌         | 4/80 [00:14<02:28,  1.95s/it]6%|▋         | 5/80 [00:14<01:36,  1.28s/it]8%|▊         | 6/80 [00:14<01:05,  1.13it/s]9%|▉         | 7/80 [00:14<00:46,  1.56it/s]10%|█         | 8/80 [00:14<00:33,  2.12it/s]11%|█▏        | 9/80 [00:14<00:25,  2.81it/s]12%|█▎        | 10/80 [00:14<00:19,  3.61it/s]14%|█▍        | 11/80 [00:14<00:15,  4.43it/s]15%|█▌        | 12/80 [00:14<00:12,  5.25it/s]16%|█▋        | 13/80 [00:15<00:10,  6.13it/s]18%|█▊        | 14/80 [00:15<00:09,  6.86it/s]19%|█▉        | 15/80 [00:15<00:08,  7.42it/s]21%|██▏       | 17/80 [00:15<00:07,  8.42it/s]22%|██▎       | 18/80 [00:15<00:07,  8.73it/s]24%|██▍       | 19/80 [00:15<00:06,  9.03it/s]25%|██▌       | 20/80 [00:15<00:06,  9.22it/s]26%|██▋       | 21/80 [00:15<00:06,  9.40it/s]28%|██▊       | 22/80 [00:15<00:06,  9.23it/s]29%|██▉       | 23/80 [00:16<00:06,  9.33it/s]30%|███       | 24/80 [00:16<00:05,  9.37it/s]32%|███▎      | 26/80 [00:16<00:05,  9.61it/s]34%|███▍      | 27/80 [00:16<00:05,  9.62it/s]36%|███▋      | 29/80 [00:16<00:05,  8.51it/s]39%|███▉      | 31/80 [00:16<00:05,  9.02it/s]41%|████▏     | 33/80 [00:17<00:05,  9.25it/s]42%|████▎     | 34/80 [00:17<00:04,  9.34it/s]44%|████▍     | 35/80 [00:17<00:04,  9.45it/s]45%|████▌     | 36/80 [00:17<00:04,  9.55it/s]46%|████▋     | 37/80 [00:17<00:04,  8.99it/s]48%|████▊     | 38/80 [00:17<00:04,  8.48it/s]49%|████▉     | 39/80 [00:17<00:04,  8.29it/s]50%|█████     | 40/80 [00:17<00:04,  8.08it/s]51%|█████▏    | 41/80 [00:18<00:04,  7.99it/s]52%|█████▎    | 42/80 [00:18<00:04,  7.79it/s]54%|█████▍    | 43/80 [00:18<00:04,  8.01it/s]55%|█████▌    | 44/80 [00:18<00:04,  8.47it/s]56%|█████▋    | 45/80 [00:18<00:03,  8.84it/s]57%|█████▊    | 46/80 [00:18<00:03,  9.03it/s]59%|█████▉    | 47/80 [00:18<00:04,  7.95it/s]60%|██████    | 48/80 [00:18<00:04,  7.89it/s]61%|██████▏   | 49/80 [00:19<00:03,  7.77it/s]62%|██████▎   | 50/80 [00:19<00:04,  6.43it/s]64%|██████▍   | 51/80 [00:19<00:04,  6.74it/s]65%|██████▌   | 52/80 [00:19<00:03,  7.45it/s]66%|██████▋   | 53/80 [00:19<00:03,  8.04it/s]68%|██████▊   | 54/80 [00:19<00:03,  8.41it/s]70%|███████   | 56/80 [00:19<00:02,  9.00it/s]71%|███████▏  | 57/80 [00:20<00:02,  9.20it/s]72%|███████▎  | 58/80 [00:20<00:02,  9.38it/s]74%|███████▍  | 59/80 [00:20<00:02,  9.52it/s]75%|███████▌  | 60/80 [00:20<00:02,  9.64it/s]76%|███████▋  | 61/80 [00:20<00:01,  9.65it/s]78%|███████▊  | 62/80 [00:20<00:01,  9.57it/s]79%|███████▉  | 63/80 [00:20<00:01,  9.69it/s]80%|████████  | 64/80 [00:20<00:01,  9.75it/s]81%|████████▏ | 65/80 [00:20<00:01,  9.69it/s]82%|████████▎ | 66/80 [00:21<00:01,  8.49it/s]84%|████████▍ | 67/80 [00:21<00:01,  8.29it/s]85%|████████▌ | 68/80 [00:21<00:01,  8.08it/s]86%|████████▋ | 69/80 [00:21<00:01,  8.03it/s]88%|████████▊ | 70/80 [00:21<00:01,  7.96it/s]89%|████████▉ | 71/80 [00:21<00:01,  8.10it/s]91%|█████████▏| 73/80 [00:21<00:00,  7.17it/s]92%|█████████▎| 74/80 [00:22<00:00,  7.71it/s]94%|█████████▍| 75/80 [00:22<00:00,  8.13it/s]95%|█████████▌| 76/80 [00:22<00:00,  8.56it/s]98%|█████████▊| 78/80 [00:22<00:00,  9.11it/s]100%|██████████| 80/80 [00:22<00:00,  9.43it/s]100%|██████████| 80/80 [00:22<00:00,  3.52it/s]==> Test Clients initialization..===> Building data iterators..0it [00:00, ?it/s]0it [00:00, ?it/s]===> Initializing clients..0it [00:00, ?it/s]0it [00:00, ?it/s]++++++++++++++++++++++++++++++Global..Train Loss: 2.299 | Train Acc: 10.643% |Test Loss: 2.298 | Test Acc: 10.503% |++++++++++++++++++++++++++++++++++++++++++++++++++################################################################################Training..0%|          | 0/200 [00:00<?, ?it/s]0%|          | 1/200 [01:08<3:48:37, 68.93s/it]1%|          | 2/200 [02:16<3:45:00, 68.18s/it]2%|▏         | 3/200 [03:23<3:41:16, 67.40s/it]2%|▏         | 4/200 [04:29<3:39:15, 67.12s/it]++++++++++++++++++++++++++++++Global..Train Loss: 1.003 | Train Acc: 65.321% |Test Loss: 1.036 | Test Acc: 63.872% |++++++++++++++++++++++++++++++++++++++++++++++++++################################################################################2%|▎         | 5/200 [05:56<4:00:55, 74.13s/it]3%|▎         | 6/200 [07:02<3:50:47, 71.38s/it]4%|▎         | 7/200 [08:08<3:43:55, 69.61s/it]4%|▍         | 8/200 [09:14<3:38:57, 68.43s/it]4%|▍         | 9/200 [10:20<3:35:23, 67.66s/it]++++++++++++++++++++++++++++++Global..Train Loss: 0.754 | Train Acc: 73.587% |Test Loss: 0.835 | Test Acc: 70.577% |++++++++++++++++++++++++++++++++++++++++++++++++++

这里附上论文中数据集和其采用模型的对应关系和论文中所声称的在以上各数据集中能达到的精度。

参考文献

  • [1] Marfoq O, Neglia G, Bellet A, et al. Federated multi-task learning under a mixture of distributions[J]. Advances in Neural Information Processing Systems, 2021, 34.
  • [2] https://github.com/omarfoq/FedEM
赞(0) 打赏
未经允许不得转载:爱站程序员基地 » 分布式机器学习中的模型架构