Federated Learning Simulation
The Federated Learning Simulation module demonstrates how to use the key components of the library (Client, Server, ByzantineClient, and DataDistributor) to simulate a federated learning environment. This example shows how to perform distributed learning with Byzantine-resilient aggregation strategies.
Key Features
Client Class: Simulates honest participants that compute gradients based on their local data.
Server Class: Simulates the central server that aggregates gradients and updates the global model.
ByzantineClient Class: Simulates malicious participants injecting adversarial gradients into the aggregation process.
DataDistributor Class: Handles the distribution of data among clients in various configurations, including IID and non-IID distributions (e.g., Dirichlet, Gamma, Extreme), to simulate realistic federated learning setups.
Robust Aggregation: Demonstrates the usage of robust aggregation techniques such as Trimmed Mean, combined with pre-aggregation methods like Static Clipping and Nearest Neighbor Mixing (NNM).
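The DataDistributor bullet mentions Dirichlet non-IID splits. A common way to build such a split draws, for each class, the share of that class's samples given to each client from a Dirichlet(α) distribution; smaller α produces more skewed label distributions across clients. The plain-Python sketch below assumes ByzFL's `dirichlet_niid` option works along these lines, but its exact implementation may differ. `dirichlet_split` is a hypothetical helper, not part of the library.

```python
import random
from collections import defaultdict


def dirichlet_split(labels, nb_clients, alpha, seed=0):
    """Hypothetical helper: assign sample indices to clients using
    per-class Dirichlet(alpha) proportions (not the ByzFL API)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    clients = [[] for _ in range(nb_clients)]
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        # A Dirichlet(alpha, ..., alpha) sample is a vector of independent
        # Gamma(alpha, 1) draws normalised to sum to 1.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(nb_clients)]
        total = sum(draws)
        # Turn the cumulative proportions into cut points over this class.
        cumulative, start = 0.0, 0
        for c in range(nb_clients):
            cumulative += draws[c] / total
            end = len(idxs) if c == nb_clients - 1 else round(cumulative * len(idxs))
            clients[c].extend(idxs[start:end])
            start = end
    return clients


# Toy labels: 1000 samples, 100 per class, split across 3 clients.
labels = [i % 10 for i in range(1000)]
parts = dirichlet_split(labels, nb_clients=3, alpha=0.5, seed=42)
print([len(p) for p in parts])  # client sizes (vary with the seed)
print(sorted(i for p in parts for i in p) == list(range(1000)))  # True: every index used once
```

Because each class is split independently, one client may receive most of a given digit and almost none of another, which is the kind of heterogeneity a `distribution_parameter` of 0.5 is meant to induce.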
Example: Federated Learning Workflow
This example uses the MNIST dataset to simulate a federated learning setup with three honest clients and one Byzantine client. Follow the steps below to run the simulation:
# Import necessary libraries
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from byzfl import Client, Server, ByzantineClient, DataDistributor
from byzfl.utils.misc import set_random_seed

# Set random seed for reproducibility
SEED = 42
set_random_seed(SEED)

# Configurations
nb_honest_clients = 3
nb_byz_clients = 1
nb_training_steps = 1000
batch_size = 25

# Data Preparation
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True)

# Distribute data among clients using non-IID Dirichlet distribution
data_distributor = DataDistributor({
    "data_distribution_name": "dirichlet_niid",
    "distribution_parameter": 0.5,
    "nb_honest": nb_honest_clients,
    "data_loader": train_loader,
    "batch_size": batch_size,
})
client_dataloaders = data_distributor.split_data()

# Initialize Honest Clients
honest_clients = [
    Client({
        "model_name": "cnn_mnist",
        "device": "cpu",
        "optimizer_name": "SGD",
        "learning_rate": 0.1,
        "loss_name": "NLLLoss",
        "weight_decay": 0.0001,
        "milestones": [1000],
        "learning_rate_decay": 0.25,
        "LabelFlipping": False,
        "training_dataloader": client_dataloaders[i],
        "momentum": 0.9,
        "nb_labels": 10,
    }) for i in range(nb_honest_clients)
]

# Prepare Test Dataset
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Server Setup, Use SGD Optimizer
server = Server({
    "device": "cpu",
    "model_name": "cnn_mnist",
    "test_loader": test_loader,
    "optimizer_name": "SGD",
    "learning_rate": 0.1,
    "weight_decay": 0.0001,
    "milestones": [1000],
    "learning_rate_decay": 0.25,
    "aggregator_info": {"name": "TrMean", "parameters": {"f": nb_byz_clients}},
    "pre_agg_list": [
        {"name": "Clipping", "parameters": {"c": 2.0}},
        {"name": "NNM", "parameters": {"f": nb_byz_clients}},
    ],
})

# Byzantine Client Setup
attack = {
    "name": "InnerProductManipulation",
    "f": nb_byz_clients,
    "parameters": {"tau": 3.0},
}
byz_client = ByzantineClient(attack)

# Training Loop
for training_step in range(nb_training_steps):

    # Evaluate Global Model Every 100 Training Steps
    if training_step % 100 == 0:
        test_acc = server.compute_test_accuracy()
        print(f"--- Training Step {training_step}/{nb_training_steps} ---")
        print(f"Test Accuracy: {test_acc:.4f}")

    # Honest Clients Compute Gradients
    for client in honest_clients:
        client.compute_gradients()

    # Collect Honest Gradients (with momentum)
    honest_gradients = [client.get_flat_gradients_with_momentum() for client in honest_clients]

    # Apply Byzantine Attack (returns one vector per Byzantine client)
    byz_vector = byz_client.apply_attack(honest_gradients)

    # Combine Honest and Byzantine Gradients
    gradients = honest_gradients + byz_vector

    # Update Global Model (pre-aggregation, then robust aggregation)
    server.update_model(gradients)

    # Send Updated Model to Clients
    new_model = server.get_dict_parameters()
    for client in honest_clients:
        client.set_model_state(new_model)

print("Training Complete!")
Example Output
Running the above code produces output similar to the following (exact values may vary across hardware and library versions):
--- Training Step 0/1000 ---
Test Accuracy: 0.0600
--- Training Step 100/1000 ---
Test Accuracy: 0.6454
--- Training Step 200/1000 ---
Test Accuracy: 0.8156
--- Training Step 300/1000 ---
Test Accuracy: 0.8878
--- Training Step 400/1000 ---
Test Accuracy: 0.8667
--- Training Step 500/1000 ---
Test Accuracy: 0.9014
--- Training Step 600/1000 ---
Test Accuracy: 0.9103
--- Training Step 700/1000 ---
Test Accuracy: 0.9647
--- Training Step 800/1000 ---
Test Accuracy: 0.9531
--- Training Step 900/1000 ---
Test Accuracy: 0.9732
Training Complete!
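The server configuration in this example chains static clipping (c = 2.0), Nearest Neighbor Mixing and coordinate-wise Trimmed Mean, while the Byzantine client runs Inner Product Manipulation with tau = 3.0. The plain-Python sketch below mimics one aggregation round on toy 2-D "gradients" to show how these pieces fit together. The functions `clip`, `nnm`, `trimmed_mean` and `ipm_attack` are hypothetical and follow the textbook definitions of each technique; they are not ByzFL's internal implementations, which operate on tensors and may differ in details.

```python
import math


def clip(v, c):
    """Static clipping: rescale v so that its L2 norm is at most c."""
    norm = math.sqrt(sum(x * x for x in v))
    scale = min(1.0, c / norm) if norm > 0 else 1.0
    return [x * scale for x in v]


def nnm(vectors, f):
    """Nearest Neighbor Mixing: replace each vector by the average of its
    n - f nearest neighbours (itself included) in Euclidean distance."""
    n = len(vectors)
    mixed = []
    for v in vectors:
        by_dist = sorted(vectors, key=lambda w: sum((a - b) ** 2 for a, b in zip(v, w)))
        nearest = by_dist[: n - f]
        mixed.append([sum(col) / len(nearest) for col in zip(*nearest)])
    return mixed


def trimmed_mean(vectors, f):
    """Coordinate-wise trimmed mean: per coordinate, drop the f smallest and
    f largest values and average the remaining n - 2f."""
    n = len(vectors)
    result = []
    for col in zip(*vectors):
        kept = sorted(col)[f : n - f]
        result.append(sum(kept) / len(kept))
    return result


def ipm_attack(honest, tau, f):
    """Inner Product Manipulation: f copies of -tau times the honest mean."""
    mean = [sum(col) / len(honest) for col in zip(*honest)]
    return [[-tau * x for x in mean] for _ in range(f)]


# Toy round: 3 honest clients, 1 Byzantine client, 2-D gradients.
honest = [[1.0, 0.5], [0.8, 0.7], [1.2, 0.4]]
f = 1
gradients = honest + ipm_attack(honest, tau=3.0, f=f)

# Plain averaging: with tau = 3 and 3 honest clients, the attack cancels the signal.
plain = [sum(col) / len(gradients) for col in zip(*gradients)]
print(plain)  # ≈ [0.0, 0.0]

# Clipping -> NNM -> TrMean, in the same order as the server configuration.
pre_aggregated = nnm([clip(g, 2.0) for g in gradients], f)
agg = trimmed_mean(pre_aggregated, f)
print(agg)  # ≈ [1.0, 0.533], close to the honest mean
```

Plain averaging collapses to roughly zero here because the Byzantine vector is exactly the negated sum of the honest ones, whereas the clipped, mixed and trimmed aggregate stays close to the honest mean.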
Documentation References
For more information about individual components, refer to the following:
- Client Class: Client
- Server Class: Server
- ByzantineClient Class: Byzantine Client
- RobustAggregator Class: Robust Aggregator
- DataDistributor Class: Data Distributor
- Models Module: Models
This example can be extended to other datasets and models by modifying the parameters accordingly.
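The robustness of the system depends on the chosen aggregation and pre-aggregation methods and on the number of Byzantine participants; in this example, the f parameter of TrMean and NNM is set to nb_byz_clients so that the defenses account for every Byzantine client.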
The module is designed to be flexible and adaptable for experimentation with different setups.
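To make the dependence on the number of Byzantine participants concrete, the toy sketch below applies a scalar trimmed mean (the one-dimensional version of TrMean's per-coordinate rule) to three honest values plus one or two hypothetical outliers. When the trimming parameter f is smaller than the number of attackers, an outlier survives the trimming and drags the result far from the honest values. `trimmed_mean_1d` and the outlier value are illustrative assumptions, not library code.

```python
def trimmed_mean_1d(values, f):
    """Trimmed mean of scalars: drop the f smallest and f largest values."""
    kept = sorted(values)[f : len(values) - f]
    return sum(kept) / len(kept)


honest = [1.0, 1.1, 0.9]
outlier = 100.0  # hypothetical value sent by each Byzantine client

one_byz = honest + [outlier]
two_byz = honest + [outlier, outlier]

print(trimmed_mean_1d(one_byz, f=1))  # ≈ 1.05: the single outlier is trimmed
print(trimmed_mean_1d(two_byz, f=1))  # ≈ 34.03: one outlier survives trimming
print(trimmed_mean_1d(two_byz, f=2))  # 1.1: f matches the number of attackers
```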