Overfitting, Regularization, and Cross-Validation (annotated slides)



Page 1


Overfitting

Machine Learning – CSE546, Sham Kakade, University of Washington
October 6, 2016

Announcements:

- HW1 posted
- Today:
  - Review: bias-variance
  - Overfitting, overfitting, overfitting
  - Ridge regression
  - Checking for overfitting: cross-validation


Review: Bias-Variance Tradeoff

Machine Learning – CSE546, Sham Kakade, University of Washington
October 4, 2016

Bias-Variance tradeoff – Intuition

- Model too "simple" → does not fit the data well
  - A biased solution
- Model too complex → small changes to the data change the solution a lot
  - A high-variance solution

Page 2


(Squared) Bias of learner

- Given dataset D with N samples, learn function h_D(x)
- If you sample a different dataset D′ with N samples, you will learn a different h_D′(x)
- Expected hypothesis: E_D[h_D(x)]

- Bias: the difference between what you expect to learn and the truth
  - Measures how well you expect to represent the true solution
  - Decreases with a more complex model
  - Bias² at one point x: (E_D[h_D(x)] − t(x))²
  - Average Bias²: E_x[(E_D[h_D(x)] − t(x))²]


Variance of learner

- Given dataset D with N samples, learn function h_D(x)
- If you sample a different dataset D′ with N samples, you will learn a different h_D′(x)

- Variance: the difference between what you expect to learn and what you learn from a particular dataset (a simulation sketch follows below)
  - Measures how sensitive the learner is to the specific dataset
  - Decreases with a simpler model
  - Variance at one point x: E_D[(h_D(x) − E_D[h_D(x)])²]
  - Average variance: E_x[E_D[(h_D(x) − E_D[h_D(x)])²]]

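To make these definitions concrete, here is a minimal simulation sketch (the ground truth t(x) = sin(2πx), the noise level, and the polynomial degrees are illustrative assumptions, not from the slides). It resamples many datasets D, refits h_D each time, and estimates the average Bias² and variance for models of increasing complexity:

```python
import numpy as np

rng = np.random.default_rng(0)

def t(x):                      # assumed ground-truth function (for illustration)
    return np.sin(2 * np.pi * x)

def sample_dataset(N=30, noise=0.3):
    x = rng.uniform(0, 1, N)
    return x, t(x) + rng.normal(0, noise, N)

def fit_and_predict(degree, x_eval, trials=500):
    """Fit a degree-`degree` polynomial on many resampled datasets D
    and collect the predictions h_D(x_eval)."""
    preds = np.empty((trials, len(x_eval)))
    for i in range(trials):
        x, y = sample_dataset()
        preds[i] = np.polyval(np.polyfit(x, y, degree), x_eval)
    return preds

x_eval = np.linspace(0, 1, 100)
for degree in [1, 3, 9]:
    preds = fit_and_predict(degree, x_eval)
    mean_h = preds.mean(axis=0)                  # E_D[h_D(x)]
    bias2 = np.mean((mean_h - t(x_eval)) ** 2)   # average Bias^2
    var = np.mean(preds.var(axis=0))             # average variance
    print(f"degree {degree}: bias^2 = {bias2:.4f}, variance = {var:.4f}")
```

Low-degree fits come out with high bias and low variance; high-degree fits show the reverse, matching the intuition above.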

Overfitting

Machine Learning – CSE546, Sham Kakade, University of Washington
October 6, 2016

Training set error

- Given a dataset (training data)
- Choose a loss function
  - e.g., squared error (L2) for regression
- Training set error: for a particular set of parameters w, the loss function evaluated on the training data:

\text{error}_{\text{train}}(w) = \frac{1}{N_{\text{train}}} \sum_{j=1}^{N_{\text{train}}} \left( t(x_j) - \sum_i w_i h_i(x_j) \right)^2


Page 3


Training set error as a function of model complexity

[Figure: training set error vs. model complexity; it decreases as complexity grows]

Prediction error

- Training set error can be a poor measure of the "quality" of the solution
- Prediction error: we really care about the error over all possible input points, not just the training data:

\text{error}_{\text{true}}(w) = E_x\!\left[ \left( t(x) - \sum_i w_i h_i(x) \right)^2 \right] = \int \left( t(x) - \sum_i w_i h_i(x) \right)^2 p(x)\, dx


Prediction error as a function of model complexity

[Figure: prediction error vs. model complexity; it falls at first, then rises as the model overfits]

Computing prediction error

- Computing the prediction error requires:
  - A hard integral
  - t(x), which we may not know for every x
- Monte Carlo integration (sampling approximation), sketched below:
  - Sample a set of i.i.d. points {x_1, …, x_M} from p(x)
  - Approximate the integral with the sample average:

\text{error}_{\text{true}}(w) \approx \frac{1}{M} \sum_{j=1}^{M} \left( t(x_j) - \sum_i w_i h_i(x_j) \right)^2, \quad x_j \sim p(x)

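A minimal sketch of this sampling approximation, assuming for illustration that we know both t(x) and p(x) (here Uniform(0, 1)) so that we can draw fresh points:

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative setup (assumed, not from the slides): true function t,
# input distribution p(x) = Uniform(0, 1), and a fitted weight vector w.
def t(x):
    return np.sin(2 * np.pi * x)

def features(x, k=5):
    # polynomial basis h_i(x) = x^i, i = 0..k
    return np.vander(x, k + 1, increasing=True)

# Fit w on a small training set (ordinary least squares).
x_train = rng.uniform(0, 1, 20)
y_train = t(x_train) + rng.normal(0, 0.2, 20)
w = np.linalg.lstsq(features(x_train), y_train, rcond=None)[0]

# Monte Carlo estimate of prediction error:
# sample x_1..x_M i.i.d. from p(x), average the squared errors.
M = 100_000
x_mc = rng.uniform(0, 1, M)
pred_error = np.mean((t(x_mc) - features(x_mc) @ w) ** 2)
print(f"Monte Carlo estimate of prediction error: {pred_error:.4f}")
```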

Page 4


Why doesn't training set error approximate prediction error?

- Sampling approximation of the prediction error (fresh x_j ~ p(x)):

\text{error}_{\text{true}}(w) \approx \frac{1}{M} \sum_{j=1}^{M} \left( t(x_j) - \sum_i w_i h_i(x_j) \right)^2

- Training error:

\text{error}_{\text{train}}(w) = \frac{1}{N_{\text{train}}} \sum_{j=1}^{N_{\text{train}}} \left( t(x_j) - \sum_i w_i h_i(x_j) \right)^2

- Very similar equations!
  - So why is the training set a bad measure of prediction error?


Because you cheated!

The training error is a good estimate for a single, fixed w. But you optimized w with respect to the training error and found a w that is good for this particular set of samples.

Training error is an (optimistically) biased estimate of prediction error.


Test set error

- Given a dataset, randomly split it into two parts:
  - Training data: {x_1, …, x_{N_train}}
  - Test data: {x_1, …, x_{N_test}}
- Use the training data to optimize the parameters w (see the sketch below)
- Test set error: for the final output ŵ, evaluate the error using:

\text{error}_{\text{test}}(\hat{w}) = \frac{1}{N_{\text{test}}} \sum_{j=1}^{N_{\text{test}}} \left( t(x_j) - \sum_i \hat{w}_i h_i(x_j) \right)^2

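A minimal train/test split sketch (the dataset, split sizes, and degree-9 basis are illustrative assumptions). It also shows the point of the previous slide: the error on held-out test points is an honest estimate, while the training error is optimistic:

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy data (assumed for illustration): degree-9 polynomial fit on noisy sine data.
x = rng.uniform(0, 1, 50)
y = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, 50)

# Random train/test split: optimize w on train, report error on test only.
idx = rng.permutation(len(x))
train, test = idx[:35], idx[35:]

def design(x, k=9):
    return np.vander(x, k + 1, increasing=True)  # h_i(x) = x^i

w = np.linalg.lstsq(design(x[train]), y[train], rcond=None)[0]

train_err = np.mean((y[train] - design(x[train]) @ w) ** 2)
test_err = np.mean((y[test] - design(x[test]) @ w) ** 2)
print(f"train error: {train_err:.4f}   test error: {test_err:.4f}")
# Training error is optimistically biased; test error is the honest estimate.
```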

Test set error as a function of model complexity

[Figure: test set error vs. model complexity; like prediction error, it falls and then rises as the model overfits]

Page 5


How many points do I use for training/testing?

- A very hard question to answer!
  - Too few training points: the learned w is bad
  - Too few test points: you never know whether you reached a good solution
- Bounds, such as Hoeffding's inequality, can help. For a loss bounded in [0, 1]:

P\left( \left| \text{error}_{\text{true}}(\hat{w}) - \text{error}_{\text{test}}(\hat{w}) \right| \ge \epsilon \right) \le 2 e^{-2 N_{\text{test}} \epsilon^2}

- More on this later this quarter, but it is still hard to answer
- Typically (see the sizing sketch below):
  - If you have a reasonable amount of data, pick a test set "large enough" for a "reasonable" estimate of error, and use the rest for learning
  - If you have little data, then you need to pull out the big guns…
    - e.g., bootstrapping

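As a rough sizing sketch, the Hoeffding bound above can be inverted to ask how large the test set must be so that, with probability at least 1 − δ, the test error is within ε of the true error (this assumes a loss bounded in [0, 1]):

```python
import numpy as np

# Hoeffding-based sizing heuristic: |error_true - error_test| < eps
# with probability >= 1 - delta whenever N_test >= ln(2/delta) / (2 eps^2).
def test_set_size(eps, delta):
    return int(np.ceil(np.log(2 / delta) / (2 * eps ** 2)))

print(test_set_size(eps=0.05, delta=0.05))  # ~738 test points
print(test_set_size(eps=0.01, delta=0.05))  # ~18445 test points
```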


Error estimators

Be careful!!!

The test set error is only unbiased if you never never ever ever do any any any any learning on the test data.

For example, if you use the test set to select the degree of the polynomial… it is no longer unbiased!

(We will address this problem later.)


What you need to know

- True error, training error, test error
  - Never learn on the test data
  - Never learn on the test data
  - Never learn on the test data
  - Never learn on the test data
  - Never learn on the test data
- Overfitting


Page 6


Regularization

Machine Learning – CSE546, Sham Kakade, University of Washington
October 6, 2016

Regularization in Linear Regression

- Overfitting usually leads to very large parameter values, e.g.:

  −2.2 + 3.1 x − 0.30 x²
  −1.1 + 4,700,910.7 x − 8,585,638.4 x² + …

- Regularized or penalized regression aims to impose a "complexity" penalty by penalizing large weights
  - A "shrinkage" method

Ridge Regression

- Alleviating issues with overfitting:
- New objective:

w_{\text{ridge}} = \arg\min_w \sum_{j=1}^{N} \left( t(x_j) - \Big( w_0 + \sum_{i=1}^{k} w_i h_i(x_j) \Big) \right)^2 + \lambda \sum_{i=1}^{k} w_i^2

Ridge Regression in Matrix Notation

Stack the N data points as rows: H is the N × (K+1) matrix of basis-function values (columns h_0, …, h_k), t is the N × 1 vector of observations, and w is the (K+1) × 1 weight vector. The objective becomes

w_{\text{ridge}} = \arg\min_w \sum_{j=1}^{N} \left( t(x_j) - \Big( w_0 + \sum_{i=1}^{k} w_i h_i(x_j) \Big) \right)^2 + \lambda \sum_{i=1}^{k} w_i^2 = \arg\min_w\, (Hw - t)^T (Hw - t) + \lambda\, w^T I_{0+k}\, w

where I_{0+k} is the (K+1) × (K+1) identity matrix with a 0 in its top-left entry, so the intercept w_0 is not penalized.

Page 7


Minimizing the Ridge Regression Objective

w_{\text{ridge}} = \arg\min_w \sum_{j=1}^{N} \left( t(x_j) - \Big( w_0 + \sum_{i=1}^{k} w_i h_i(x_j) \Big) \right)^2 + \lambda \sum_{i=1}^{k} w_i^2 = \arg\min_w\, (Hw - t)^T (Hw - t) + \lambda\, w^T I_{0+k}\, w

Setting the gradient to zero, 2 H^T (Hw - t) + 2 \lambda I_{0+k} w = 0, gives the closed-form solution shown on the next slide; a code sketch follows below.

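A minimal sketch of this closed-form solution (the polynomial basis and the data are illustrative assumptions; `ridge_fit` is a hypothetical helper name):

```python
import numpy as np

def ridge_fit(H, t, lam):
    """Closed-form ridge solution w = (H^T H + lam * I0k)^{-1} H^T t.
    Assumes the first column of H is the all-ones intercept column,
    which I0k leaves unpenalized (0 in its top-left entry)."""
    I0k = np.eye(H.shape[1])
    I0k[0, 0] = 0.0                  # do not penalize the intercept w_0
    return np.linalg.solve(H.T @ H + lam * I0k, H.T @ t)

# Tiny usage example with a polynomial basis h_i(x) = x^i (illustrative).
rng = np.random.default_rng(3)
x = rng.uniform(0, 1, 30)
t_obs = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, 30)
H = np.vander(x, 10, increasing=True)    # N x (K+1), first column is ones

for lam in [0.0, 0.1, 10.0]:
    w = ridge_fit(H, t_obs, lam)
    print(f"lambda={lam:>5}: max |w_i| = {np.abs(w).max():.2f}")
```

Running it shows the huge coefficients of the unregularized fit shrinking as λ grows.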

Shrinkage Properties

- Closed-form solution:

w_{\text{ridge}} = (H^T H + \lambda I_{0+k})^{-1} H^T t

- If the features/basis are orthonormal, H^T H = I, so each ridge coefficient is just the least-squares coefficient shrunk by a constant factor: w_i^{\text{ridge}} = w_i^{\text{LS}} / (1 + \lambda) (numerical check below)

Ridge Regression: Effect of Regularization

- The solution is indexed by the regularization parameter λ
- Larger λ: more shrinkage, smaller weights (more bias, less variance)
- Smaller λ: less shrinkage, closer to the unregularized fit (less bias, more variance)
- As λ → 0: w_ridge approaches the ordinary least-squares solution
- As λ → ∞: the penalized weights w_1, …, w_k all go to 0

w_{\text{ridge}} = \arg\min_w \sum_{j=1}^{N} \left( t(x_j) - \Big( w_0 + \sum_{i=1}^{k} w_i h_i(x_j) \Big) \right)^2 + \lambda \sum_{i=1}^{k} w_i^2


Ridge Coefficient Path

- Typical approach: select λ using cross-validation; more on this later in the quarter (a coefficient-path sketch follows below)

[Figure: ridge coefficient path as λ varies, from the Kevin Murphy textbook]
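A minimal coefficient-path sketch (the setup is an illustrative assumption): as λ sweeps over several orders of magnitude, the penalized coefficients shrink smoothly toward 0, which is what the Murphy figure plots coefficient by coefficient:

```python
import numpy as np

# Trace ridge coefficients as lambda varies (the "coefficient path").
rng = np.random.default_rng(5)
x = rng.uniform(0, 1, 40)
t_obs = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, 40)
H = np.vander(x, 8, increasing=True)

I0k = np.eye(H.shape[1])
I0k[0, 0] = 0.0                              # leave the intercept unpenalized

for lam in np.logspace(-4, 4, 9):
    w = np.linalg.solve(H.T @ H + lam * I0k, H.T @ t_obs)
    print(f"lambda={lam:10.4f}  ||w_1..k||_2 = {np.linalg.norm(w[1:]):8.3f}")
# The penalized coefficients shrink smoothly toward 0 as lambda grows.
```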

Page 8


What you need to know…

- Regularization
  - Penalizes complex models
- Ridge regression
  - L2-penalized least-squares regression
  - The regularization parameter trades off model complexity against training error


Cross-Validation

Machine Learning – CSE546, Sham Kakade, University of Washington
October 6, 2016

Test set error as a function of model complexity

[Figure: annotated test-error curve vs. model complexity, revisited]

How… How… How???????

- How do we pick the regularization constant λ…
  - And all other constants in ML, 'cause one thing ML doesn't lack is constants to tune… ☹
- We could use the test data, but…


Page 9


Leave-one-out (LOO) cross-validation

- Consider a validation set with 1 example:
  - D: the training data
  - D\j: the training data with the j-th data point moved to the validation set
- Learn classifier h_{D\j} with the D\j dataset
- Estimate the true error as the squared error on predicting t(x_j): (t(x_j) − h_{D\j}(x_j))²
  - An unbiased estimate of error_true(h_{D\j})!
  - Seems like a really bad estimator, but wait!
- LOO cross-validation: average over all data points j:
  - For each data point you leave out, learn a new classifier h_{D\j}
  - Estimate the error as (implementation sketch below):


\text{error}_{\text{LOO}} = \frac{1}{N} \sum_{j=1}^{N} \left( t(x_j) - h_{D \setminus j}(x_j) \right)^2

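A minimal LOO sketch (ridge regression on a polynomial basis is an illustrative assumption; `ridge_fit` and `design` are hypothetical helper names). Each of the N fits leaves out one point and scores it:

```python
import numpy as np

rng = np.random.default_rng(6)
x = rng.uniform(0, 1, 25)
t_obs = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, 25)

def design(x, k=5):
    return np.vander(x, k + 1, increasing=True)

def ridge_fit(H, t, lam):
    I0k = np.eye(H.shape[1])
    I0k[0, 0] = 0.0                           # intercept unpenalized
    return np.linalg.solve(H.T @ H + lam * I0k, H.T @ t)

def loo_error(x, t, lam):
    # error_LOO = (1/N) sum_j (t(x_j) - h_{D\j}(x_j))^2
    N = len(x)
    errs = []
    for j in range(N):
        keep = np.arange(N) != j              # D\j: leave point j out
        w = ridge_fit(design(x[keep]), t[keep], lam)
        errs.append((t[j] - (design(x[j:j+1]) @ w)[0]) ** 2)
    return np.mean(errs)

# Model selection: pick the lambda with the smallest LOO error.
for lam in [1e-3, 1e-1, 1e1]:
    print(f"lambda={lam}: LOO error = {loo_error(x, t_obs, lam):.4f}")
```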

LOO cross-validation is an (almost) unbiased estimate of the true error of h_D!

- When computing the LOOCV error, we only use N−1 data points
  - So it is not an estimate of the true error of learning with N data points!
  - It is usually pessimistic, though: learning with less data typically gives a worse answer
- LOO is almost unbiased!
- Great news!
  - Use the LOO error for model selection!
  - E.g., picking λ


Computational cost of LOO

- Suppose you have 100,000 data points
- You implemented a great version of your learning algorithm
  - It learns in only 1 second
- Computing LOO will take about 1 day! (100,000 fits × 1 second ≈ 28 hours)
  - If you have to do this for each choice of basis functions, it will take fooooooreeeever!
- Solution 1: preferred, but not usually possible
  - Find a cool trick to compute LOO (e.g., see the homework); one such trick is sketched below

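For ridge regression specifically, one such trick is the classic leave-one-out identity for linear smoothers: the LOO residual at point j equals the ordinary residual divided by 1 − S_jj, where S = H(HᵀH + λI_{0+k})⁻¹Hᵀ. This may or may not be the homework's intended trick; a sketch:

```python
import numpy as np

def loo_error_fast(H, t, lam):
    """Exact LOO error for ridge in one fit: the LOO residual at point j
    equals the ordinary residual divided by (1 - S_jj), where
    S = H (H^T H + lam * I0k)^{-1} H^T is the smoother ("hat") matrix."""
    I0k = np.eye(H.shape[1])
    I0k[0, 0] = 0.0                             # intercept unpenalized
    S = H @ np.linalg.solve(H.T @ H + lam * I0k, H.T)
    resid = t - S @ t                           # ordinary residuals
    return np.mean((resid / (1.0 - np.diag(S))) ** 2)

# Usage: one matrix solve instead of N refits.
rng = np.random.default_rng(8)
x = rng.uniform(0, 1, 25)
t_obs = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, 25)
H = np.vander(x, 6, increasing=True)
print(loo_error_fast(H, t_obs, lam=0.1))
```

This returns the same value as the N-refit loop above, at the cost of a single fit.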

Solution 2 to the complexity of computing LOO (more typical): use k-fold cross-validation

- Randomly divide the training data into k equal parts
  - D_1, …, D_k
- For each i:
  - Learn classifier h_{D\D_i} using the data points not in D_i
  - Estimate the error of h_{D\D_i} on the validation set D_i:
- The k-fold cross-validation error is the average over the data splits (a code sketch follows below): error_{k-fold} = (1/k) Σ_i error_{D_i}
- k-fold cross-validation properties:
  - Much faster to compute than LOO
  - More (pessimistically) biased: each fold trains on much less data, only N(k−1)/k points
  - Usually, k = 10 ☺


\text{error}_{D_i} = \frac{k}{N} \sum_{x_j \in D_i} \left( t(x_j) - h_{D \setminus D_i}(x_j) \right)^2

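A minimal k-fold sketch tying the pieces together (the data, basis, and candidate λ grid are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(7)
x = rng.uniform(0, 1, 100)
t_obs = np.sin(2 * np.pi * x) + rng.normal(0, 0.2, 100)

def design(x, deg=5):
    return np.vander(x, deg + 1, increasing=True)

def ridge_fit(H, t, lam):
    I0k = np.eye(H.shape[1])
    I0k[0, 0] = 0.0                               # intercept unpenalized
    return np.linalg.solve(H.T @ H + lam * I0k, H.T @ t)

def kfold_error(x, t, lam, k=10):
    idx = rng.permutation(len(x))
    errs = []
    for val in np.array_split(idx, k):            # validation fold D_i
        train = np.setdiff1d(idx, val)            # D \ D_i
        w = ridge_fit(design(x[train]), t[train], lam)
        errs.append(np.mean((t[val] - design(x[val]) @ w) ** 2))
    return np.mean(errs)                          # average over the k splits

# Model selection: pick the lambda with the smallest 10-fold CV error.
best = min([1e-3, 1e-2, 1e-1, 1, 10], key=lambda lam: kfold_error(x, t_obs, lam))
print("lambda chosen by 10-fold CV:", best)
```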

Page 10


What you need to know…

- Never ever ever ever ever ever ever ever ever ever train on the test data
- Use cross-validation to choose magic parameters such as λ
- Leave-one-out is the best you can do, but it is sometimes too slow
  - In that case, use k-fold cross-validation
