MOCHA: Federated Multi-Task Learning
Virginia Smith · Stanford / CMU
Chao-Kai Chiang · USC
Maziar Sanjabi · USC
Ameet Talwalkar · CMU
NIPS ‘17
MACHINE LEARNING WORKFLOW
machine learning model
data & problem
optimization algorithm: $\min_w \ \sum_{i=1}^{n} \ell(w, x_i) + g(w)$
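The objective above is regularized empirical risk minimization. A minimal numeric sketch; the squared loss and L2 regularizer below are my illustrative choices, not fixed by the slide:

```python
import numpy as np

# Regularized empirical risk: min_w sum_i l(w, x_i) + g(w).
# Illustrative instantiation (my choice): squared loss
# l(w, x_i) = 0.5 * (w^T x_i - y_i)^2 and g(w) = 0.5 * lam * ||w||^2.
def objective(w, X, y, lam):
    r = X @ w - y
    return 0.5 * np.sum(r ** 2) + 0.5 * lam * np.sum(w ** 2)

def gradient_descent(X, y, lam, lr=0.01, steps=500):
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        w -= lr * (X.T @ (X @ w - y) + lam * w)  # gradient of the objective
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.0, -2.0, 0.5])   # synthetic labels
w_hat = gradient_descent(X, y, lam=0.1)
```

For this choice of loss and regularizer the minimizer has a closed form, which the iterates approach.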
IN PRACTICE: the same workflow must run in a systems setting —
how can we perform fast distributed optimization?
BEYOND THE DATACENTER
Systems challenges: massively distributed; node heterogeneity
Statistical challenges: unbalanced; non-IID; underlying structure
OUTLINE
Statistical challenges: unbalanced; non-IID; underlying structure
Systems challenges: massively distributed; node heterogeneity
A GLOBAL APPROACH: one shared model W for all devices [MMRHA, AISTATS 16]
A LOCAL APPROACH: a separate model per device (W1, …, W12)
OUR APPROACH: PERSONALIZED MODELS
One model per device (W1, …, W12), coupled through a task-relationship regularizer.
MULTI-TASK LEARNING [ZCY, SDM 2012]
Task relationships: all tasks related; outlier tasks; clusters / groups; asymmetric relationships

$\min_{W,\Omega} \ \sum_{t=1}^{m} \sum_{i=1}^{n_t} \ell_t(w_t, x_t^i) + \mathcal{R}(W, \Omega)$
(per-task losses + task-relationship regularizer over models W)
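The multi-task objective couples per-task models through R(W, Ω). A minimal numeric sketch; the squared losses and the trace regularizer R = λ·tr(W Ω Wᵀ) are my illustrative choices, not prescribed by the slide:

```python
import numpy as np

# Multi-task objective: sum_t sum_i l_t(w_t, x_t^i) + R(W, Omega).
# Illustrative choices (mine): squared per-task losses and the trace
# regularizer R(W, Omega) = lam * tr(W @ Omega @ W.T).
def mtl_objective(W, Omega, data, lam):
    # W: (d, m), one column per task; Omega: (m, m) task-relationship matrix
    loss = sum(0.5 * np.sum((X @ W[:, t] - y) ** 2)
               for t, (X, y) in enumerate(data))
    return loss + lam * np.trace(W @ Omega @ W.T)

rng = np.random.default_rng(1)
m, d = 3, 4
data = [(rng.normal(size=(20, d)), rng.normal(size=20)) for _ in range(m)]
W = rng.normal(size=(d, m))
Omega = np.eye(m)   # identity Omega: tasks regularized independently
val = mtl_objective(W, Omega, data, lam=0.1)
```

With Ω = I the regularizer reduces to an independent penalty λ‖W‖²_F; non-diagonal choices of Ω are what encode clusters, outliers, or asymmetric relationships.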
FEDERATED DATASETS
Human Activity
Google Glass
Land Mine
Vehicle Sensor
PREDICTION ERROR  (mean error, std in parentheses)
                  Global          Local           MTL
Human Activity    2.23 (0.30)     1.34 (0.21)     0.46 (0.11)
Google Glass      5.34 (0.26)     4.92 (0.26)     2.02 (0.15)
Land Mine         27.72 (1.08)    23.43 (0.77)    20.09 (1.04)
Vehicle Sensor    13.4 (0.26)     7.81 (0.13)     6.59 (0.21)
OUTLINE
Statistical challenges: unbalanced; non-IID; underlying structure
Systems challenges: massively distributed; node heterogeneity
GOAL: FEDERATED OPTIMIZATION FOR MULTI-TASK LEARNING

$\min_{W,\Omega} \ \sum_{t=1}^{m} \sum_{i=1}^{n_t} \ell_t(w_t^T x_t^i) + \mathcal{R}(W, \Omega)$

Solve for W, Ω in an alternating fashion: Ω can be updated centrally; W needs to be solved in the federated setting.

Challenges: communication is expensive; statistical & systems heterogeneity; stragglers; fault tolerance
Idea: modify a communication-efficient method designed for the datacenter setting to handle:
✔ multi-task learning  ✔ stragglers  ✔ fault tolerance
COCOA: COMMUNICATION-EFFICIENT DISTRIBUTED OPTIMIZATION
A spectrum from mini-batch methods (frequent communication) to one-shot communication.
Key idea: control the amount of communication.
COCOA: PRIMAL-DUAL FRAMEWORK  (PRIMAL ≥ DUAL)

PRIMAL: $\min_{w \in \mathbb{R}^d} \ \frac{1}{n} \sum_{i=1}^{n} \ell(w^T x_i) + \lambda g(w)$

DUAL: $\max_{\alpha \in \mathbb{R}^n} \ -\frac{1}{n} \sum_{i=1}^{n} \ell^*(-\alpha_i) - \lambda g^*(X, \alpha)$

Data-local subproblems: $\max_{\alpha \in \mathbb{R}^n} \ \frac{1}{n} \sum_{i=1}^{n} -\ell^*(-\alpha_i) - \lambda \sum_{k=1}^{K} \tilde{g}^*(X_{[k]}, \alpha_{[k]})$
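To make the primal-dual correspondence concrete: for squared loss with an L2 regularizer (the standard ridge pairing; one admissible loss/regularizer choice in this framework, not the only one), the closed-form primal solution coincides with the dual solution mapped back via w(α) = Xα / (λn):

```python
import numpy as np

# Primal:  min_w (1/n) sum_i l(w^T x_i) + lam * g(w)
# with squared loss l(z) = 0.5 * (z - y_i)^2 and g(w) = 0.5 * ||w||^2.
# The dual optimum alpha maps to the primal via w = X @ alpha / (lam * n).
rng = np.random.default_rng(2)
d, n, lam = 5, 40, 0.3
X = rng.normal(size=(d, n))          # columns are examples x_i
y = rng.normal(size=n)

# Closed-form primal (ridge) solution
w_primal = np.linalg.solve(X @ X.T + lam * n * np.eye(d), X @ y)

# Closed-form dual solution (set the dual gradient to zero and solve)
alpha = lam * n * np.linalg.solve(X.T @ X + lam * n * np.eye(n), y)
w_dual = X @ alpha / (lam * n)
```

The two solutions agree by the push-through identity X(XᵀX + cI)⁻¹ = (XXᵀ + cI)⁻¹X; at any non-optimal α the two objectives bracket the optimum, which is what the duality gap measures.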
challenge #1: extend to MTL setup
COCOA: COMMUNICATION PARAMETER

Θ ≈ amount of local computation vs. communication, Θ ∈ [0, 1)
Main assumption: each subproblem is solved to accuracy Θ
(Θ = 0: exactly solve; Θ → 1: inexactly solve)

challenge #2: make communication more flexible
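The parameter Θ can be read as the fraction of a subproblem's optimality gap remaining after local work. A toy sketch; the quadratic below and the gradient-descent solver are my stand-ins for the actual dual subproblem:

```python
import numpy as np

# Theta = (suboptimality after local work) / (suboptimality before),
# so Theta = 0 means an exact local solve and Theta near 1 means very
# little local work. Toy subproblem: f(x) = 0.5 x^T A x - b^T x.
def theta_after(steps, A, b, lr=0.2):
    f = lambda x: 0.5 * x @ A @ x - b @ x
    x_star = np.linalg.solve(A, b)        # exact subproblem solution
    x = np.zeros_like(b)
    gap0 = f(x) - f(x_star)               # gap before any local work
    for _ in range(steps):                # local iterations
        x = x - lr * (A @ x - b)
    return (f(x) - f(x_star)) / gap0

A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])
# more local computation -> smaller Theta (better subproblem accuracy)
```

Spending more local iterations drives Θ toward 0 at the cost of computation per round; communicating sooner leaves Θ closer to 1.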
MOCHA: COMMUNICATION-EFFICIENT FEDERATED OPTIMIZATION

$\min_{W,\Omega} \ \sum_{t=1}^{m} \sum_{i=1}^{n_t} \ell_t(w_t^T x_t^i) + \mathcal{R}(W, \Omega)$

Solve for W, Ω in an alternating fashion; modify CoCoA to solve for W in the federated setting.

Dual: $\min_{\alpha} \ \sum_{t=1}^{m} \sum_{i=1}^{n_t} \ell_t^*(-\alpha_t^i) + \mathcal{R}^*(X\alpha)$

Per-device subproblem: $\min_{\Delta\alpha_t} \ \sum_{i=1}^{n_t} \ell_t^*(-\alpha_t^i - \Delta\alpha_t^i) + \langle w_t(\alpha), X_t \Delta\alpha_t \rangle + \frac{\sigma'}{2} \| X_t \Delta\alpha_t \|^2_{M_t}$
MOCHA: PER-DEVICE, PER-ITERATION APPROXIMATIONS

Stragglers (statistical heterogeneity): difficulty of solving the subproblem; size of the local dataset
Stragglers (systems heterogeneity): hardware (CPU, memory); network connection (3G, LTE, …); power (battery level)
Fault tolerance: devices going offline

New assumption: each subproblem is solved to accuracy $\theta_t^h \in [0, 1]$ (vs. a single $\theta \in [0, 1)$ before)
CONVERGENCE

New assumption: each subproblem is solved to accuracy $\theta_t^h \in [0, 1]$, and assume $\mathbb{P}[\theta_t^h := 1] < 1$

Theorem 1 ($1/\epsilon$ rate). Let $\ell_t$ be $L$-Lipschitz; then convergence holds for
$T \geq \frac{1}{1 - \bar{\Theta}} \left( \frac{8 L^2 n^2}{\epsilon} + \tilde{c} \right)$

Theorem 2 (linear rate). Let $\ell_t$ be $(1/\mu)$-smooth; then convergence holds for
$T \geq \frac{1}{1 - \bar{\Theta}} \, \frac{\mu + n}{\mu} \log \frac{n}{\epsilon}$
MOCHA: COMMUNICATION-EFFICIENT FEDERATED OPTIMIZATION

Algorithm 1 MOCHA: Federated Multi-Task Learning Framework
1: Input: data X_t stored on devices t = 1, …, m
2: Initialize α^(0) := 0, v^(0) := 0
3: for iterations i = 0, 1, … do
4:   for iterations h = 0, 1, …, H_i do
5:     for devices t ∈ {1, 2, …, m} in parallel do
6:       call local solver, returning a θ_t^h-approximate solution Δα_t
7:       update local variables α_t ← α_t + Δα_t
8:     reduce: v ← v + Σ_t X_t Δα_t
9:   update Ω centrally using w(v) := ∇R*(v)
10: compute w(v) := ∇R*(v)
11: return: W := [w_1, …, w_m]
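A heavily simplified sketch of the alternating structure in Algorithm 1. Simplifications are mine: Ω is held fixed at the identity, and the local solver is plain gradient descent on a per-device ridge objective rather than the dual subproblem:

```python
import numpy as np

# Sketch of the alternating loop: each device improves its local model
# w_t on its own data, then a central step would update the relationship
# matrix Omega. Here Omega stays at the identity and the "local solver"
# is gradient descent on a ridge objective (my simplifications).
def local_solve(X, y, w, lam, lr=0.01, steps=50):
    for _ in range(steps):
        w = w - lr * (X.T @ (X @ w - y) + lam * w)
    return w

rng = np.random.default_rng(4)
m, d = 3, 2                              # m devices/tasks, d features
Xs = [rng.normal(size=(30, d)) for _ in range(m)]
ys = [X @ rng.normal(size=d) for X in Xs]   # synthetic per-device labels

W = np.zeros((d, m))                     # one model column per device
for _ in range(5):                       # outer (communication) rounds
    for t in range(m):                   # devices run in parallel in MOCHA
        W[:, t] = local_solve(Xs[t], ys[t], W[:, t], lam=0.1)
    # central Omega update would go here; omitted in this sketch
```

Per-device approximation quality maps onto the number of local steps: a straggler would simply return after fewer iterations (larger θ_t^h) without stalling the round.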
STATISTICAL HETEROGENEITY

[Figure: primal sub-optimality vs. estimated time on Human Activity, under WiFi, LTE, and 3G networks; methods: MOCHA, CoCoA, Mb-SDCA, Mb-SGD]

MOCHA is robust to statistical heterogeneity.
MOCHA & CoCoA perform particularly well in high-communication settings.
SYSTEMS HETEROGENEITY

[Figure: primal sub-optimality vs. estimated time on Vehicle Sensor, under low and high systems heterogeneity; methods: MOCHA, CoCoA, Mb-SDCA, Mb-SGD]

MOCHA significantly outperforms all competitors (by 2 orders of magnitude).
FAULT TOLERANCE

[Figure: primal sub-optimality vs. estimated time on Google Glass, for the W step alone and for the full method]

MOCHA is robust to dropped nodes.
OUTLINE
Statistical challenges: unbalanced; non-IID; underlying structure
Systems challenges: massively distributed; node heterogeneity
Virginia Smith · Stanford / CMU
cs.berkeley.edu/~vsmith
CODE & PAPERS: WWW.SYSML.CC