Types of Federated Learning
Organizations need to understand user behavior using data to improve their market position. Therefore, businesses solicit user feedback in a variety of ways. For example, Garmin, a well-known technology business, has a dedicated page for customers to submit ideas and suggestions. Similarly, Hotjar utilizes usability testing, whereas Zapier focuses on user feedback surveys.
Personalization is another important aspect of user experience. It is essential for engaging consumers and providing a positive experience online. Netflix is a good example: it scores high on personalizing the viewing experience. The user's viewing habits and choices reveal their genre preferences, which are used to make relevant suggestions.
Netflix uses personalization to help its customers navigate its vast library of movies and television series. However, this gives the company access to user information, which raises privacy concerns for users. This is where federated learning helps: it can improve customer experience through personalization while preserving user privacy.
Source: Business Wire
What is Federated Learning?
Federated learning is a form of on-device machine learning in which a model is trained on end-user devices. No user data is sent to the server for training. Instead, each device sends only its updated model parameters back to the central server for aggregation. The server uses these updates to improve the global model and then sends the global model back to the user devices as a version update.
Federated learning offers a way to build personalized online systems while preserving end-user data privacy. It is also described as cost-efficient and fast, since much of the processing happens on the devices themselves, and the resulting global model can be made available worldwide.
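To make this data flow concrete, here is a minimal Python sketch of a few federated rounds. It assumes a toy one-parameter linear model (y ≈ w * x) and three simulated devices with made-up data; the function names `local_update`, `server_aggregate`, and `federated_round` are hypothetical and not taken from any framework. The point is only that each device returns a parameter, never its raw samples.

```python
def local_update(w, data, lr=0.01, epochs=5):
    """Train on one device's private data; return only the updated parameter."""
    for _ in range(epochs):
        for x, y in data:
            grad = 2 * (w * x - y) * x  # derivative of the squared error (w*x - y)**2
            w -= lr * grad
    return w  # the raw (x, y) pairs never leave this function


def server_aggregate(client_weights, client_sizes):
    """Average the client parameters, weighting each client by its data size."""
    total = sum(client_sizes)
    return sum(w * n for w, n in zip(client_weights, client_sizes)) / total


def federated_round(global_w, devices):
    updates = [local_update(global_w, data) for data in devices]
    sizes = [len(data) for data in devices]
    return server_aggregate(updates, sizes)


# Hypothetical private data on three devices, roughly following y = 3x.
devices = [
    [(1.0, 3.1), (2.0, 5.9)],
    [(0.5, 1.4), (1.5, 4.6), (2.5, 7.4)],
    [(3.0, 9.2)],
]

w = 0.0
for _ in range(10):
    w = federated_round(w, devices)  # the new global model goes back to the devices
print(f"global parameter after 10 rounds: {w:.2f}")
```

Weighting by data size is one common aggregation choice; a plain unweighted average would also match the description above.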
Types of Federated Learning
Over time, federated learning has taken on various forms. The variations can be classified based on:
- Schemes
- Data Partitions
Types Based on the Schemes
Cross-Silo Federated Learning Model
A silo in information technology is a segregated data store belonging to an organization and isolated from the rest of the network. It contains raw, often unstructured data with restricted access, so the information is not readily available to the outside network for use or further processing.
In the cross-silo federated learning architecture, numerous silos are connected to a central server. For example, multiple organizations can connect through a single network while keeping their raw data separate in their own silos. This enables large-scale data processing across businesses while keeping each organization's raw data private.
Explaining the Architecture and Data Flow Through an Example
The architecture comprises end-users from various companies, a silo from each company, and a central server.
For example, the diagram above shows a cross-silo system connecting a community hospital, a local laboratory, and a medical research facility. Each entity serves the community's general population, so they all work with a similar sample space.
The system lets these entities communicate and share what they learn without pooling their raw data into a single database. Each entity's data passes through a privacy algorithm before being stored in its silo. Then, based on pre-defined parameters, a local machine learning model is built and trained before being shared with the central server.
For example, the hospital's local machine learning model trains on the distribution of diseases among its patients. The laboratory shares a local model based on the types of tests run in the lab and their frequency. Each of these models uses only data restricted to specific parameters, to protect the patients' privacy and data security.
The server receives the three local ML models, one from each silo, and combines them into a global model available to every entity in the network. For example, suppose a viral disease is spreading in the community and the medical research center, as part of the cross-silo system, has to study it. The center can rely on the global model to track the spread of the disease and to see which tests can detect it.
The global model is then continually reviewed and updated as the local models change. Because the process is iterative, each new aggregation improves the information shared within the network. As a result, the hospital, laboratory, and medical research center can interact through an increasingly versatile machine learning model.
Cross-Device Federated Learning Model
This model of federated learning deals with data from the individual devices operating within a network. End-user devices, such as laptops and mobile phones, act as data sources and train the model locally. The central server then aggregates the local models into a global model.
An excellent example is Google's use of this setup to train a next-word prediction model. A local language model is trained on each device, using the user's input on that device to predict the next word. This is achieved with sequential models trained in a federated way, which 'memorize' the user's past input to make predictions in the future.
Each user device trains a local model on its own data. Because there are many devices, training on such a large number of data points improves prediction accuracy. For example, an RNN model tested on a population of English speakers in the US was trained on around 7.5 billion sentences. As a result, prediction accuracy increased by almost 24%, resulting in approximately 10% more clicks on the generated prediction strip.
Every device computes an update to the sequential model by taking one or more stochastic gradient steps on its local data. The central server then aggregates the updates (gradients or parameters) from all devices into the global model.
Source: CMU ML Blog
As a result, the system becomes more efficient, improving the user experience. After all, who does not appreciate a correctly predicted next word when they are in the middle of drafting a text?
Types Based on the Data Partitions
Before we dive deeper into this categorization model, we must understand the following terms about federated learning:
- Feature Space — It defines the key characteristics used to categorize the available data set in the system
- Sample Space — It consists of the samples (for example, the users or customers) whose records make up the data sets that end-user devices or silos provide to the central server
Depending on how these two spaces overlap across clients, data can be partitioned into three structural forms, leading to different kinds of federated systems. Let us take a look at them.
Horizontal Federated Learning Model
In this structural form, clients within the network share the same feature space, but each holds different samples. It is also known as sample-based federated learning. It primarily applies to clients whose data sets are homogeneous, that is, described by the same features.
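As a rough illustration, the Python sketch below splits one table horizontally, assuming hypothetical fitness-tracker records. Every client keeps all of the feature columns, but each holds a different subset of users; the helper `split_horizontally` is illustrative only.

```python
# Hypothetical records: every row has the same columns (shared feature space).
dataset = [
    {"user": "u1", "age": 34, "steps_per_day": 8000, "resting_hr": 62},
    {"user": "u2", "age": 51, "steps_per_day": 4500, "resting_hr": 70},
    {"user": "u3", "age": 27, "steps_per_day": 12000, "resting_hr": 58},
    {"user": "u4", "age": 45, "steps_per_day": 6000, "resting_hr": 66},
]


def split_horizontally(rows, n_clients):
    """Deal rows out round-robin: each client gets different users, same columns."""
    return [rows[i::n_clients] for i in range(n_clients)]


clients = split_horizontally(dataset, 2)
for i, part in enumerate(clients):
    users = [r["user"] for r in part]
    print(f"client {i}: users={users}, features={sorted(part[0])}")
```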
The horizontal federated learning model allows for establishing a multi-task federated learning system.
MOCHA — Multi-Task Federated Learning System
The system consists of multiple clients within the network. First, each client trains a local model on the common features and shares it with the central server, which uses these local models to create a global model.
MOCHA is a federated multi-task learning framework that trains models through an iterative process. It can handle the divergence that arises when data varies across clients. This multi-task approach is designed to keep communication costs low and to tolerate faults in the network.
Vertical Federated Learning Model
In vertical federated learning, clients hold different features but share a common sample space. It is also called feature-based federated learning. This concept is used primarily in B2B interaction and data sharing, where multiple organizations serving the same set of consumers share a common network.
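For contrast, here is a minimal sketch of vertically partitioned data, assuming a hypothetical bank and retailer that share some customers. Each party holds different columns for overlapping customer IDs. Computing the shared IDs in the clear, as done here for brevity, would expose each party's customer list; vertical FL systems use cryptographic protocols such as private set intersection for this step instead.

```python
# Party A (hypothetical bank): financial features keyed by customer ID.
bank = {
    "c1": {"income": 52000, "credit_score": 710},
    "c2": {"income": 38000, "credit_score": 640},
    "c3": {"income": 91000, "credit_score": 780},
}
# Party B (hypothetical retailer): purchasing features for partly the same customers.
retailer = {
    "c1": {"monthly_spend": 420},
    "c3": {"monthly_spend": 1310},
    "c4": {"monthly_spend": 95},
}

# Shared sample space: customers known to both parties.
shared_ids = sorted(bank.keys() & retailer.keys())
print(shared_ids)

# Conceptually, training uses both parties' columns for these customers,
# but in vertical FL the columns stay with their owners and are never merged.
for cid in shared_ids:
    print(cid, sorted(bank[cid]) + sorted(retailer[cid]))
```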
Source: arXiv Vanity
One framework used for vertical federated learning is PyVertical. Let us take a closer look at the steps involved.
PyVertical — Framework for Vertical Federated Learning Model
The framework builds on two techniques to implement a vertical FL model while preserving data privacy and training the model efficiently:
- PSI — Private Set Intersection: A cryptographic technique that lets multiple parties with different data sets compute the intersection of those sets without revealing their raw data to each other. This way, the parties can find the data points common to all of them while each party's remaining data stays private.
- SplitNN — Split Neural Networks: A neural network split into segments, each held and trained by a different party in the network. Each party runs its segment on its own data and passes only the resulting outputs (activations) to the next segment; during training, gradients flow back along the same path. Raw data never leaves its owner, and spreading the segments across several parties shares the computational load.
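The PSI idea can be illustrated with a toy version of the Diffie-Hellman-style protocol, a well-known PSI construction (not necessarily the one PyVertical uses). Each party hashes its IDs and raises them to a secret exponent, then the other party raises them to its own exponent; because exponentiation commutes, only shared IDs end up with matching values. The small prime and the missing hardening steps make this a teaching sketch under stated assumptions, not a secure implementation.

```python
import hashlib
import math
import secrets

# Toy modulus: 2**127 - 1 is prime, but far too small and unhardened for real use.
P = 2**127 - 1


def hash_to_group(item):
    """Map an ID to a nonzero residue mod P."""
    digest = hashlib.sha256(item.encode()).digest()
    return int.from_bytes(digest, "big") % (P - 1) + 1


def secret_exponent():
    """Pick a key coprime to P - 1 so that v -> v**k mod P is one-to-one."""
    while True:
        k = secrets.randbelow(P - 3) + 2
        if math.gcd(k, P - 1) == 1:
            return k


def blind(values, k):
    return [pow(v, k, P) for v in values]


alice_ids = ["c1", "c2", "c3", "c5"]
bob_ids = ["c1", "c3", "c4"]
a_key, b_key = secret_exponent(), secret_exponent()  # each party keeps its own key

# Round 1: each party hashes and blinds its own IDs, then sends them over.
alice_once = blind([hash_to_group(x) for x in alice_ids], a_key)
bob_once = blind([hash_to_group(x) for x in bob_ids], b_key)

# Round 2: each side blinds the other's values with its own key.
# (h**a)**b == (h**b)**a mod P, so shared IDs yield identical values.
alice_twice = blind(alice_once, b_key)   # computed by Bob, returned in order
bob_twice = set(blind(bob_once, a_key))  # computed by Alice

# Alice learns which of her IDs are shared, and nothing about Bob's other IDs.
intersection = [x for x, v in zip(alice_ids, alice_twice) if v in bob_twice]
print(intersection)
```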
PyVertical then prepares the data and trains the model in three steps: vertical partitioning, PSI, and a SplitNN.
- The data is partitioned vertically, so that the images and the labels of the data points are held separately. Each data point is assigned a unique ID, and the data is randomly shuffled.
- PSI links the different data sets using the unique IDs. Only the data points that form the intersection are kept for training.
- A split neural network is trained on the linked data: one segment processes the images, and the other produces predictions that are compared with the labels.
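The SplitNN step can be sketched as follows, assuming a tiny two-party network written in plain Python with made-up data. The hypothetical `PartyA` holds the inputs and the first layer, and `PartyB` holds the labels and the output layer. Only the cut-layer activations travel forward and only their gradients travel back.

```python
import random

random.seed(0)


class PartyA:
    """Data owner: holds the inputs and the first segment (linear + ReLU)."""

    def __init__(self, n_in, n_hidden):
        # Small positive init keeps the ReLUs active for the positive toy inputs.
        self.w = [[random.uniform(0.1, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]

    def forward(self, x):
        self.x = x
        self.z = [sum(wi * xi for wi, xi in zip(row, x)) for row in self.w]
        return [max(0.0, z) for z in self.z]  # cut-layer activations sent to B

    def backward(self, grad_h, lr):
        for j, row in enumerate(self.w):
            g = grad_h[j] if self.z[j] > 0 else 0.0  # ReLU derivative
            for i in range(len(row)):
                row[i] -= lr * g * self.x[i]


class PartyB:
    """Label owner: holds the labels and the output segment (linear)."""

    def __init__(self, n_hidden):
        self.v = [random.uniform(-0.5, 0.5) for _ in range(n_hidden)]
        self.bias = 0.0

    def step(self, h, y, lr):
        pred = sum(vj * hj for vj, hj in zip(self.v, h)) + self.bias
        err = pred - y                            # squared-error loss: err ** 2
        grad_h = [2 * err * vj for vj in self.v]  # dLoss/dh, sent back to A
        self.v = [vj - lr * 2 * err * hj for vj, hj in zip(self.v, h)]
        self.bias -= lr * 2 * err
        return err ** 2, grad_h


# Toy target y = x0 + 2*x1; A holds only the inputs, B holds only the labels.
data = [([x0, x1], x0 + 2 * x1) for x0 in (0.1, 0.5, 0.9) for x1 in (0.2, 0.6, 1.0)]
party_a, party_b = PartyA(2, 4), PartyB(4)

losses = []
for epoch in range(500):
    total = 0.0
    for x, y in data:
        h = party_a.forward(x)                      # A -> B: activations only
        loss, grad_h = party_b.step(h, y, lr=0.02)  # B -> A: gradient at the cut
        party_a.backward(grad_h, lr=0.02)
        total += loss
    losses.append(total / len(data))

print(f"mean loss: first epoch {losses[0]:.4f}, last epoch {losses[-1]:.4f}")
```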
It allows vertical federated learning models to deal with heterogeneous data efficiently.
Federated Transfer Learning Model
This model of federated learning combines elements of the horizontal and vertical federated learning systems. It allows information exchange across domains that share few common features or samples. It enables different entities to share a global model without a common feature space, while keeping their data private.
The idea is to train a model on a large data set for a specific problem and then apply it to a different problem within a similar domain. The approach therefore depends on the problems' domains being similar.
Federated Transfer Learning (FTL) in Autonomous Driving
Autonomous driving can use reinforcement learning (RL), in which a model learns from a collection of observations made and actions taken, with FTL used to share what the vehicles learn. Because there is no human driver to intervene, the success of this application depends entirely on how well each model is trained.
Each autonomous vehicle is an independent entity in the system that shares few common features or samples with other cars of the same type. Each car runs its own reinforcement learning process, and what it learns can be shared through the online transfer of information.
Source: FTL — Concept, and Applications
The architecture consists of three parts, each handling a different communication link in the network:
- Sensing and Data Acquisition: The first step in the reinforcement learning process is to gather data from the observations and actions of an autonomous vehicle. The automobile senses and collects information from its interactions with people, other cars, and other variables in its surroundings. This data is used to train the vehicle's local model.
- Vehicular Cloud: The local model created through RL is stored in the vehicular cloud of that specific automobile. The cloud is essential for storing the local model, which has to be optimized repeatedly as the vehicle makes new observations and takes new actions over time.
- Internet Cloud: The vehicular cloud ensures that the local model is available to the central server, the internet cloud in this case. The central server regularly obtains the locally created RL models and averages them using the federated averaging method.
This allows the creation of a global model, which can, in turn, help optimize the RL models of all autonomous vehicles connected via the internet cloud.
Some Commonly Used FL Strategies
A federated learning system involves multiple clients, each holding its own data, often in large amounts. Any training strategy for such a network has to produce local models that can be combined into, and converge to, a common global model used by the entire system.
Below is the basic structure that can define a federated learning model.
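A common way to write this objective, assuming clients i drawn from a distribution P over the participating clients, each with a local data distribution D_i from which samples ξ are drawn, and with F_i denoting client i's expected loss, is:

$$
\min_{x} \; F(x) = \mathbb{E}_{i \sim \mathcal{P}}\left[ F_i(x) \right],
\qquad
F_i(x) = \mathbb{E}_{\xi \sim \mathcal{D}_i}\left[ f_i(x, \xi) \right]
$$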
where x denotes the model parameters, E denotes an expectation taken over the clients and over each client's data, and f_i is client i's loss function with respect to the parameters and the client data.
In other words, each client has its own loss function, defined on its local data and the shared model parameters. Taking the expectation over a client's data gives that client's expected loss, and taking the expectation (in practice, a weighted average) of these losses over all clients gives the global objective that the model parameters are trained to minimize.
Different approaches are used to carry out the above processes. Some notable ones are as follows:
SGD — Stochastic Gradient Descent
As the name indicates, this algorithm moves the parameters down the gradient of the loss function at each iteration, using a gradient estimated from a random sample of the data. The goal is to converge to a minimum of the loss function.
In federated SGD, each client uses its local data to compute a single stochastic gradient of its own loss function. The central server receives these gradients from the clients, averages them, and uses the average to take one step on the global model.
The process is synchronous: in every round, each client condenses its data into a single gradient before the server averages them. One disadvantage of this approach is slow progress, since the global model takes only one gradient step per communication round.
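A minimal Python sketch of federated SGD as described above, assuming a toy model y = w * x and two simulated clients with made-up data. Each client returns a single stochastic gradient per round, and the server averages the gradients and takes one step; the helper names `local_gradient` and `fedsgd` are hypothetical.

```python
import random


def local_gradient(w, data, rng, batch_size=2):
    """One stochastic gradient of a client's squared-error loss for y = w * x."""
    batch = rng.sample(data, min(batch_size, len(data)))
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def fedsgd(clients, rounds=200, lr=0.05, seed=0):
    rng = random.Random(seed)
    w = 0.0
    for _ in range(rounds):
        # Each client sends exactly one gradient per round...
        grads = [local_gradient(w, data, rng) for data in clients]
        # ...and the server averages them and takes a single step.
        w -= lr * sum(grads) / len(grads)
    return w


# Hypothetical client data, roughly following y = 2x.
clients = [
    [(1.0, 2.0), (2.0, 4.1)],
    [(0.5, 1.0), (1.5, 2.9), (3.0, 6.2)],
]
w_final = fedsgd(clients)
print(f"w after 200 rounds: {w_final:.2f}")
```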
FedAvg — Federated Averaging
FedAvg is a faster approach to federated training. Unlike federated SGD, it lets each client take multiple local gradient steps before communicating, so every round makes more progress.
The following equation defines a single SGD step at the client end. In FedAvg, this step is repeated multiple times on each client in the sampled set S.
The central server then averages the resulting client models.
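Written out in one standard form, assuming a sampled client set S, a local learning rate η, K local steps, and a stochastic gradient g_i of client i's loss F_i, each client starts from the current global model x, runs K local steps, and the server averages the results:

$$
x_i^{(0)} = x, \qquad
x_i^{(k+1)} = x_i^{(k)} - \eta \, g_i\!\left(x_i^{(k)}\right), \quad k = 0, \dots, K-1
$$

$$
x \leftarrow \frac{1}{|S|} \sum_{i \in S} x_i^{(K)}
$$

With K = 1 this reduces to federated SGD. Weighting each client by its number of samples, instead of the plain average shown here, is a common variant.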
These updated model parameters from each client are shared with the central server, which aggregates them into a global model. This makes the process faster, since part of the optimization is done locally at the client's end.
The drawback is that the server cannot coordinate the clients during their local steps, so their models can drift apart. If the clients' data is heterogeneous, meaning it varies greatly between clients, this drift can keep the model from converging to the optimum of the global objective.
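The drift can be seen in a small deterministic sketch, assuming two hypothetical clients whose losses are simple quadratics F_i(x) = a_i / 2 * (x - b_i)^2 with very different curvatures a_i, as a stand-in for heterogeneous data. Exact local gradients keep the run reproducible.

```python
# Two hypothetical clients with quadratic losses F_i(x) = a_i / 2 * (x - b_i) ** 2.
# Very different curvatures a_i stand in for heterogeneous client data.
CLIENTS = [(1.0, 0.0), (10.0, 1.0)]  # (a_i, b_i)


def local_sgd(x, a, b, lr, steps):
    """Run `steps` exact gradient steps on one client's loss, starting from x."""
    for _ in range(steps):
        x -= lr * a * (x - b)
    return x


def fedavg(local_steps, rounds=500, lr=0.05):
    x = 0.5  # initial global model
    for _ in range(rounds):
        local_models = [local_sgd(x, a, b, lr, local_steps) for a, b in CLIENTS]
        x = sum(local_models) / len(local_models)  # server averages the models
    return x


# Minimizer of the average loss (F_1 + F_2) / 2.
true_opt = sum(a * b for a, b in CLIENTS) / sum(a for a, _ in CLIENTS)
for k in (1, 5, 20):
    print(f"K={k:2d}: FedAvg -> {fedavg(k):.3f}   optimum {true_opt:.3f}")
```

With one local step the result matches the optimum; as K grows, the fixed point moves away from it, which is the convergence problem described above.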
SCAFFOLD — Stochastic Controlled Averaging for Federated Learning
These issues with FedAvg led to another solution: SCAFFOLD. As the name indicates, the calculations are more controlled in this method. It tackles FedAvg's convergence problem on heterogeneous data by adding a correction term to every gradient step in the local iterations.
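Assuming the notation of the SCAFFOLD paper (y_i is client i's local model, η_l the local learning rate, g_i a stochastic gradient of client i's loss, and c and c_i the server and client control variates), each local step is:

$$
y_i \leftarrow y_i - \eta_l \left( g_i(y_i) - c_i + c \right)
$$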
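Here, c - c_i is the correction term. It estimates the difference between the global update direction and the client's own, so adding it steers each local step toward the global objective.
The results from these local iterations are then sent to the central server for averaging, as in the other methods. The process stays fast thanks to the local gradient steps, while the correction term keeps the clients' updates aligned so that the model converges even when client data is heterogeneous.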
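Reusing the same two hypothetical quadratic clients, here is a sketch of SCAFFOLD with exact local gradients and full client participation. The control-variate update follows what the SCAFFOLD paper describes as option II (each c_i becomes the average corrected gradient along the client's local path); the function names and constants are illustrative. Even with 20 local steps, the corrected updates reach the optimum of the average loss.

```python
# Same two hypothetical quadratic clients: F_i(x) = a_i / 2 * (x - b_i) ** 2.
CLIENTS = [(1.0, 0.0), (10.0, 1.0)]  # (a_i, b_i)


def grad(a, b, y):
    return a * (y - b)  # exact gradient of F_i at y


def scaffold(local_steps=20, rounds=200, lr=0.05, global_lr=1.0):
    x, c = 0.5, 0.0                 # global model, server control variate
    c_local = [0.0] * len(CLIENTS)  # client control variates c_i
    for _ in range(rounds):
        delta_y, delta_c = [], []
        for i, (a, b) in enumerate(CLIENTS):
            y = x
            for _ in range(local_steps):
                # Corrected local step: y <- y - lr * (g_i(y) - c_i + c)
                y -= lr * (grad(a, b, y) - c_local[i] + c)
            # Control-variate update ("option II" in the SCAFFOLD paper).
            c_new = c_local[i] - c + (x - y) / (local_steps * lr)
            delta_y.append(y - x)
            delta_c.append(c_new - c_local[i])
            c_local[i] = c_new
        # Server step; every client participates in every round here.
        x += global_lr * sum(delta_y) / len(delta_y)
        c += sum(delta_c) / len(CLIENTS)
    return x


true_opt = sum(a * b for a, b in CLIENTS) / sum(a for a, _ in CLIENTS)
print(f"SCAFFOLD with K=20 -> {scaffold():.3f}   optimum {true_opt:.3f}")
```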
Federated learning is a modern approach to personalized digital spaces for end-users. With fast data processing and strong protection of end-user data, it meets a pressing need of today's digital spaces. It allows multiple clients to form a shared network and collaborate on data models without compromising privacy. It also offers several schemes and data partitions to cater to the diverse requirements of digital spaces.
Cross-silo federated learning supports collaboration between multiple organizations, while the cross-device model coordinates secure learning across end-user devices. In addition, there are three structural models, depending on the features that define a data set and the sample space it covers: horizontal, vertical, and transfer federated learning models.
The choice of a learning model for any client is highly dependent on the type of data exchange required by the organization and the common factor between each set of data in the network. It is essential to make the right choice to reap all the benefits of federated learning, but the first step remains to study and understand the types of federated learning models.
Author: Huda Mahmood