Source code for byzfl.fed_framework.client
import torch
from byzfl.fed_framework import ModelBaseInterface
from byzfl.utils.conversion import flatten_dict
[docs]
class Client(ModelBaseInterface):
    """
    Local training participant built on top of ``ModelBaseInterface``.

    Each call to :meth:`compute_gradients` draws one batch from the training
    dataloader, back-propagates the loss through ``self.model`` and records
    the batch loss and accuracy. When label flipping is enabled, a gradient
    on flipped targets is computed as well. A momentum-averaged flat gradient
    is available through :meth:`get_flat_gradients_with_momentum`.

    ``self.model``, ``self.device``, ``get_flat_gradients`` and
    ``get_dict_gradients`` are not defined in this class; they are expected
    to be provided by ``ModelBaseInterface``.

    Parameters
    ----------
    params : dict
        Configuration dictionary. Keys read by this class:

        - ``"loss_name"`` (str): name of a loss class looked up in
          ``torch.nn`` and instantiated without arguments.
        - ``"LabelFlipping"`` (bool): whether :meth:`compute_gradients` also
          computes a gradient on flipped targets.
        - ``"nb_labels"`` (int, > 1): used to flip a target ``t`` into
          ``nb_labels - 1 - t``.
        - ``"momentum"`` (float, in ``[0, 1)``): momentum coefficient.
        - ``"training_dataloader"`` (torch.utils.data.DataLoader): source of
          training batches.
        - ``"model_name"``, ``"device"``, ``"learning_rate"``,
          ``"weight_decay"``, ``"milestones"``, ``"learning_rate_decay"``,
          ``"optimizer_name"`` and the optional ``"optimizer_params"``
          (default ``{}``) are forwarded to ``ModelBaseInterface.__init__``.

    Raises
    ------
    TypeError
        If ``params`` is not a dict, or ``loss_name``, ``LabelFlipping`` or
        ``training_dataloader`` has the wrong type.
    ValueError
        If ``nb_labels`` is not an int greater than 1, or ``momentum`` is not
        a float in ``[0, 1)``.
    KeyError
        If a required key is missing from ``params``.
    """

    def __init__(self, params):
        # Check for correct types and values in params
        if not isinstance(params, dict):
            raise TypeError(f"'params' must be of type dict, but got {type(params).__name__}")
        if not isinstance(params["loss_name"], str):
            raise TypeError(f"'loss_name' must be of type str, but got {type(params['loss_name']).__name__}")
        if not isinstance(params["LabelFlipping"], bool):
            raise TypeError(f"'LabelFlipping' must be of type bool, but got {type(params['LabelFlipping']).__name__}")
        if not isinstance(params["nb_labels"], int) or not params["nb_labels"] > 1:
            raise ValueError("'nb_labels' must be an integer greater than 1")
        if not isinstance(params["momentum"], float) or not 0 <= params["momentum"] < 1:
            raise ValueError("'momentum' must be a float in the range [0, 1)")
        if not isinstance(params["training_dataloader"], torch.utils.data.DataLoader):
            raise TypeError(f"'training_dataloader' must be a DataLoader, but got {type(params['training_dataloader']).__name__}")

        # Initialize Client instance
        super().__init__({
            "model_name": params["model_name"],
            "device": params["device"],
            "learning_rate": params["learning_rate"],
            "weight_decay": params["weight_decay"],
            "milestones": params["milestones"],
            "learning_rate_decay": params["learning_rate_decay"],
            "optimizer_name": params["optimizer_name"],
            "optimizer_params": params.get("optimizer_params", {}),
        })

        self.criterion = getattr(torch.nn, params["loss_name"])()

        # Placeholder until compute_gradients() runs with label flipping on.
        self.gradient_LF = 0
        self.labelflipping = params["LabelFlipping"]
        self.nb_labels = params["nb_labels"]

        self.momentum = params["momentum"]
        # Flat zero buffer with one entry per model parameter element.
        self.momentum_gradient = torch.zeros_like(
            torch.cat(tuple(
                tensor.view(-1)
                for tensor in self.model.parameters()
            )),
            device=params["device"]
        )

        self.training_dataloader = params["training_dataloader"]
        self.train_iterator = iter(self.training_dataloader)

        self.loss_list = list()
        self.train_acc_list = list()

    def _sample_train_batch(self):
        """
        Private function to get the next data from the dataloader.

        When the current iterator is exhausted, a new one is created from
        ``self.training_dataloader`` and its first batch is returned.

        Raises
        ------
        RuntimeError
            If the dataloader yields no batch even after being restarted.
        """
        try:
            return next(self.train_iterator)
        except StopIteration:
            # Current pass is exhausted: start a new one.
            self.train_iterator = iter(self.training_dataloader)

        try:
            return next(self.train_iterator)
        except StopIteration:
            # Do not let a bare StopIteration escape: inside a generator or
            # a for-loop it would silently end the caller's iteration.
            raise RuntimeError(
                "'training_dataloader' yielded no batches; "
                "cannot sample a training batch"
            ) from None

    def compute_gradients(self):
        """
        Computes the gradients of the local model loss function.

        One batch is sampled and moved to ``self.device``. If label flipping
        is enabled, the gradient on flipped targets is computed first (model
        in eval mode) and stored in ``self.gradient_LF``; the model is then
        put back in train mode. Afterwards the loss on the true targets is
        back-propagated, and the batch loss and accuracy are appended to
        ``self.loss_list`` and ``self.train_acc_list``.
        """
        inputs, targets = self._sample_train_batch()
        inputs, targets = inputs.to(self.device), targets.to(self.device)

        if self.labelflipping:
            self.model.eval()
            self.model.zero_grad()
            # Equivalent to (nb_labels - 1) - targets.
            targets_flipped = targets.sub(self.nb_labels - 1).mul(-1)
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets_flipped)
            loss.backward()
            # NOTE(review): if get_dict_gradients() returns the live
            # ``.grad`` tensors rather than copies, the zero_grad() below
            # may affect them depending on its set_to_none behaviour --
            # confirm against ModelBaseInterface.
            self.gradient_LF = self.get_dict_gradients()
            self.model.train()

        self.model.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        self.loss_list.append(loss.item())
        loss.backward()

        # Compute train accuracy
        # assumes outputs is (batch, nb_classes) so dim 1 indexes classes
        # -- TODO confirm
        _, predicted = torch.max(outputs.data, 1)
        total = targets.size(0)
        correct = (predicted == targets).sum().item()
        acc = correct / total
        self.train_acc_list.append(acc)

    def get_flat_flipped_gradients(self):
        """
        Returns the gradients of the model with flipped targets in a flat array.

        The value flattened is ``self.gradient_LF``, which is only populated
        by :meth:`compute_gradients` when label flipping is enabled; before
        that it holds the placeholder ``0``.
        """
        return flatten_dict(self.gradient_LF)

    def get_flat_gradients_with_momentum(self):
        """
        Returns the gradients with momentum applied in a flat array.

        Updates the internal buffer in place as
        ``buffer = momentum * buffer + (1 - momentum) * flat_gradients``
        and returns that buffer. The returned tensor is the internal buffer
        itself, not a copy, so the next call modifies it in place.
        """
        self.momentum_gradient.mul_(self.momentum)
        self.momentum_gradient.add_(
            self.get_flat_gradients(),
            alpha=1 - self.momentum
        )
        return self.momentum_gradient

    def get_loss_list(self):
        """
        Returns the list of computed losses over training.

        One entry is appended per call to :meth:`compute_gradients`.
        """
        return self.loss_list

    def get_train_accuracy(self):
        """
        Returns the training accuracy per batch.

        One entry (fraction of correct predictions in the batch) is appended
        per call to :meth:`compute_gradients`.
        """
        return self.train_acc_list

    def set_model_state(self, state_dict):
        """
        Updates the model state with the provided state dictionary.

        Parameters
        ----------
        state_dict : dict
            The state dictionary of a model.

        Raises
        ------
        TypeError
            If ``state_dict`` is not a dict.
        """
        if not isinstance(state_dict, dict):
            raise TypeError(f"'state_dict' must be of type dict, but got {type(state_dict).__name__}")
        self.model.load_state_dict(state_dict)