Source code for byzfl.fed_framework.server
import torch
from byzfl.fed_framework import ModelBaseInterface
from byzfl.fed_framework import RobustAggregator
[docs]
class Server(ModelBaseInterface):
def __init__(self, params):
    """
    Description
    -----------
    Validates ``params``, forwards the model/optimizer settings to the
    parent initializer, builds the robust aggregator and stores the
    evaluation data loaders.

    Parameters
    ----------
    params : dict
        Configuration dictionary. Keys read with subscript access (a
        missing one raises ``KeyError``): ``"test_loader"``, ``"device"``,
        ``"model_name"``, ``"optimizer_name"``, ``"learning_rate"``,
        ``"weight_decay"``, ``"milestones"``, ``"learning_rate_decay"``,
        ``"aggregator_info"`` and ``"pre_agg_list"``.
        Optional keys: ``"optimizer_params"`` (defaults to ``{}``) and
        ``"validation_loader"`` (defaults to ``None``).

    Raises
    ------
    TypeError
        If ``params`` is not a dict, if ``params["test_loader"]`` is not a
        ``torch.utils.data.DataLoader``, or if a non-``None``
        ``"validation_loader"`` is not a ``torch.utils.data.DataLoader``.
    """
    # Reject malformed input before anything is constructed.
    if not isinstance(params, dict):
        raise TypeError(f"'params' must be of type dict, but got {type(params).__name__}")
    if not isinstance(params["test_loader"], torch.utils.data.DataLoader):
        raise TypeError(f"'test_loader' must be a DataLoader, but got {type(params['test_loader']).__name__}")

    # Settings handed to the parent class; only "optimizer_params" is optional.
    base_config = {
        "device": params["device"],
        "model_name": params["model_name"],
        "optimizer_name": params["optimizer_name"],
        "optimizer_params": params.get("optimizer_params", {}),
        "learning_rate": params["learning_rate"],
        "weight_decay": params["weight_decay"],
        "milestones": params["milestones"],
        "learning_rate_decay": params["learning_rate_decay"],
    }
    super().__init__(base_config)

    self.robust_aggregator = RobustAggregator(params["aggregator_info"], params["pre_agg_list"])
    self.test_loader = params["test_loader"]

    # The validation loader is optional, but when supplied it must be a DataLoader.
    validation_loader = params.get("validation_loader")
    self.validation_loader = validation_loader
    if validation_loader is not None and not isinstance(params["validation_loader"], torch.utils.data.DataLoader):
        raise TypeError(f"'validation_loader' must be a DataLoader, but got {type(params['validation_loader']).__name__}")

    # NOTE(review): self.model is not assigned in this method; presumably the
    # parent initializer creates it - confirm against ModelBaseInterface.
    self.model.eval()
[docs]
def aggregate(self, vectors):
    """
    Description
    -----------
    Delegates aggregation of the given vectors to the server's configured
    robust aggregator.

    Parameters
    ----------
    vectors : list or np.ndarray or torch.Tensor
        The collection of input vectors to combine.

    Returns
    -------
    The value returned by the robust aggregator's ``aggregate_vectors``.
    """
    aggregator = self.robust_aggregator
    return aggregator.aggregate_vectors(vectors)
[docs]
def update_model(self, gradients):
    """
    Description
    -----------
    Runs one global update: aggregates the supplied gradients, installs the
    aggregate via ``set_gradients`` and then calls ``step``.

    Parameters
    ----------
    gradients : list
        Gradients to aggregate and apply.
    """
    # Order matters: the aggregate must be installed before stepping.
    combined = self.aggregate(gradients)
    self.set_gradients(combined)
    self.step()
[docs]
def step(self):
    """
    Description
    -----------
    Advances training by one step: calls ``step()`` on the optimizer first,
    then on the scheduler.
    """
    optimizer = self.optimizer
    optimizer.step()
    scheduler = self.scheduler
    scheduler.step()
[docs]
def get_model(self):
    """
    Description
    -----------
    Gives access to the global model held by the server.

    Returns
    -------
    torch.nn.Module
        The object stored in ``self.model`` (the same instance, not a copy).
    """
    current_model = self.model
    return current_model
def _compute_accuracy(self, data_loader):
    """
    Description
    -----------
    Computes the fraction of samples in ``data_loader`` whose target equals
    the index of the largest model output along dimension 1.

    Parameters
    ----------
    data_loader : iterable
        Yields ``(inputs, targets)`` pairs; both are moved to ``self.device``
        before the forward pass.

    Returns
    -------
    float
        ``correct / total`` over all batches, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    total = 0
    correct = 0
    # Evaluation never back-propagates, so disable autograd to avoid building
    # (and holding memory for) a graph on every forward pass.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            outputs = self.model(inputs)
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    # Accuracy over zero samples is undefined; report 0.0 rather than crash.
    if total == 0:
        return 0.0
    return correct / total
[docs]
def compute_validation_accuracy(self):
    """
    Description
    -----------
    Evaluates the global model on the validation data loader.

    Returns
    -------
    float or None
        The accuracy computed by ``_compute_accuracy`` on the validation
        loader. When no validation loader is set, a notice is printed and
        ``None`` is returned.
    """
    loader = self.validation_loader
    if loader is not None:
        return self._compute_accuracy(loader)
    print("Validation Data Loader is not set.")
    return None
[docs]
def compute_test_accuracy(self):
    """
    Description
    -----------
    Evaluates the global model on the test data loader.

    Returns
    -------
    float
        The accuracy computed by ``_compute_accuracy`` on ``self.test_loader``.
    """
    accuracy = self._compute_accuracy(self.test_loader)
    return accuracy