Direct Optimization (CSC2547)
Adamo Young, Dami Choi, Sepehr Abbasi Zadeh
Direct Optimization
● A way to obtain gradient estimates that directly optimizes a non-differentiable objective.
● It first appeared in structured prediction problems.
Structured Prediction
Whenever the components of the goal state have inter-dependencies.
[Images from Wikipedia and http://dbmsnotes-ritu.blogspot.com/]
Structured Prediction
Scoring function $s_w(x, y)$, with $y$ discrete ($y \in \mathcal{Y}$)
Inference: $y_w(x) = \arg\max_{y \in \mathcal{Y}} s_w(x, y)$
Training: minimize the expected task loss $E[L(y^*, y_w(x))]$
Gradient Estimator
● Gradient descent on the discrete objective: $y_w(x)$ is piecewise constant in $w$, so the gradient of $E[L(y^*, y_w(x))]$ cannot be taken through the argmax directly.
● Option 1: continuous relaxation
● Option 2: estimate the gradient directly
Loss Gradient Theorem (McAllester et al., 2010; Song et al., 2016)
$\nabla_w E[L(y^*, y_w(x))] = \lim_{\epsilon \to 0} \frac{1}{\epsilon} E[\nabla_w s_w(x, y_\epsilon(x)) - \nabla_w s_w(x, y_w(x))]$
Inference: $y_w(x) = \arg\max_y s_w(x, y)$
Loss-augmented inference: $y_\epsilon(x) = \arg\max_y \{ s_w(x, y) \pm \epsilon L(y, y^*) \}$
The $+\epsilon$ version updates “away from worse”; the $-\epsilon$ version (with the difference negated) updates “towards better”.
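To make this concrete, a minimal numpy sketch of the finite-$\epsilon$ “away from worse” estimator for a toy linear scorer with 0-1 loss (the data, labels, and sizes are illustrative); one coordinate is checked against a finite-difference estimate of the expected loss:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 2)) * 0.5          # linear scorer: s_W(x, y) = W[y] @ x
eps = 0.1

x = rng.standard_normal((200_000, 2))
y_star = (x[:, 0] > 0).astype(int)             # toy "ground truth" label in {0, 1}
loss = np.ones((len(x), 3))                    # 0-1 loss: L(y, y*) = [y != y*]
loss[np.arange(len(x)), y_star] = 0.0
scores = x @ W.T                               # (n, 3)

y_w = scores.argmax(axis=1)                    # inference
y_eps = (scores + eps * loss).argmax(axis=1)   # loss-augmented ("away from worse")

# grad of s_W(x, y) w.r.t. W is onehot(y) outer x; average the theorem's difference.
def feat(y):
    return (np.eye(3)[y][:, :, None] * x[:, None, :]).mean(axis=0)

grad_direct = (feat(y_eps) - feat(y_w)) / eps

# Check one coordinate by finite differences on the same x (common random numbers);
# both carry some bias (finite eps, finite step), so they should roughly agree.
d = 0.05
def expected_loss(Wp):
    return loss[np.arange(len(x)), (x @ Wp.T).argmax(axis=1)].mean()

Wp, Wm = W.copy(), W.copy()
Wp[0, 0] += d; Wm[0, 0] -= d
print(grad_direct[0, 0], (expected_loss(Wp) - expected_loss(Wm)) / (2 * d))
```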
Limitations
● Existence of the limit $\epsilon \to 0$
○ In practice, a finite $\epsilon$ gives a bias/variance trade-off
● Solving the argmax of loss-augmented inference
Applications
● Phoneme-to-speech alignment (McAllester et al. 2010)
● Maximizing average precision for ranking (Song et al. 2016)
● Discrete structured VAE (Lorberbom et al. 2018)
● RL with discrete action spaces (Lorberbom et al. 2019)
Direct Optimization through arg max for Discrete Variational Auto-Encoder
Guy Lorberbom, Andreea Gane, Tommi Jaakkola, Tamir Hazan
Probability Background
● Gumbel Distribution
● Various sampling “tricks”
○ Reparameterization
○ Gumbel-Max
○ Gumbel-Softmax
Gumbel Distribution
Intuitively: the distribution of the maximum of a large number of samples (e.g., from a normal or exponential distribution).
[Figure: Gumbel density $p(x)$ vs. $x$, from https://en.wikipedia.org/wiki/Gumbel_distribution]
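For reference, the Gumbel distribution with location $\mu$ has the closed forms below; the max-stability property is what all of the following tricks exploit:

```latex
F(x) = \exp\!\left(-e^{-(x-\mu)}\right), \qquad
p(x) = e^{-(x-\mu)} \exp\!\left(-e^{-(x-\mu)}\right), \qquad
\max_i \, \mathrm{Gumbel}(\mu_i) \;\sim\; \mathrm{Gumbel}\!\left(\log \sum\nolimits_i e^{\mu_i}\right)
```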
Gradient Estimators for Stochastic Computation Graphs (Schulman et al. 2015)
[Diagram legend: dot = parameter node; rectangle = deterministic node; circle = stochastic node; line = functional dependency; red line = gradient propagation]
Reparameterization Trick (Kingma et al. 2015)
Write $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$: the stochasticity moves into a parameter-free noise node, so gradients flow through $\mu_\phi$ and $\sigma_\phi$.
REINFORCE/REBAR/RELAX (Williams 1988; Tucker et al. 2017; Grathwohl et al. 2018) vs. Reparameterization
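A toy numpy comparison of the two estimator families on $f(z) = z^2$, chosen so the true gradient $\partial_\mu E[z^2] = 2\mu$ is known:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.5, 1.2
eps = rng.standard_normal(100_000)
z = mu + sigma * eps                          # reparameterized sample: z ~ N(mu, sigma^2)

# Both estimate d/dmu E[z^2]; the true value is 2*mu = 1.0.
pathwise = 2 * z                              # reparameterization: f'(z)
score = z**2 * (z - mu) / sigma**2            # REINFORCE: f(z) * d log N / d mu
print(pathwise.mean(), score.mean())          # both close to 1.0
print(pathwise.std(), score.std())            # the score-function terms vary far more
```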
Gumbel-Max Trick
$z = \arg\max_i \{\log \pi_i + g_i\}$ with $g_i \sim \text{Gumbel}(0, 1)$ is an exact sample from $\text{Categorical}(\pi)$.
REINFORCE/REBAR/RELAX vs. Direct Optimization
Gumbel-Softmax Trick (Jang et al. 2017; Maddison et al. 2017)
REINFORCE/REBAR/RELAX vs. CONCRETE
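A quick numpy check of the trick; empirical argmax frequencies should match $\pi$:

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.2, 0.5, 0.3])

# Gumbel-Max: argmax_i (log pi_i + g_i), with g_i ~ Gumbel(0,1), is an
# exact Categorical(pi) sample.
g = -np.log(-np.log(rng.uniform(size=(100_000, 3))))
z = (np.log(pi) + g).argmax(axis=1)
print(np.bincount(z) / 100_000)   # ~ [0.2, 0.5, 0.3]
```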
Gumbel-Softmax Distribution (Jang et al. 2017)
Relax the argmax to a softmax with temperature $\tau$: $z = \operatorname{softmax}((\log \pi + g)/\tau)$; as $\tau \to 0$, samples approach one-hot.
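A minimal numpy sketch of relaxed sampling (the helper name is ours); lowering $\tau$ pushes samples toward one-hot:

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax(logits, tau):
    """One relaxed sample: softmax((logits + Gumbel noise) / tau)."""
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    y = (logits + g) / tau
    y = np.exp(y - y.max())   # stable softmax
    return y / y.sum()

logits = np.log(np.array([0.2, 0.5, 0.3]))
for tau in (5.0, 1.0, 0.1):
    print(tau, gumbel_softmax(logits, tau).round(3))  # nearly one-hot as tau -> 0
```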
Why discrete latent variables?
● Stronger inductive bias
● Interpretability
● Allow structural relations in the encoder
Standard (Gaussian) VAE (Kingma et al. 2013)
Encoder $q_\phi(z \mid x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x))$, decoder $p_\theta(x \mid z)$; maximize the ELBO
$E_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathrm{KL}(q_\phi(z \mid x) \,\|\, p(z))$
using the reparameterization trick for the expectation.
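The KL term has a closed form when $q_\phi$ is a diagonal Gaussian and the prior is $\mathcal{N}(0, I)$; a small numpy helper (the function name is ours):

```python
import numpy as np

def gaussian_kl(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

print(gaussian_kl(np.zeros(4), np.zeros(4)))   # 0.0: posterior equals the prior
print(gaussian_kl(np.ones(4), np.zeros(4)))    # 2.0: mean shifted by 1 per dim
```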
Naive Categorical VAE
Replace the Gaussian latent with $z \sim \text{Categorical}(\pi_\phi(x))$: sampling is no longer differentiable, but we can apply standard gradient estimators (REINFORCE/REBAR/RELAX).
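As a sanity check that the score-function estimator is unbiased for a categorical latent, a toy numpy comparison against the exact gradient (the logits and the values of $f$ are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([0.2, -0.4, 0.7])                # logits of q(z)
pi = np.exp(theta) / np.exp(theta).sum()
f = np.array([1.0, 3.0, -2.0])                    # stand-in for a per-z loss

# Exact gradient of E_q[f(z)] w.r.t. theta.
true_grad = pi * (f - pi @ f)

# REINFORCE: f(z) * d log pi_z / d theta = f(z) * (onehot(z) - pi).
z = rng.choice(3, size=200_000, p=pi)
est = (f[z][:, None] * (np.eye(3)[z] - pi)).mean(axis=0)
print(true_grad, est, sep="\n")                   # the two should closely agree
```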
Gumbel-Max VAE
Sample the categorical latent via the Gumbel-Max trick: $z = \arg\max_z \{\log \pi_\phi(z \mid x) + g(z)\}$.
Gumbel-Max VAE + Direct Optimization
Algorithm:
1) Sample Gumbel noise $g$
2) Compute $z^* = \arg\max_z \{\log \pi_\phi(z \mid x) + g(z)\}$ and $z^\epsilon = \arg\max_z \{\log \pi_\phi(z \mid x) + g(z) + \epsilon \log p_\theta(x \mid z)\}$
3) Estimate the gradient as $\frac{1}{\epsilon}\left(\nabla_\phi \log \pi_\phi(z^\epsilon \mid x) - \nabla_\phi \log \pi_\phi(z^* \mid x)\right)$
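A minimal numpy sketch of the direct estimator for a single categorical latent with explicit logits; here $f$ stands in for the decoder term $\log p_\theta(x \mid z)$, and the finite $\epsilon$ makes the estimate slightly biased:

```python
import numpy as np

rng = np.random.default_rng(0)
theta = np.array([0.2, -0.4, 0.7])                # encoder logits for one latent
pi = np.exp(theta) / np.exp(theta).sum()
f = np.array([1.0, 3.0, -2.0])                    # stand-in for log p(x|z)
eps_ = 0.3                                        # finite epsilon: biased estimate

n = 200_000
g = -np.log(-np.log(rng.uniform(size=(n, 3))))    # Gumbel(0,1) noise
z_star = (theta + g).argmax(axis=1)               # z*: plain argmax (a sample of q)
z_eps = (theta + g + eps_ * f).argmax(axis=1)     # z^eps: loss-augmented argmax

# d/dtheta score(z) = onehot(z) for score(z) = theta_z, so the direct
# estimator is (onehot(z_eps) - onehot(z_star)) / eps.
est = (np.eye(3)[z_eps] - np.eye(3)[z_star]).mean(axis=0) / eps_
true_grad = pi * (f - pi @ f)                     # exact gradient of E_q[f(z)]
print(true_grad, est, sep="\n")                   # close for small eps
```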
Structured Encoder
No structure: the posterior factorizes into independent categoricals.
Pairwise relationships: add pairwise potentials between latent variables; solve the argmax with QIP (CPLEX) or MaxFlow.
Not practical with Gumbel-Softmax: an exponential number of terms to sum over in the denominator.
[Results: a structured encoder may help]
Gradient Bias-Variance Tradeoff
[Plot: Direct Gumbel-Max VAE (varying $\epsilon$) vs. Gumbel-Softmax VAE (varying $\tau$)]
Direct Gumbel-Max VAE trains faster (K = 10)
VAE Comparison
● Standard (Gaussian): + unbiased, low-variance gradients; - continuous latent variables; - limited structural relations
● Gumbel-Softmax: + discrete latent variables; - biased gradients; - limited structural relations; - extra parameter (tau)
● Naive Categorical + standard gradient estimator: + discrete latent variables; + unbiased gradients; - limited structural relations
● Gumbel-Max + Direct: + discrete latent variables; + allows structural relations; - biased gradients; - extra parameter (epsilon); - optimization subproblem to get gradients
Direct Policy Gradients: Direct Optimization of Policies in Discrete Action Spaces
Guy Lorberbom, Chris J. Maddison, Nicolas Heess, Tamir Hazan, Daniel Tarlow
Reinforcement Learning
[Diagram: the agent sends an action to the environment, which returns a reward and the next state]
Goal: maximize cumulative reward
Policy Gradient Method
Goal: $\max_\theta E_{\tau \sim p_\theta}[r(\tau)]$, where the policy induces a distribution $p_\theta$ over trajectories $\tau$ and $r(\tau)$ is the cumulative reward.
Want: $\nabla_\theta E_{\tau \sim p_\theta}[r(\tau)]$
REINFORCE: $E_{\tau \sim p_\theta}[r(\tau)\, \nabla_\theta \log p_\theta(\tau)]$
Direct Policy Gradient: $\lim_{\epsilon \to 0} \frac{1}{\epsilon} E\left[\nabla_\theta \log p_\theta(a_{\text{direct}}) - \nabla_\theta \log p_\theta(a_{\text{opt}})\right]$
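For completeness, REINFORCE follows from the log-derivative identity $\nabla_\theta p_\theta = p_\theta \nabla_\theta \log p_\theta$ (written here for a discrete set of trajectories):

```latex
\nabla_\theta \, \mathbb{E}_{\tau \sim p_\theta}[r(\tau)]
  = \sum_\tau r(\tau)\, \nabla_\theta p_\theta(\tau)
  = \sum_\tau r(\tau)\, p_\theta(\tau)\, \nabla_\theta \log p_\theta(\tau)
  = \mathbb{E}_{\tau \sim p_\theta}\!\left[\, r(\tau)\, \nabla_\theta \log p_\theta(\tau) \,\right]
```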
State Reward Tree
Tree of all possible trajectories (fix the seed of the environment): this separates environment stochasticity from policy stochasticity.
Given the tree, we can sample trajectories by sampling an action at each node.
Reparameterize the Policy
Instead of sampling per-timestep, we sample per-trajectory.
Given action sequences $a = (a_1, \dots, a_T)$, define $\log p_\theta(a) = \sum_t \log \pi_\theta(a_t \mid s_t)$.
Gumbel-max reparameterization
Now that we have a distribution over whole trajectories: let $G(a) \sim \text{Gumbel}(0, 1)$ i.i.d. for each trajectory $a$, and $a_{\text{opt}} = \arg\max_a \{\log p_\theta(a) + G(a)\}$.
Then under this reparameterization, $a_{\text{opt}} \sim p_\theta$.
Structured Prediction ↔ RL
● Discrete configurations: structures $y \in \mathcal{Y}$ ↔ trajectories (action sequences) $a$
● Scoring function: $s_w(x, y)$ ↔ $\log p_\theta(a) + G(a)$
● Loss: $L(y, y^*)$ ↔ reward $r(a)$ (negated)
● Inference: $\arg\max_y s_w(x, y)$ ↔ $a_{\text{opt}} = \arg\max_a \{\log p_\theta(a) + G(a)\}$
● Loss-augmented inference: $\arg\max_y \{s_w(x, y) + \epsilon L(y, y^*)\}$ ↔ $a_{\text{direct}} = \arg\max_a \{\log p_\theta(a) + G(a) + \epsilon r(a)\}$
Direct Policy Gradient (DirPG)
Algorithm, for every training step:
1. Sample $a_{\text{opt}}$
2. Find $a_{\text{direct}} = \arg\max_a \{\log p_\theta(a) + G(a) + \epsilon r(a)\}$
3. Compute gradients $\frac{1}{\epsilon}\left(\nabla_\theta \log p_\theta(a_{\text{direct}}) - \nabla_\theta \log p_\theta(a_{\text{opt}})\right)$
Problem: step 2 is an argmax over all trajectories ⇐ how to obtain this?
Solution: A* sampling (Maddison et al., 2014)
Use heuristic search to find a trajectory whose direct objective is better than that of $a_{\text{opt}}$.
Complete Algorithm
For every training step:
1. Sample $a_{\text{opt}}$ and compute its direct objective
2. While budget not exceeded:
a. Obtain a candidate trajectory from heuristic search
b. End search if its direct objective beats that of $a_{\text{opt}}$
3. Compute gradients
Limitations
● Must be able to reset environment to previously visited states.
● Termination on first improvement.
Combinatorial bandits
[Plot: the number of trajectories searched to find $a_{\text{direct}}$ increases as training progresses]
MiniGrid
[Plots: comparisons between different heuristics for DirPG and REINFORCE on MiniGrid; evidence of “pulling up”]
Related Work
● Gradient Estimators
○ REINFORCE (Williams 1988)
○ REBAR (Tucker et al 2017)
○ RELAX (Grathwohl et al 2018)
○ Gumbel-Softmax (Jang et al 2017, Maddison et al 2017)
● Discrete Deep Generative Models
○ VQ-VAE (Oord et al 2017)
○ Discrete VAE (Rolfe 2017)
○ Gumbel-Sinkhorn (Mena et al 2018)
● Reinforcement Learning
Top-Down sampling using A* Sampling
Non-starters
● Compute $\log p_\theta(a) + G(a) + \epsilon r(a)$ for all possible trajectories (exponentially many)
● Roll out many trajectories and select the best
Gumbel Process
We know: $\max_i \{\log \pi_i + g_i\} \sim \text{Gumbel}(\log \sum_i \pi_i)$, and the max value is independent of the argmax.
Therefore: we can sample a region’s max first and refine lazily; partitioning a region into subsets A and B, the two subset maxes can be sampled consistently with the parent’s (top-down).
[Diagram: a region split into subsets A and B, each carrying its own Gumbel value]
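A quick numpy check of the max-stability fact (the region masses are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
pi = np.array([0.1, 0.25, 0.25, 0.4])   # masses of 4 disjoint regions

g = np.log(pi) - np.log(-np.log(rng.uniform(size=(100_000, 4))))
maxes = g.max(axis=1)

# Max-stability: max_i Gumbel(log pi_i) ~ Gumbel(log sum_i pi_i) = Gumbel(0).
ref = -np.log(-np.log(rng.uniform(size=100_000)))
print("max of region Gumbels: mean %.3f std %.3f" % (maxes.mean(), maxes.std()))
print("Gumbel(0) reference:   mean %.3f std %.3f" % (ref.mean(), ref.std()))
```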
Trajectory Generation
● Lazily create partitions of trajectories.
● Recursion rule:
○ For the child containing the sampled trajectory, copy the parent node’s Gumbel value.
○ For the remaining choices of actions, group them and compute a truncated Gumbel value.
● Repeat until a terminating state is found.
● Yield the trajectory and its direct objective.
[Tree diagram: Gumbel values propagate down the tree, e.g. a parent value of 1.3 is copied to the sampled child, while siblings receive truncated values such as 1.1 and 0.19]
Recall the goal: find $a_{\text{direct}}$ whose direct objective beats that of $a_{\text{opt}}$. How do we prioritize which partition to expand next?
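A numpy sketch of the recursion rule on a single two-way split, checking that top-down sampling (copy the parent’s value to the sampled child, truncate the rest) matches sampling each child’s Gumbel directly; the truncation formula $-\log(e^{-b} + e^{-g})$ is the standard one from A* sampling:

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.3, 0.7])

def gumbel(loc, size=None):
    return loc - np.log(-np.log(rng.uniform(size=size)))

def trunc_gumbel(loc, bound):
    # Gumbel(loc) conditioned to be <= bound.
    return -np.logaddexp(-bound, -gumbel(loc))

n = 50_000
# Bottom-up: sample each child's Gumbel directly.
direct = gumbel(np.log(p), size=(n, 2))
# Top-down (recursion rule): sample the region's max, pick the inheriting
# child ~ p, copy the parent's value to it, truncate the other child.
top_down = np.empty((n, 2))
for k in range(n):
    v = gumbel(np.log(p.sum()))
    j = rng.choice(2, p=p)
    top_down[k, j] = v
    top_down[k, 1 - j] = trunc_gumbel(np.log(p[1 - j]), v)

print("means:", direct.mean(0), top_down.mean(0))   # per-child, should match
print("stds: ", direct.std(0), top_down.std(0))
```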
Search for a large direct objective using A* Sampling
● Lower bound on accumulated reward ($L$)
● Upper bound on reward-to-go ($U$)
● Prioritize partitions by their Gumbel value plus $\epsilon(L + U)$, an optimistic estimate of the best direct objective inside
● In practice: use an approximate, domain-specific bound for $U$
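Putting the pieces together: a self-contained numpy sketch of the full loop on a toy problem (T binary actions, reward = number of 1-actions, so the reward so far gives $L$ and the steps remaining give $U$); the toy environment and all names are illustrative rather than the paper’s code:

```python
import heapq
import itertools
import numpy as np

rng = np.random.default_rng(0)

def gumbel(loc):
    return loc - np.log(-np.log(rng.uniform()))

def trunc_gumbel(loc, bound):
    """Sample Gumbel(loc) conditioned on being <= bound."""
    return -np.logaddexp(-bound, -gumbel(loc))

# Toy setting: T binary actions, a fixed per-step policy softmax(theta),
# reward = number of 1-actions (so L = reward so far, U = steps remaining).
T, eps = 6, 1.0
theta = np.array([0.0, 0.5])
logpi = theta - np.logaddexp(theta[0], theta[1])
counter = itertools.count()  # tie-breaker for the heap

def push(heap, prefix, logp, value, acc_r):
    # Optimistic bound on the direct objective of any leaf in this region.
    priority = value + eps * (acc_r + (T - len(prefix)))
    heapq.heappush(heap, (-priority, next(counter), prefix, logp, value, acc_r))

# a_opt: follow the inheriting child from the root; that path is the argmax
# of log p + G, and it carries the root's Gumbel value all the way down.
heap = []
prefix, logp, value, acc_r = (), 0.0, gumbel(0.0), 0
while len(prefix) < T:
    a = rng.choice(2, p=np.exp(logpi))          # inheriting child ~ pi
    other = 1 - a
    t_val = trunc_gumbel(logp + logpi[other], value)
    push(heap, prefix + (other,), logp + logpi[other], t_val, acc_r + other)
    prefix, logp, acc_r = prefix + (a,), logp + logpi[a], acc_r + a
a_opt, opt_direct = prefix, value + eps * acc_r

# Search for a_direct: first popped leaf whose direct objective beats a_opt's
# (a real implementation would also cap the number of expansions by a budget).
a_direct = None
while heap:
    neg_p, _, prefix, logp, value, acc_r = heapq.heappop(heap)
    if -neg_p <= opt_direct:
        break                                    # no region can improve
    if len(prefix) == T:
        a_direct = prefix
        break
    a = rng.choice(2, p=np.exp(logpi))
    other = 1 - a
    push(heap, prefix + (a,), logp + logpi[a], value, acc_r + a)
    push(heap, prefix + (other,), logp + logpi[other],
         trunc_gumbel(logp + logpi[other], value), acc_r + other)

print("a_opt   :", a_opt, "direct objective:", round(opt_direct, 3))
print("a_direct:", a_direct)
if a_direct is not None:
    # d/dtheta log p(a) = sum_t (onehot(a_t) - pi); gradient per the theorem.
    feats = lambda a: sum(np.eye(2)[list(a)]) - T * np.exp(logpi)
    grad = (feats(a_direct) - feats(a_opt)) / eps
    print("gradient estimate:", grad)
```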