Custom PyTorch Model#

Base Class#

class octaipipe.model_classes.fl_aquarium.base_pytorch.BasePytorch(loss_fn: Optional[str] = None, scaling: Optional[str] = None, metric: Optional[str] = None, epochs: int = 5, batch_size: int = 32, input_shape: Optional[int] = None, output_shape: Optional[int] = None, **kwargs)#


Federated learning can be run with a wide array of different models and model setups. As there are far too many models and deep learning architectures out there to include in OctaiPipe itself, we have instead added support for custom FL models. This guide explains the Base PyTorch model in OctaiPipe, and details how to build on it to make a custom model.

This guide together with the guide on configuring the FL Train Step can be used to set up and configure or customize federated learning in OctaiPipe.

The following step-by-step guide shows how to build a custom model that is compatible with the FL pipeline in OctaiPipe, using the Base PyTorch model as a template.

1. Using a Base Model#

Each model implemented for FL in OctaiPipe should inherit from a base model that is already implemented in OctaiPipe. In this example, we inherit from the Base PyTorch model.

The base PyTorch model implements the following methods:

  • __init__: the initialization of the model

  • _build_model: setup of architecture for model

  • setup_loaders: a method for setting up dataloaders

  • forward: the process for forward propagation through network

  • train_model: method that runs training on data

  • test_model: method that runs testing on trained model

Any of these methods can be overridden by implementing it in the child class. For example, a custom architecture can be implemented by specifying a custom _build_model method.

As a word of caution, it is useful to go through dependencies in the base class when implementing methods as some methods rely on other methods in the class. For example, the __init__ method calls the _build_model method, so changing either could have effects on the other.
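
The dependency between __init__ and _build_model can be sketched with a stand-in base class. `ToyBase` below is hypothetical and only mimics the call order described above; it is not OctaiPipe's BasePytorch. The sketch shows that any attribute a custom _build_model reads must exist before the parent's __init__ runs.

```python
class ToyBase:
    """Hypothetical stand-in for a base model whose __init__ calls _build_model."""

    def __init__(self, input_shape=4, output_shape=1):
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.layers = []
        self._build_model()  # a child's override runs here, during parent init

    def _build_model(self):
        self.layers = [(self.input_shape, self.output_shape)]


class WideModel(ToyBase):
    def __init__(self, hidden_size, **kwargs):
        # Set this first: the overridden _build_model needs it, and the
        # parent __init__ calls _build_model before it returns.
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

    def _build_model(self):
        self.layers = [
            (self.input_shape, self.hidden_size),
            (self.hidden_size, self.output_shape),
        ]


model = WideModel(hidden_size=8, input_shape=4)
print(model.layers)
```

With a real torch.nn.Module, plain values such as integers can be assigned before super().__init__(), but layers and parameters can only be assigned after it.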

It is also worth mentioning here that the base PyTorch model implements either z-scoring or min-max scaling depending on arguments given to __init__. In order to get appropriate results during inference, input data needs to be scaled accordingly. If no scaling method is selected, features are not scaled.
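
To illustrate why inference data must be scaled the same way as training data, here is a minimal sketch of z-scoring and min-max scaling in plain Python. The function names are hypothetical, and the exact statistics the base class uses (for example population versus sample standard deviation) are not stated in this guide, so treat this as an illustration of the idea only.

```python
from statistics import mean, pstdev


def fit_zscore(values):
    """Return (mean, std) computed on the training feature."""
    return mean(values), pstdev(values)


def fit_minmax(values):
    """Return (min, max) computed on the training feature."""
    return min(values), max(values)


def apply_zscore(values, mu, sigma):
    return [(v - mu) / sigma for v in values]


def apply_minmax(values, lo, hi):
    return [(v - lo) / (hi - lo) for v in values]


train_feature = [10.0, 20.0, 30.0]
inference_feature = [20.0, 40.0]

mu, sigma = fit_zscore(train_feature)
lo, hi = fit_minmax(train_feature)

# Inference data reuses the training statistics; it is not re-fitted.
z_scaled = apply_zscore(inference_feature, mu, sigma)
mm_scaled = apply_minmax(inference_feature, lo, hi)
print(z_scaled, mm_scaled)
```

Note that a min-max scaled inference value can fall outside the 0 to 1 range when it lies outside the range seen in training, as the second value does here.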

Next we will go through the methods in the base class in more detail in order to see how it works and how we might implement a custom version.

2. Building our own custom methods#

The __init__ method in the current implementation sets up the model class with settings such as the input shape and the evaluation metric to use. The key thing to note about the __init__ method is that any arguments the user wants to hand to the model class should go here. In the FL config file, the model_params field in the model_specs is handed to the __init__ method as a dictionary. If arguments need to be handed to other methods, the arguments from __init__ can be saved as class attributes. It is important that any custom __init__ method initializes the parent class with super().__init__().

Example init function with custom argument for evaluation metric#
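def __init__(self, metric_function, **kwargs):
    # Initialize the parent class; remaining arguments are passed through
    super().__init__(**kwargs)
    self._metric_function = metric_function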
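
The path from model_params in the config to __init__ can be sketched as follows. `ToyModel` is a hypothetical class, and the sketch assumes the dictionary is unpacked into keyword arguments; how OctaiPipe passes it internally is not shown in this guide.

```python
class ToyModel:
    """Hypothetical model class; stands in for a custom FL model."""

    def __init__(self, metric_function, epochs=5):
        # Saved as attributes so that other methods can use them later.
        self._metric_function = metric_function
        self._epochs = epochs

    def describe(self):
        return f"{self._epochs} epochs, scored with {self._metric_function}"


# Mirrors a model_params mapping from the FL config file.
model_params = {"metric_function": "custom_metric_function", "epochs": 10}

model = ToyModel(**model_params)
print(model.describe())
```

Any key in model_params that __init__ does not accept would raise a TypeError in this sketch, which is one reason to keep the config and the __init__ signature in sync.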

The _build_model method specifies the components of the model. It is here that the specific layers of the model are defined; the PyTorch documentation is the most useful reference for writing this method. Likewise, the forward method specifies the path the data takes through the components defined in _build_model. Implementing these two methods together builds the architecture of the network.

Example _build_model and forward methods using torch nn and nn.functional modules#
import torch.nn as nn
import torch.nn.functional as F


def _build_model(self):
    # 20 input features -> 40 hidden units -> 1 output value
    self.in_layer = nn.Linear(20, 40)
    self.dropout = nn.Dropout(p=0.2)
    self.out_layer = nn.Linear(40, 1)

def forward(self, x):
    x = F.relu(self.in_layer(x))
    x = self.dropout(x)
    x = self.out_layer(x)
    return x
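
A quick way to check an architecture like the one above is to push a dummy batch through it and inspect the output shape. The sketch below uses a plain torch.nn.Module named `TinyRegressor` (a hypothetical name) rather than the OctaiPipe base class, so it only verifies the layers and the forward pass.

```python
import torch
from torch import nn
import torch.nn.functional as F


class TinyRegressor(nn.Module):
    """Standalone module with the same layers as the example above."""

    def __init__(self):
        super().__init__()
        self.in_layer = nn.Linear(20, 40)
        self.dropout = nn.Dropout(p=0.2)
        self.out_layer = nn.Linear(40, 1)

    def forward(self, x):
        x = F.relu(self.in_layer(x))
        x = self.dropout(x)
        return self.out_layer(x)


model = TinyRegressor()
model.eval()  # disables dropout so the check is deterministic

batch = torch.randn(8, 20)  # 8 rows with 20 features each
with torch.no_grad():
    out = model(batch)

print(out.shape)
```

If the first Linear layer's input size does not match the number of feature columns, this check fails immediately with a shape error instead of failing later during FL training.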

As its name suggests, the setup_loaders method sets up the PyTorch dataloaders for the model. It loads the training and testing data and scales the features before creating dataloaders for training, validation and testing. Customizing this method lets the user define their own data generators and implement their own preprocessing steps. The PyTorch documentation has more information on dataloaders.

The key to the setup_loaders method is that it takes the FL step as an input. This means that all the methods in OctaiPipe’s PipelineStep and FL Train Step are available within the setup_loaders method. The FL step has the _input_data_specs and _evaluation_data_specs attributes, which can be used with the _load_data method to get data from a datasource. Thus, the user can define their own data ingestion functions within a custom setup_loaders. It is worth noting that the _load_data method uses self._input_data_specs to load data. To get evaluation data, assign self._evaluation_data_specs as self._input_data_specs and use _load_data.

Example setup_loaders method that implements custom train/val split#
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def setup_loaders(self, train_step, target_label):
    X_train, X_test, y_train, y_test = train_step.load_datasets(target_label)

    def make_loader(X, y):
        X = X.to_numpy(dtype=np.float64)
        y = y.to_numpy(dtype=np.float64)
        tensor_x = torch.Tensor(X)
        tensor_y = torch.Tensor(y)

        data = TensorDataset(tensor_x, tensor_y)
        loader = DataLoader(data, batch_size=16, shuffle=False)
        return loader

    # Randomly keep 90% of the training rows; the rest become validation data.
    # Assumes X_train and y_train share the same unique index.
    split_idx = X_train.sample(frac=0.9).index

    X_val = X_train.drop(split_idx)
    y_val = y_train.drop(split_idx)

    X_train = X_train.loc[split_idx]
    y_train = y_train.loc[split_idx]

    trainloader = make_loader(X_train, y_train)
    valloader = make_loader(X_val, y_val)
    testloader = make_loader(X_test, y_test)

    return trainloader, valloader, testloader
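
The pattern for loading evaluation data described earlier, assigning _evaluation_data_specs to _input_data_specs and then calling _load_data, can be sketched with a stand-in step. `ToyStep` is hypothetical: only the attribute and method names come from this guide, and the real _load_data signature and return type are not shown here. Restoring the original specs afterwards is a design choice of this sketch.

```python
class ToyStep:
    """Hypothetical stand-in for an FL step; only the names match."""

    def __init__(self):
        self._input_data_specs = {"source": "train_table"}
        self._evaluation_data_specs = {"source": "eval_table"}

    def _load_data(self):
        # Like the step described above, always reads self._input_data_specs.
        return f"rows from {self._input_data_specs['source']}"


def load_train_and_eval(step):
    """Load training data, then evaluation data by swapping the specs."""
    train_data = step._load_data()

    original_specs = step._input_data_specs
    step._input_data_specs = step._evaluation_data_specs
    try:
        eval_data = step._load_data()
    finally:
        # Restore so later calls on the step still see the training specs.
        step._input_data_specs = original_specs

    return train_data, eval_data


step = ToyStep()
train_data, eval_data = load_train_and_eval(step)
print(train_data, "|", eval_data)
```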

The train_model and test_model methods implement the training and testing of the model by running forward and backward propagation. This is where the loss function and optimizers are implemented. For FL, the train_model method is called in the FL client’s fit method and the test_model method is called in the evaluate method. This is the process that runs federated learning.

Example train_model and test_model#
def train_model(self, trainloader, valloader):
    opt = torch.optim.SGD(self.parameters(), momentum=.9, lr=.00001)
    criterion = nn.L1Loss()

    for epoch in range(5):
        train_loss = 0.0
        for X_, y_ in trainloader:
            X_, y_ = X_.to('cpu'), y_.to('cpu')
            y_ = y_[:, None]
            opt.zero_grad()
            loss = criterion(self(X_), y_)
            # Backward pass and weight update
            loss.backward()
            opt.step()
            train_loss += loss.item()
        val_loss, val_score = self.test_model(valloader)
        print(f'Epoch {epoch}. Train loss: {train_loss}. '
              f'Validation loss: {val_loss}. Validation score: {val_score}')

def test_model(self, testloader):
    loss = 0.0
    criterion = nn.L1Loss()
    predictions = []
    y_true = []
    with torch.no_grad():
        for X_, y_ in testloader:
            X_, y_ = X_.to('cpu'), y_.to('cpu')
            y_pred = self(X_)
            y_ = y_[:, None]
            loss += criterion(y_pred, y_).item()
            # Collect predictions and targets for the metric function
            predictions.extend(y_pred.flatten().tolist())
            y_true.extend(y_.flatten().tolist())

    score = self._metric_function(y_true, predictions)

    return loss, score


If you want to use the FedProx aggregation strategy with your PyTorch model, you need to add a clause to the train_model method to include a proximal term in the loss function during training. This adjustment is necessary to account for the proximal term required by FedProx, as described in the paper Federated Optimization in Heterogeneous Networks. The loss function modification would look like this:

Example train_model with FedProx#
def train_model(self, trainloader, valloader, proximal_mu: float | None):
    # Snapshot of the global weights received at the start of the round
    global_params = [p.detach().clone() for p in self.parameters()]
    opt = torch.optim.SGD(self.parameters(), momentum=.9, lr=.00001)
    criterion = nn.L1Loss()

    for epoch in range(5):
        train_loss = 0.0
        for X_, y_ in trainloader:
            X_, y_ = X_.to('cpu'), y_.to('cpu')
            y_ = y_[:, None]
            opt.zero_grad()
            loss = criterion(self(X_), y_)
            if proximal_mu is not None:
                proximal_term = 0.0
                for local_weights, global_weights in zip(self.parameters(), global_params):
                    # Squared L2 distance, as in the FedProx objective
                    proximal_term += (local_weights - global_weights).norm(2) ** 2
                loss = loss + (proximal_mu / 2) * proximal_term
            loss.backward()
            opt.step()
            train_loss += loss.item()
        val_loss, val_score = self.test_model(valloader)
        print(f'Epoch {epoch}. Train loss: {train_loss}. '
              f'Validation loss: {val_loss}. Validation score: {val_score}')

The proximal_mu value will be passed to the train_model function automatically if using OctaiPipe’s PyTorch client and the FedProx strategy.
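
As a worked example of the proximal term, the FedProx paper adds (mu / 2) times the squared L2 distance between the local and global weights to the local loss. The sketch below computes that penalty for small, made-up weight lists in plain Python; the function name `proximal_penalty` is hypothetical and the values are illustrative only.

```python
def proximal_penalty(local_weights, global_weights, mu):
    """Return (mu / 2) * ||w - w_global||^2 for flat lists of floats."""
    squared_distance = sum(
        (w - g) ** 2 for w, g in zip(local_weights, global_weights)
    )
    return (mu / 2) * squared_distance


global_weights = [0.5, -1.0, 2.0]  # weights received from the server
local_weights = [1.5, -1.0, 0.0]   # weights after some local updates

penalty = proximal_penalty(local_weights, global_weights, mu=0.1)
print(penalty)
```

The penalty is zero when the local weights equal the global weights and grows as the client drifts away from them, which is how a larger proximal_mu keeps local training closer to the global model.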

3. Using a custom model for FL in OctaiPipe#

To use a custom model for FL in OctaiPipe, the user needs to assign the model a unique model type and point to the Python file with the model class in the model_specs in the FL configs when running OctaiFL. An example is provided below:

model_specs:
  type: improved_pytorch_model
  load_existing: false
  name: degradation_regressor_for_windmill
  model_load_specs:
    version: '000'
    from_blob: false
  model_params:
    metric_function: custom_metric_function
  custom_model:
    file_path: path/to/

Here, we have defined a type that does not already exist in OctaiPipe. If this is done, and there is not already a model with that type in the database, OctaiFL expects the user to provide custom_model specs. Here, the file_path to the model class is provided. If the model already exists, a new record will be created with an incremented version number. If no model exists, version number will be assigned as 1.

In model_load_specs, from_blob is set to False. This variable defines whether a new version of a model that already exists on a device should be pulled from blob. In other words, if this is false and we have already pulled a model with the same name, the model will be taken from local storage on the device. To force pulling a new model from blob, set this to True.