
Failures of Gradient-Based Deep Learning

Shai Shalev-Shwartz, Shaked Shammah, Ohad Shamir

The Hebrew University and Mobileye

Representation Learning Workshop, Simons Institute, Berkeley, 2017


Deep Learning is Amazing ...


INTEL IS BUYING MOBILEYE FOR $15.3 BILLION IN BIGGEST ISRAELI HI-TECH DEAL EVER


This Talk

Simple problems where standard deep learning either

Does not work well

Requires prior knowledge for better architectural/algorithmic choices
Requires an update rule other than the gradient
Requires decomposing the problem and adding more supervision

Does not work at all

No “local-search” algorithm can work
Even for “nice” distributions and well-specified models
Even with over-parameterization (a.k.a. improper learning)

Mix of theory and experiments


Outline

1 Piece-wise Linear Curves

2 Flat Activations

3 End-to-end Training

4 Learning Many Orthogonal Functions


Piecewise-linear Curves: Motivation


Piecewise-linear Curves

Problem: Train a piecewise-linear curve detector

Input: f = (f(0), f(1), . . . , f(n−1)) where

f(x) = ∑_{r=1}^{k} a_r [x − θ_r]_+ ,   θ_r ∈ {0, . . . , n−1}

Output: Curve parameters {a_r, θ_r}_{r=1}^{k}


First try: Deep AutoEncoder

Encoding network, E_{w1}: Dense(500, relu) - Dense(100, relu) - Dense(2k)

Decoding network, D_{w2}: Dense(100, relu) - Dense(100, relu) - Dense(n)

Squared loss: (D_{w2}(E_{w1}(f)) − f)²

Doesn’t work well ...

[Figure: reconstructions after 500, 10,000, and 50,000 iterations]
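For concreteness, a minimal sketch of such an autoencoder (PyTorch; the signal length n, the curve sampler, the optimizer, and the training budget are assumptions, not taken from the talk):

import torch, torch.nn as nn
import numpy as np

n, k = 100, 3                                   # assumed problem size

def sample_curves(m):
    # rows are curves f(x) = sum_r a_r [x - theta_r]_+ on the grid x = 0, ..., n-1
    a = np.random.randn(m, k)
    theta = np.random.randint(0, n, size=(m, k))
    x = np.arange(n)
    f = (a[..., None] * np.maximum(x - theta[..., None], 0)).sum(1)
    return torch.tensor(f, dtype=torch.float32)

encoder = nn.Sequential(nn.Linear(n, 500), nn.ReLU(), nn.Linear(500, 100), nn.ReLU(), nn.Linear(100, 2 * k))
decoder = nn.Sequential(nn.Linear(2 * k, 100), nn.ReLU(), nn.Linear(100, 100), nn.ReLU(), nn.Linear(100, n))
opt = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

for t in range(10000):
    f = sample_curves(64)
    loss = ((decoder(encoder(f)) - f) ** 2).mean()   # squared reconstruction loss
    opt.zero_grad(); loss.backward(); opt.step()
    if t % 1000 == 0:
        print(t, loss.item())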


Second try: Pose as a Convex Objective

Problem: Train a piecewise-linear curve detector

Input: f ∈ ℝⁿ where f(x) = ∑_{r=1}^{k} a_r [x − θ_r]_+
Output: Curve parameters {a_r, θ_r}_{r=1}^{k}

Convex Formulation

Let p ∈ ℝⁿ be a k-sparse vector whose k non-zero values (and their positions) are {a_r, θ_r}_{r=1}^{k}

Observe: f = Wp where W ∈ ℝ^{n×n} is s.t. W_{i,j} = [i − j + 1]_+

Learning approach (linear regression): Train a one-layer fully connected network on (f, p) examples:

min_U E[(Uf − p)²] = E[(Uf − W⁻¹f)²]
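A sketch of this baseline (numpy SGD on the one-layer linear model; n, the step size, and the iteration budget are assumptions):

import numpy as np

n, k = 100, 3
i, j = np.arange(n)[:, None], np.arange(n)[None, :]
W = np.maximum(i - j + 1, 0).astype(float)            # f = W p

def sample():
    p = np.zeros(n)
    p[np.random.choice(n, k, replace=False)] = np.random.randn(k)
    return W @ p, p

U = np.zeros((n, n))
lr = 1e-7                                              # assumed step size; the conditioning (next slide) makes any choice slow
for t in range(50000):
    f, p = sample()
    U -= lr * 2.0 * np.outer(U @ f - p, f)             # SGD step on E[(U f - p)^2]

f, p = sample()
print(np.mean((U @ f - p) ** 2))                       # still far from zero after 50,000 iterations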


Second try: Pose as a Convex Objective

min_U E[(Uf − p)²] = E[(Uf − W⁻¹f)²]

Convex; realizable; but still doesn't work well ...

[Figure: results after 500, 10,000, and 50,000 iterations]


What went wrong?

Theorem

The convergence of SGD is governed by the condition number of W⊤W, which is large:

λ_max(W⊤W) / λ_min(W⊤W) = Ω(n^3.5)

⇒ SGD requires Ω(n^3.5) iterations to reach U s.t. ‖E[U] − W⁻¹‖ < 1/2

Note: Adagrad/Adam don't work because they perform diagonal conditioning
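A quick numerical sanity check of this growth (a sketch; the values of n are arbitrary):

import numpy as np

def cond_WtW(n):
    i, j = np.arange(n)[:, None], np.arange(n)[None, :]
    W = np.maximum(i - j + 1, 0).astype(float)      # W_{i,j} = [i - j + 1]_+
    s = np.linalg.svd(W, compute_uv=False)
    return (s[0] / s[-1]) ** 2                      # cond(W^T W) = cond(W)^2

for n in [50, 100, 200, 400]:
    print(n, cond_WtW(n))                           # grows rapidly with n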


3rd Try: Pose as a Convolution

p = W⁻¹f

Observation:

W⁻¹ =
  [  1   0   0   0  · · · ]
  [ −2   1   0   0  · · · ]
  [  1  −2   1   0  · · · ]
  [  0   1  −2   1  · · · ]
  [  0   0   1  −2  · · · ]
  [  ⋮   ⋮   ⋮   ⋮   ⋱   ]

W⁻¹f is a 1D convolution of f with the “2nd derivative” filter (1, −2, 1)

Can train a one-layer convnet to learn the filter (a problem in ℝ³!)
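A small numpy check of this observation (the curve used is arbitrary):

import numpy as np

n = 100
i, j = np.arange(n)[:, None], np.arange(n)[None, :]
W = np.maximum(i - j + 1, 0).astype(float)

f = W @ np.random.randn(n)                   # any curve of the form f = W p
p = np.linalg.solve(W, f)                    # p = W^{-1} f

f_pad = np.concatenate([[0.0, 0.0], f])      # implicit zeros to the left of the signal
conv = f_pad[2:] - 2 * f_pad[1:-1] + f_pad[:-2]   # convolution with (1, -2, 1)

print(np.allclose(p, conv))                  # True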


Better, but still doesn’t work well ...

[Figure: results after 500, 10,000, and 50,000 iterations]

Theorem: Condition number reduced to Θ(n³). Convolutions aid geometry!

But Θ(n³) is very disappointing for a problem in ℝ³ ...


4th Try: Convolution + Preconditioning

ConvNet is equivalent to solving min_x E[(Fx − p)²] where F is an n × 3 matrix with (f(i−1), f(i), f(i+1)) at row i

Observation: Problem is now low-dimensional, so can easily precondition

Compute an empirical approximation C of E[F⊤F]
Solve min_x E[(FC^{−1/2}x − p)²]
Return C^{−1/2}x

Condition number of FC^{−1/2} is close to 1
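A sketch of the preconditioning computation (numpy; the curve distribution, n, and the sample size m are assumptions):

import numpy as np

n, k, m = 100, 3, 2000
grid = np.arange(n)

def design():
    # rows (f(i-1), f(i), f(i+1)) for the interior points of one random curve
    a = np.random.randn(k)
    theta = np.random.choice(np.arange(1, n - 1), size=k, replace=False)
    f = (a[:, None] * np.maximum(grid[None, :] - theta[:, None], 0)).sum(0)
    return np.stack([f[:-2], f[1:-1], f[2:]], axis=1)

C = sum(F.T @ F for F in (design() for _ in range(m))) / m        # empirical approximation of E[F^T F]
M = np.linalg.inv(np.linalg.cholesky(C)).T                        # one choice of C^{-1/2}: M^T C M = I

G = sum((F @ M).T @ (F @ M) for F in (design() for _ in range(m))) / m
print(np.linalg.cond(C), np.linalg.cond(G))                       # huge vs. close to 1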


Finally, it works ...

[Figure: results after 500, 10,000, and 50,000 iterations]

Remark:

Use of convnet allows for efficient preconditioning

Estimating and manipulating 3× 3 rather than n× n matrices.


Lesson Learned

SGD might be extremely slow

Prior knowledge allows us to:

Choose a better architecture (not for expressivity, but for a better geometry)
Choose a better algorithm (preconditioned SGD)


Flat Activations

Vanishing gradients due to saturating activations (e.g. in RNNs)


Flat Activations

Problem: Learning x ↦ u(⟨w*, x⟩) where u is a fixed step function

Optimization problem:

min_w E_x[(u(N_w(x)) − u(w*⊤x))²]

u′(z) = 0 almost everywhere → can’t apply gradient-based methods


Flat Activations: Smooth approximation

Smooth approximation: replace u with a smooth surrogate ũ (similar to using sigmoids as gates in LSTMs)

[Figure: the step function u and its smooth approximation ũ]

Sometimes works, but slow and only an approximate result. Often completely fails
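To make the idea concrete, one possible smooth surrogate (a sketch; the particular step function and the temperature T are assumptions, not the talk's):

import numpy as np

steps = np.arange(-3, 3)                       # u(z) = #{s in steps : z > s} - 3, a 7-level step function (assumed)
u        = lambda z: np.sum(z[..., None] > steps, axis=-1) - 3.0
u_smooth = lambda z, T=0.1: np.sum(1.0 / (1.0 + np.exp(-(z[..., None] - steps) / T)), axis=-1) - 3.0

z = np.linspace(-4, 4, 9)
print(u(z))
print(u_smooth(z))                             # differentiable, so gradients flow, but only approximates u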


Flat Activations: Perhaps I should use a deeper network ...

Approach: End-to-end

min_w E_x[(N_w(x) − u(w*⊤x))²]

(3 ReLU + 1 linear layers; 10000 iterations)

Slow train+test time; curve not captured well


Flat Activations: Multiclass

Approach: Multiclass

min_w E_x[ℓ(N_w(x), y(x))]

N_w(x): to which step does x belong

Problem capturing boundaries


Flat Activations: “Forward only” backpropagation

Different approach (Kalai & Sastry 2009; Kakade, Kalai, Kanade, Shamir 2011): Gradient descent, but replace the gradient with something else

Objective: min_w E_x[ ½ (u(w⊤x) − u(w*⊤x))² ]

Gradient: ∇ = E_x[ (u(w⊤x) − u(w*⊤x)) · u′(w⊤x) · x ]

Non-gradient direction: E_x[ (u(w⊤x) − u(w*⊤x)) · x ]

Interpretation: “Forward only” backpropagation
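A sketch of this update rule for the one-layer case (numpy; the step function, the data distribution, and the step size are assumptions):

import numpy as np

d = 10
rng = np.random.default_rng(0)
u = lambda z: np.floor(np.clip(z, -3, 3))          # an assumed monotone step function
w_star = rng.normal(size=d); w_star /= np.linalg.norm(w_star)

w = np.zeros(d)
lr = 0.1                                           # assumed step size
for t in range(5000):
    x = rng.normal(size=d)
    # "forward only" update: the u'(w^T x) factor of the true gradient is dropped
    w -= lr * (u(w @ x) - u(w_star @ x)) * x

err = np.mean([(u(w @ x) - u(w_star @ x)) ** 2 for x in rng.normal(size=(2000, d))])
print(err)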


Flat Activation

(linear; 5000 iterations)

Best results, and smallest train+test time

Analysis (KS09, KKKS11): Needs O(L²/ε²) iterations if u is L-Lipschitz

Lesson learned: Local search works, but not with the gradient ...


End-to-End vs. Decomposition

Input x: k-tuple of images of random lines

f1(x): For each image, whether slope is negative/positive

f2(x): return parity of slope signs

Goal: Learn f2(f1(x))


End-to-End vs. Decomposition

Architecture: Concatenation of LeNet and 2-layer ReLU, linked by sigmoid

End-to-end approach: Train overall network on primary objective

Decomposition approach: Augment objective with a loss specific to the first net, using per-image labels (see the sketch below the figure)

[Figure: end-to-end vs. decomposition for k = 1, 2, 3; 20,000 iterations]
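A sketch of the two objectives (PyTorch; per_image_net is a small stand-in for LeNet, and the image size, widths, and losses are assumptions):

import torch, torch.nn as nn

k = 3
per_image_net = nn.Sequential(nn.Flatten(), nn.Linear(64, 32), nn.ReLU(),
                              nn.Linear(32, 1))                      # stand-in for LeNet on 8x8 "images"
head = nn.Sequential(nn.Linear(k, 10), nn.ReLU(), nn.Linear(10, 1))  # 2-layer ReLU network

def forward(images):                       # images: (batch, k, 8, 8)
    s = torch.sigmoid(per_image_net(images.flatten(0, 1))).view(-1, k)   # per-image slope-sign scores in (0, 1)
    return head(s).squeeze(1), s

bce_logits = nn.BCEWithLogitsLoss()

def end_to_end_loss(images, parity):
    out, _ = forward(images)
    return bce_logits(out, parity)

def decomposition_loss(images, parity, signs):       # signs: per-image labels, the extra supervision
    out, s = forward(images)
    return bce_logits(out, parity) + nn.functional.binary_cross_entropy(s, signs)

images = torch.randn(16, k, 8, 8)
parity = torch.randint(0, 2, (16,)).float()
signs = torch.randint(0, 2, (16, k)).float()
print(end_to_end_loss(images, parity).item(), decomposition_loss(images, parity, signs).item())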


End-to-End vs. Decomposition

Why doesn't end-to-end training work?
Similar experiment by Gulcehre and Bengio, 2016
They suggest “local minima” problems
We show that the problem is different

Signal-to-Noise Ratio (SNR) for random initialization:
End-to-end (red) vs. decomposition (blue), as a function of k
SNR of end-to-end for k ≥ 3 is below the precision of float32

[Figure: SNR (scale 10⁻⁴) as a function of k = 1, . . . , 4]


Learning Many Orthogonal Functions

Let H be a hypothesis class of orthonormal functions: ∀h, h′ ∈ H, E[h(x)h′(x)] = 0

(Improper) learning of H using gradient-based deep learning:

Learn the parameter vector w of some architecture p_w : X → ℝ
For every target h ∈ H, the learning task is to solve:

min_w F_h(w) := E_x[ℓ(p_w(x), h(x))]

Start with a random w and update the weights based on ∇F_h(w)

Analysis tool: How much does ∇F_h(w) tell us about the identity of h?


Analysis: how much information in the gradient?

min_w F_h(w) := E_x[ℓ(p_w(x), h(x))]

Theorem: For every w, there are many pairs h, h′ ∈ H s.t. E_x[h(x)h′(x)] = 0 while

‖∇F_h(w) − ∇F_{h′}(w)‖² = O(1/|H|)

Proof idea: show that if the functions in H are orthonormal then, for every w,

Var(H, F, w) := E_h ‖∇F_h(w) − E_{h′} ∇F_{h′}(w)‖² = O(1/|H|)

To do so, express every coordinate of ∇p_w(x) using the orthonormal functions in H


Proof idea

Assume the squared loss, then

∇F_h(w) = E_x[(p_w(x) − h(x)) ∇p_w(x)]
        = E_x[p_w(x) ∇p_w(x)]  (independent of h)  −  E_x[h(x) ∇p_w(x)]

Fix some j and denote g(x) = ∇_j p_w(x)

Can expand g = ∑_{i=1}^{|H|} ⟨h_i, g⟩ h_i + orthogonal component

Therefore, E_h (E_x[h(x) g(x)])² ≤ E_x[g(x)²] / |H|

It follows that

Var(H, F, w) ≤ E_x[‖∇p_w(x)‖²] / |H|


Example: Parity Functions

H is the class of parity functions over {0, 1}^d and x is uniformly distributed

There are 2^d orthonormal functions, hence there are many pairs h, h′ ∈ H s.t. E_x[h(x)h′(x)] = 0 while ‖∇F_h(w) − ∇F_{h′}(w)‖² = O(2^{−d})

Remark:

A similar hardness result can be shown by combining existing results:

Parities on the uniform distribution over {0, 1}^d are difficult for statistical query algorithms (Kearns, 1999)
Gradient descent with approximate gradients can be implemented with statistical queries (Feldman, Guzman, Vempala 2015)
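A small experiment in this spirit (PyTorch; d, the architecture, and the choice of degree-5 parities are assumptions; expectations are exact because all 2^d inputs are enumerated):

import itertools, torch, torch.nn as nn

torch.manual_seed(0)
d = 10
X = torch.tensor(list(itertools.product([0.0, 1.0], repeat=d)))     # all 2^d points
net = nn.Sequential(nn.Linear(d, 64), nn.ReLU(), nn.Linear(64, 1))

def grad_for_parity(S):
    # exact gradient of F_h(w) = E_x[(p_w(x) - h(x))^2] for the parity h over the coordinate set S
    y = (X[:, S].sum(dim=1) % 2) * 2 - 1                             # parity in {-1, +1}
    net.zero_grad()
    ((net(X).squeeze(1) - y) ** 2).mean().backward()
    return torch.cat([p.grad.flatten() for p in net.parameters()])

parities = [torch.randperm(d)[:5].tolist() for _ in range(50)]       # 50 random parities
grads = torch.stack([grad_for_parity(S) for S in parities])
var = ((grads - grads.mean(0)) ** 2).sum(1).mean()                   # estimate of Var(H, F, w)
print((var / (grads ** 2).sum(1).mean()).item())                     # tiny fraction: the gradient barely depends on h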


Visual Illustration: Linear-Periodic Functions

F_h(w) = E_x[(cos(w⊤x) − h(x))²] for h(x) = cos([2, 2]⊤x), in 2 dimensions, x ∼ N(0, I):

No local minima/saddle points

However, extremely flat unless very close to the optimum
⇒ difficult for gradient methods, especially stochastic

In fact, difficult for any local-search method
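A sketch that reproduces this kind of picture numerically (Monte Carlo estimate of F_h(w) on a grid; the grid range and sample size are assumptions):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10000, 2))                       # x ~ N(0, I) in 2 dimensions
h = np.cos(X @ np.array([2.0, 2.0]))                  # target h(x) = cos([2, 2]^T x)

w1, w2 = np.meshgrid(np.linspace(-4, 4, 61), np.linspace(-4, 4, 61))
F = np.array([[np.mean((np.cos(X @ np.array([a, b])) - h) ** 2)
               for a, b in zip(ra, rb)]
              for ra, rb in zip(w1, w2)])

print(F.min(), F.max())   # close to 0 only near w = +-(2, 2); roughly constant over most of the plane
# import matplotlib.pyplot as plt; plt.contourf(w1, w2, F, 50); plt.show()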


Summary

Cause of failures: optimization can be difficult for geometric reasons other than local minima / saddle points

Condition number, flatness
Using bigger/deeper networks doesn't always help

Remedies: prior knowledge can still be important

Convolution can improve geometry (and not just sample complexity)
“Other than gradient” update rules
Decomposing the problem and adding supervision can improve geometry

Understanding the limitations: While deep learning is great, understanding the limitations may lead to better algorithms and/or better theoretical guarantees

For more information:

“Failures of Deep Learning”: arXiv 1703.07950
github.com/shakedshammah/failures_of_DL
