Federated Learning Meets Multi-objective Optimization

Zeou Hu, Kiarash Shaloudegi, Guojun Zhang, and Yaoliang Yu

Zeou Hu and Yaoliang Yu are with the Cheriton School of Computer Science, University of Waterloo, Waterloo, ON, N2L 3G1. E-mail: {zeou.hu,yaoliang.yu}@uwaterloo.ca. Kiarash Shaloudegi is with Amazon Advertising; this work was done while he was at Huawei Noah’s Ark Lab, Montreal, QC, H3N 1X9. E-mail: kiarashs@amazon.com. Guojun Zhang is with Huawei Noah’s Ark Lab, Montreal, QC, H3N 1X9. E-mail: guojun.zhang@huawei.com.

Abstract

Federated learning has emerged as a promising, massively distributed way to train a joint deep model over a large number of edge devices while keeping private user data strictly on device. In this work, motivated by the need to ensure fairness among users and robustness against malicious adversaries, we formulate federated learning as multi-objective optimization and propose a new algorithm, FedMGDA+, that is guaranteed to converge to Pareto stationary solutions. FedMGDA+ is simple to implement, has fewer hyperparameters to tune, and refrains from sacrificing the performance of any participating user. We establish the convergence properties of FedMGDA+ and point out its connections to existing approaches. Extensive experiments on a variety of datasets confirm that FedMGDA+ compares favorably against the state of the art.

Keywords: Pareto optimization, Distributed algorithms, Federated learning, Edge computing, Machine learning, Neural networks.

1 Introduction

Deep learning has achieved impressive successes on a number of domain applications, thanks largely to innovations in algorithmic and architectural design, and equally importantly to the tremendous amount of computational power one can harness through GPUs, computer clusters and dedicated software and hardware. Edge devices, such as smart phones, tablets, routers, car devices, home sensors, etc., due to their ubiquity and moderate computational power, present new opportunities and challenges for deep learning. On the one hand, edge devices have direct access to privacy-sensitive data that users may be reluctant to share (with, say, data centers), and they are much more powerful than their predecessors, capable of conducting a significant amount of on-device computation. On the other hand, edge devices are largely heterogeneous in terms of capacity, power, data, availability, communication, memory, etc., posing new challenges beyond conventional in-house training of machine learning models. Thus, a new paradigm, known as federated learning (FL) (McMahan et al., 2017), that aims at harvesting the prospects of edge devices has recently emerged. Developing new FL algorithms and systems on edge devices has since become a hot research topic in machine learning.

From its inception, FL has had close ties to conventional distributed optimization. However, FL emerged from the pressing need to address new challenges in the mobile era that existing distributed optimization algorithms were not designed for. We mention the following characteristics of FL that are most relevant to our work, and refer to the excellent surveys (Li et al., 2019; Yang et al., 2019; Kairouz et al., 2019) and the references therein for more challenges and applications in FL.

  • Non-IID: Each user’s data can be distinctively different from every other user’s, violating the standard iid assumption in statistical learning and posing significant difficulty in formulating the goal in precise mathematical terms (Mohri et al., 2019). The distribution of user data is often severely unbalanced.

  • Limited communication: Communication between each user and a central server is constrained by network bandwidth, device status, user participation incentive, etc., demanding a thoughtful balance between computation (on each user device) and communication.

  • Privacy: Protecting user (data) privacy is of utmost importance in FL. It is thus not possible to share user data (even with a cloud arbitrator), which adds another layer of difficulty in addressing the previous two challenges.

  • Fairness: As argued forcibly in recent work (e.g., Mohri et al., 2019; Li et al., 2020b), ensuring fairness among users has become another serious goal in FL, as it largely determines users’ willingness to participate and ensures some degree of robustness against malicious user manipulations.

  • Robustness: FL algorithms are eventually deployed in the wild and are hence subject to malicious attacks. Indeed, adversarial attacks (e.g., Bagdasaryan et al., 2020; Sun et al., 2019; Bhagoji et al., 2019) have been constructed recently to reveal vulnerabilities of FL systems against malicious manipulations at the user side.

In this work, motivated by the last two challenges above, i.e., fairness and robustness, we propose a new algorithm, FedMGDA+, that complements and improves existing FL systems. FedMGDA+ is based on multi-objective optimization and is guaranteed to converge to Pareto stationary solutions. FedMGDA+ is simple to implement, has fewer hyperparameters to tune, and most importantly refrains from sacrificing the performance of any participating user. We demonstrate the superior performance of FedMGDA+ under a variety of metrics including accuracy, fairness, and robustness.

We summarize our contributions as follows:

  • In §3, based on the proximal average we provide a novel, unifying and revealing interpretation of existing FL practices.

  • In §4, we summarize some background on multi-objective optimization and point out its connections to existing FL algorithms. We believe this new perspective will yield more fruitful exchanges between the two fields in the future.

  • In §5, we propose FedMGDA+ that complements existing FL systems while taking robustness and fairness explicitly into its algorithmic design. We prove that FedMGDA+ converges to a Pareto stationary solution under mild assumptions.

  • In §6, we perform extensive experiments to validate the competitiveness of FedMGDA+ under a variety of desirable metrics, and to illustrate the respective pros and cons of our and alternative algorithms.

We discuss more related work in §2 and we conclude in §7 with some future directions.

To facilitate reproducibility, we have released our code at: https://github.com/watml/Fed-MGDA.

2 Related Work

In this section we give a brief review of some recent work that is directly related to ours and put our contributions in context. To start with, McMahan et al. (2017) proposed the first FL algorithm, known as “Federated Averaging” (a.k.a. FedAvg), which is a synchronous update scheme that proceeds in several rounds. At each round, the central server sends the current global model to a subset of users, each of which then uses its respective local data to update the received model. Upon receiving the updated local models from users, the server performs aggregation, such as simple averaging, to update the global model. For more discussion on different averaging schemes, see Li et al. (2020). Li et al. (2020a) extended FedAvg to better deal with non-iid distribution of data, by adding a “proximal regularizer” to the local loss functions and minimizing the Moreau envelope function for each user. The resulting algorithm FedProx, as pointed out in §3, is a randomized version of the proximal average algorithm in Yu (2013) and reduces to FedAvg when regularization diminishes.

Analysing FedAvg has been a challenging task due to its flexible updating scheme, partial user participation, and non-iid distribution of client data (Li et al., 2020a). The first theoretical analyses of FedAvg for strongly convex and smooth problems with iid and non-iid data appeared in Stich (2019) and Li et al. (2020), respectively, where the effect of different sampling and averaging schemes on the convergence rate of FedAvg was also investigated, leading to the conclusion that such effect becomes particularly important when the dataset is unbalanced and non-iid distributed. In Huo et al. (2020), FedAvg was analyzed for non-convex problems, where FedAvg was formulated as a stochastic gradient-based algorithm with biased gradients, and the convergence of FedAvg with decaying step sizes to stationary points was proved. Moreover, Huo et al. (2020) proposed FedMom, a server-side acceleration based on Nesterov’s momentum, and proved again its convergence to stationary points. Lately, Reddi et al. (2020) proposed and analyzed federated versions of several popular adaptive optimizers (e.g., ADAM). They generalize the framework of FedAvg by decoupling the FL update scheme into a server optimizer and a client optimizer. Interestingly, like us, Reddi et al. (2020) also observed the importance of learning rate decay on both clients and server.

Recently, Pathak and Wainwright (2020) demonstrated theoretically that fixed points reached by FedAvg and FedProx (if they exist) need not be stationary points of the original optimization problem, even in convex settings and with deterministic updates. To address this issue, they proposed FedSplit to restore the correct fixed points. It remains open, though, whether FedSplit can still converge to the correct fixed points under asynchronous and stochastic user updates, both of which are widely adopted in practice and studied here.

Ensuring fairness among users has become a serious goal in FL since it largely determines users’ willingness to participate in the training process. Mohri et al. (2019) argued that existing FL algorithms can lead to federated models that are biased toward different users. To solve this issue, they proposed agnostic federated learning (AFL) to improve fairness among users. AFL considers the target distribution as a weighted combination of the user distributions and optimizes the centralized model for the worst-case realization, leading to a saddle-point optimization problem which was solved by a fast stochastic optimization algorithm. On the other hand, based on fair resource allocation in wireless networks, Li et al. (2020b) proposed q-fair federated learning (q-FFL) to achieve more uniform test accuracy across users. Li et al. (2020b) further proposed q-FedAvg as a communication-efficient algorithm to solve q-FFL. However, neither AFL nor q-FedAvg explicitly encourages user participation, and both are vulnerable to adversarial attacks, while our algorithm FedMGDA+ is designed to be fair among participants and robust against both additive and multiplicative attacks.

FedAvg relies on a coordinate-wise averaging of local models to update the global model. According to Wang et al. (2020b), in neural network (NN) based models, such coordinate-wise averaging might lead to sub-optimal results due to the permutation invariance of NN parameters. To address this issue, Yurochkin et al. (2019) proposed probabilistic federated neural matching (PFNM), which is only applicable to fully connected feed-forward networks. The recent work of Wang et al. (2020b) proposed federated matched averaging (FedMA) as a layer-wise extension of PFNM to accommodate CNNs and LSTMs. However, the Bayesian non-parametric mechanism in PFNM and FedMA may be vulnerable to model poisoning attacks (Bagdasaryan et al., 2020; Bhagoji et al., 2019; Wang et al., 2020a), while some simple defences, such as norm thresholding and differential privacy, were discussed in Sun et al. (2019). We note that these ideas are complementary to FedMGDA+ and we plan to investigate possible integrations of them in future work.

Lastly, we note that there is significant interest in standardizing the benchmarks, protocols and evaluations in FL; see for instance Caldas et al. (2018) and He et al. (2020). We have spent significant effort in adhering to the rules suggested there, by reporting on common datasets, open-sourcing our code and including all experimental details.

3 Problem Setup

We recall the federated learning (FL) framework of McMahan et al. (2017) and point out a simple interpretation that seemingly unifies different implementations. We consider FL with $m$ users (edge devices), where the $i$-th user is interested in minimizing a function $f_i : \mathbb{R}^d \to \mathbb{R}$, $i = 1, \ldots, m$, defined on a shared model parameter $\mathbf{w} \in \mathbb{R}^d$. Typically, each user function $f_i$ also depends on the respective user’s local (private) data $\mathcal{D}_i$. The main goal in FL is to collectively and efficiently optimize the individual objectives $\{f_i\}$ while meeting challenges such as those mentioned in the Introduction (§1): non-iid distribution of user data, limited communication, user privacy, fairness, robustness, etc.

McMahan et al. (2017) proposed FedAvg to optimize the arithmetic average of individual user functions:

$$\min_{\mathbf{w} \in \mathbb{R}^d} \mathsf{A}^0_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}), \quad \text{where} \quad \mathsf{A}^0_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}) := \sum_{i=1}^m \lambda_i f_i(\mathbf{w}). \tag{1}$$

The weights $\lambda_i$ need to be specified beforehand. Typical choices include the dataset size at each user, the “importance” of each user, or simply uniform, i.e., $\lambda_i \equiv 1/m$. FedAvg works as follows: at each round, a (random) subset of users is selected, each of which performs $k$ epochs of local (full or minibatch) gradient descent:

$$\text{for all } i \text{ in parallel}, \quad \mathbf{w}_i \leftarrow \mathbf{w}_i - \eta \nabla f_i(\mathbf{w}_i), \tag{2}$$

and then the local models are averaged at the server side:

$$\mathbf{w} \leftarrow \sum_i \lambda_i \mathbf{w}_i, \tag{3}$$

which is finally broadcast to the users in the next round. The number of local epochs $k$ turns out to be a key factor. Setting $k = 1$ amounts to solving (1) by the usual gradient descent algorithm, while setting $k = \infty$ (and assuming convergence for each local function $f_i$) amounts to (repeatedly) averaging the respective minimizers of the $f_i$’s. We now give a new interpretation of FedAvg that yields insights on what it optimizes with an intermediate $k$.
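The two extreme regimes of $k$ can be checked numerically on a toy problem. The sketch below is only an illustration under our own assumptions: scalar quadratic user functions $f_i(w) = \tfrac{1}{2} a_i (w - c_i)^2$, full user participation, and a hypothetical helper named `fedavg`; none of these come from the experiments of the paper.

```python
def fedavg(a, c, lam, k, eta=0.1, rounds=200, w=0.0):
    """Toy FedAvg on scalar quadratics f_i(w) = 0.5 * a[i] * (w - c[i])**2."""
    for _ in range(rounds):
        local_models = []
        for a_i, c_i in zip(a, c):
            w_i = w  # each user starts from the broadcast global model
            for _ in range(k):  # k local gradient steps, update (2)
                w_i -= eta * a_i * (w_i - c_i)
            local_models.append(w_i)
        # server-side weighted averaging, update (3)
        w = sum(l * w_i for l, w_i in zip(lam, local_models))
    return w


a, c, lam = [1.0, 4.0], [0.0, 1.0], [0.5, 0.5]
w_k1 = fedavg(a, c, lam, k=1)      # one local step per round
w_kbig = fedavg(a, c, lam, k=200)  # local training run (almost) to convergence
print(w_k1, w_kbig)
```

With one local step the iterate approaches 0.8, the minimizer of the weighted average in (1), whereas with many local steps it approaches 0.5, the weighted average of the two local minimizers.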

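Our interpretation is based on the proximal average (Bauschke et al., 2008). Recall that the Moreau envelope and proximal map of a convex function $f$ (for nonconvex functions, similar results hold once we address multi-valuedness of the proximal map; see Yu et al., 2015) are defined respectively as:

$$\mathsf{M}^\eta_f(\mathbf{w}) = \min_{\mathbf{x}} \tfrac{1}{2\eta}\|\mathbf{x} - \mathbf{w}\|_2^2 + f(\mathbf{x}), \tag{4}$$
$$\mathsf{P}^\eta_f(\mathbf{w}) = \operatorname*{argmin}_{\mathbf{x}} \tfrac{1}{2\eta}\|\mathbf{x} - \mathbf{w}\|_2^2 + f(\mathbf{x}). \tag{5}$$

Given a set of convex functions $\mathbf{f} = (f_1, \ldots, f_m)$ and positive weights $\boldsymbol{\lambda} = (\lambda_1, \ldots, \lambda_m)$ that sum to 1, we define the proximal average as the unique function $\mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}$ such that $\mathsf{P}^\eta_{\mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}} = \sum_i \lambda_i \mathsf{P}^\eta_{f_i}$. In other words, the proximal map of the proximal average is the average of proximal maps. More concretely, Bauschke et al. (2008) gave the following explicit, albeit complicated, formula for the proximal average:

$$\mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}) = \min_{\mathbf{w}_1, \ldots, \mathbf{w}_m} \sum_{i=1}^m \lambda_i \left[ f_i(\mathbf{w}_i) + \tfrac{1}{2\eta}\|\mathbf{w}_i\|_2^2 \right] - \tfrac{1}{2\eta}\|\mathbf{w}\|_2^2, \tag{6}$$
$$\text{s.t.} \quad \sum_{i=1}^m \lambda_i \mathbf{w}_i = \mathbf{w}. \tag{7}$$

From the above formula we can easily derive that

$$\mathsf{A}^0_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}) := \lim_{\eta \to 0^+} \mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}) = \sum_i \lambda_i f_i(\mathbf{w}),$$
$$\mathsf{A}^\infty_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}) := \lim_{\eta \to \infty} \mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}) = \min_{\sum_i \lambda_i \mathbf{w}_i = \mathbf{w}} \sum_i \lambda_i f_i(\mathbf{w}_i).$$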

Interestingly, we can now interpret FedAvg in two extreme settings as minimizing the proximal average:

  • FedAvg with $k = 1$ local step is exactly the same as minimizing the proximal average $\mathsf{A}^0_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w})$ with $\eta = 0$. This is clear from the objective (1) of FedAvg (as our notation already suggests).

  • FedAvg with $k = \infty$ local steps is exactly the same as minimizing the proximal average $\mathsf{A}^\infty_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w})$ with $\eta = \infty$. Indeed,

    $$\min_{\mathbf{w}} \mathsf{A}^\infty_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}) = \min_{\mathbf{w}_1, \ldots, \mathbf{w}_m} \sum_i \lambda_i f_i(\mathbf{w}_i), \tag{8}$$

    where the right-hand side decouples and hence $\mathbf{w}_i$ at optimality is a minimizer of $f_i$ (recall that $\boldsymbol{\lambda} > \mathbf{0}$).

Therefore, we may interpret FedAvg with an intermediate $k$ as minimizing $\mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}$ with an intermediate $\eta$. More interestingly, if we apply the PA-PG algorithm in Yu (2013, Algorithm 2) to minimize $\mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}$, we obtain the simple update rule

$$\mathbf{w} \leftarrow \sum_i \lambda_i \mathsf{P}^\eta_{f_i}(\mathbf{w}), \tag{9}$$

where the proximal maps are computed in parallel at the user’s side. We note that the recent FedProx algorithm (Li et al., 2020a) is essentially a randomized version of (9). Crucially, we do not need to evaluate the complicated formula (6) as the update (9) only requires its proximal map, which by definition is the average of the individual proximal maps (computed by each user separately). Moreover, the difference between the proximal average $\mathsf{A}^\eta_{\mathbf{f},\boldsymbol{\lambda}}$ and the arithmetic average $\mathsf{A}^0_{\mathbf{f},\boldsymbol{\lambda}}$ can be uniformly bounded using the Lipschitz constant of each function $f_i$ (Yu, 2013). Thus, for small step size $\eta$, FedAvg (with any finite $k$) and FedProx all minimize some approximate form of the arithmetic average in (1).
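As a small numerical illustration of update (9), the sketch below iterates the averaged proximal maps for scalar quadratics $f_i(w) = \tfrac{1}{2} a_i (w - c_i)^2$, whose proximal map (5) has the closed form $(w + \eta a_i c_i)/(1 + \eta a_i)$. The quadratic user functions and the helper names are assumptions made for this example only.

```python
def prox_quadratic(w, a, c, eta):
    """Proximal map (5) of f(x) = 0.5 * a * (x - c)**2, in closed form."""
    return (w + eta * a * c) / (1.0 + eta * a)


def proximal_average_fixed_point(a, c, lam, eta, iters=20000, w=0.0):
    """Iterate update (9): the server averages the users' proximal maps."""
    for _ in range(iters):
        w = sum(l * prox_quadratic(w, a_i, c_i, eta)
                for l, a_i, c_i in zip(lam, a, c))
    return w


a, c, lam = [1.0, 4.0], [0.0, 1.0], [0.5, 0.5]
w_small_eta = proximal_average_fixed_point(a, c, lam, eta=1e-3)
w_large_eta = proximal_average_fixed_point(a, c, lam, eta=1e3)
print(w_small_eta, w_large_eta)
```

For small $\eta$ the fixed point is close to 0.8, the minimizer of the arithmetic average, and for large $\eta$ it is close to 0.5, the average of the local minimizers, in line with the two limits $\mathsf{A}^0$ and $\mathsf{A}^\infty$.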

How to set the weights $\boldsymbol{\lambda}$ in FedAvg has been a major challenge. In FL, data is distributed in a highly non-iid and unbalanced fashion, so it is not clear if some chosen arithmetic average in (1) would really satisfy one’s actual intention. A second issue with the arithmetic average in (1) is its well-known non-robustness against malicious manipulations, which has been exploited in recent adversarial attacks (Bhagoji et al., 2019). Instead, agnostic FL (AFL; Mohri et al., 2019) aims to optimize the worst-case loss:

$$\min_{\mathbf{w}} \max_{\boldsymbol{\lambda} \in \Lambda} \mathsf{A}^0_{\mathbf{f},\boldsymbol{\lambda}}(\mathbf{w}), \tag{10}$$

where the set $\Lambda$ might cover reality better than any specific $\boldsymbol{\lambda}$ and provide some minimum guarantee for all users (hence achieving mild fairness). On the other hand, the worst-case loss in (10) is perhaps even more non-robust against adversarial attacks. For instance, adding a positive constant to some loss $f_i$ can make it dominate the entire optimization process. The recent work q-FedAvg (Li et al., 2020b) proposes an $\ell_q$ norm interpolation between FedAvg (essentially the $\ell_1$ norm) and AFL (essentially the $\ell_\infty$ norm). By tuning $q$, q-FedAvg can achieve a better compromise than FedAvg or AFL.

4 Multi-objective Minimization (MoM)

Multi-objective minimization (MoM) refers to the setting where multiple scalar objective functions, possibly incompatible with each other, need to be minimized simultaneously. It is also called vector optimization (Jahn, 2009) because the objective functions can be combined into a single vector-valued function. In mathematical terms, MoM can be written as

$$\min_{\mathbf{w} \in \mathbb{R}^d} \mathbf{f}(\mathbf{w}) := \big(f_1(\mathbf{w}), f_2(\mathbf{w}), \ldots, f_m(\mathbf{w})\big), \tag{11}$$

where the minimum is defined wrt the partial ordering:

$$\mathbf{f}(\mathbf{w}) \le \mathbf{f}(\mathbf{z}) \iff \forall i = 1, \ldots, m, \; f_i(\mathbf{w}) \le f_i(\mathbf{z}). \tag{12}$$

(We remind the reader that algebraic operations such as $\le$ and $+$, when applied to a vector with another vector or scalar, are always performed component-wise.) Unlike single objective optimization, with multiple objectives it is possible that

$$\mathbf{f}(\mathbf{w}) \not\le \mathbf{f}(\mathbf{z}) \quad \text{and} \quad \mathbf{f}(\mathbf{z}) \not\le \mathbf{f}(\mathbf{w}), \tag{13}$$

in which case we say $\mathbf{w}$ and $\mathbf{z}$ are not comparable.

We call $\mathbf{w}^*$ a Pareto optimal solution of (11) if its objective value $\mathbf{f}(\mathbf{w}^*)$ is a minimal element (wrt the partial ordering in (12)), or equivalently for any $\mathbf{w}$, $\mathbf{f}(\mathbf{w}) \le \mathbf{f}(\mathbf{w}^*)$ implies $\mathbf{f}(\mathbf{w}) = \mathbf{f}(\mathbf{w}^*)$. In other words, it is not possible to improve any component objective in $\mathbf{f}(\mathbf{w}^*)$ without compromising some other objective. Similarly, we call $\mathbf{w}^*$ a weakly Pareto optimal solution if there does not exist any $\mathbf{w}$ such that $\mathbf{f}(\mathbf{w}) < \mathbf{f}(\mathbf{w}^*)$, i.e., it is not possible to improve all component objectives in $\mathbf{f}(\mathbf{w}^*)$ simultaneously. Clearly, any Pareto optimal solution is also weakly Pareto optimal but the converse may not hold.

We point out that the optimal solutions in MoM usually form a set (in general of infinite cardinality) (Mukai, 1980), and without additional subjective preference information, all Pareto optimal solutions are considered equally good (as they are not comparable against each other). This is fundamentally different from the single objective case.
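To make the partial ordering concrete, the following sketch filters the Pareto optimal members of a finite list of objective vectors. The list of candidate values and the helper names `dominates` and `pareto_optimal` are made up for illustration; in MoM the candidates range over all of the parameter space, not a finite list.

```python
def dominates(fa, fb):
    """True if fa <= fb component-wise, as in (12), and fa differs from fb."""
    return (all(x <= y for x, y in zip(fa, fb))
            and any(x < y for x, y in zip(fa, fb)))


def pareto_optimal(values):
    """Keep the objective vectors that no other vector in the list dominates."""
    return [fa for fa in values
            if not any(dominates(fb, fa) for fb in values)]


# objective vectors (f_1, f_2) of four hypothetical candidate models
values = [(1.0, 3.0), (2.0, 2.0), (3.0, 1.0), (3.0, 3.0)]
front = pareto_optimal(values)
print(front)
```

Here the last candidate is dominated by `(2.0, 2.0)`, while the first three are pairwise not comparable in the sense of (13), so all three survive.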

From now on, for simplicity we assume all objective functions are continuously differentiable but not necessarily convex (to accommodate deep models). Finding a (weakly) Pareto optimal solution in this setting is quite challenging (already so in the single objective case). Instead, we will content ourselves with Pareto stationary solutions, namely those that satisfy an intuitive first-order necessary condition:

Definition 1 (Pareto-stationarity, Mukai, 1980).

We call $\mathbf{w}^*$ Pareto-stationary iff some convex combination of the gradients $\{\nabla f_i(\mathbf{w}^*)\}$ vanishes, i.e., there exists some $\boldsymbol{\lambda} \ge \mathbf{0}$ such that $\sum_i \lambda_i = 1$ and $\sum_i \lambda_i \nabla f_i(\mathbf{w}^*) = \mathbf{0}$.

Lemma 1 (Mukai, 1980).

Any Pareto optimal solution is Pareto stationary. Conversely, if all functions are convex, then any Pareto stationary solution is weakly Pareto optimal.

Needless to say, the above results reduce to the familiar ones for the single objective case ($m = 1$).

There exist many algorithms for finding Pareto stationary solutions. We briefly review three popular ones that are relevant for us, and refer the reader to the excellent monograph of Miettinen (1998) for more details.

Weighted approach. Let $\boldsymbol{\lambda} \in \Delta$ (the simplex) and consider the following single, weighted objective:

$$\min_{\mathbf{w}} \sum_{i=1}^m \lambda_i f_i(\mathbf{w}). \tag{14}$$

This is essentially the approach taken by FedAvg, with any (global) minimizer of (14) being weakly Pareto optimal (in fact, Pareto optimal if all weights $\lambda_i$ are positive). From Definition 1 it is clear that any stationary solution of the weighted scalar problem (14) is a Pareto stationary solution of the original MoM (11). Note that the scalarization weights $\boldsymbol{\lambda}$, once chosen, are fixed throughout. Different $\boldsymbol{\lambda}$ lead to different Pareto stationary solutions.

ϵ-constraint. Let $\boldsymbol{\epsilon} \in \mathbb{R}^{m-1}$, $\iota \in \{1, \ldots, m\}$ and consider the following constrained scalar problem:

$$\min_{\mathbf{w}} \; f_\iota(\mathbf{w}) \tag{15}$$
$$\text{s.t.} \quad f_i(\mathbf{w}) \le \epsilon_i, \quad \forall i \ne \iota. \tag{16}$$

Assuming the constraints are satisfiable, any (global) minimizer of (15) is again weakly Pareto optimal. The ϵ-constraint approach is closely related to the weighted approach above, through the usual Lagrangian reformulation. Both require fixing an $(m-1)$-dimensional parameter in advance ($\boldsymbol{\lambda}$ vs. $\boldsymbol{\epsilon}$), though.

Chebyshev approach. Let $\mathbf{s} \in \mathbb{R}^m$ and consider the minimax problem (recall that $\Delta$ is the simplex):

$$\min_{\mathbf{w}} \max_{\boldsymbol{\lambda} \in \Delta} \boldsymbol{\lambda}^\top \big(\mathbf{f}(\mathbf{w}) - \mathbf{s}\big). \tag{17}$$

Again, any (global) minimizer is weakly Pareto optimal. Here $\mathbf{s}$ is a fixed vector that ideally lower bounds $\mathbf{f}$. This is essentially the approach taken by AFL (Mohri et al., 2019) with $\mathbf{s} = \mathbf{0}$.
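The weighted and Chebyshev approaches can be compared on a one-dimensional toy problem. The sketch below assumes two made-up objectives $f_1(w) = w^2$ and $f_2(w) = (w-1)^2$ and replaces a proper optimizer by a brute-force grid search; both choices are ours and serve only to illustrate (14) and (17).

```python
def f(w):
    """Two conflicting scalar objectives with minimizers at 0 and 1."""
    return (w ** 2, (w - 1.0) ** 2)


grid = [-1.0 + 0.001 * j for j in range(3001)]  # candidate models in [-1, 2]


def weighted(lam):
    """Weighted approach (14): minimize sum_i lam_i * f_i(w) over the grid."""
    return min(grid, key=lambda w: sum(l * fi for l, fi in zip(lam, f(w))))


def chebyshev(s):
    """Chebyshev approach (17): a linear function on the simplex is maximized
    at a vertex, so the inner problem reduces to max_i (f_i(w) - s_i)."""
    return min(grid, key=lambda w: max(fi - si for fi, si in zip(f(w), s)))


w_a = weighted((0.25, 0.75))
w_b = weighted((0.75, 0.25))
w_c = chebyshev((0.0, 0.0))
print(w_a, w_b, w_c)
```

Different weights select different points of the Pareto set $[0, 1]$ (about 0.75 and 0.25 here), while the Chebyshev formulation with $\mathbf{s} = \mathbf{0}$ picks the point that equalizes the two losses (about 0.5).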

5 FL as Multi-objective Minimization

Having introduced both FL and MoM, and observed some connections between the two, it is very natural to treat each user function $f_i$ in FL as a separate objective in MoM and aim to optimize them simultaneously as in (11). This will be the main approach we follow below, which, to the best of our knowledge, has not been formally explored before (despite the apparent connections that we saw in the previous section, perhaps retrospectively). In particular, we will extend the multiple gradient descent algorithm (Mukai, 1980) in MoM to FL, draw connections to existing FL algorithms, and prove convergence properties of our extended algorithm FedMGDA+. Very importantly, the notion of Pareto optimality and stationarity immediately enforces fairness among users, as we are discouraged from improving certain users by sacrificing others.

To further motivate our development, let us compare to the objective in AFL (Mohri et al., 2019):

$$\min_{\mathbf{w}} \max_{\boldsymbol{\lambda} \in \Delta} \boldsymbol{\lambda}^\top \mathbf{f}(\mathbf{w}) \equiv \min_{\mathbf{w}} \max_{i = 1, \ldots, m} f_i(\mathbf{w}), \tag{18}$$

where $\Delta$ denotes the simplex. (To be precise, AFL restricted $\boldsymbol{\lambda}$ to a subset $\Lambda \subseteq \Delta$. We simply set $\Lambda = \Delta$ to ease the discussion.) By optimizing the worst loss rather than the average loss in FedAvg, AFL provides some guarantee to all users, hence achieving some form of fairness. However, note that AFL’s objective (18) is not robust against adversarial attacks. In fact, if a malicious user artificially “inflates” its loss $f_i$ (e.g., even by adding or multiplying a constant), it can completely dominate and mislead AFL to solely focus on optimizing its performance. The same issue applies to q-FedAvg (Li et al., 2020b), albeit with a less dramatic effect if $q$ is small.

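AFL’s objective (18) is very similar to the Chebyshev approach in MoM (see Section 4), which inspires us to propose the following iterative algorithm for solving (11):

$$\tilde{\mathbf{w}}_{t+1} = \operatorname*{argmin}_{\mathbf{w}} \max_{\boldsymbol{\lambda} \in \Delta} \boldsymbol{\lambda}^\top \big(\mathbf{f}(\mathbf{w}) - \mathbf{f}(\tilde{\mathbf{w}}_t)\big), \tag{19}$$

where we adaptively “center” the user functions using function values from the previous iteration. When the functions $f_i$ are smooth, we apply the quadratic bound to obtain:

$$\mathbf{w}_{t+1} = \operatorname*{argmin}_{\mathbf{w}} \max_{\boldsymbol{\lambda} \in \Delta} \boldsymbol{\lambda}^\top J_{\mathbf{f}}^\top(\mathbf{w}_t)(\mathbf{w} - \mathbf{w}_t) + \tfrac{1}{2\eta}\|\mathbf{w} - \mathbf{w}_t\|^2, \tag{20}$$

where $J_{\mathbf{f}} = [\nabla f_1, \ldots, \nabla f_m] \in \mathbb{R}^{d \times m}$ is the Jacobian and $\eta > 0$ is the step size. Crucially, note that $\mathbf{f}(\mathbf{w}_t)$ does not appear in the above bound (20) since we subtracted it off in (19). Since (20) is convex in $\mathbf{w}$ and concave in $\boldsymbol{\lambda}$ we can swap min with max and obtain the dual:

$$\max_{\boldsymbol{\lambda} \in \Delta} \min_{\mathbf{w}} \boldsymbol{\lambda}^\top J_{\mathbf{f}}^\top(\mathbf{w}_t)(\mathbf{w} - \mathbf{w}_t) + \tfrac{1}{2\eta}\|\mathbf{w} - \mathbf{w}_t\|^2. \tag{21}$$

Solving for $\mathbf{w}$ by setting its derivative to $\mathbf{0}$ we arrive at:

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \mathbf{d}_t, \quad \mathbf{d}_t = J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}_t, \tag{22}$$
$$\text{where} \quad \boldsymbol{\lambda}_t = \operatorname*{argmin}_{\boldsymbol{\lambda} \in \Delta} \|J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}\|^2. \tag{23}$$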
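For completeness, the step from (21) to (22) and (23) can be spelled out as follows. This is a routine calculation written in the notation above, not an additional result.

```latex
\begin{align*}
% Inner minimization in (21) for a fixed \lambda: set the gradient w.r.t. w to zero.
J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda} + \tfrac{1}{\eta}(\mathbf{w}-\mathbf{w}_t) = \mathbf{0}
  &\;\Longrightarrow\; \mathbf{w} = \mathbf{w}_t - \eta\, J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}, \\
% Plug this minimizer back into the objective of (21).
\boldsymbol{\lambda}^\top J_{\mathbf{f}}^\top(\mathbf{w}_t)(\mathbf{w}-\mathbf{w}_t)
  + \tfrac{1}{2\eta}\|\mathbf{w}-\mathbf{w}_t\|^2
  &= -\eta\,\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|^2
     + \tfrac{\eta}{2}\,\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|^2
   = -\tfrac{\eta}{2}\,\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|^2, \\
% Outer maximization over the simplex.
\max_{\boldsymbol{\lambda}\in\Delta} \; -\tfrac{\eta}{2}\,\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|^2
  &\;\Longleftrightarrow\; \min_{\boldsymbol{\lambda}\in\Delta} \; \|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|^2 .
\end{align*}
```

The first line gives the update (22) once $\boldsymbol{\lambda}$ is set to the maximizer, and the last line is exactly the quadratic program (23).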

Note that $\mathbf{d}_t$ is precisely the minimum-norm element in the convex hull of the columns (i.e., gradients) of the Jacobian $J_{\mathbf{f}}$, and finding $\boldsymbol{\lambda}_t$ amounts to solving a simple quadratic program. The resulting iterative algorithm in (22) is known as the multiple gradient descent algorithm (MGDA), which has been (re)discovered in Mukai (1980), Fliege and Svaiter (2000) and Désidéri (2012), and recently applied to multitask learning in Sener and Koltun (2018) and Lin et al. (2019) and to training GANs in Albuquerque et al. (2019). Our concise derivation here reveals some new insights about MGDA, in particular its connection to AFL.
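The quadratic program (23) can be approximated in a few lines. The sketch below uses a Frank-Wolfe iteration with exact line search, which is our choice for illustration; the text above only states that a simple quadratic program has to be solved and does not prescribe a solver. The helper names and the example gradients are likewise assumptions.

```python
def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def min_norm_weights(grads, iters=1000, tol=1e-12):
    """Approximate argmin over the simplex of ||sum_i lam_i * grads[i]||^2."""
    m = len(grads)
    gram = [[dot(gi, gj) for gj in grads] for gi in grads]  # G = J^T J
    lam = [1.0 / m] * m
    for _ in range(iters):
        g_lam = [dot(row, lam) for row in gram]   # G lam (half the gradient)
        s = min(range(m), key=g_lam.__getitem__)  # best simplex vertex
        if dot(g_lam, lam) - g_lam[s] < tol:      # Frank-Wolfe gap is ~0
            break
        d = [-l for l in lam]
        d[s] += 1.0                               # direction e_s - lam
        g_d = [dot(row, d) for row in gram]       # G d
        # exact line search for the quadratic along d, clipped to [0, 1]
        gamma = min(1.0, max(0.0, -dot(lam, g_d) / dot(d, g_d)))
        lam = [l + gamma * di for l, di in zip(lam, d)]
    return lam


def common_direction(grads, lam):
    """d_t = J lam: the weighted combination of the gradients, as in (22)."""
    return [sum(l * g[j] for l, g in zip(lam, grads))
            for j in range(len(grads[0]))]


grads = [[2.0, 0.0], [0.0, 1.0]]  # two hypothetical user gradients
lam = min_norm_weights(grads)
d = common_direction(grads, lam)
print(lam, d)
```

For these two gradients the weights come out as roughly (0.2, 0.8), so the user with the smaller gradient receives the larger weight, and the resulting direction has a positive inner product with both gradients.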

To adapt MGDA to the federated learning setting, we propose the following extensions.

Balancing user average performance and fairness. We observe that the MGDA update in (22) resembles FedAvg, with the crucial difference that MGDA automatically tunes the dual weighting variable $\boldsymbol{\lambda}$ in each step while FedAvg pre-sets $\boldsymbol{\lambda}$ based on a priori information about the user functions (or simply uniform in the absence of such information). Importantly, the direction $\mathbf{d}_t$ found in MGDA is a common descent direction for all participating objectives:

$$\mathbf{f}(\mathbf{w}_{t+1}) \le \mathbf{f}(\mathbf{w}_t) + J_{\mathbf{f}}^\top(\mathbf{w}_t)(\mathbf{w}_{t+1} - \mathbf{w}_t) + \tfrac{1}{2\eta}\|\mathbf{w}_{t+1} - \mathbf{w}_t\|^2 \le \mathbf{f}(\mathbf{w}_t), \tag{24}$$

where the first inequality follows from the familiar smoothness assumption on $\mathbf{f}$ while the second inequality follows simply from plugging $\mathbf{w} = \mathbf{w}_t$ in (20) and noting that $\mathbf{w}_{t+1}$ by definition can only decrease (20) even more. It is clear that equality is attained iff $\mathbf{d}_t = J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}_t = \mathbf{0}$, i.e., $\mathbf{w}_t$ is Pareto-stationary (see Section 4). In other words, MGDA never sacrifices any participating objective to trade for more sizable improvements over some other objective, something FedAvg with a fixed weighting $\boldsymbol{\lambda}$ might attempt to do. On the other hand, FedAvg with a fixed weighting $\boldsymbol{\lambda}$ may achieve higher average performance under the weighting $\boldsymbol{\lambda}$. It is natural to introduce the following trade-off between average performance and fairness:

$$\text{update (22) with} \quad \boldsymbol{\lambda}_t = \operatorname*{argmin}_{\boldsymbol{\lambda} \in \Delta, \; \|\boldsymbol{\lambda} - \boldsymbol{\lambda}_0\|_\infty \le \epsilon} \|J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}\|^2. \tag{25}$$

Clearly, setting $\epsilon = 0$ recovers FedAvg with a priori weighting $\boldsymbol{\lambda}_0$ while setting $\epsilon = 1$ recovers MGDA where the weighting variable $\boldsymbol{\lambda}$ is tuned without any restriction to achieve maximal fairness. In practice, with an intermediate $\epsilon \in (0, 1)$ we may strike a desirable balance between the two (sometimes) conflicting goals. Moreover, even with the uninformative weighting $\boldsymbol{\lambda}_0 = \mathbf{1}/m$, using an intermediate $\epsilon$ allows us to upper bound the contribution of each user function to the common direction $\mathbf{d}_t$, hence achieving some form of robustness against malicious manipulations.
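For two users the constrained problem (25) has a closed-form solution, which makes the role of $\epsilon$ easy to see. The sketch below assumes $m = 2$, an $\ell_\infty$ ball around the a priori weights, and a hypothetical helper `two_user_weights`; for general $m$ a quadratic programming solver would be needed instead.

```python
def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def two_user_weights(g1, g2, lam0=0.5, eps=1.0):
    """Weights (lam, 1 - lam) solving (25) for m = 2 users.

    lam0 is the a priori weight of user 1 and eps the infinity-norm radius.
    """
    diff = [x - y for x, y in zip(g1, g2)]
    denom = dot(diff, diff)
    # unconstrained minimizer of ||lam * g1 + (1 - lam) * g2||^2 on the real line
    lam = lam0 if denom == 0.0 else dot([-x for x in diff], g2) / denom
    # a 1-D convex quadratic is minimized over an interval by clipping
    lo, hi = max(0.0, lam0 - eps), min(1.0, lam0 + eps)
    lam = min(max(lam, lo), hi)
    return lam, 1.0 - lam


g1, g2 = [2.0, 0.0], [0.0, 1.0]  # two hypothetical user gradients
w_fedavg = two_user_weights(g1, g2, eps=0.0)  # weights stay at lam0
w_mid = two_user_weights(g1, g2, eps=0.1)     # restricted re-weighting
w_mgda = two_user_weights(g1, g2, eps=1.0)    # unrestricted MGDA weights
print(w_fedavg, w_mid, w_mgda)
```

With these gradients the weight of the first user moves from 0.5 (FedAvg) to 0.4 (intermediate $\epsilon$) to 0.2 (MGDA), showing how $\epsilon$ caps how far the weights may drift from $\boldsymbol{\lambda}_0$.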

Robustness against malicious users through normalization. Existing work (e.g., Bhagoji et al., 2019; Xie et al., 2019) has demonstrated that the average gradient in FedAvg can be easily manipulated by even a single malicious user. While more robust aggregation strategies have been studied recently (see e.g., Blanchard et al., 2017; Yin et al., 2018; Diakonikolas et al., 2019), they do not necessarily maintain the convergence properties of FedMGDA+ (e.g., finding a common descent direction and converging to a Pareto stationary solution). Instead, we propose to simply normalize the gradients from each user to unit length, based on the following considerations: (a) Normalizing the (sub)gradient is common for specialists in nonsmooth and stochastic optimization (Anstreicher and Wolsey, 2009) and sometimes eases step size tuning. (b) Solving for the weights $\boldsymbol{\lambda}_t$ in (22) with normalized gradients still guarantees fairness, i.e., the resulting direction $\mathbf{d}_t$ is descending for all participating objectives (by completely similar reasoning to the remark after (24)). (c) Normalization restores robustness against multiplicative “inflation” from any malicious user, which, combined with MGDA’s built-in robustness against additive “inflation” (see Equation 19), offers reasonable robustness guarantees against adversarial attacks.
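Consideration (c) is easy to verify numerically: rescaling a loss by a positive constant rescales its gradient by the same constant, and normalization removes that factor. The sketch below is a minimal illustration with made-up vectors; the `normalize` helper and its small-norm guard are our assumptions.

```python
import math


def normalize(g, tol=1e-12):
    """Scale an update vector to unit Euclidean length (unchanged if ~0)."""
    norm = math.sqrt(sum(x * x for x in g))
    return list(g) if norm < tol else [x / norm for x in g]


honest = [3.0, 4.0]
inflated = [1000.0 * x for x in honest]  # a user multiplying its loss by 1000

n_honest = normalize(honest)
n_inflated = normalize(inflated)
print(n_honest, n_inflated)
```

Both calls return the same unit vector, so a multiplicative inflation of one user's loss leaves the input to the weight computation unchanged.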

Balancing communication and on-device computation. Communication between user devices and the central server is heavily constrained in FL, due to a variety of reasons mentioned in §3. On the other hand, modern edge devices are capable of performing a reasonable amount of on-device computation. Thus, we allow each user device to perform multiple local updates before communicating its update $\mathbf{g} = \mathbf{w}_0 - \mathbf{w}$, namely the difference between the initial $\mathbf{w}_0$ and the final $\mathbf{w}$, to the central server. The server then calls the (extended) MGDA to perform a global update, which will be broadcast to the next round of user devices. We note that a similar strategy was already adopted in many existing FL systems (e.g., McMahan et al., 2017; Li et al., 2020b; Li et al., 2020a).

Subsampling to alleviate non-iid and enhance throughput. Due to the massive number of edge devices in FL, it is not realistic to expect most devices to participate at each or even most rounds. Consequently, the current practice in FL is to select a (different) subset of user devices to participate in each round (McMahan et al., 2017). Moreover, randomly subsampling user devices can also help combat the non-iid distribution of user-specific data (e.g., McMahan et al., 2017; Li et al., 2020). Here we point out an important advantage of our MGDA-based algorithm: its update is along a common descending direction (see (24)), meaning that the objective of any participating user can only decrease. We believe this unique property of MGDA provides a strong incentive for users to participate in FL. To the best of our knowledge, existing FL algorithms do not provide similar algorithmic incentives. Last but not least, subsampling also solves a degeneracy issue in MGDA: when the number of participating users exceeds the dimension $d$, the Jacobian $J_{\mathbf{f}}$ has full row-rank, hence (22) achieves Pareto-stationarity in a single iteration and stops making progress. Subsampling removes this undesirable effect and allows different subsets of users to be continuously optimized.

With the above extensions, we summarize our extended algorithm FedMGDA+ in Algorithm 1, and we prove the following convergence guarantees (precise statements and proofs can be found in Appendix A):

Theorem 1a.

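Let each user function $f_i$ be $L$-Lipschitz smooth and $M$-Lipschitz continuous, and choose step size $\eta_t$ so that $\sum_t \eta_t = \infty$ and $\sum_t \sigma_t \eta_t < \infty$, where $\sigma_t^2 := \mathbf{E}\|\mathbf{d}_t - \hat{\mathbf{d}}_t\|^2$ with

$$\mathbf{d}_t := J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}_t, \quad \boldsymbol{\lambda}_t = \operatorname*{argmin}_{\boldsymbol{\lambda} \in \Delta} \|J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}\|, \tag{26}$$
$$\hat{\mathbf{d}}_t := \hat{J}_{\mathbf{f}}(\mathbf{w}_t) \hat{\boldsymbol{\lambda}}_t, \quad \hat{\boldsymbol{\lambda}}_t = \operatorname*{argmin}_{\boldsymbol{\lambda} \in \Delta} \|\hat{J}_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}\|. \tag{27}$$

Then, with $k = r = 1$ we have:

$$\min_{t = 0, \ldots, T} \mathbf{E}\|J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}_t\|^2 \to 0. \tag{28}$$

Here $k$ is the number of local updates and $r$ is the number of minibatches in each local update. The convergence rate depends on how quickly the “variance” term $\sigma_t$ of the stochastic common descent direction $\hat{\mathbf{d}}_t$ diminishes (if at all), which in turn depends on how aggressively we subsample users or how heterogeneous the users are.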

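For deterministic gradient updates, we can prove convergence even with more local updates (i.e., $k > 1$):

Theorem 1b.

Let each user function $f_i$ be $L$-Lipschitz smooth and $M$-Lipschitz continuous. For any number of local updates $k$, if the global step size $\eta_t \to 0$ with $\sum_t \eta_t = \infty$, local learning rate $\eta_t^l \to 0$ and $\varepsilon_t := \|\boldsymbol{\lambda}_t - \hat{\boldsymbol{\lambda}}_t\| \to 0$, then we have:

$$\min_{t = 0, \ldots, T} \|J_{\mathbf{f}}(\mathbf{w}_t) \boldsymbol{\lambda}_t\|^2 \to 0. \tag{29}$$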

Please refer to Appendix A for the precise statement of the theorem and its proof. We note that one natural approach to bound the deviation $\varepsilon_t$ is by applying the ϵ-constrained version of FedMGDA. For example, if $\|\boldsymbol{\lambda} - \boldsymbol{\lambda}_0\|_\infty \le \epsilon_t$, and $\epsilon_t$ is bounded, then $\varepsilon_t \le 2m\epsilon_t$ is also bounded. Thus, $\varepsilon_t \to 0$ when $\epsilon_t \to 0$. Moreover, when $k = 1$, we do not need the local learning rate $\eta_t^l$ to decay for convergence; in addition, if $\varepsilon_t \equiv 0$ (e.g., in FedAvg), then our convergence guarantee reduces to the usual one for gradient descent, which is expected since we know FedAvg with $k = 1, r = 1$ is the same as centralized gradient descent. Lastly, we note that when $k > 1$, the local learning rate $\eta_t^l$ must vanish in order to obtain convergence. This importance of local learning rate decay is also pointed out in Reddi et al. (2020).

When the functions fi are convex, we can derive a finer result:

Theorem 2.

Suppose each user function $f_i$ is convex and $M$-Lipschitz continuous. Suppose at each round FedMGDA+ includes a strongly convex user function whose weight is bounded away from 0. Then, with the choice $\eta_t = \frac{2}{c(t+2)}$ and $k = r = 1$, we have

$$\mathbf{E}\|\mathbf{w}_t - \mathbf{w}_t^*\|^2 \le \frac{4M^2}{c^2(t+3)}, \tag{30}$$

and $\mathbf{w}_t - \mathbf{w}_t^* \to \mathbf{0}$ almost surely, where $\mathbf{w}_t^*$ is the nearest Pareto stationary solution to $\mathbf{w}_t$ and $c$ is some constant.

A slightly stronger result where we also allow some user functions to be nonconvex can be found in Appendix A. The same results hold if the gradient normalization is bounded away from 0 (otherwise we are already close to Pareto stationarity). For r,k>1, using a similar argument as in §3, we expect FedMGDA+ to optimize some proxy problem (such as the proximal average), and we leave the thorough theoretical analysis for future work.

We remark that convergence rate for MGDA, even when restricted to the deterministic case, was only derived recently in Fliege et al., (2019). The stochastic case (that we consider here) is much more challenging and our theorems provide one of the first convergence guarantees for FedMGDA+. We wish to emphasize that FedMGDA+ is not just an alternative algorithm for FL practitioners; it can be used as a post-processing step to enhance existing FL systems or combined with existing FL algorithms (such as FedProx or q-FedAvg). This is particularly appealing with nonconvex user functions as MGDA is capable of converging to all Pareto stationary points while approaches such as FedAvg do not necessarily enjoy this property even when we enumerate the weighting 𝝀0 (Miettinen,, 1998). Furthermore, it is possible to find multiple or even enumerate all Pareto optimal solutions (i.e.the Pareto front). For instance, we may run FedMGDA+ multiple times with different random seeds or initializations. As shown by Lin et al., (2019), we could also incorporate additional linear constraints in (22) to encode one’s preference and encourage more diverse solutions. However, these techniques become less effective in higher dimensions (i.e.when the number of users is large) and in communication limited settings. Practically, the server may dynamically adjust the linear constraints in (22) to steer the algorithm to a more desirable Pareto stationary solution.

Lastly, we mention that finding the common descent direction (i.e., Line 6 of Algorithm 1) is a standard quadratic programming (QP) problem that is solved only at the server side. For a moderate number of (sampled) users, it suffices to employ a generic QP solver, while for a large number of users we could also solve for $\boldsymbol{\lambda}$ efficiently using, for instance, the conditional gradient algorithm (Sener and Koltun, 2018), with per-step complexity proportional to the model dimension and the number of participating users. For our experiments below, we used a generic QP solver and observed that this overhead is negligible, resulting in almost the same overall running time for FedAvg and FedMGDA.
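To make the conditional gradient option concrete, here is a minimal NumPy sketch of a Frank-Wolfe solver for the min-norm problem over the simplex. It is an illustration under our own assumptions, not the solver used in the experiments: the function name `min_norm_weights` and the iteration budget are hypothetical, the $\epsilon$-constraint is omitted, and the Gram matrix is formed explicitly for brevity.

```python
import numpy as np

def min_norm_weights(G, iters=200):
    """Approximate argmin over the simplex of ||G @ lam||^2 by Frank-Wolfe.

    G has shape (d, m): one (normalized) user update per column.
    """
    m = G.shape[1]
    Q = G.T @ G                        # m x m Gram matrix of the updates
    lam = np.full(m, 1.0 / m)          # start from uniform weights
    for _ in range(iters):
        grad = Q @ lam                 # gradient of 0.5 * lam^T Q lam
        i = int(np.argmin(grad))       # simplex vertex minimizing the linearization
        direction = -lam.copy()
        direction[i] += 1.0            # e_i - lam
        curvature = direction @ Q @ direction
        if curvature <= 1e-12:         # already at the selected vertex (or flat)
            break
        # exact line search for a quadratic, clipped to stay in the simplex
        step = np.clip(-(grad @ direction) / curvature, 0.0, 1.0)
        lam = lam + step * direction
    return lam

# Two orthogonal unit updates: the min-norm combination weights them equally.
lam_orth = min_norm_weights(np.array([[1.0, 0.0], [0.0, 1.0]]))
# Two parallel updates: all weight goes to the shorter one.
lam_par = min_norm_weights(np.array([[1.0, 2.0], [0.0, 0.0]]))
```

Each iteration costs one matrix-vector product with the Gram matrix; working with `G` directly instead would give the per-step cost proportional to the model dimension and the number of users mentioned above.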

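 1  for $t = 1, 2, \ldots$ do
 2      choose a subset $I_t$ of $pm$ clients/users
 3      for $i \in I_t$ do
 4          $\mathbf{g}_i \leftarrow \mathrm{ClientUpdate}(i, \mathbf{w}_t)$
 5          $\bar{\mathbf{g}}_i := \mathbf{g}_i / \|\mathbf{g}_i\|$   // normalize
 6      $\boldsymbol{\lambda}^* \leftarrow \operatorname{argmin}_{\boldsymbol{\lambda}\in\Delta,\ \|\boldsymbol{\lambda}-\boldsymbol{\lambda}_0\|_\infty \le \epsilon} \|\sum_i \lambda_i \bar{\mathbf{g}}_i\|^2$
 7      $\mathbf{d}_t \leftarrow \sum_i \lambda_i^* \bar{\mathbf{g}}_i$   // common direction
 8      choose (global) step size $\eta_t$
 9      $\mathbf{w}_{t+1} \leftarrow \mathbf{w}_t - \eta_t \mathbf{d}_t$
10  Function ClientUpdate($i$, $\mathbf{w}$):
11      $\mathbf{w}^0 \leftarrow \mathbf{w}$
12      repeat $k$ epochs
13          // split local data into $r$ batches
            $\mathcal{D}_i = \mathcal{D}_{i,1} \cup \cdots \cup \mathcal{D}_{i,r}$
14          for $j \in \{1, \ldots, r\}$ do
15              $\mathbf{w} \leftarrow \mathbf{w} - \eta \nabla f_i(\mathbf{w}; \mathcal{D}_{i,j})$
16      return $\mathbf{g} := \mathbf{w}^0 - \mathbf{w}$ to server
Algorithm 1 FedMGDA+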
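The server loop of Algorithm 1 can be sketched on a toy problem. Everything specific here is an assumption for illustration: two users with quadratic losses $f_i(\mathbf{w}) = \frac{1}{2}\|\mathbf{w} - \mathbf{c}_i\|^2$, full participation, $k = r = 1$, a constant global step size, no $\epsilon$-constraint, and the closed-form two-vector min-norm solution in place of a QP solver. For this toy problem the Pareto stationary set is the segment between $\mathbf{c}_1$ and $\mathbf{c}_2$.

```python
import numpy as np

def client_update(w, center, lr=1.0):
    """One full-batch gradient step on f_i(w) = 0.5 * ||w - center||^2; returns w0 - w."""
    w_new = w - lr * (w - center)
    return w - w_new

def min_norm_two(g1, g2):
    """Closed-form weights of the min-norm convex combination of two vectors."""
    diff = g1 - g2
    denom = diff @ diff
    a = 0.5 if denom == 0.0 else float(np.clip(g2 @ (g2 - g1) / denom, 0.0, 1.0))
    return np.array([a, 1.0 - a])

def fedmgda_plus(w, centers, rounds=100, step=0.5):
    """Toy server loop: normalize updates, find common direction, take a step."""
    for _ in range(rounds):
        updates = [client_update(w, c) for c in centers]
        normed = [g / max(np.linalg.norm(g), 1e-12) for g in updates]  # normalize
        lam = min_norm_two(normed[0], normed[1])
        d = lam[0] * normed[0] + lam[1] * normed[1]                    # common direction
        w = w - step * d
    return w

centers = [np.array([0.0, 0.0]), np.array([2.0, 0.0])]
w_final = fedmgda_plus(np.array([1.0, 3.0]), centers)
```

Starting from $(1, 3)$, the iterate moves straight down onto the segment joining the two centers, where the normalized updates point in opposite directions and the common direction vanishes.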

6 Experiments

6.1 Experimental setups

Table 1: Dataset summary
Dataset Train Clients Train samples Test clients Test samples Batch size
CIFAR-10 (Krizhevsky,, 2009) 100 50000 100 10000 {10,}
F-MNIST (Xiao et al.,, 2017) 100 60000 100 10000 {10,}
FEMNIST (Caldas et al.,, 2018) 3406 709385 3406 80011 {20,}
Shakespeare (Li et al., 2020a, ) 31 92959 31 23255 {10}
Adult (Dua and Graff,, 2017) 2 32561 2 16281 {10}
Table 2: CIFAR-10 model
Layer Output Shape # of Trainable Parameters Activation Hyper-parameters
Input (3,32,32) 0
Conv2d (64,28,28) 4864 ReLU kernel size =5; strides=(1,1)
MaxPool2d (64,14,14) 0 pool size=(2,2)
LocalResponseNorm (64,14,14) 0 size=2
Conv2d (64,10,10) 102464 ReLU kernel size =5; strides=(1,1)
LocalResponseNorm (64,10,10) 0 size=2
MaxPool2d (64,5,5) 0 pool size=(2,2)
Flatten 1600 0
Dense 384 614784 ReLU
Dense 192 73920 ReLU
Dense 10 1930 softmax
Total 797962
Table 3: Fashion MNIST model
Layer Output Shape # of Trainable Parameters Activation Hyper-parameters
Input (1,28,28) 0
Conv2d (10,24,24) 260 ReLU kernel size =5; strides=(1,1)
MaxPool2d (10,12,12) 0 pool size=(2,2)
Conv2d (20,8,8) 5020 ReLU kernel size =5; strides=(1,1)
MaxPool2d (20,4,4) 0 pool size=(2,2)
Dropout2d (20,4,4) 0 p=0.5
Flatten 320 0
Dense 50 16050 ReLU
Dropout 50 0 p=0.5
Dense 10 510 softmax
Total 21840
Table 4: Federated EMNIST model (Reddi et al.,, 2020)
Layer Output Shape # of Trainable Parameters Activation Hyper-parameters
Input (1,28,28) 0
Conv2d (32,26,26) 320 kernel size =3; strides=(1,1)
Conv2d (64,24,24) 18496 ReLU kernel size =3; strides=(1,1)
MaxPool2d (64,12,12) 0 pool size=(2,2)
Dropout (64,12,12) 0 p=0.25
Flatten 9216 0
Dense 128 1179776
Dropout 128 0 p=0.5
Dense 62 7998 softmax
Total 1206590
Table 5: Hyperparameters used in our experiments.
Name Parameters
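AFL $\gamma_\lambda \in \{0.01, 0.1, 0.2, 0.5\}$, $\gamma_w \in \{0.01, 0.1\}$
q-FedAvg $q \in \{0.001, 0.01, 0.1, 0.5, 1, 2, 5, 10\}$, $L \in \{0.1, 1, 10\}$
FedMGDA+ $\eta \in \{0.5, 1, 1.5, 2\}$, and Decay $\in \{0, 1/40, 1/30, 1/20, 1/10, 1/3, 1/2\}$
FedAvg-n $\eta \in \{0.5, 1, 1.5, 2\}$, and Decay $\in \{0, 1/40, 1/30, 1/20, 1/10, 1/3, 1/2\}$
FedProx $\mu \in \{0.001, 0.01, 0.1, 0.5, 1, 10\}$
MGDA-Prox $\mu = 0.1$, $\eta \in \{0.5, 1, 1.5, 2\}$, and Decay $\in \{0, 1/40, 1/30, 1/20, 1/10, 1/5, 1/3, 1/2\}$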

In this subsection we provide experimental details including dataset descriptions, sampling schemes, model configurations and hyper-parameter settings. A quick summary of the datasets that we use can be found in Table 1. We have two parameters in FedMGDA+ to control the total number of local updates in each communication round: k, the number of local epochs, and r=n/b, the number of updates in each local epoch. Here n is the number of samples at each user (assumed the same for simplicity) while b is the minibatch size for each local update. As observed by, e.g., McMahan et al. (2017, Table 2), having a larger k is similar to having a smaller b (or equivalently a larger r), in terms of the total number of local updates. Moreover, k=1 with a suitable b usually leads to satisfactory performance while a very large k can result in a plateau or divergence. Thus, in our experiments we fix k=1 and vary b to reduce the total number of hyperparameters. This corresponds to a single pass over the local data at each user in every communication round.

6.1.1 CIFAR-10 (Krizhevsky,, 2009) and Fashion MNIST (Xiao et al.,, 2017) datasets

In order to create a non-i.i.d. dataset, we follow a similar sampling procedure as in McMahan et al. (2017): first we sort all data points according to their classes. Then, they are split into 500 shards, and each user is randomly assigned 5 shards of data. With 100 users, this procedure guarantees that no user receives data from more than 5 classes and that the data distribution of each user differs from the others. The local datasets are balanced: all users have the same number of training samples. The local data is split into train, validation, and test sets with percentages of 80%, 10%, and 10%, respectively. In this way, each user has 400 data points for training, 50 for test, and 50 for validation. We use a CNN model which resembles the one in McMahan et al. (2017), with two convolutional layers followed by three fully connected layers. The details are included in Table 2 for CIFAR-10 and in Table 3 for Fashion MNIST. To update the local models at each user using its local data, we apply stochastic gradient descent (SGD) with local batch size b=10, local epoch k=1, and local learning rate η=0.01, or b=400, k=1, and η=0.1. To model the fact that not all users may participate in each communication round, we employ a parameter p to control the fraction of participating users: p=0.1 is the default setting, which means that only 10% of users participate in each communication round.
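The shard-based sampling procedure can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the authors' preprocessing code: the function name `shard_partition`, the seed handling and the synthetic balanced labels are ours.

```python
import numpy as np

def shard_partition(labels, num_users=100, shards_per_user=5, seed=0):
    """Sort sample indices by class, cut them into shards, deal shards to users."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind="stable")        # indices sorted by class
    num_shards = num_users * shards_per_user         # 500 shards in the default setting
    shards = np.array_split(order, num_shards)       # contiguous, (nearly) single-class
    shard_ids = rng.permutation(num_shards)          # random assignment of shards
    return [
        np.concatenate(
            [shards[s] for s in shard_ids[u * shards_per_user:(u + 1) * shards_per_user]]
        )
        for u in range(num_users)
    ]

# Synthetic stand-in for a balanced 10-class dataset with 50000 samples.
labels = np.repeat(np.arange(10), 5000)
parts = shard_partition(labels)
classes_per_user = [len(np.unique(labels[p])) for p in parts]
```

With balanced classes each shard contains a single class, so no user sees more than 5 classes and every user holds 500 samples before the 80/10/10 split.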

6.1.2 Federated EMNIST dataset (Caldas et al.,, 2018)

For this experimental setup, we use the same dataset, model, and hyper-parameters as Reddi et al., (2020). We use the federated EMNIST dataset of Caldas et al., (2018). The dataset consists of images of digits, and English characters—both lower and upper cases, with 62 classes in total. The images are partitioned by their authors in a way that naturally makes the dataset heterogeneous and unbalanced. We use the model described in Table 4 and the following hyper-parameters: local learning rate η=0.1 and selecting 10 clients per communication round as recommended. The only difference between our setup and the one in (Reddi et al.,, 2020) is that we use local epoch k=1 for all algorithms.

6.1.3 Shakespeare dataset (Li et al., 2020a, )

For experiments on the Shakespeare dataset, we use the same model, data pre-processing and sampling procedure as in q-FedAvg paper (Li et al., 2020b, ). The dataset is built from The Complete Works of William Shakespeare, where each role in the play represents one user. Following Li et al., 2020a , we subsample 31 users to train a neural language model for next character prediction. Each character is embedded in an 8-dimensional space and the sequence length is 80 characters. The model we use is a two-layer LSTM (with hidden size 256) followed by one dense layer (McMahan et al.,, 2017; Li et al., 2020a, ). Joint hyper-parameters that are shared by all algorithms include: total communication rounds T=200, local batch size b=10, local epoch k=1, and local optimizer being SGD, unless otherwise stated.

6.1.4 Adult dataset (Dua and Graff,, 2017)

Following the setting in AFL (Mohri et al.,, 2019), we split the Adult dataset into two non-overlapping domains based on the education attribute—phd domain and non-phd domain. The resulting FL setting consists of two users each of which has data from one of the two domains. Further, data is pre-processed as in Li et al., 2020b to have 99 binary features. We use a logistic regression model for all FL algorithms mentioned in the main paper. Local data is split into train, validation, and test sets with percentage of 80%, 10%, and 10%, respectively. In each round, both users participate and the server aggregates their losses and gradients (or weights). Joint hyper-parameters that are shared by all algorithms include: total communication rounds T=500, local batch size b=10, local epoch k=1, local learning rate η=0.01, and local optimizer being SGD without momentum, unless otherwise stated. Algorithm-specific hyper-parameters will be mentioned in the appropriate places below. One important note is that the phd domain has only 413 samples while the non-phd domain has 32,148 samples, so the split is very unbalanced while training only on the phd domain yields inferior performance on all domains due to the insufficient sample size.

6.1.5 Hyper-parameters

We evaluate the performance of different algorithms with a wide range of hyper-parameters, summarized in Table 5. In particular, following Anstreicher and Wolsey (2009) we tried sublinear $O(1/t)$ and exponential decay $O(\beta^t)$ learning rates $\eta$ on the server, and a fixed local learning rate $\eta$ for client updates. Eventually we settled on decaying $\eta_t$ by a factor of $\beta$ every 100 steps: $\eta_t = \beta^{\lfloor t/100 \rfloor}$, where $\beta = \mathrm{decay}^{100/T}$ and $T$ is the total number of communication rounds (with e.g. decay = 1/10). We note that Reddi et al. (2020) also found exponential decay to be most effective in their experiments. We use grid search to choose suitable local learning rates for all algorithms.
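A short sketch of this step-decay schedule follows. The function name and the `eta0` argument are assumptions: we multiply the decay factor by an initial server step size, which is how we read the $\eta$ grid in Table 5, and we only consider decay values in $(0, 1]$.

```python
def server_step_size(t, total_rounds, eta0=1.0, decay=0.1, period=100):
    """Step size after t rounds: multiplied by beta once every `period` rounds."""
    beta = decay ** (period / total_rounds)   # so that beta ** (T / period) == decay
    return eta0 * beta ** (t // period)

first = server_step_size(0, 1000)
still_first = server_step_size(99, 1000)
last = server_step_size(1000, 1000)
```

With this choice of $\beta$, the step size at round $T$ equals the initial step size times the `decay` value, regardless of how many rounds are run.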

Figure 1: Interpolation between FedAvg and FedMGDA on CIFAR-10. x-axis is the number of communication rounds. From left to right: (a) and (b) Average user accuracy in non-iid/iid setting resp. (c) and (d) Uniformly averaged training loss in non-iid/iid setting resp. Results are averaged over 5 runs with different random seeds.
Figure 2: (Left) Test accuracy of SOTA algorithms on Adult dataset with adversarial biases added to the loss of PhD domain; and compared to the baseline of training only on PhD domain. The scales of biases for AFL and q-FedAvg are different because AFL uses averaged loss while q-FedAvg uses (non-averaged) total loss. (Right) Test accuracy of different algorithms on CIFAR-10 in the presence of a malicious user who scales its loss function with a constant factor. All algorithms are run for 500 rounds on Adult and 1500 rounds on CIFAR-10. The reported results are averaged across 5 runs with different random seeds. For detailed hyperparameter setting, see Section B.2.

We evaluate our algorithm FedMGDA+ on several public datasets: CIFAR-10 (Krizhevsky, 2009), F-MNIST (Xiao et al., 2017), Federated EMNIST (Caldas et al., 2018), Shakespeare (Li et al., 2020a) and Adult (Dua and Graff, 2017), and compare to existing FL systems including FedAvg (McMahan et al., 2017), FedProx (Li et al., 2020a), q-FedAvg (Li et al., 2020b), and AFL (Mohri et al., 2019). (Footnote 3: Experiments on AFL in the original work (Mohri et al., 2019) and in later work that compares with it (e.g. Li et al., 2020b) were reported on datasets with very few clients (2 or 3), possibly due to applicability reasons. We followed this convention in our work.) In addition, from the discussions in §5, one can envision several potential extensions of existing algorithms to improve their performance. So, we also compare to the following extensions: FedAvg-n, which is FedAvg with gradient normalization, and MGDA-Prox, which is FedMGDA+ with a proximal regularizer added to each user’s loss function. (Footnote 4: One can also apply the gradient normalization idea to q-FedAvg; however, we observed from our experiments that the resulting algorithm is unstable, particularly for large q values.) We distinguish between FedMGDA+ and FedMGDA, which is a vanilla extension of MGDA to FL.

We point out that FL algorithms are to be deployed on smart devices with moderate computational capabilities. Thus, the models we chose to experiment on are medium-sized (see Tables 2, 3 and 4 for details), with similar complexity to the ones in FedAvg, q-FedAvg and AFL. Due to space limits we only report some representative results in the main paper, and defer the full set of experiments to Appendix B.

6.2 Experimental results

In this subsection we report experimental results about our proposed algorithm FedMGDA+ and compare it with state-of-the-art (SOTA) alternatives under a variety of performance metrics, including accuracy, robustness and fairness. We remind that the accuracy metric is exactly what FedAvg aims to optimize during training, and hence it has some advantage in this metric over other alternative algorithms such as FedMGDA+, AFL, and q-FedAvg, which all aim to bring some fairness among users, perhaps at some occasional, and hopefully small, loss of accuracy.

6.2.1 Recovering FedAvg

As mentioned in §5, we can control the balance between the user average performance and fairness by tuning the ϵ-constraint in Equation 25. Setting ϵ=0 recovers FedAvg while setting ϵ=1 recovers FedMGDA. To verify this empirically, we run (25) with different ϵ, and report results on CIFAR-10 in Figure 1 for both iid and non-iid distributions of data (for results on F-MNIST, see Section B.1). These results confirm that changing ϵ from 0 to 1 yields an interpolation between FedAvg and FedMGDA, as expected. Since FedAvg essentially optimizes the (uniformly) averaged training loss, it naturally performs the best under this metric (Figure 1 (c) and (d)). Nevertheless, it is interesting to note that some intermediate ϵ values actually lead to better user accuracy than FedAvg in the non-iid setting (Figure 1 (a)).
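The interpolation can be illustrated in the simplest case of two users, where the constrained problem has a closed form. The sketch below is ours and rests on assumptions: the constraint is taken in the $\ell_\infty$ norm around a uniform $\boldsymbol{\lambda}_0$, and the function name `two_user_weights` is hypothetical.

```python
import numpy as np

def two_user_weights(g1, g2, lam0=0.5, eps=1.0):
    """Min-norm weights (a, 1 - a) for two updates, subject to |a - lam0| <= eps."""
    diff = g1 - g2
    denom = diff @ diff
    # unconstrained minimizer of ||a * g1 + (1 - a) * g2||^2 over the real line
    a = lam0 if denom == 0.0 else float(g2 @ (g2 - g1) / denom)
    lo, hi = max(0.0, lam0 - eps), min(1.0, lam0 + eps)
    a = min(max(a, lo), hi)      # 1-D convex quadratic: clipping gives the constrained optimum
    return np.array([a, 1.0 - a])

g1, g2 = np.array([1.0, 0.0]), np.array([0.0, 2.0])
w_avg = two_user_weights(g1, g2, eps=0.0)    # fixed weights, as in FedAvg
w_mid = two_user_weights(g1, g2, eps=0.1)    # partially constrained
w_mgda = two_user_weights(g1, g2, eps=1.0)   # unconstrained, as in FedMGDA
```

As $\epsilon$ grows from 0 to 1, the weights move from the fixed $\boldsymbol{\lambda}_0$ to the unconstrained min-norm solution.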

Figure 3: Distribution of the user test accuracy on CIFAR-10: (Left) the algorithms are run for 2000 communication rounds and b=10. The hyperparameters are: μ=0.01 for FedProx; η=1.5 and decay=1/10 for FedMGDA+ and FedAvg; η=1.0 and decay=1/10 for MGDA-Prox; q=0.5 and L=1.0 for q-FedAvg. (Right) the algorithms are run for 3000 communication rounds and b=400. The hyperparameters are: μ=0.5 for FedProx; η=1.0 and decay=1/40 for FedMGDA+, MGDA-Prox, and FedAvg; q=0.1 and L=0.1 for q-FedAvg. The reported statistics are averaged across 4 runs with different random seeds.
Figure 4: The percentage of improved users in terms of training loss vs communication rounds on the CIFAR-10 dataset. Two representative cases are shown: (Left) the local batch size b=10, and (Right) the local batch size b=400. The results are averaged across 4 runs with different random seeds.

6.2.2 Robustness

We discussed earlier in §5 that gradient normalization and MGDA’s built-in robustness allow FedMGDA+ to withstand certain adversarial attacks in practical FL deployment. We now empirically evaluate the robustness of FedMGDA+ against these attacks. We run various FL algorithms in the presence of a single malicious user who aims to manipulate the system by inflating its loss. We consider an adversarial setting where the attacker participates in each communication round and inflates its loss function by (i) adding a bias to it, or (ii) multiplying it by a scaling factor, termed the bias and scaling attack, respectively. In the first experiment, we simulate a bias attack on the Adult dataset by adding a constant bias to the underrepresented user, i.e. the PhD domain, since it is more natural to expect attackers to make up a small fraction of the users. In this setup, the worst performance we can get is bounded by training the model using PhD data only. Results under the bias attack are presented in Figure 2 (Left); also see Section B.2 for more results. We observe that AFL and q-FedAvg perform slightly better than FedMGDA+ without the attack; however, their performance deteriorates to a level close to the worst-case scenario under the attack. In contrast, FedMGDA+ is not affected by the attack with any bias, which empirically supports our claim in §5. Note that we did not include FedAvg in this comparison since from its definition it is clear that FedAvg, like FedMGDA+, is not affected by the bias attack. Figure 2 (Right) shows the results of different algorithms on CIFAR-10 with and without an adversarial scaling. As mentioned earlier, q-FedAvg with gradient normalization is highly unstable, particularly under the scaling attack, so we did not include its result here.
From Figure 2 (Right) it is immediate to see that (i) the scaling attack affects all algorithms that do not employ gradient normalization; (ii) q-FedAvg is the most affected under this attack; (iii) surprisingly, FedMGDA+ and, to a lesser extent, MGDA-Prox actually converge to slightly better Pareto solutions, compared to their own results under no scaling attack. The above results empirically verify the robustness of FedMGDA+ under perhaps the most common bias and scaling attacks.
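The mechanism behind both observations can be checked numerically on a toy loss. The sketch below is only an illustration under our own assumptions (a quadratic loss, finite-difference gradients, and hypothetical attack magnitudes): a constant bias leaves the gradient unchanged, and a positive scaling changes its length but not its normalized direction.

```python
import numpy as np

def numerical_grad(f, w, h=1e-5):
    """Central finite-difference gradient of a scalar function f at w."""
    g = np.zeros_like(w)
    for j in range(w.size):
        e = np.zeros_like(w)
        e[j] = h
        g[j] = (f(w + e) - f(w - e)) / (2 * h)
    return g

def normalize(g):
    return g / np.linalg.norm(g)

def loss(w):
    return 0.5 * np.sum((w - 1.0) ** 2)

def biased_loss(w):          # bias attack: add a constant to the loss
    return loss(w) + 1000.0

def scaled_loss(w):          # scaling attack: multiply the loss by a constant
    return 10.0 * loss(w)

w = np.array([3.0, -2.0])
g = numerical_grad(loss, w)
g_bias = numerical_grad(biased_loss, w)
g_scale = numerical_grad(scaled_loss, w)
```

Any method that only uses (normalized) gradients is therefore blind to both manipulations, whereas methods that weight users by their reported loss values are not.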

6.2.3 Fairness

Lastly, we compare FedMGDA+ with existing FL algorithms using different notions of fairness on CIFAR-10. For the first experiment, we adopt the same fairness metric as (Li et al., 2020b, ), and measure fairness by calculating the variance of users’ test error. We run each algorithm with different hyperparameters, and among the results, we pick the best ones in terms of average accuracy to be shown in Figure 3; full table of results can be found in Section B.3. From this figure, we observe that (i) FedMGDA+ achieves the best average accuracy while its standard deviation is comparable with that of q-FedAvg; (ii) FedMGDA+ significantly outperforms FedMGDA, which clearly justifies our proposed modifications in Algorithm 1 to the vanilla MGDA; and (iii) FedMGDA+ outperforms FedAvg-n, which uses the same normalization step as FedMGDA+, in terms of average accuracy and standard deviation. These observations confirm the effectiveness of FedMGDA+ in inducing fairness. We perform the same experiment on the Federated EMNIST dataset, and observed similar results, which can be found in Table 6 and Section B.4.

In the next experiment, we show that FedMGDA+ not only yields a fair final solution but also maintains fairness during the entire training process in the sense that, in each round, it refrains from sacrificing the performance of any participating user for the sake of improving the overall performance. To the best of our knowledge, “fairness during training” has not been investigated before, in spite of having great practical implications, since it encourages user participation. To examine this fairness, we run several experiments on CIFAR-10 and measure the percentage of improved participants in each communication round. Specifically, we measure the training loss before and after each round for all participating users, and report the percentage of those that improve or stay unchanged. (Footnote 5: The percentage of improved users at time $t$ is defined as $\sum_{i\in I_t} \mathbb{I}\{f_i(\mathbf{w}_{t+1}) \le f_i(\mathbf{w}_t)\} / |I_t|$, where $I_t$ is the set of selected users at time $t$, and $\mathbb{I}\{A\}$ is the indicator function of an event $A$.) Figure 4 shows the percentage of improved participating users in each communication round in terms of training loss for two representative cases; see Section B.5 for full results with different hyperparameters.
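The per-round metric is simple to compute once the training losses of the participating users are recorded before and after the server update. A minimal sketch (the function name and the example loss values are hypothetical):

```python
def fraction_improved(loss_before, loss_after):
    """Share of participating users whose training loss did not increase."""
    if len(loss_before) != len(loss_after) or not loss_before:
        raise ValueError("need one before/after loss per participating user")
    improved = sum(after <= before for before, after in zip(loss_before, loss_after))
    return improved / len(loss_before)

# Four participating users: two improve, one is unchanged, one gets worse.
share = fraction_improved([1.0, 2.0, 3.0, 4.0], [0.9, 2.0, 3.5, 3.0])
```

Multiplying the returned share by 100 gives the percentage of improved users for one communication round.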

We can see that FedMGDA+ consistently outperforms other algorithms in terms of the percentage of improved users, which means that by using FedMGDA+, fewer users’ performances get worse after each participation. Furthermore, we notice from Figure 4 (Left) that, with local batch size $b = 10$, the percentage of improved users is less than 100%, which can be explained as follows: for small batch sizes (i.e., $b < |\mathcal{D}|$ where $\mathcal{D}$ represents a local dataset), the received updates from users are not the true gradients of users’ losses given the global model (i.e., $\mathbf{g}_i \neq \nabla f_i(\mathbf{w})$); they are noisy estimates of the true gradients. Consequently, the common descent direction calculated by MGDA is noisy and may not always work for all participating users. To remove the effect of this noise, we set $b = |\mathcal{D}|$, which allows us to recover the true gradients from the users. The results are presented in Figure 4 (Right), which confirms that, when the step size decays (less overshooting), the percentage of improved users for FedMGDA+ approaches 100% during training, as expected.

Table 6: Test accuracy of users on federated EMNIST with full batch, 10 users per round, local learning rate η=0.1, total communication rounds 1500. The reported statistics are averaged across 4 runs with different random seeds.
Algorithm Average (%) Std. (%)
FedMGDA 85.73±0.05 14.79±0.12
FedMGDA+ 87.60±0.20 13.68±0.19
MGDA-Prox 87.59±0.19 13.75±0.18
FedAvg 84.97±0.44 15.25±0.36
FedAvg-n 87.57±0.09 13.74±0.11
FedProx 84.97±0.45 15.26±0.35
q-FedAvg 84.97±0.44 15.25±0.37

7 Conclusion

We have proposed a novel algorithm FedMGDA+ for federated learning. FedMGDA+ is based on multi-objective optimization and aims to converge to Pareto stationary solutions. FedMGDA+ is simple to implement, has fewer hyperparameters to tune, and complements existing FL systems nicely. Most importantly, FedMGDA+ is robust against additive and multiplicative adversarial manipulations and ensures fairness among all participating users. We established preliminary convergence guarantees for FedMGDA+, pointed out its connections to recent FL algorithms, and conducted extensive experiments to verify its effectiveness. In the future we plan to formally quantify the tradeoff induced by multiple local updates and to establish some privacy guarantee for FedMGDA+.

Acknowledgment

Resources used in preparing this research were provided, in part, by the Province of Ontario, the Government of Canada through CIFAR, and companies sponsoring the Vector Institute. We gratefully acknowledge funding support from NSERC, the Canada CIFAR AI Chairs Program, and Waterloo-Huawei Joint Innovation Lab. We thank NVIDIA Corporation (the data science grant) for donating two Titan V GPUs that enabled in part the computation in this work.

References

Appendix A Proofs

Theorem 1a (full version).

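Suppose each user function $f_i$ is $L$-Lipschitz smooth (i.e., $\nabla^2 f_i \preceq L\mathbf{I}$) and $M$-Lipschitz continuous. Then, with step size $\eta_t \in (0, \frac{1}{2L}]$ we have

$$\min_{t=0,\ldots,T} \mathbf{E}\big[\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2\big] \le \frac{2\big[\mathbf{f}(\mathbf{w}_0) - \mathbf{E}\mathbf{f}(\mathbf{w}_{T+1}) + \sum_{t=0}^{T}\eta_t(M\sigma_t + L\eta_t\sigma_t^2)\big]}{\sum_{t=0}^{T}\eta_t}, \tag{31}$$

where $\sigma_t^2 := \mathbf{E}\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t - \hat{J}_{\mathbf{f}}(\mathbf{w}_t)\hat{\boldsymbol{\lambda}}_t\|^2$ is the variance of the stochastic common direction. Moreover, if some user function $f_i$ is bounded from below, and it is possible to choose $\eta_t$ so that $\sum_t \eta_t = \infty$, $\sum_t \eta_t\sigma_t < \infty$, then the left-hand side in (31) converges to 0.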

Proof.

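Let $\boldsymbol{\xi}_t := J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t - \hat{J}_{\mathbf{f}}(\mathbf{w}_t)\hat{\boldsymbol{\lambda}}_t$, where $\hat{J}_{\mathbf{f}}(\mathbf{w}_t) := [\hat{\nabla} f_1(\mathbf{w}_t), \ldots, \hat{\nabla} f_m(\mathbf{w}_t)]$ is the concatenation of stochastic gradients at each user, and

$$\boldsymbol{\lambda}_t = \operatorname*{argmin}_{\boldsymbol{\lambda}\in\Delta} \|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|, \quad \hat{\boldsymbol{\lambda}}_t = \operatorname*{argmin}_{\hat{\boldsymbol{\lambda}}\in\Delta} \|\hat{J}_{\mathbf{f}}(\mathbf{w}_t)\hat{\boldsymbol{\lambda}}\|, \tag{32}$$

where for the latter we also constrain $\hat{\lambda}_i = 0$ if the $i$-th user is not participating in round $t$. Then, applying the quadratic bound and the update rule (we remind that comparison between vector and scalar should be understood as component-wise):

\begin{align}
\mathbf{f}(\mathbf{w}_{t+1}) &\le \mathbf{f}(\mathbf{w}_t) - \eta_t J_{\mathbf{f}}^\top(\mathbf{w}_t)\hat{J}_{\mathbf{f}}(\mathbf{w}_t)\hat{\boldsymbol{\lambda}}_t + \frac{L\eta_t^2}{2}\|\hat{J}_{\mathbf{f}}(\mathbf{w}_t)\hat{\boldsymbol{\lambda}}_t\|^2 \tag{33}\\
&\le \mathbf{f}(\mathbf{w}_t) - \eta_t J_{\mathbf{f}}^\top(\mathbf{w}_t)J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t + L\eta_t^2\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2 + \eta_t J_{\mathbf{f}}^\top(\mathbf{w}_t)\boldsymbol{\xi}_t + L\eta_t^2\|\boldsymbol{\xi}_t\|^2 \tag{34}\\
&\le \mathbf{f}(\mathbf{w}_t) - \eta_t(1 - L\eta_t)\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2 + \eta_t M\|\boldsymbol{\xi}_t\| + L\eta_t^2\|\boldsymbol{\xi}_t\|^2, \tag{35}
\end{align}

where we used the Lipschitz continuity $\|\nabla f_i(\mathbf{w})\| \le M$ and the first-order optimality condition of $\boldsymbol{\lambda}_t$ so that

$$\forall \boldsymbol{\lambda}\in\Delta, \quad \langle\boldsymbol{\lambda}, J_{\mathbf{f}}^\top(\mathbf{w}_t)J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\rangle \ge \langle\boldsymbol{\lambda}_t, J_{\mathbf{f}}^\top(\mathbf{w}_t)J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\rangle. \tag{36}$$

Letting $\eta_t \le \frac{1}{2L}$, taking expectations and rearranging we obtain

$$\min_{t=0,\ldots,T} \mathbf{E}\big[\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2\big] \le \frac{2\big[\mathbf{f}(\mathbf{w}_0) - \mathbf{E}\mathbf{f}(\mathbf{w}_{T+1}) + \sum_{t=0}^{T}\eta_t(M\sigma_t + L\eta_t\sigma_t^2)\big]}{\sum_{t=0}^{T}\eta_t}, \tag{37}$$

where $\sigma_t^2 := \mathbf{E}\|\boldsymbol{\xi}_t\|^2$.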

Theorem 1b (full version).

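Suppose each user function $f_i$ is $L$-Lipschitz smooth (i.e., $\nabla^2 f_i \preceq L\mathbf{I}$) and $M$-Lipschitz continuous. Then, for any number of local updates $k$, with global learning rate $\eta_t \in (0, \frac{1}{2L}]$, deterministic gradient update and local learning rate $\eta_t^l$, we have

$$\min_{t=0,\ldots,T} \|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2 \le \frac{2\Big[\mathbf{f}(\mathbf{w}_0) - \mathbf{f}(\mathbf{w}_{T+1}) + M^2\sum_{t=0}^{T}\eta_t\Big(\big(\varepsilon_t\sqrt{m} + \eta_t^l(k-1)\big) + L\eta_t\big(\varepsilon_t\sqrt{m} + \eta_t^l(k-1)\big)^2\Big)\Big]}{\sum_{t=0}^{T}\eta_t}, \tag{38}$$

where $\varepsilon_t := \|\boldsymbol{\lambda}_t - \tilde{\boldsymbol{\lambda}}_t\|$ is the deviation between the exact and approximate (dual) weightings. Moreover, if some user function $f_i$ is bounded from below, then the left-hand side in (38) converges to 0 as long as $\varepsilon_t \to 0$, $\eta_t^l \to 0$ and $\eta_t \to 0$ with $\sum_t \eta_t = \infty$.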

Proof.

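Let

$$\boldsymbol{\lambda}_t = \operatorname*{argmin}_{\boldsymbol{\lambda}\in\Delta} \|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|, \quad \tilde{\boldsymbol{\lambda}}_t = \operatorname*{argmin}_{\boldsymbol{\lambda}\in\Delta} \|\tilde{J}_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}\|, \tag{39}$$

and $\delta_t := J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t - \tilde{J}_{\mathbf{f}}(\mathbf{w}_t)\tilde{\boldsymbol{\lambda}}_t$, where $\tilde{J}_{\mathbf{f}}(\mathbf{w}_t) := [\tilde{\nabla} f_1(\mathbf{w}_t), \ldots, \tilde{\nabla} f_m(\mathbf{w}_t)]$ is the concatenation of accumulated updates $\tilde{\nabla} f_i(\mathbf{w}_t)$ at each user. Formally, $\tilde{\nabla} f_i(\mathbf{w}_t) := \mathbf{w}_t - \mathbf{w}_t^k$, which denotes the difference between the initial $\mathbf{w}_t$ and the final $\mathbf{w}_t^k$ after $k$ local updates, for user $i$. (Note that we have abused the notation $\mathbf{w}_t$ and $\mathbf{w}_t^k$ a bit for simplicity here, as they do not distinguish user $i$. This is not a big problem since the context is clear.)

Let $\mathbf{w}_t^1 := \mathbf{w}_t - \nabla f_i(\mathbf{w}_t)$ and $\mathbf{w}_t^{j+1} := \mathbf{w}_t^j - \eta_t^l\nabla f_i(\mathbf{w}_t^j)$, $j = 1, \ldots, k-1$, be the local optimization steps.

Then,

\begin{align}
\tilde{\nabla} f_i(\mathbf{w}_t) &= \mathbf{w}_t - \mathbf{w}_t^k \tag{40}\\
&= (\mathbf{w}_t - \mathbf{w}_t^1) + (\mathbf{w}_t^1 - \mathbf{w}_t^2) + \cdots + (\mathbf{w}_t^{k-1} - \mathbf{w}_t^k) \tag{41}\\
&= \nabla f_i(\mathbf{w}_t) + \eta_t^l\nabla f_i(\mathbf{w}_t^1) + \cdots + \eta_t^l\nabla f_i(\mathbf{w}_t^{k-1}). \tag{42}
\end{align}

Thus, the difference between $\tilde{\nabla} f_i(\mathbf{w}_t)$ and the gradient $\nabla f_i(\mathbf{w}_t)$ is bounded by:

\begin{align}
\|\tilde{\nabla} f_i(\mathbf{w}_t) - \nabla f_i(\mathbf{w}_t)\| &= \eta_t^l\Big\|\sum_{j=1}^{k-1}\nabla f_i(\mathbf{w}_t^j)\Big\| \tag{44}\\
&\le \eta_t^l\sum_{j=1}^{k-1}\|\nabla f_i(\mathbf{w}_t^j)\| \tag{45}\\
&\le \eta_t^l(k-1)M. \tag{46}
\end{align}

Thus,

\begin{align}
\|\delta_t\| &= \|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t - \tilde{J}_{\mathbf{f}}(\mathbf{w}_t)\tilde{\boldsymbol{\lambda}}_t\| \tag{47}\\
&\le \|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t - J_{\mathbf{f}}(\mathbf{w}_t)\tilde{\boldsymbol{\lambda}}_t\| + \|J_{\mathbf{f}}(\mathbf{w}_t)\tilde{\boldsymbol{\lambda}}_t - \tilde{J}_{\mathbf{f}}(\mathbf{w}_t)\tilde{\boldsymbol{\lambda}}_t\| \tag{48}\\
&\le \varepsilon_t\sqrt{m}M + \eta_t^l(k-1)M, \tag{49}
\end{align}

where the last step comes from the matrix norm inequality on the first term, and the triangle inequality on the second term. Note that $\delta_t$ vanishes when $\varepsilon_t \to 0$ and $\eta_t^l \to 0$.

Then, applying the quadratic upper bound, we have

\begin{align}
\mathbf{f}(\mathbf{w}_{t+1}) &\le \mathbf{f}(\mathbf{w}_t) - \eta_t J_{\mathbf{f}}^\top(\mathbf{w}_t)\tilde{J}_{\mathbf{f}}(\mathbf{w}_t)\tilde{\boldsymbol{\lambda}}_t + \frac{L\eta_t^2}{2}\|\tilde{J}_{\mathbf{f}}(\mathbf{w}_t)\tilde{\boldsymbol{\lambda}}_t\|^2 \tag{50}\\
&\le \mathbf{f}(\mathbf{w}_t) - \eta_t J_{\mathbf{f}}^\top(\mathbf{w}_t)J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t + L\eta_t^2\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2 + \eta_t J_{\mathbf{f}}^\top(\mathbf{w}_t)\delta_t + L\eta_t^2\|\delta_t\|^2 \tag{51}\\
&\le \mathbf{f}(\mathbf{w}_t) - \eta_t(1 - L\eta_t)\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2 + \eta_t M\|\delta_t\| + L\eta_t^2\|\delta_t\|^2. \tag{52}
\end{align}

Letting $\eta_t \le \frac{1}{2L}$, telescoping and rearranging we obtain

$$\min_{t=0,\ldots,T} \|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|^2 \le \frac{2\big[\mathbf{f}(\mathbf{w}_0) - \mathbf{f}(\mathbf{w}_{T+1}) + \sum_{t=0}^{T}\eta_t(M\|\delta_t\| + L\eta_t\|\delta_t\|^2)\big]}{\sum_{t=0}^{T}\eta_t}. \tag{53}$$

Bounding $\|\delta_t\|$ with (49), we get (38).

Finally, if $\varepsilon_t \to 0$ and $\eta_t^l \to 0$, then $\delta_t \to 0$ and hence the right-hand side in (38) $\to 0$ when $T \to \infty$, in which case the left-hand side $\min_{t=0,\ldots,T}\|J_{\mathbf{f}}(\mathbf{w}_t)\boldsymbol{\lambda}_t\|$ converges to 0 as well.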

Theorem 2 (full version).

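Suppose each user function $f_i$ is $\sigma$-strongly convex (i.e. $\nabla^2 f_i \succeq \sigma\mathbf{I}$) and $M$-Lipschitz continuous. Suppose at each round $t$ FedMGDA includes some function $f_{v_t}$ such that

$$f_{v_t}(\mathbf{w}_t) - f_{v_t}(\mathbf{w}_t^*) \ge \frac{\ell_t}{2}\|\mathbf{w}_t - \mathbf{w}_t^*\|^2, \tag{54}$$

where $\mathbf{w}_t^*$ is the projection of $\mathbf{w}_t$ onto the Pareto stationary set $W$ of (11). Assume $\mathbf{E}[\lambda_{v_t}\ell_t + \sigma_t \mid \mathbf{w}_t] \ge c > 0$; then

$$\mathbf{E}[\|\mathbf{w}_{t+1} - \mathbf{w}_{t+1}^*\|^2] \le \pi_t(1 - c\eta_0)\,\mathbf{E}[\|\mathbf{w}_0 - \mathbf{w}_0^*\|^2] + \sum_{s=0}^{t}\frac{\pi_t}{\pi_s}\eta_s^2 M^2, \tag{55}$$

where $\pi_t = \prod_{s=1}^{t}(1 - c\eta_s)$ and $\pi_0 = 1$. In particular,

  • if $\sum_t \eta_t = \infty$, $\sum_t \eta_t^2 < \infty$, then $\mathbf{E}[\|\mathbf{w}_t - \mathbf{w}_t^*\|^2] \to 0$ and $\mathbf{w}_t$ converges to the Pareto stationary set $W$ almost surely;

  • with the choice $\eta_t = \frac{2}{c(t+2)}$ we have

    $$\mathbf{E}[\|\mathbf{w}_t - \mathbf{w}_t^*\|^2] \le \frac{4M^2}{c^2(t+3)}. \tag{56}$$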
Proof.

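For each user $i$, let us define the function

$$\hat{f}_i(\mathbf{w}, I) := I_i f_i(\mathbf{w}), \tag{57}$$

where the random variable $I \in \{0,1\}^m$ indicates which users participate at a particular round. Clearly, we have $\mathbf{E}\hat{f}_i(\mathbf{w}, I) = f_i(\mathbf{w})\,\mathbf{E}I_i$. Therefore, our multi-objective minimization problem is equivalent to:

$$\min_{\mathbf{w}} \{\mathbf{E}\hat{f}_1(\mathbf{w}, I), \ldots, \mathbf{E}\hat{f}_m(\mathbf{w}, I)\}, \tag{58}$$

since positive scaling does not change Pareto stationarity. (If one prefers, we can also normalize the stochastic functions $\hat{f}_i(\mathbf{w}, I)$ so that the unbiasedness property $\mathbf{E}\hat{f}_i(\mathbf{w}, I) = f_i(\mathbf{w})$ holds.)

We now proceed as in Mercier et al. (2018) and provide a slightly sharper analysis. Let us denote by $\mathbf{w}_t^*$ the projection of $\mathbf{w}_t$ onto the Pareto-stationary set $W$ of (58), i.e.,

$$\mathbf{w}_t^* = \operatorname*{argmin}_{\mathbf{w}\in W} \|\mathbf{w}_t - \mathbf{w}\|. \tag{59}$$

Then,

\begin{align}
\|\mathbf{w}_{t+1} - \mathbf{w}_{t+1}^*\|^2 &\le \|\mathbf{w}_{t+1} - \mathbf{w}_t^*\|^2 \tag{60}\\
&= \|\mathbf{w}_t - \eta_t\mathbf{d}_t - \mathbf{w}_t^*\|^2 \tag{61}\\
&= \|\mathbf{w}_t - \mathbf{w}_t^*\|^2 - 2\eta_t\langle\mathbf{w}_t - \mathbf{w}_t^*, \mathbf{d}_t\rangle + \eta_t^2\|\mathbf{d}_t\|^2. \tag{62}
\end{align}

To bound the middle term, we have from our assumption:

\begin{align}
\exists v_t, \quad \hat{f}_{v_t}(\mathbf{w}_t, I_t) - \hat{f}_{v_t}(\mathbf{w}_t^*, I_t) &\ge \frac{\ell_t}{2}\|\mathbf{w}_t - \mathbf{w}_t^*\|^2, \tag{63}\\
\forall i, \quad \hat{f}_i(\mathbf{w}_t, I_t) - \hat{f}_i(\mathbf{w}_t^*, I_t) &\ge 0, \tag{64}
\end{align}

where the second inequality follows from the definition of $\mathbf{w}_t^*$. Therefore,

\begin{align}
\langle\mathbf{w}_t - \mathbf{w}_t^*, \mathbf{d}_t\rangle &= \Big\langle\mathbf{w}_t - \mathbf{w}_t^*, \sum_{i: I_i = 1}\lambda_i\nabla f_i(\mathbf{w}_t)\Big\rangle \tag{65}\\
&\ge \sum_{i: I_i = 1}\lambda_i\big(f_i(\mathbf{w}_t) - f_i(\mathbf{w}_t^*)\big) + \frac{\sigma_t}{2}\|\mathbf{w}_t - \mathbf{w}_t^*\|^2 \tag{66}\\
&= \sum_i \lambda_i\big(\hat{f}_i(\mathbf{w}_t, I_t) - \hat{f}_i(\mathbf{w}_t^*, I_t)\big) + \frac{\sigma_t}{2}\|\mathbf{w}_t - \mathbf{w}_t^*\|^2 \tag{67}\\
&\ge \frac{\lambda_{v_t}\ell_t + \sigma_t}{2}\|\mathbf{w}_t - \mathbf{w}_t^*\|^2. \tag{68}
\end{align}

Continuing from (62) and taking conditional expectation:

$$\mathbf{E}[\|\mathbf{w}_{t+1} - \mathbf{w}_{t+1}^*\|^2 \mid \mathbf{w}_t] \le (1 - c_t\eta_t)\|\mathbf{w}_t - \mathbf{w}_t^*\|^2 + \eta_t^2 M^2, \tag{69}$$

where $c_t := \mathbf{E}[\lambda_{v_t}\ell_t + \sigma_t \mid \mathbf{w}_t] \ge c > 0$. Taking expectation we obtain the familiar recursion:

$$\mathbf{E}[\|\mathbf{w}_{t+1} - \mathbf{w}_{t+1}^*\|^2] \le (1 - c\eta_t)\,\mathbf{E}[\|\mathbf{w}_t - \mathbf{w}_t^*\|^2] + \eta_t^2 M^2, \tag{70}$$

from which we derive

$$\mathbf{E}[\|\mathbf{w}_{t+1} - \mathbf{w}_{t+1}^*\|^2] \le \pi_t(1 - c\eta_0)\,\mathbf{E}[\|\mathbf{w}_0 - \mathbf{w}_0^*\|^2] + \sum_{s=0}^{t}\frac{\pi_t}{\pi_s}\eta_s^2 M^2, \tag{71}$$

where $\pi_t = \prod_{s=1}^{t}(1 - c\eta_s)$ and $\pi_0 = 1$. Since $\pi_t \to 0 \iff \sum_t \eta_t = \infty$, we know

$$\mathbf{E}[\|\mathbf{w}_{t+1} - \mathbf{w}_{t+1}^*\|^2] \to 0 \tag{72}$$

if $\sum_t \eta_t = \infty$ and $\sum_t \eta_t^2 < \infty$.

Setting $\eta_t = \frac{2}{c(t+2)}$ we obtain $\pi_t = \frac{2}{(t+2)(t+1)}$ and by induction

$$\sum_{s=0}^{t}\frac{\pi_t}{\pi_s}\eta_s^2 = \frac{4}{c^2(t+2)(t+1)}\sum_{s=0}^{t}\frac{s+1}{s+2} \le \frac{4}{c^2(t+4)}, \tag{73}$$

whence

$$\mathbf{E}[\|\mathbf{w}_{t+1} - \mathbf{w}_{t+1}^*\|^2] \le \frac{4M^2}{c^2(t+4)}. \tag{74}$$

Using a standard supermartingale argument we can also prove that

$$\mathbf{w}_t - \mathbf{w}_t^* \to 0 \text{ almost surely}. \tag{75}$$

The proof is well-known in stochastic optimization and hence omitted (or see Mercier et al. (2018, Theorem 5) for details).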
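The closed form of $\pi_t$ and the bound on the accumulated noise term for the step size $\eta_t = \frac{2}{c(t+2)}$ can be sanity-checked numerically. The sketch below is only a check under our own choices: the constant $c = 0.5$, the range of $t$ and the function names are hypothetical.

```python
def eta(s, c):
    """Step size 2 / (c * (s + 2))."""
    return 2.0 / (c * (s + 2))

def pi(t, c):
    """Product of (1 - c * eta_s) for s = 1..t, with pi(0) = 1."""
    out = 1.0
    for s in range(1, t + 1):
        out *= 1.0 - c * eta(s, c)
    return out

def noise_coeff(t, c):
    """Sum over s = 0..t of (pi_t / pi_s) * eta_s^2, the factor multiplying M^2."""
    return sum(pi(t, c) / pi(s, c) * eta(s, c) ** 2 for s in range(t + 1))

c = 0.5
checks = [
    (pi(t, c), 2.0 / ((t + 1) * (t + 2)), noise_coeff(t, c), 4.0 / (c ** 2 * (t + 4)))
    for t in range(50)
]
```

Each tuple pairs the computed product with its claimed closed form, and the computed noise coefficient with its claimed upper bound; the bound is tight at $t = 0$.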

Appendix B Full experimental results

In this section we provide additional results that are deferred from the main paper.

B.1 Recovering FedAvg full results: results on Fashion-MNIST and CIFAR-10

Complementary to the results shown in Figure 1, Figure 5 and Figure 6 summarize similar results on the F-MNIST dataset, while Figure 7 depicts the training losses on CIFAR-10 dataset in log-scale.

Figure 5: Interpolation between FedAvg and FedMGDA (F-MNIST, iid setting). (Left) Average user accuracy. (Right) Uniformly averaged training loss. Results are averaged over 5 runs with different random seeds.
Figure 6: Interpolation between FedAvg and FedMGDA (F-MNIST, non-iid setting). (Left) Average user accuracy. (Right) Uniformly averaged training loss. Results are averaged over 5 runs with different random seeds.
Figure 7: Interpolation between FedAvg and FedMGDA (CIFAR-10). Both figures plot the uniformly averaged training loss in log-scale. (Left) non-iid setting. (Right) iid setting. Results are averaged over 5 runs with different random seeds.

B.2 Robustness full results: bias attack on Adult dataset

Table 7 shows the full results of the experiment presented in Figure 2 (Left).

Table 7: Test accuracy of SOTA algorithms on Adult dataset with various scales of adversarial bias added to the domain loss of PhD; and compared to the baseline of training only on the PhD domain. The scale of bias for AFL is different from q-FedAvg since AFL uses averaged loss while q-FedAvg uses (non-averaged) total loss. The algorithms are run for 500 rounds, and the reported results are averaged across 5 runs with different random seeds.
Name Bias Uniform PhD Non-PhD
AFL 0 83.26±0.01 77.90±0.00 83.32±0.01
AFL 0.01 83.28±0.03 76.58±0.27 83.36±0.03
AFL 0.1 82.30±0.04 74.59±0.00 82.39±0.04
AFL 1 81.86±0.05 74.25±0.57 81.94±0.05
q-FedAvg, q=5 0 83.26±0.18 76.80±0.61 83.33±0.19
q-FedAvg, q=5 1000 83.34±0.04 76.57±0.44 83.41±0.04
q-FedAvg, q=5 5000 81.19±0.03 74.14±0.41 81.27±0.03
q-FedAvg, q=5 10000 81.07±0.03 73.48±0.78 81.16±0.02
q-FedAvg, q=2 0 83.30±0.09 76.46±0.56 83.38±0.09
q-FedAvg, q=2 1000 83.33±0.04 76.24±0.00 83.41±0.04
q-FedAvg, q=2 5000 83.11±0.03 75.69±0.00 83.20±0.03
q-FedAvg, q=2 10000 82.50±0.07 75.69±0.00 82.58±0.07
q-FedAvg, q=0.1 0 83.44±0.06 76.46±0.56 83.52±0.07
q-FedAvg, q=0.1 1000 83.3