# Source code for byzfl.fed_framework.models

import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Models Module
=============
This module contains a collection of models designed for various datasets such as MNIST and CIFAR. 
These models include fully connected networks, convolutional neural networks, logistic regression, 
and ResNet architectures.

Available Models
----------------

### MNIST Models

1. **`fc_mnist`**
   - **Type**: Fully connected neural network.
   - **Description**: A simple two-layer fully connected network designed for the MNIST dataset.
   - **Input Shape**: `28 x 28` (flattened into a vector).
   - **Output Classes**: 10.

2. **`cnn_mnist`**
   - **Type**: Convolutional neural network.
   - **Description**: A small convolutional neural network tailored for MNIST.
   - **Input Shape**: `1 x 28 x 28`.
   - **Output Classes**: 10.

3. **`logreg_mnist`**
   - **Type**: Logistic regression model.
   - **Description**: A simple logistic regression model for MNIST.
   - **Input Shape**: `28 x 28` (flattened into a vector).
   - **Output Classes**: 10.

### CIFAR Models

1. **`cnn_cifar_old`**
   - **Type**: Convolutional neural network.
   - **Description**: A small convolutional network for CIFAR datasets.
   - **Input Shape**: `3 x 32 x 32`.
   - **Output Classes**: 10.

2. **`cnn_cifar`**
   - **Type**: Convolutional neural network.
   - **Description**: An updated and efficient convolutional network for CIFAR datasets.
   - **Input Shape**: `3 x 32 x 32`.
   - **Output Classes**: 10.

3. **`cifar_Net`**
   - **Type**: Convolutional neural network.
   - **Description**: Another small convolutional network for CIFAR datasets.
   - **Input Shape**: `3 x 32 x 32`.
   - **Output Classes**: 10.

### ResNet Models

ResNet models are general-purpose convolutional neural networks capable of handling datasets like CIFAR-10 and CIFAR-100.

1. **`ResNet18`**
   - **Description**: ResNet with 18 layers.

2. **`ResNet34`**
   - **Description**: ResNet with 34 layers.

3. **`ResNet50`**
   - **Description**: ResNet with 50 layers.

4. **`ResNet101`**
   - **Description**: ResNet with 101 layers.

5. **`ResNet152`**
   - **Description**: ResNet with 152 layers.

Notes
-----
- All models are subclasses of `torch.nn.Module` and are compatible with PyTorch training pipelines.
- The ResNet implementations support custom class numbers via the `num_classes` parameter.
"""

class fc_mnist(nn.Module):
    """Two-layer fully connected network for MNIST.

    Description:
    ------------
    Flattens each input image into a 784-dimensional vector, applies a
    hidden layer of 100 units with ReLU activation, and returns
    log-probabilities over the 10 digit classes via log-softmax.

    Examples:
    ---------
    >>> model = fc_mnist()
    >>> x = torch.randn(16, 28*28)  # Batch of 16 MNIST images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([16, 10])
    """

    def __init__(self):
        """Build the two linear layers."""
        super().__init__()
        self._f1 = nn.Linear(28 * 28, 100)
        self._f2 = nn.Linear(100, 10)

    def forward(self, x):
        """Return log-softmax class scores for a batch of images."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self._f1(flat))
        return F.log_softmax(self._f2(hidden), dim=1)
class cnn_mnist(nn.Module):
    """Small convolutional network for MNIST.

    Description:
    ------------
    Two 5x5 convolution stages (each followed by ReLU and 2x2 max
    pooling) reduce a 1x28x28 image to 50 channels of 4x4, which are
    then classified by two fully connected layers with a log-softmax
    output over 10 classes.

    Examples:
    ---------
    >>> model = cnn_mnist()
    >>> x = torch.randn(16, 1, 28, 28)  # Batch of 16 grayscale MNIST images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([16, 10])
    """

    def __init__(self):
        """Build the convolutional and fully connected layers."""
        super().__init__()
        self._c1 = nn.Conv2d(1, 20, 5, 1)
        self._c2 = nn.Conv2d(20, 50, 5, 1)
        # 50 channels * 4 * 4 spatial positions = 800 flattened features.
        self._f1 = nn.Linear(800, 500)
        self._f2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return log-softmax class scores for a batch of images."""
        out = F.max_pool2d(F.relu(self._c1(x)), 2, 2)
        out = F.max_pool2d(F.relu(self._c2(out)), 2, 2)
        out = F.relu(self._f1(out.view(-1, 800)))
        return F.log_softmax(self._f2(out), dim=1)
class logreg_mnist(nn.Module):
    """Logistic regression model for MNIST.

    Description:
    ------------
    A single linear layer mapping the flattened 784-pixel image to 10
    outputs, each squashed independently through a sigmoid (one-vs-rest
    style scores rather than a softmax distribution).

    Examples:
    ---------
    >>> model = logreg_mnist()
    >>> x = torch.randn(16, 28*28)  # Batch of 16 MNIST images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([16, 10])
    """

    def __init__(self):
        """Build the single linear layer."""
        super().__init__()
        self._linear = nn.Linear(784, 10)

    def forward(self, x):
        """Return per-class sigmoid scores for a batch of images."""
        logits = self._linear(x.view(-1, 784))
        return torch.sigmoid(logits)
# ---------------------------------------------------------------------------- #
class cnn_cifar(nn.Module):
    """Convolutional network for CIFAR datasets.

    Description:
    ------------
    Three 5x5 convolution stages (padding 2 keeps the spatial size,
    then 2x2 max pooling halves it: 32 -> 16 -> 8 -> 4), followed by
    three fully connected layers producing raw logits over 10 classes.

    Examples:
    ---------
    >>> model = cnn_cifar()
    >>> x = torch.randn(16, 3, 32, 32)  # Batch of 16 CIFAR images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([16, 10])
    """

    def __init__(self):
        """Build the convolutional and fully connected layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, 5, padding=2)
        self.conv2 = nn.Conv2d(self.conv1.out_channels, 100, 5, padding=2)
        self.conv3 = nn.Conv2d(self.conv2.out_channels, 200, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        # After three poolings a 32x32 input is 4x4 spatially.
        self.fc1 = nn.Linear(self.conv3.out_channels * 4 * 4, 512)
        self.fc2 = nn.Linear(self.fc1.out_features, 256)
        self.fc3 = nn.Linear(self.fc2.out_features, 10)

    def forward(self, x):
        """Return class logits for a batch of CIFAR images."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18/34."""

    # Output channels are `expansion * planes`.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Projection shortcut only when the spatial size or channel
        # count changes; otherwise the identity path is used.
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Residual forward pass: conv-bn-relu, conv-bn, add, relu."""
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y + self.shortcut(x)
        return F.relu(y)


class Bottleneck(nn.Module):
    """Three-convolution bottleneck block used by ResNet-50/101/152."""

    # Output channels are `expansion * planes`.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # 1x1 reduce, 3x3 process, 1x1 expand.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        # Projection shortcut only when shape changes (see BasicBlock).
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Residual forward pass through the bottleneck."""
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = y + self.shortcut(x)
        return F.relu(y)


class ResNet(nn.Module):
    """Generic ResNet backbone built from residual blocks.

    Parameters:
    ------------
    block : type
        Residual block class (`BasicBlock` or `Bottleneck`).
    num_blocks : sequence of int
        Number of blocks in each of the four stages.
    num_classes : int
        Number of output classes (default 10).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64
        # CIFAR-style stem: 3x3 conv, no initial downsampling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one may downsample."""
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        """Return class logits; also the pooled feature if requested.

        Parameters:
        ------------
        x : torch.Tensor
            Input image batch.
        out_feature : bool
            If True, return `(logits, feature)` where `feature` is the
            flattened representation before the final linear layer.
        """
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.layer4(self.layer3(self.layer2(self.layer1(h))))
        h = F.avg_pool2d(h, 4)
        feature = h.view(h.size(0), -1)
        logits = self.linear(feature)
        return (logits, feature) if out_feature else logits
class ResNet18(nn.Module):
    """ResNet with 18 layers for image classification.

    Description:
    ------------
    Wraps the generic `ResNet` backbone with `BasicBlock` and two
    blocks per stage. Skip connections keep gradients flowing through
    the full depth of the network.

    Parameters:
    ------------
    num_classes : int
        The number of output classes for classification (default is 10).

    Examples:
    ---------
    >>> model = ResNet18(num_classes=10)
    >>> x = torch.randn(16, 3, 32, 32)  # Batch of 16 CIFAR images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([16, 10])
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

    def forward(self, x):
        """Delegate the forward pass to the wrapped ResNet."""
        return self.model(x)
class ResNet34(nn.Module):
    """ResNet with 34 layers for image classification.

    Description:
    ------------
    Wraps the generic `ResNet` backbone with `BasicBlock` and a
    [3, 4, 6, 3] stage layout. Residual connections enable stable
    training at this depth.

    Parameters:
    ------------
    num_classes : int
        The number of output classes for classification (default is 10).

    Examples:
    ---------
    >>> model = ResNet34(num_classes=100)
    >>> x = torch.randn(8, 3, 32, 32)  # Batch of 8 CIFAR-like images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([8, 100])
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes)

    def forward(self, x):
        """Delegate the forward pass to the wrapped ResNet."""
        return self.model(x)
class ResNet50(nn.Module):
    """ResNet with 50 layers for image classification.

    Description:
    ------------
    Wraps the generic `ResNet` backbone with `Bottleneck` blocks in a
    [3, 4, 6, 3] stage layout, trading the two-conv block for the
    cheaper 1x1-3x3-1x1 bottleneck at greater depth.

    Parameters:
    ------------
    num_classes : int
        The number of output classes for classification (default is 10).

    Examples:
    ---------
    >>> model = ResNet50(num_classes=1000)
    >>> x = torch.randn(16, 3, 224, 224)  # Batch of 16 ImageNet-like images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([16, 1000])
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes)

    def forward(self, x):
        """Delegate the forward pass to the wrapped ResNet."""
        return self.model(x)
class ResNet101(nn.Module):
    """ResNet with 101 layers for image classification.

    Description:
    ------------
    Wraps the generic `ResNet` backbone with `Bottleneck` blocks in a
    [3, 4, 23, 3] stage layout, suitable for highly complex tasks
    while keeping per-block cost manageable.

    Parameters:
    ------------
    num_classes : int
        The number of output classes for classification (default is 10).

    Examples:
    ---------
    >>> model = ResNet101(num_classes=100)
    >>> x = torch.randn(4, 3, 64, 64)  # Batch of 4 images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([4, 100])
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes)

    def forward(self, x):
        """Delegate the forward pass to the wrapped ResNet."""
        return self.model(x)
class ResNet152(nn.Module):
    """ResNet with 152 layers for image classification.

    Description:
    ------------
    Wraps the generic `ResNet` backbone with `Bottleneck` blocks in a
    [3, 8, 36, 3] stage layout — the deepest of the standard
    configurations, balancing depth against computational feasibility.

    Parameters:
    ------------
    num_classes : int
        The number of output classes for classification (default is 10).

    Examples:
    ---------
    >>> model = ResNet152(num_classes=10)
    >>> x = torch.randn(2, 3, 128, 128)  # Batch of 2 high-resolution images
    >>> output = model(x)
    >>> print(output.shape)
    torch.Size([2, 10])
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes)

    def forward(self, x):
        """Delegate the forward pass to the wrapped ResNet."""
        return self.model(x)