Server
The Server class simulates the central server in a federated learning environment. It aggregates client updates, applies robust aggregation techniques, and updates the global model so that training remains robust against Byzantine attacks.
Key Features
Global Model Management: Updates and maintains the global model using aggregated gradients received from clients.
Robust Aggregation: Integrates defensive techniques, such as Trimmed Mean and Static Clipping, to ensure resilience against malicious updates.
Performance Evaluation: Computes accuracy metrics on validation and test datasets to track the global model’s progress.
Integration: Works in conjunction with the Client and ByzantineClient classes to simulate realistic federated learning scenarios.
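Trimmed Mean and Static Clipping can be illustrated without ByzFL. The following is a minimal NumPy sketch assuming their common definitions: static clipping rescales each client vector x to x · min(1, c/‖x‖), and the coordinate-wise trimmed mean discards the f largest and f smallest values of each coordinate before averaging the rest. The function names static_clip and trimmed_mean are hypothetical and are not part of the byzfl API.

```python
import numpy as np


def static_clip(vectors, c):
    """Static clipping: rescale each vector x to x * min(1, c / ||x||)."""
    x = np.asarray(vectors, dtype=float)                 # shape (n, d)
    norms = np.linalg.norm(x, axis=1, keepdims=True)     # shape (n, 1)
    # np.maximum guards against division by zero for all-zero vectors.
    scale = np.minimum(1.0, c / np.maximum(norms, 1e-12))
    return x * scale


def trimmed_mean(vectors, f):
    """Coordinate-wise trimmed mean: per coordinate, drop the f largest and
    f smallest values and average the remaining n - 2f."""
    x = np.asarray(vectors, dtype=float)                 # shape (n, d)
    n = x.shape[0]
    if not 0 <= 2 * f < n:
        raise ValueError("trimmed mean requires 0 <= 2f < n")
    x_sorted = np.sort(x, axis=0)                        # sorts each coordinate independently
    return x_sorted[f : n - f].mean(axis=0)


if __name__ == "__main__":
    # A vector of norm 5 is scaled down to norm 2; one of norm 1 is left unchanged.
    print(static_clip([[3.0, 4.0], [0.6, 0.8]], c=2.0))
    # Three honest updates and one large outlier.
    updates = [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [100.0, -50.0]]
    print(trimmed_mean(updates, f=1))
```

With f = 1, trimming removes 1 and 100 from the first coordinate and -50 and 4 from the second, so trimmed_mean returns (2.5, 2.5): both of the outlier's coordinates are discarded before averaging.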
Initialization Parameters
- param params:
(dict) A dictionary containing the configuration for the Server. It must include the following keys; entries marked optional may be omitted:
  - "model_name" (str): Name of the model to be used. Refer to Models for available models.
  - "device" (str): Name of the device to be used for computations (e.g., "cpu", "cuda").
  - "optimizer_name" (str): Name of the optimizer to use (e.g., "SGD", "Adam").
  - "optimizer_params" (dict, optional): Parameters for the optimizer (e.g., betas for Adam, momentum for SGD).
  - "learning_rate" (float): Learning rate for the global model optimizer.
  - "weight_decay" (float): Weight decay (L2 regularization) for the optimizer.
  - "milestones" (list): List of training steps at which the learning rate decay is applied.
  - "learning_rate_decay" (float): Factor by which the learning rate is multiplied at each milestone (e.g., 0.5 halves it).
  - "aggregator_info" (dict): Dictionary specifying the aggregation method and its parameters:
    - "name" (str): Name of the aggregator (e.g., "TrMean").
    - "parameters" (dict): Parameters for the aggregator.
  - "pre_agg_list" (list): List of dictionaries, each specifying a pre-aggregation method and its parameters:
    - "name" (str): Name of the pre-aggregator (e.g., "Clipping").
    - "parameters" (dict): Parameters for the pre-aggregator.
  - "test_loader" (DataLoader): DataLoader for the test dataset, used to evaluate the global model.
  - "validation_loader" (DataLoader, optional): DataLoader for the validation dataset, used to monitor training performance.
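The "milestones" and "learning_rate_decay" entries together define a step-wise learning-rate schedule. Assuming they behave like PyTorch's torch.optim.lr_scheduler.MultiStepLR (with learning_rate_decay playing the role of its gamma), the rate after t scheduler steps is learning_rate × learning_rate_decay^k, where k is the number of milestones less than or equal to t. The helper learning_rate_at below is a hypothetical pure-Python illustration of that rule, not a byzfl function.

```python
from bisect import bisect_right


def learning_rate_at(step, learning_rate, milestones, learning_rate_decay):
    """Learning rate after `step` scheduler steps, assuming MultiStepLR-style decay.

    The rate is multiplied by `learning_rate_decay` once for every milestone
    that is <= step.
    """
    passed = bisect_right(sorted(milestones), step)   # number of milestones reached
    return learning_rate * learning_rate_decay ** passed


if __name__ == "__main__":
    # Values taken from the server_params dictionary in the Examples section.
    for step in (0, 9, 10, 19, 20, 50):
        print(step, learning_rate_at(step, 0.01, [10, 20], 0.5))
```

With the values from the Examples section (learning_rate 0.01, milestones [10, 20], decay 0.5), this gives 0.01 for steps 0 to 9, 0.005 for steps 10 to 19, and 0.0025 from step 20 on.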
Methods
- aggregate(vectors): Aggregates input vectors using the configured robust aggregator.
- update_model(gradients): Updates the global model by aggregating gradients and performing an optimization step.
- step(): Executes a single optimization step for the global model.
- get_model(): Retrieves the current global model.
- compute_validation_accuracy(): Computes accuracy on the validation dataset.
- compute_test_accuracy(): Computes accuracy on the test dataset.
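The relationship between update_model, aggregate, and step can be illustrated with a deliberately simplified, hypothetical ToyServer that holds its model as a flat NumPy vector and performs plain gradient descent. It reuses the method names listed above but is not the byzfl implementation, which works on a PyTorch model with the configured optimizer and aggregators; the coordinate-wise median used here is only a placeholder robust rule.

```python
import numpy as np


def coordinate_median(vectors):
    """Placeholder robust rule: coordinate-wise median of an (n, d) array."""
    return np.median(vectors, axis=0)


class ToyServer:
    """Hypothetical, simplified server holding its model as a flat NumPy vector."""

    def __init__(self, params, aggregator, learning_rate):
        self.params = np.asarray(params, dtype=float)
        self.aggregator = aggregator            # callable: (n, d) array -> (d,) array
        self.learning_rate = learning_rate
        self.grad = np.zeros_like(self.params)  # last aggregated gradient

    def aggregate(self, vectors):
        return self.aggregator(np.asarray(vectors, dtype=float))

    def step(self):
        # Plain gradient-descent step using the stored aggregated gradient.
        self.params = self.params - self.learning_rate * self.grad

    def update_model(self, gradients):
        # Robustly aggregate the client gradients, then take one optimization step.
        self.grad = self.aggregate(gradients)
        self.step()

    def get_model(self):
        return self.params


if __name__ == "__main__":
    server = ToyServer([0.0, 0.0], coordinate_median, learning_rate=0.1)
    # Two honest gradients and one outlier, e.g. sent by a Byzantine client.
    server.update_model([[1.0, 1.0], [1.2, 0.8], [100.0, -100.0]])
    print(server.get_model())
```

The median ignores the outlier, so the aggregated gradient is (1.2, 0.8) and the model moves to (-0.12, -0.08) after one step.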
Examples
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from byzfl import Client, Server, ByzantineClient
# Define data loader using MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# Define pre-aggregators and aggregator
pre_aggregators = [{"name": "Clipping", "parameters": {"c": 2.0}}, {"name": "NNM", "parameters": {"f": 1}}]
aggregator_info = {"name": "TrMean", "parameters": {"f": 1}}
# Define server parameters
server_params = {
    "device": "cpu",
    "model_name": "cnn_mnist",
    "test_loader": test_loader,
    "optimizer_name": "Adam",
    "optimizer_params": {"betas": (0.9, 0.999)},
    "learning_rate": 0.01,
    "weight_decay": 0.0005,
    "milestones": [10, 20],
    "learning_rate_decay": 0.5,
    "aggregator_info": aggregator_info,
    "pre_agg_list": pre_aggregators,
}
# Initialize the Server
server = Server(server_params)
# Example: Aggregation and model update
gradients = [...]  # Placeholder: list of gradient vectors collected from the clients
server.update_model(gradients)
print("Test Accuracy:", server.compute_test_accuracy())
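The example lists two pre-aggregators, Clipping and NNM with f = 1, ahead of the TrMean aggregator. For readers unfamiliar with NNM (Nearest Neighbor Mixing), the following NumPy sketch assumes its standard definition: each of the n input vectors is replaced by the average of its n − f nearest input vectors in Euclidean distance, a set that normally includes the vector itself. The function nnm is illustrative only and is not the byzfl implementation.

```python
import numpy as np


def nnm(vectors, f):
    """Nearest Neighbor Mixing: replace each of the n vectors by the mean of
    its n - f nearest inputs in Euclidean distance (the vector itself is at
    distance 0, so it is normally among them)."""
    x = np.asarray(vectors, dtype=float)                      # shape (n, d)
    n = x.shape[0]
    if not 0 <= f < n:
        raise ValueError("NNM requires 0 <= f < n")
    # Pairwise Euclidean distances, shape (n, n); the diagonal is zero.
    dists = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=2)
    # For each row, indices of its n - f closest vectors.
    nearest = np.argsort(dists, axis=1, kind="stable")[:, : n - f]
    # x[nearest] has shape (n, n - f, d); average over the neighbour axis.
    return x[nearest].mean(axis=1)


if __name__ == "__main__":
    vectors = [[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]]
    print(nnm(vectors, f=1))
```

Here the two nearby vectors are both mixed to (0.5, 0.0), while the distant vector (10, 10) is averaged with its closest neighbour (1, 0) and becomes (5.5, 5.0), which pulls it toward the rest before the final aggregation.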
Notes
The server is designed to be resilient against Byzantine behaviors by integrating pre-aggregation and aggregation techniques.
Accuracy evaluation is built-in to monitor the model’s progress throughout the simulation.
- class byzfl.Server(params)
Bases: ModelBaseInterface
- aggregate(vectors)
Description
Aggregates input vectors using the configured robust aggregator.
- param vectors:
A collection of input vectors.
- type vectors:
list or np.ndarray or torch.Tensor
- Returns:
Aggregated output vector.
- compute_test_accuracy()
Description
Computes the accuracy of the global model on the test dataset.
- Returns:
float – Test accuracy.
- compute_validation_accuracy()
Description
Computes the accuracy of the global model on the validation dataset.
- Returns:
float – Validation accuracy.
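Both accuracy methods return a float. Assuming the usual definition, namely the fraction of samples whose arg-max prediction matches the label, accumulated over all batches of the loader, the computation reduces to the sketch below. It works on precomputed NumPy logits instead of running a PyTorch model over a DataLoader, and the accuracy helper is hypothetical. It returns a fraction in [0, 1]; this page does not state which scale byzfl itself reports.

```python
import numpy as np


def accuracy(batches):
    """Fraction of correctly classified samples over (logits, labels) batches."""
    correct, total = 0, 0
    for logits, labels in batches:
        labels = np.asarray(labels)
        predictions = np.argmax(np.asarray(logits), axis=1)   # predicted class per sample
        correct += int(np.sum(predictions == labels))
        total += labels.shape[0]
    return correct / total if total else 0.0


if __name__ == "__main__":
    batches = [
        (np.array([[2.0, 1.0], [0.1, 0.9]]), [0, 1]),   # both predictions correct
        (np.array([[0.3, 0.7]]), [0]),                  # wrong: predicts class 1
    ]
    print(accuracy(batches))
```

Counting correct predictions and samples across batches, rather than averaging per-batch accuracies, keeps the result exact when the last batch is smaller than the others.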