# Source code for byzfl.fed_framework.model_base_interface


import byzfl.fed_framework.models as models
from byzfl.utils.conversion import flatten_dict, unflatten_dict, unflatten_generator
import torch
import collections

class ModelBaseInterface(object):
    """Base wrapper around a PyTorch model, its optimizer, and LR scheduler.

    Parameters
    ----------
    params : dict
        Configuration dictionary. Required keys: "model_name", "device",
        "learning_rate", "weight_decay", "milestones",
        "learning_rate_decay". Optional keys: "optimizer_name"
        (defaults to "SGD") and "optimizer_params" (extra keyword
        arguments forwarded to the optimizer constructor).
    """

    def __init__(self, params):
        # Fail fast on malformed configuration before building anything.
        self._validate_params(params)

        # Instantiate the model by name from the models module and wrap it
        # in DataParallel so multi-GPU setups are handled transparently.
        model_name = params["model_name"]
        self.device = params["device"]
        self.model = torch.nn.DataParallel(
            getattr(models, model_name)()
        ).to(self.device)

        # Resolve the optimizer class dynamically; default to SGD when no
        # optimizer_name is provided.
        optimizer_name = params.get("optimizer_name", "SGD")
        optimizer_params = params.get("optimizer_params", {})
        optimizer_class = getattr(torch.optim, optimizer_name, None)
        if optimizer_class is None:
            raise ValueError(f"Optimizer '{optimizer_name}' is not supported by PyTorch.")
        self.optimizer = optimizer_class(
            self.model.parameters(),
            lr=params["learning_rate"],
            weight_decay=params["weight_decay"],
            **optimizer_params
        )

        # Step-wise learning-rate decay at the configured milestones.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=params["milestones"],
            gamma=params["learning_rate_decay"],
        )

    @staticmethod
    def _is_real_number(value):
        # bool is a subclass of int, so it must be excluded explicitly.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _validate_params(self, params):
        """
        Validates the input parameters for correct types and values.

        Numeric parameters accept both int and float (e.g. an integer
        learning rate of 1 is valid); bool is rejected everywhere even
        though it subclasses int.

        Parameters
        ----------
        params : dict
            Dictionary of input parameters.

        Raises
        ------
        KeyError
            If a required key is missing.
        TypeError
            If a parameter has the wrong type.
        ValueError
            If a numeric parameter is outside its valid range.
        """
        required_keys = [
            "model_name", "device", "learning_rate",
            "weight_decay", "milestones", "learning_rate_decay",
        ]
        for key in required_keys:
            if key not in params:
                raise KeyError(f"Missing required parameter: {key}")

        # Validate types and ranges
        if not isinstance(params["model_name"], str):
            raise TypeError("Parameter 'model_name' must be a string.")
        if not isinstance(params["device"], str):
            raise TypeError("Parameter 'device' must be a string.")
        if not self._is_real_number(params["learning_rate"]) or params["learning_rate"] <= 0:
            raise ValueError("Parameter 'learning_rate' must be a positive float.")
        if not self._is_real_number(params["weight_decay"]) or params["weight_decay"] < 0:
            raise ValueError("Parameter 'weight_decay' must be a non-negative float.")
        if not isinstance(params["milestones"], list) or not all(
            isinstance(x, int) and not isinstance(x, bool)
            for x in params["milestones"]
        ):
            raise TypeError("Parameter 'milestones' must be a list of integers.")
        if (not self._is_real_number(params["learning_rate_decay"])
                or params["learning_rate_decay"] <= 0
                or params["learning_rate_decay"] > 1.0):
            raise ValueError("Parameter 'learning_rate_decay' must be a positive float smaller than 1.0.")
[docs] def get_flat_parameters(self): """ Returns model parameters in a flat array. Returns ------- list Flat list of model parameters. """ return flatten_dict(self.model.state_dict())
[docs] def get_flat_gradients(self): """ Returns model gradients in a flat array. Returns ------- list Flat list of model gradients. """ return flatten_dict(self.get_dict_gradients())
[docs] def get_dict_parameters(self): """ Returns model parameters in a dictionary format. Returns ------- collections.OrderedDict Dictionary of model parameters. """ return self.model.state_dict()
[docs] def get_dict_gradients(self): """ Returns model gradients in a dictionary format. Returns ------- collections.OrderedDict Dictionary of model gradients. """ new_dict = collections.OrderedDict() for key, value in self.model.named_parameters(): new_dict[key] = value.grad return new_dict
[docs] def set_parameters(self, flat_vector): """ Sets model parameters using a flat array. Parameters ---------- flat_vector : list Flat list of parameters to set. """ new_dict = unflatten_dict(self.model.state_dict(), flat_vector) self.model.load_state_dict(new_dict)
[docs] def set_gradients(self, flat_vector): """ Sets model gradients using a flat array. Parameters ---------- flat_vector : list Flat list of gradients to set. """ new_dict = unflatten_generator(self.model.named_parameters(), flat_vector) for key, value in self.model.named_parameters(): value.grad = new_dict[key].clone().detach()
[docs] def set_model_state(self, state_dict): """ Sets the state_dict of the model. Parameters ---------- state_dict : dict Dictionary containing model state. """ self.model.load_state_dict(state_dict)