FL Train Step#

When you trigger an FL train step, a server is set up which communicates with the FL clients hosted on edge devices. On the edge devices, this is handled by an OctaiPipe pipeline step (see OctaiPipe Steps) called the FL Train Step. This guide walks through the FL Train Step, showing how to configure it and how it can be extended using custom pipeline steps.

Methods in the FL Step#

The FL step inherits from the base PipelineStep in OctaiPipe. It uses PipelineStep methods as well as its own to set up and run FL training. The following methods are implemented in the FL Train Step:

  • __init__

  • _get_model

  • load_datasets

  • run

The __init__ method initializes the class by initializing the PipelineStep parent class, checking the input and evaluation data specs using the _check_data method, and initializing the model using the _get_model method.

The _get_model method checks the model_specs to see if the model type is in the model mapping of the default OctaiPipe models. If not, it attempts to retrieve the model from a local custom mapping, or to download it from blob storage.
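The lookup order can be sketched as follows. This is an illustrative stand-in, not OctaiPipe's actual internals: the mapping names and the download_from_blob helper are hypothetical.

```python
# Hypothetical sketch of the _get_model lookup order described above.
DEFAULT_MODEL_MAPPING = {"base_torch": "BaseTorchModel"}  # native OctaiPipe models
CUSTOM_MODEL_MAPPING = {}                                 # locally registered custom models


def download_from_blob(model_type: str):
    # Placeholder: the real step would fetch the model definition
    # from blob storage rather than raising immediately.
    raise FileNotFoundError(f"Model '{model_type}' not found in any mapping")


def get_model(model_specs: dict):
    model_type = model_specs["type"]
    # 1. Try the default OctaiPipe model mapping
    if model_type in DEFAULT_MODEL_MAPPING:
        return DEFAULT_MODEL_MAPPING[model_type]
    # 2. Fall back to a local custom mapping
    if model_type in CUSTOM_MODEL_MAPPING:
        return CUSTOM_MODEL_MAPPING[model_type]
    # 3. Finally, attempt to download the model from blob storage
    return download_from_blob(model_type)
```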

The load_datasets method uses the PipelineStep's _load_data method to load first the training dataset and then the test dataset. It is called from the model's setup_loaders method, so that users can define their own data generators in custom models. For more information on how to use custom models, see the documentation on custom FL models, Federated PyTorch.

The run method is the one that actually runs federated learning. It does so by calling the setup_loaders method on the model class, setting up the relevant client for the framework, and running the client. The run method takes the server_ip as an argument to hand to the client.
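The control flow described above can be sketched roughly as below. This is a simplified, hypothetical skeleton, assuming stand-ins for the model, data loading, and client; it is not the actual OctaiPipe implementation.

```python
# Minimal sketch of the FL Train Step control flow described above.
class FLTrainStep:
    def __init__(self, input_data_specs, evaluation_data_specs, model_specs, **kwargs):
        # Validate the data specs, then resolve and build the model
        self._check_data(input_data_specs, evaluation_data_specs)
        self.input_data_specs = input_data_specs
        self.evaluation_data_specs = evaluation_data_specs
        self.model = self._get_model(model_specs)

    def _check_data(self, *specs):
        if not all(specs):
            raise ValueError("input and evaluation data specs are required")

    def _get_model(self, model_specs):
        return {"type": model_specs["type"]}  # placeholder for the resolved model

    def load_datasets(self):
        # First the training dataset, then the test dataset,
        # both via the (stand-in) _load_data method
        train = self._load_data(self.input_data_specs)
        test = self._load_data(self.evaluation_data_specs)
        return train, test

    def _load_data(self, specs):
        return f"data for {sorted(specs)}"  # stand-in for PipelineStep._load_data

    def run(self, server_ip, **run_specs):
        # In the real step the model's setup_loaders calls load_datasets;
        # here it is called directly for brevity, then a client would be
        # created and run against server_ip.
        train, test = self.load_datasets()
        return f"client for {self.model['type']} trained against {server_ip}"
```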

Configuring the FL train step#

Below is an example of the config file used to set up federated learning. The infrastructure field is not included in the configuration passed to the FL Train Step itself. The run_specs are popped and given to the run method, and the rest are given to the step on initialization.

name: federated_learning

infrastructure:
  server: kubernetes
  backup_server: [deviceId]
  device_ids: [FL-01, FL-02, FL-03, FL-04]

input_data_specs:
  default:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_def.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-01:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_1.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-02:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_2.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-03:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_3.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-04:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_4.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}

evaluation_data_specs:
  default:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_eval_def.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-01:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_eval_1.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-02:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_eval_2.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-03:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_eval_3.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}
  FL-04:
    - datastore_type: influxdb
      settings:
        query_type: dataframe
        query_template_path: ./configs/data/influx_query_eval_4.txt
        query_config:
          start: "2022-11-10T00:00:00.000Z"
          stop: "2022-11-11T00:00:00.000Z"
          bucket: cmapss-bucket
          measurement: sensors-raw
          tags: {}

model_specs:
  type: base_torch
  load_existing: false
  name: test_torch
  model_load_specs:
    version: '000'
  model_params:
    loss_fn: mse
    scaling: standard
    metric: rmse
    epochs: 10
    batch_size: 32

run_specs:
  target_label: RUL
  cycle_id: "Machine number"
  backend: pytorch

The input_data_specs and evaluation_data_specs define how the training and evaluation data are retrieved. The output_data_specs are not used by the current FL Train Step, but they can be used to save data in custom implementations.

The model_specs define which model to use, whether a native OctaiPipe model or a custom model. Important here are the model_params, which are handed to the model on initialization. For the default PyTorch model, these include the number of epochs, the loss function, and the batch size.
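The hand-off of model_params can be illustrated as below; BaseTorchModel here is a hypothetical stand-in with a constructor matching the parameters from the example config, not OctaiPipe's actual class.

```python
# model_params from the example config above
model_params = {
    "loss_fn": "mse",
    "scaling": "standard",
    "metric": "rmse",
    "epochs": 10,
    "batch_size": 32,
}


class BaseTorchModel:
    # Hypothetical stand-in for the default PyTorch model
    def __init__(self, loss_fn, scaling, metric, epochs, batch_size):
        self.loss_fn = loss_fn
        self.scaling = scaling
        self.metric = metric
        self.epochs = epochs
        self.batch_size = batch_size


# The step unpacks model_params into the model on initialization
model = BaseTorchModel(**model_params)
```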

The run_specs, as mentioned, are passed to the run method. This requires a target_label (the outcome variable column name) to be defined. The cycle_id is the column which defines an operating cycle; the data can be grouped on this column so that the training and validation sets contain a certain proportion of cycles rather than a proportion of rows. The backend determines which FL client to use; for example, for PyTorch this would be "pytorch".
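Putting the two hand-offs together, the way the step might consume the config can be sketched as follows, with a trimmed-down dict standing in for the parsed YAML file; the splitting logic mirrors the description above but is illustrative only.

```python
# Trimmed-down stand-in for the parsed YAML config above
config = {
    "name": "federated_learning",
    "input_data_specs": {"default": []},
    "evaluation_data_specs": {"default": []},
    "model_specs": {"type": "base_torch"},
    "run_specs": {
        "target_label": "RUL",
        "cycle_id": "Machine number",
        "backend": "pytorch",
    },
}

# run_specs are popped off and handed to the run method...
run_specs = config.pop("run_specs")

# ...while the remaining fields are given to the step on initialization
init_kwargs = config
```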

Making a custom FL train step#

To make a completely customized FL Train Step, the user can define a custom pipeline step. This guide will not go through that in detail, but it is worth noting that a custom pipeline step needs to implement a run method which initializes a client and starts it, connecting it to the server_ip from the run_specs.

To implement a custom step, it is also important to understand any model class being used, whether it is a native OctaiPipe model or a custom model.
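A minimal skeleton satisfying the constraint above might look like the following; FLClient is a hypothetical placeholder for the framework client, not a real OctaiPipe class.

```python
class FLClient:
    # Hypothetical placeholder for a framework-specific FL client
    def __init__(self, model, server_ip):
        self.model = model
        self.server_ip = server_ip

    def start(self):
        return f"training {self.model} against {self.server_ip}"


class CustomFLTrainStep:
    def __init__(self, model_specs, **kwargs):
        # A custom step should understand the model class it wraps,
        # whether native OctaiPipe or custom
        self.model = model_specs["type"]

    def run(self, server_ip, **run_specs):
        # The one hard requirement: initialize a client and start it,
        # linked to the server_ip from the run_specs
        client = FLClient(self.model, server_ip)
        return client.start()
```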

To get more information on custom OctaiPipe pipeline steps, see this guide: Custom Pipeline Steps

To further understand the base PyTorch model and to understand how to implement a custom PyTorch model, see this guide: Federated PyTorch