Trimmed Mean

class byzfl.TrMean(f=0)

Description

Compute the trimmed mean (or truncated mean) along the first axis [1]:

\[\big[\mathrm{TrMean}_{f} \ (x_1, \dots, x_n)\big]_k = \frac{1}{n - 2f}\sum_{j = f+1}^{n-f} \big[x_{\pi(j)}\big]_k\]

where

  • \(x_1, \dots, x_n\) are the input vectors, which conceptually correspond to gradients submitted by honest and Byzantine participants during a training iteration.

  • \(f\) conceptually represents the expected number of Byzantine vectors.

  • \(\big[\cdot\big]_k\) refers to the \(k\)-th coordinate.

  • \(\pi\) denotes a permutation on \(\big[n\big]\) that sorts the \(k\)-th coordinate of the input vectors in non-decreasing order, i.e., \(\big[x_{\pi(1)}\big]_k \leq \dots \leq \big[x_{\pi(n)}\big]_k\).

In other words, for each coordinate, TrMean discards the \(f\) largest and the \(f\) smallest values among the \(n\) inputs, and then averages the remaining \(n - 2f\) values.
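The definition above can be sketched in plain NumPy. This is a minimal illustration of the formula, not byzfl's actual implementation: the function name `trimmed_mean_sketch` is hypothetical, and the sketch assumes a 2-D input of shape (n, d) with \(2f < n\).

```python
import numpy as np

def trimmed_mean_sketch(vectors, f=0):
    """Coordinate-wise trimmed mean of n vectors (illustrative only)."""
    # Stack the inputs into an (n, d) array; also accepts a list of 1-D arrays.
    x = np.asarray(vectors, dtype=float)
    n = x.shape[0]
    if f < 0 or 2 * f >= n:
        raise ValueError("need 0 <= f and 2 * f < n")
    # Sort every coordinate (column) independently: this plays the role of
    # the permutation pi in the formula.
    x_sorted = np.sort(x, axis=0)
    # Drop the f smallest and f largest values per coordinate,
    # then average the n - 2f values that remain.
    return x_sorted[f:n - f].mean(axis=0)

x = np.array([[1., 2., 3.],
              [4., 5., 6.],
              [7., 8., 9.]])
result = trimmed_mean_sketch(x, f=1)
print(result)  # [4. 5. 6.]
```

With \(n = 3\) and \(f = 1\), only the middle value of each coordinate survives the trimming, so the result coincides with the coordinate-wise median.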

Initialization parameters:

f (int, optional) – Number of faulty (Byzantine) vectors to tolerate, i.e., the number of values trimmed from each end of every coordinate. Set to 0 by default.
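Because the formula in the description divides by \(n - 2f\), it is only defined when \(2f < n\). The sketch below illustrates how the choice of \(f\) changes the result when one input is an outlier. It uses plain NumPy with hypothetical names (`trimmed_mean_sketch`, `honest`, `outlier`) and is not taken from byzfl; whether and how the library itself validates `f` is not stated here.

```python
import numpy as np

def trimmed_mean_sketch(vectors, f=0):
    """Illustrative coordinate-wise trimmed mean (not byzfl's code)."""
    x = np.sort(np.asarray(vectors, dtype=float), axis=0)
    n = x.shape[0]
    if f < 0 or 2 * f >= n:
        # n - 2f must stay positive, otherwise no value is left to average.
        raise ValueError("need 0 <= f and 2 * f < n")
    return x[f:n - f].mean(axis=0)

# Three well-behaved vectors plus one extreme vector (n = 4).
honest = np.array([[1., 2.],
                   [2., 3.],
                   [3., 4.]])
outlier = np.array([[1000., -1000.]])
vectors = np.vstack([honest, outlier])

# f = 0 is the ordinary mean, which the outlier drags far away:
# coordinates 251.5 and -247.75.
plain_mean = trimmed_mean_sketch(vectors, f=0)

# f = 1 trims the outlier's value in each coordinate:
# both coordinates equal 2.5.
trimmed = trimmed_mean_sketch(vectors, f=1)
```

In this sketch, `f = 2` with four inputs would leave nothing to average, so it raises a `ValueError`.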

Calling the instance

Input parameters:

vectors (numpy.ndarray, torch.Tensor, list of numpy.ndarray or list of torch.Tensor) – A set of vectors, given as a matrix, a tensor, or a list of vectors.

Returns:

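numpy.ndarray or torch.Tensor – The output uses the same library as the input: NumPy inputs (an array or a list of arrays) yield a numpy.ndarray, and PyTorch inputs (a tensor or a list of tensors) yield a torch.Tensor.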

Examples

>>> import byzfl
>>> agg = byzfl.TrMean(1)

Using numpy arrays

>>> import numpy as np
>>> x = np.array([[1., 2., 3.],       # np.ndarray
...               [4., 5., 6.],
...               [7., 8., 9.]])
>>> agg(x)
array([4., 5., 6.])

Using torch tensors

>>> import torch
>>> x = torch.tensor([[1., 2., 3.],   # torch.tensor
...                   [4., 5., 6.],
...                   [7., 8., 9.]])
>>> agg(x)
tensor([4., 5., 6.])

Using list of numpy arrays

>>> import numpy as np
>>> x = [np.array([1., 2., 3.]),      # list of np.ndarray
...      np.array([4., 5., 6.]),
...      np.array([7., 8., 9.])]
>>> agg(x)
array([4., 5., 6.])

Using list of torch tensors

>>> import torch
>>> x = [torch.tensor([1., 2., 3.]),  # list of torch.tensor
...      torch.tensor([4., 5., 6.]),
...      torch.tensor([7., 8., 9.])]
>>> agg(x)
tensor([4., 5., 6.])

References