Client#
The Client class simulates an honest participant in a federated learning environment. Each client trains its local model on its subset of the data and shares updates (e.g., gradients) with the central server.
Key Features#
Local Training: Allows training on client-specific datasets while maintaining data ownership.
Gradient Computation: Computes gradients of the model’s loss function with respect to its parameters.
Support for Momentum: Incorporates momentum into gradient updates to improve convergence.
Integration with Robust Aggregators: Shares updates with the server, enabling robust aggregation techniques to handle adversarial or heterogeneous data environments.
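For context, the sketch below shows how a client's updates might feed an aggregator during one federated round. The clients list and the hand-rolled coordinate-wise trimmed mean are illustrative assumptions, not part of the byzfl API; byzfl ships its own robust aggregators, documented separately.
import torch

# Assumption: `clients` is a list of byzfl.Client objects, each built as in
# the Examples section below.
def federated_round(clients, nb_byzantine=1):
    # Each client computes gradients on its next local batch
    for client in clients:
        client.compute_gradients()
    # Collect the flattened momentum gradients that clients share with the server
    updates = torch.stack(
        [client.get_flat_gradients_with_momentum() for client in clients]
    )
    # Illustrative robust aggregation: per coordinate, drop the `nb_byzantine`
    # smallest and largest values, then average the remainder
    sorted_updates, _ = torch.sort(updates, dim=0)
    trimmed = sorted_updates[nb_byzantine : len(clients) - nb_byzantine]
    return trimmed.mean(dim=0)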
Initialization Parameters#
- param dict params:
A dictionary containing the configuration for the Client. Must include:
"model_name" (str): Name of the model to be used. For a complete list of available models, refer to Models.
"device" (str): Device for computation (e.g., 'cpu' or 'cuda').
"optimizer_name" (str): Name of the optimizer to be used (e.g., "SGD", "Adam").
"optimizer_params" (dict, optional): Additional parameters for the optimizer (e.g., beta values for Adam).
"learning_rate" (float): Learning rate for the optimizer.
"loss_name" (str): Name of the loss function to be used (e.g., "CrossEntropyLoss").
"weight_decay" (float): Weight decay for regularization.
"milestones" (list): List of training steps at which the learning rate decay is applied.
"learning_rate_decay" (float): Factor by which the learning rate is reduced at each milestone.
"LabelFlipping" (bool): Flag that enables the label-flipping attack. When set to True, the class labels in the local dataset are flipped to their opposing classes.
"momentum" (float): Momentum parameter for the optimizer.
"training_dataloader" (DataLoader): PyTorch DataLoader object for the local training dataset.
"nb_labels" (int): Number of labels in the dataset, required for the label-flipping attack.
Methods#
compute_gradients(): Computes the gradients of the local model based on the client's subset of training data.
get_flat_gradients_with_momentum(): Returns the flattened gradients with momentum applied, combining current and past updates.
get_flat_flipped_gradients(): Retrieves the gradients of the model after applying the label-flipping attack, as a flattened array.
get_loss_list(): Returns the list of training losses recorded during training.
get_train_accuracy(): Provides the training accuracy for each processed batch.
set_model_state(state_dict): Updates the model's state with the provided state dictionary.
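Taken together, these methods support a per-round loop like the sketch below. The client object is assumed to be constructed as in the Examples section, and global_state stands in for the state dictionary broadcast by the server each round:
# Sketch of one client's participation across training rounds.
# `client` and `global_state` are assumed to exist (see the lead-in above).
nb_rounds = 5
for _ in range(nb_rounds):
    client.set_model_state(global_state)  # sync with the global model
    client.compute_gradients()            # local forward/backward pass
    update = client.get_flat_gradients_with_momentum()
    # ... send `update` to the server and receive the new `global_state` ...

print(client.get_loss_list()[-1])        # most recent batch loss
print(client.get_train_accuracy()[-1])   # most recent batch accuracy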
Examples#
Initialize the Client class with an MNIST data loader:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from byzfl import Client
# Fix the random seed
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Define the training data loader using MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Define client parameters
client_params = {
    "model_name": "cnn_mnist",
    "device": "cpu",
    "optimizer_name": "Adam",
    "optimizer_params": {"betas": (0.9, 0.999)},
    "learning_rate": 0.01,
    "loss_name": "CrossEntropyLoss",
    "weight_decay": 0.0005,
    "milestones": [10, 20],
    "learning_rate_decay": 0.5,
    "LabelFlipping": True,
    "momentum": 0.9,
    "training_dataloader": train_loader,
    "nb_labels": 10,
}
# Initialize the Client
client = Client(client_params)
Compute gradients for the local dataset:
# Compute gradients
client.compute_gradients()
# Retrieve the training accuracy for the first batch
print(client.get_train_accuracy()[0])
# Retrieve the gradients after applying the label-flipping attack
print(client.get_flat_flipped_gradients())
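The label-flipping attack used above maps each class to its "opposing" class, which is why "nb_labels" is required. A common convention, assumed here for illustration, is label -> nb_labels - 1 - label:
import torch

nb_labels = 10
labels = torch.tensor([0, 1, 7, 9])

# Assumed "opposing class" mapping: label -> nb_labels - 1 - label
flipped_labels = nb_labels - 1 - labels
print(flipped_labels)  # tensor([9, 8, 2, 0])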
- class byzfl.Client(params)#
Bases: ModelBaseInterface
- compute_gradients()#
Description#
Computes the gradients of the local model’s loss function for the current training batch. If the LabelFlipping attack is enabled, gradients for flipped targets are computed and stored separately. Additionally, the training loss and accuracy for the batch are computed and recorded.
- get_flat_flipped_gradients()#
Description#
Retrieves the gradients computed using flipped targets as a flat array.
- Returns:
numpy.ndarray or torch.Tensor – A flat array containing the gradients for the model parameters when trained with flipped targets.
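For example, after a call to compute_gradients() with "LabelFlipping": True, one might compare the honest and flipped updates. A sketch, reusing the client from the Examples section and coercing both updates to tensors:
import torch

honest = torch.as_tensor(client.get_flat_gradients_with_momentum())
flipped = torch.as_tensor(client.get_flat_flipped_gradients())

# Cosine similarity between the honest and flipped updates; label flipping
# tends to push the update away from the honest direction
print(torch.nn.functional.cosine_similarity(honest, flipped, dim=0).item())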
- get_flat_gradients_with_momentum()#
Description#
Computes the gradients with momentum applied and returns them as a flat array.
- Returns:
torch.Tensor – A flat array containing the gradients with momentum applied.
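For reference, one common client-side momentum convention in Byzantine-robust learning, and a plausible reading of "combining current and past updates", is

$m_t = \beta \, m_{t-1} + (1 - \beta) \, g_t$

where $g_t$ is the gradient of the current batch, $m_t$ is the returned flat vector, and $\beta$ is the "momentum" parameter. This coefficient convention is an assumption; PyTorch's built-in SGD momentum instead accumulates $m_t = \beta \, m_{t-1} + g_t$.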
- get_loss_list()#
Description#
Retrieves the list of training losses recorded over the course of training.
- Returns:
list – A list of float values representing the training losses for each batch.
- get_train_accuracy()#
Description#
Retrieves the training accuracy for each batch processed during training.
- Returns:
list – A list of float values representing the training accuracy for each batch.
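A small usage sketch summarizing the recorded metrics, assuming a client that has already processed some batches:
# `client` is a byzfl.Client that has run compute_gradients() at least once
losses = client.get_loss_list()
accuracies = client.get_train_accuracy()
print(f"mean training loss: {sum(losses) / len(losses):.4f}")
# Assumes accuracies are fractions in [0, 1]
print(f"last batch accuracy: {accuracies[-1]:.2%}")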
- set_model_state(state_dict)#
Description#
Updates the state of the model with the provided state dictionary. In a federated learning context, this is typically used to load a saved model state or to synchronize the client with the global model.
- param state_dict:
The state dictionary containing model parameters and buffers.
- type state_dict:
dict
- raises TypeError:
If state_dict is not a dictionary.
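A hedged sketch of the synchronization step this enables; global_model is an illustrative stand-in for the server's reference model and is assumed to share the client's architecture:
# `client` is a byzfl.Client as in the Examples section; `global_model` is
# an illustrative server-side model with the same architecture.
client.set_model_state(global_model.state_dict())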