Post on 16-Nov-2019


Learning representations for counterfactual inference

Fredrik D. Johansson*2, Uri Shalit*1, David Sontag1 (*Equal contribution)

NIPS 2016 Deep Learning Symposium, December 2016

Talk today about two papers

• Fredrik D. Johansson, Uri Shalit, David Sontag, “Learning Representations for Counterfactual Inference”, ICML 2016

• Uri Shalit, Fredrik D. Johansson, David Sontag, “Estimating individual treatment effect: generalization bounds and algorithms”, arXiv:1606.03976

Code: https://github.com/clinicalml/cfrnet

Causal inference from observational data

• Patient “Anna” comes in with hypertension
  • Asian, 54, history of diabetes, blood pressure 150/95, …
• Which of the treatments 𝑡 will cause Anna to have lower blood pressure?
  • Calcium channel blocker (𝑡 = 1)
  • ACE inhibitor (𝑡 = 0)
• Dataset of observational data from many patients: medications, blood tests, past diagnoses, demographics…

Causal inference from observational data

How to best use observational data for individual-level causal inference?

Causal inference from observational data: Job training

• 1,000 unemployed persons
• Job training program with capacity of 100
  • Training (𝑡 = 1)
  • No training (𝑡 = 0)
• Who should get job training?
  • For which persons will job training have the most impact?
• Observational data about thousands of people: job history, job training, education, skills, demographics…

Observational data

• Dataset of features, actions and outcomes
• We do not control the actions
• We do not know the model generating the actions

Causal inference from observational data and reinforcement learning

• Robot on the sideline, learning by observing other robots playing robot football
• Sideline-robot does not know the playing-robots’ internal model
• Form of off-policy learning, learning from demonstration

Outline: Background, Model, Experiments, Theory

Causal inference from observational data: Medication

• Patient “Anna” comes in with hypertension
  • Asian, 54, history of diabetes, blood pressure 150/95, …
• Which of the treatments 𝑡 will lower Anna’s blood pressure?
  • Calcium channel blocker (𝑡 = 1)
  • ACE inhibitor (𝑡 = 0)
• Dataset of observational data: medications, blood tests, past diagnoses, demographics…

Regression modeling

• Build a regression model from patient features and treatment decision to blood pressure (BP) using our observational data
• Input: Anna’s features 𝑥 with 𝑡 = 1, and Anna’s features 𝑥 with 𝑡 = 0
• Output: predicted BP for each input
• Compare: is the predicted BP under 𝑡 = 1 equal to the predicted BP under 𝑡 = 0?

[Figure: neural network with inputs 𝑥 (features) and 𝑡 (treatment), trained with loss(ℎ(𝑥, 𝑡), 𝑦)]

Not supervised learning!

• This is not a classic supervised learning problem
• Supervised learning is optimized to predict outcome, not to differentiate the influence of 𝑡 = 1 vs. 𝑡 = 0
• What if our high-dimensional model threw away the feature of treatment 𝑡?
• Maybe there’s confounding: younger patients tend to get medication 𝑡 = 1, older patients tend to get medication 𝑡 = 0
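The confounding bullet can be made concrete with a small synthetic sketch. Everything below is an illustrative assumption rather than data from the talk: the age range, the logistic assignment rule, the linear outcome model and the constant true effect of −5 are all made up. It shows how a naive comparison of treated and control means can be far from the true effect when younger patients are more likely to receive 𝑡 = 1.

```python
import numpy as np

def simulate(n=20000, seed=0):
    """Synthetic patients; every number here is a hypothetical choice."""
    rng = np.random.default_rng(seed)
    age = rng.uniform(30, 80, size=n)
    # Confounded assignment: younger patients are more likely to get t = 1
    p_treat = 1.0 / (1.0 + np.exp((age - 55.0) / 5.0))
    t = rng.binomial(1, p_treat)
    # Blood pressure rises with age; t = 1 lowers it by exactly 5 for everyone
    y0 = 100.0 + 1.0 * age + rng.normal(0.0, 2.0, size=n)
    y1 = y0 - 5.0
    y = np.where(t == 1, y1, y0)  # only the factual outcome is observed
    return age, t, y

age, t, y = simulate()

# Naive comparison of group means mixes the treatment effect with the age gap
naive = y[t == 1].mean() - y[t == 0].mean()

# Least-squares fit of y on (1, age, t); the coefficient on t adjusts for age
X = np.column_stack([np.ones_like(age), age, t])
coef, *_ = np.linalg.lstsq(X, y, rcond=None)
adjusted = coef[2]

print(f"naive difference: {naive:.2f}, age-adjusted: {adjusted:.2f}, true: -5.00")
```

The adjustment works here only because the outcome model is linear and correctly specified by construction; with high-dimensional features and flexible models that guarantee disappears, which is the concern the slide raises.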

Potential outcomes (Rubin & Neyman)

For every sample $x \in \mathcal{X}$ and treatment $t \in \{0,1\}$, there is a potential outcome $Y_t \mid x$:

• $Y_1 \mid x$: blood pressure had they received treatment 1
• $Y_0 \mid x$: blood pressure had they received treatment 0

Individual treatment effect: $ITE(x) := \mathbb{E}[Y_1 - Y_0 \mid x]$

We observe only one potential outcome, and not at random!

Example: patient blood pressure (BP)

Features: $x = (\mathit{age}, \mathit{gender})$; treatment: $t \in \{0,1\}$

Factual (observed) set

(age, gender, treatment)    BP after medication
(40, F, 1)                  $Y_1 = 140$
(40, M, 1)                  $Y_1 = 145$
(65, F, 0)                  $Y_0 = 170$
(65, M, 0)                  $Y_0 = 175$
(70, F, 0)                  $Y_0 = 165$

Counterfactual set (the prediction set)

(age, gender, treatment)    BP after medication
(40, F, 0)                  $Y_0 = ?$
(40, M, 0)                  $Y_0 = ?$
(65, F, 1)                  $Y_1 = ?$
(65, M, 1)                  $Y_1 = ?$
(70, F, 1)                  $Y_1 = ?$

• Closely related to unsupervised domain adaptation
• No samples from the test set
• Can’t perform cross-validation!
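The relationship between the factual set and the prediction set can be sketched in a few lines. The tuple layout and the helper name `counterfactual_set` are illustrative assumptions; the rows are the five patients from the example. The sketch shows that the prediction set is the same patients with the treatment flipped, so none of its (features, treatment) pairs carries an observed outcome.

```python
# Factual rows from the example: ((age, gender), treatment, observed BP)
factual = [
    ((40, "F"), 1, 140),
    ((40, "M"), 1, 145),
    ((65, "F"), 0, 170),
    ((65, "M"), 0, 175),
    ((70, "F"), 0, 165),
]

def counterfactual_set(rows):
    """Same patients with the opposite treatment; outcomes are unknown."""
    return [(x, 1 - t) for x, t, _ in rows]

cf = counterfactual_set(factual)

# No counterfactual (features, treatment) pair appears in the factual data,
# so there is nothing to hold out for validation on the prediction set.
observed_pairs = {(x, t) for x, t, _ in factual}
overlap = observed_pairs & set(cf)

for x, t in cf:
    print(x, t, "outcome unknown")
```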


Our Work

• New neural-net based representation learning algorithm with explicit regularization for counterfactual estimation
• State-of-the-art on previous benchmark and on real-world causal inference task
• First error bound for estimating individual treatment effect (ITE)

When is this problem easier? Randomized Controlled Trials

[Figure: control (𝑡 = 0) and treated (𝑡 = 1) units plotted over features 𝑥]

Randomized treatment → counterfactual and factual have identical distributions

When is this problem harder? Observational study

[Figure: control (𝑡 = 0) and treated (𝑡 = 1) units plotted over features 𝑥]

Treatment assignment non-random → counterfactual and factual have different distributions

Learning more balanced representations

[Figure: control (𝑡 = 0) and treated (𝑡 = 1) units with densities $p^{\mathrm{treated}}(x)$ and $p^{\mathrm{control}}(x)$ over the features 𝑥, mapped by the representation Φ(𝑥) to densities $p_\Phi^{\mathrm{treated}}(x)$ and $p_\Phi^{\mathrm{control}}(x)$]

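Naïve Neural Network for estimating individual treatment effect (ITE)

[Figure: network with inputs 𝑥 (features) and 𝑡 (treatment), trained with $\mathrm{loss}(h(x, t), Y_t)$]

Vanilla Neural Network for Counterfactual Regression (CFR)

[Figure: representation layers Φ applied to the features 𝑥, followed by prediction layers ℎ that take Φ and the treatment 𝑡, trained with $\mathrm{loss}(h(\Phi, t), Y_t)$]

Balancing Neural Network for Counterfactual Regression (CFR)

[Figure: the same network with an added term $\mathrm{dist}(p_\Phi^{\mathrm{treated}}, p_\Phi^{\mathrm{control}})$ computed on the representation Φ]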
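As a rough sketch of how the balancing network's two terms might combine, the snippet below evaluates a factual loss plus a weighted imbalance penalty for fixed parameters. Several choices are assumptions made only for illustration: the single ReLU layer for Φ, the linear head for ℎ, the squared loss, the weight `alpha`, and the use of a squared difference of group means as a stand-in for dist. It is not the released cfrnet code.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def cfr_objective(x, t, y, W_phi, w_h, alpha):
    """Factual squared loss plus alpha times a simple imbalance penalty."""
    phi = relu(x @ W_phi)                 # representation Phi(x), shape (n, d)
    h_in = np.column_stack([phi, t])      # prediction head sees Phi and t
    y_hat = h_in @ w_h                    # h(Phi(x), t), shape (n,)
    factual_loss = float(np.mean((y_hat - y) ** 2))
    # Stand-in for dist: squared distance between treated and control means
    # of Phi (equivalent to MMD with a linear kernel)
    gap = phi[t == 1].mean(axis=0) - phi[t == 0].mean(axis=0)
    imbalance = float(gap @ gap)
    return factual_loss + alpha * imbalance, factual_loss, imbalance

# Hypothetical data and randomly initialised parameters
rng = np.random.default_rng(0)
n, p, d = 200, 5, 4
x = rng.normal(size=(n, p))
t = (x[:, 0] + rng.normal(size=n) > 0).astype(int)   # assignment depends on x
y = x.sum(axis=1) + 2.0 * t + rng.normal(size=n)
W_phi = rng.normal(size=(p, d))
w_h = rng.normal(size=d + 1)

total, loss, imb = cfr_objective(x, t, y, W_phi, w_h, alpha=1.0)
print(f"factual loss {loss:.3f} + 1.0 * imbalance {imb:.3f} = {total:.3f}")
```

Setting `alpha=0.0` recovers the vanilla network's objective; training would minimise the total with respect to both `W_phi` and `w_h`.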

Choices for $\mathrm{dist}(p_\Phi^{\mathrm{treated}}, p_\Phi^{\mathrm{control}})$:

• MMD distance (Gretton et al. 2012)
• Wasserstein distance (Villani 2008, Cuturi 2013)

Inspired by Domain Adversarial Networks (Ganin et al., 2016):

(source domain, target domain) → (treated population, control population)
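To give a feel for the MMD option, here is a minimal sketch of the biased (V-statistic) estimate of squared MMD with a Gaussian RBF kernel. The kernel choice, the bandwidth `sigma` and the synthetic samples are assumptions made for illustration; the talk does not specify them.

```python
import numpy as np

def rbf_kernel(a, b, sigma):
    """Gaussian RBF kernel matrix between the rows of a and the rows of b."""
    sq = (np.sum(a ** 2, axis=1)[:, None]
          + np.sum(b ** 2, axis=1)[None, :]
          - 2.0 * a @ b.T)
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * sigma ** 2))

def mmd2_rbf(a, b, sigma=1.0):
    """Biased estimate of squared MMD between samples a and b."""
    return (rbf_kernel(a, a, sigma).mean()
            + rbf_kernel(b, b, sigma).mean()
            - 2.0 * rbf_kernel(a, b, sigma).mean())

# Hypothetical representations of treated and control units
rng = np.random.default_rng(0)
treated = rng.normal(loc=1.0, size=(300, 2))    # shifted away from control
control = rng.normal(loc=0.0, size=(300, 2))
control_b = rng.normal(loc=0.0, size=(300, 2))  # same distribution as control

far = mmd2_rbf(treated, control)
near = mmd2_rbf(control_b, control)
print(f"shifted groups: {far:.4f}, matching groups: {near:.4f}")
```

A representation that balances the two groups should push this quantity toward the small value seen for the matching samples.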


Evaluating counterfactual inference

• Train-test paradigm breaks
• No observations from the counterfactual “test” set
• Can’t do cross-validation for hyper-parameter selection

1) Simulated data: IHDP (Hill, 2011)
2) Real data: National Supported Work study (LaLonde, 1986; Todd & Smith, 2005)
  • The effect of job training on employment and income
  • Observational study with a randomized controlled trial subset
  • 3212 samples, 8 features incl. education and previous income

Evaluating models with randomized controlled trials data

• We can’t directly evaluate individual treatment effect (ITE) error because we never see the counterfactual
• Every ITE estimator implies a policy: $\widehat{ITE}(x) = f(x)$
  • Policy $\pi_{f,\lambda}: \mathcal{X} \to \{0,1\}$
  • Treat all persons $x$ with $f(x) > \lambda$, for threshold $\lambda$
• Every policy $\pi$ has a policy value:

$\mathbb{E}[Y_1 \mid \pi(x) = 1]\, p(\pi = 1) + \mathbb{E}[Y_0 \mid \pi(x) = 0]\, p(\pi = 0)$

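Evaluating model performance using randomized data (off-policy evaluation)

[Figure: randomized controlled trial units, control (𝑡 = 0) and treated (𝑡 = 1), compared with the policy $\pi$; units where the two coincide are labelled “Agreement”]

Policy value: $\mathbb{E}[Y_1 \mid \pi(x) = 1]\, p(\pi = 1) + \mathbb{E}[Y_0 \mid \pi(x) = 0]\, p(\pi = 0)$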
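The policy value formula suggests a simple plug-in estimate on randomized data: average the outcomes of units whose randomized treatment agrees with the policy, separately for each recommended action, and weight by how often the policy recommends that action. The sketch below follows that reading. The function names and the six-unit toy data are hypothetical, and the estimate is only meaningful when treatment was assigned at random.

```python
import numpy as np

def threshold_policy(ite_hat, lam):
    """pi(x) = 1 where the estimated ITE f(x) exceeds the threshold lam."""
    return (np.asarray(ite_hat) > lam).astype(int)

def _term(recommended, agree, y):
    share = recommended.mean()            # estimate of p(pi = a)
    if share == 0.0:
        return 0.0
    if not agree.any():
        raise ValueError("no randomized unit agrees with the policy here")
    return y[agree].mean() * share        # E[Y_a | pi(x) = a] * p(pi = a)

def policy_value(pi, t, y):
    """Plug-in policy value from randomized (t, y) and recommendations pi."""
    pi, t = np.asarray(pi), np.asarray(t)
    y = np.asarray(y, dtype=float)
    treat_term = _term(pi == 1, (pi == 1) & (t == 1), y)
    control_term = _term(pi == 0, (pi == 0) & (t == 0), y)
    return treat_term + control_term

# Toy randomized sample: policy recommendation, actual treatment, outcome
pi = [1, 1, 0, 0, 1, 0]
t = [1, 0, 0, 1, 1, 0]
y = [10, 4, 6, 8, 12, 2]

value = policy_value(pi, t, y)
print(value)  # 0.5 * mean(10, 12) + 0.5 * mean(6, 2) = 7.5
```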

Experimental results: National Supported Work Study

• National Supported Work: randomized trial embedded in an observational study
• Policy risk estimated on randomized subsample
• CFR-2-2: our model, with 2 layers before Φ and 2 layers after Φ

Method                                        Policy risk (std)
ℓ1-reg. logistic regression                   0.23 ± 0.00
BART (Chipman, George & McCulloch, 2010)      0.24 ± 0.01
Causal forests (Wager & Athey, 2015)          0.17 ± 0.006
CFR-2-2 Vanilla                               0.16 ± 0.02
CFR-2-2 Wasserstein                           0.15 ± 0.02
CFR-2-2 MMD                                   0.13 ± 0.02

Lower is better.

[Figure: comparison of Causal forest, CFR-2-2 Vanilla, CFR-2-2 MMD and Random Policy; lower is better]


Theory of causal effect inference

• Standard results in statistics: asymptotic rate of convergence to true average effect
  • Assumptions: we know true model (consistency)
• Our result: generalization error bound for individual-level inference
  • Assumptions: true model lies within large model family, e.g. bounded Lipschitz functions

Theorem (informal)

• Let $\hat{Y}_t^{\Phi,h}(x) = h(\Phi(x), t)$ for $t = 0, 1$
• $\widehat{ITE}^{\Phi,h}(x) := \hat{Y}_1^{\Phi,h}(x) - \hat{Y}_0^{\Phi,h}(x)$
• If “strong ignorability” holds, and if $\mathrm{dist}$ is “nice” with respect to the true potential outcomes $Y_0$ and $Y_1$ and the representation $\Phi$, then for all normalized $\Phi$ and $h$:

$$\mathbb{E}_x\big[\mathrm{error}\big(\widehat{ITE}^{\Phi,h}(x)\big)\big] \le 2 \cdot \mathbb{E}_{x,t}\big[\mathrm{error}\big(\hat{Y}_t^{\Phi,h}(x)\big)\big] + \mathrm{dist}\big(p_\Phi^{\mathrm{treated}}, p_\Phi^{\mathrm{control}}\big)$$

• Left-hand side: expected error in estimating ITE
• First term on the right: “supervised learning generalization error”
• Second term on the right: distance between Φ-induced distributions
• We minimize the upper bound with respect to Φ and ℎ

Summary

• Estimating Individual Treatment Effect is different from supervised learning
  • Bears strong connections to domain adaptation
• We give new representation learning algorithms for estimating Individual Treatment Effect
  • Use the MMD and Wasserstein distributional distances
  • Experiments show our method is competitive or better than state-of-the-art
• We give a new error bound for estimating Individual Treatment Effect

• Fredrik D. Johansson, Uri Shalit, David Sontag, “Learning Representations for Counterfactual Inference”, ICML 2016

• Uri Shalit, Fredrik D. Johansson, David Sontag, “Estimating individual treatment effect: generalization bounds and algorithms”, arXiv:1606.03976

Acknowledgments: Justin Chiu (FAIR), Marco Cuturi (ENSAE/CREST), Jennifer Hill (NYU), Aahlad Manas (NYU), Sanjong Misra (U. Chicago), Esteban Tabak (NYU) and Stefan Wager (Columbia)

Thank you!
