
Page 1:

Recurrent Networks, Relational Learning & Reinforcement

Petar Veličković

Artificial Intelligence Group
Department of Computer Science and Technology, University of Cambridge, UK

MILA Graph Representation Reading Group Meeting, 22 August 2018

Page 2:

Introduction

- In this talk, I will survey the Relational Network architecture, and its recent deployment in recurrent neural networks and deep reinforcement learning.

- The discussion will span the following papers:
  - A simple neural network module for relational reasoning (Santoro, Raposo et al., NIPS 2017)
  - Relational deep reinforcement learning (Zambaldi, Raposo, Santoro et al., 2018)
  - Relational recurrent neural networks (Santoro, Faulkner, Raposo et al., 2018)

- A substantial part of DeepMind’s recent “graph networks surge”.

Page 3:

Relational reasoning

- Being able to reason about relations between entities present in an input is an important aspect of intelligence!

- Consider the simple task of inferring which two points from a given point set are furthest apart—this requires computing and comparing all* of their pairwise distances.

- Keep this task in mind—it will be revisited!

Page 4:

Approaches to relational reasoning

- Relations can be naturally expressed within symbolic methods (defined by e.g. the rules of logic)—but these are not robust to small variations of inputs/tasks.

- Robustness is often achievable with standard neural network architectures (such as MLPs), but it is extremely challenging for them to capture relations, despite their theoretical potency!

- This claim is extensively validated throughout the three papers.

⇒ Seek a model inspired by symbolic AI, while empowered by neural networks (explicitly representing relations in a robust way).

Page 5:

Our task for today

[Figure: the task — input: an image of card suits (♣ ♦ ♥ ♠); query: "How many outlined objects are above the spade?"; output: "two".]

Page 6:

Our task for today

[Figure: as before, with the input now decomposed into objects ~o_♣, ~o_♦, ~o_♥, ~o_♠.]

Page 7:

Our task for today

[Figure: as before, with an RN mapping the objects, conditioned on the query, to the output.]

Page 9:

The Relational Network

[Figure: the full pipeline — input → objects ~o_♣, ~o_♦, ~o_♥, ~o_♠ → RN (conditioned on the query) → output "two".]

Page 10:

Relational Networks

- Initially, we will assume that the objects are provided as input.

- Consider a set of n objects, O = {~o_1, ~o_2, . . . , ~o_n}, with each object represented by a feature vector ~o_i ∈ R^m.

- A Relational Network (RN) summarises the relations between these objects as follows:

RN(O) = f_φ( ∑_{i,j} g_θ(~o_i, ~o_j) )

where g_θ : R^m × R^m → R^k and f_φ : R^k → R^l are functions with parameters θ and φ (usually MLPs).
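
As a concrete illustration, here is a minimal PyTorch sketch of the equation above (a sketch under assumed dimensions; the two-layer MLPs for g_θ and f_φ are my own choice, not the paper's exact architecture):

    import torch
    import torch.nn as nn

    class RelationalNetwork(nn.Module):
        """Minimal RN sketch: RN(O) = f_phi( sum_{i,j} g_theta(o_i, o_j) )."""
        def __init__(self, m, k, l):
            super().__init__()
            # g_theta : R^m x R^m -> R^k, shared across all object pairs
            self.g = nn.Sequential(nn.Linear(2 * m, k), nn.ReLU(), nn.Linear(k, k))
            # f_phi : R^k -> R^l, applied once to the aggregated pair codes
            self.f = nn.Sequential(nn.Linear(k, l), nn.ReLU(), nn.Linear(l, l))

        def forward(self, objects):                       # objects: (n, m)
            n = objects.size(0)
            oi = objects.unsqueeze(1).expand(n, n, -1)    # o_i along one axis
            oj = objects.unsqueeze(0).expand(n, n, -1)    # o_j along the other
            pairs = torch.cat([oi, oj], dim=-1)           # (n, n, 2m): all pairs
            return self.f(self.g(pairs).sum(dim=(0, 1)))  # sum over i, j; then f_phi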

Page 11:

Properties of RNs

- The central component of RNs is g_θ, the relation function. Its role is to infer the nature of relations between objects i and j.

- An RN may be seen as a message-passing neural network over the complete graph of object nodes.

- RNs have several highly desirable properties:
  - Relational inference—given the all-pairs nature of the computation, the module does not assume upfront knowledge of which pairs of objects are related, and how.
  - Data efficiency—an MLP would need to learn and embed n² (identical) functions to replicate the behaviour of RNs.
  - Permutation invariance—the summation operation ensures that the order of objects does not matter; therefore RNs can be applied to arbitrary sets (see the quick check below).
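
A quick numerical check of the permutation-invariance property, reusing the sketch from the previous slide (shapes arbitrary):

    torch.manual_seed(0)
    rn = RelationalNetwork(m=8, k=32, l=16)
    O = torch.randn(5, 8)                        # five objects, 8 features each
    perm = torch.randperm(5)                     # an arbitrary reordering
    # The pair sum ranges over the same terms either way, so outputs match:
    assert torch.allclose(rn(O), rn(O[perm]), atol=1e-5)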

Page 12:

Dynamic physical systems

MuJoCo-simulated physical mass-spring systems with 10 objects.

Input: ~o_i is RGB color and (x, y) coordinates across 16 time steps.
Tasks: (i) infer relations; (ii) count number of systems (harder!).

Page 13:

Results on physical systems

- Relational Networks achieve 93% accuracy in predicting the existence/absence of relations between objects, and 95% accuracy in predicting the number of interacting systems.

- MLPs fail to predict better than chance on either task!

- The learnt function is transferable to unseen motion-capture data!

Page 14:

Conditioning in RNs

[Figure: the pipeline again, focusing on the query's conditioning of the RN.]

Page 15:

RN conditioning

- An RN may be seen as a module that “captures” the relations between objects in a set—this computation may be arbitrarily conditioned, e.g. to answer a specific relational query.

- Assuming we have a conditioning vector ~q, the RN architecture may be trivially modified to include it:

RN(O, ~q) = f_φ( ∑_{i,j} g_θ(~o_i, ~o_j, ~q) )
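
Continuing the earlier sketch, the conditioned variant only changes the input to g_θ: the query ~q is broadcast and concatenated onto every object pair (dimensions again assumed):

    class ConditionedRN(RelationalNetwork):
        def __init__(self, m, q_dim, k, l):
            super().__init__(m, k, l)
            # g_theta now consumes (o_i, o_j, q): R^(2m + q_dim) -> R^k
            self.g = nn.Sequential(nn.Linear(2 * m + q_dim, k), nn.ReLU(),
                                   nn.Linear(k, k))

        def forward(self, objects, q):           # objects: (n, m), q: (q_dim,)
            n = objects.size(0)
            oi = objects.unsqueeze(1).expand(n, n, -1)
            oj = objects.unsqueeze(0).expand(n, n, -1)
            qq = q.expand(n, n, -1)              # the same q for every pair
            pairs = torch.cat([oi, oj, qq], dim=-1)
            return self.f(self.g(pairs).sum(dim=(0, 1)))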

Page 16:

The CLEVR dataset (Johnson et al., 2017)

Question Answering dataset on 3D-rendered objects.

Original Image: [Figure: rendered 3D scene]
Non-relational question: What is the size of the brown sphere?
Relational question: Are there any rubber things that have the same size as the yellow metallic cylinder?

Input: ~o_i is RGB color, (x, y, z) coordinates, shape/material/size.
Queries: count, exist, compare numbers, query attribute, compare attribute.

Page 18:

Results on CLEVR

- The query sentence is encoded into ~q as the last-stage output of a word-level LSTM (with learned word embeddings).

- Relational Networks achieve an accuracy of 96.4% on CLEVR.

- Human performance is 92.6%! This sounds great!

- . . . OK, I lied to you. (Sorry!)

Page 19:

The actual CLEVR dataset (Johnson et al., 2017)

Visual Question Answering dataset on 3D-rendered objects.

Original Image: [Figure: rendered 3D scene]
Non-relational question: What is the size of the brown sphere?
Relational question: Are there any rubber things that have the same size as the yellow metallic cylinder?

Input: the scene image. The ~o_i vectors are not explicitly given!
Queries: count, exist, compare numbers, query attribute, compare attribute.

Page 20:

Object extraction

[Figure: the pipeline again, focusing on the input → objects extraction stage.]

Page 21:

Object extraction from images

- In general, we should not assume the ~o_i will be given!

- Arguably, obtaining the ~o_i from raw input will be the most variable pipeline component.

- Often, we can obtain object representations as high-level outputs of neural networks specialised for such inputs.

- In the case of images (most common!) this will be a convolutional neural network.

Page 22:

CNN object extraction

- A convolutional architecture generally consists of interleaving convolutional and pooling layers—progressively building more sophisticated feature maps.

- At any point during a CNN, a feature map f may have the shape n × m × k, where n and m are the height and width of the feature map, and each pixel is represented by k features.

- Each pixel represents a summary of a certain region of the image. Without any further assumptions, it is safest to let each pixel constitute an object!

- Therefore, we will have an object set O = {~o_1, . . . , ~o_{n·m}} with n·m objects, where each ~o_i ∈ R^k corresponds to ~f_{xy} (a small sketch of this follows below).
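
A small sketch of this extraction (imports as before), assuming a channels-first feature map as in PyTorch; appending each pixel's (x, y) coordinates, so objects keep their position, is my assumption of the usual tagging trick:

    def feature_map_to_objects(f):
        """Turn a CNN feature map of shape (k, n, m) into n*m objects."""
        k, n, m = f.shape
        objs = f.permute(1, 2, 0).reshape(n * m, k)        # one row per pixel
        ys, xs = torch.meshgrid(torch.arange(n), torch.arange(m), indexing="ij")
        coords = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()
        return torch.cat([objs, coords], dim=-1)           # (n*m, k + 2)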

Page 23:

CNN object extraction

[Figure: input → Conv. → Pool → Conv. → feature map f]

Page 24:

CNN object extraction

[Figure: the same CNN — the pixels of the final feature map become objects ~o_1, ~o_2, ~o_3, ~o_4]

Page 25:

Overall CLEVR architecture

[Figure: overall CLEVR architecture — a CNN (Conv.) produces final feature maps whose pixels become objects; an LSTM encodes the question ("What size is the cylinder that is left of the brown metal thing that is left of the big sphere?"); each object pair, concatenated with the question, passes through the g_θ-MLP; the results are summed element-wise and fed to the f_φ-MLP, producing the answer ("small").]

End-to-end trainable with gradient descent.

Page 26:

Actual results on CLEVR

First approach to achieve superhuman performance on this task!

Page 27:

Actual results on CLEVR

[Figure: accuracy by query type (overall, count, exist; compare numbers: equal/less than/more than; query attribute: size/shape/material/color; compare attribute: size/shape/material/color) for the Q-type baseline, LSTM, CNN+LSTM, CNN+LSTM+SA and CNN+LSTM+RN models, against Human performance.]

Especially excels at compare attribute, the query type which heavily relies on relational reasoning.

Page 28:

Failure cases on CLEVR

Failure cases often occur under heavy occlusion—challenging for humans as well!

Page 29:

Results on Sort-of-CLEVR

A simple CLEVR-inspired dataset with clear separation of relational vs. non-relational queries.

Demonstrates a clear advantage of RNs on relational queries.

Page 30:

LSTM object extraction: bAbI (Weston et al., 2015)

A text-based set of 20 question-answering tasks.

Let O = {~o_1, . . . , ~o_20} be the LSTM representations of up to 20 sentences preceding the question. ~q is once again obtained as the LSTM representation of the question.
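
A toy sketch of this encoding (imports as before; vocabulary size and dimensions invented purely for illustration):

    emb = nn.Embedding(num_embeddings=200, embedding_dim=32)      # toy vocabulary
    lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)

    def encode(token_ids):           # token_ids: (seq_len,) tensor of word indices
        _, (h, _) = lstm(emb(token_ids).unsqueeze(0))
        return h[-1, 0]              # final hidden state: one object (or ~q)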

Page 31:

Results on bAbI

- RN(O, ~q) passes (95+% accuracy) 18/20 tasks after joint training—comparable with other state-of-the-art memory network architectures:
  - Memory networks: 14/20
  - DNC: 18/20
  - Sparse DNC: 19/20
  - EntNet: 16/20

- Does not catastrophically fail (91.9% and 83.5% accuracy) on the remaining two.

- Notably, it succeeds on the basic induction task (97.9%), where Sparse DNC (46%), DNC (44.9%) and EntNet (47.9%) all fail.

Page 32:

Self-attention

- Thus far, the building-block functions of a Relational Network (f_φ, g_θ) were simple MLPs.

- For more recent RN architectures, we focus instead on the self-attention operator.

- A self-attentional operator, A_θ, acts on a set of n entities, E = {~e_1, ~e_2, . . . , ~e_n}, producing higher-level representations:

E′ = A_θ(E)

where E′ = {~e′_1, ~e′_2, . . . , ~e′_n}, and θ are learnable parameters.

Page 33:

Self-attention, cont’d

- Each component of E′ will be derived by examining all components of E (by way of linear combinations):

~e′_i = ∑_j α_ij f_ψ(~e_j)

where f_ψ : R^m → R^k is a learnable transformation.

- Here, the coefficients α_ij correspond to the importance of the features of entity j to entity i, and are derived by a learnable attention mechanism, a_φ : R^m × R^m → R:

α_ij = a_φ(~e_i, ~e_j)
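
A minimal sketch of this generic operator (imports as before); the slides leave the form of a_φ open, so the pairwise scoring layer and the softmax normalisation over j are my assumptions:

    class SelfAttention(nn.Module):
        def __init__(self, m, k):
            super().__init__()
            self.f = nn.Linear(m, k)             # f_psi : R^m -> R^k
            self.a = nn.Linear(2 * m, 1)         # a_phi : scalar score per pair

        def forward(self, E):                    # E: (n, m)
            n = E.size(0)
            ei = E.unsqueeze(1).expand(n, n, -1)
            ej = E.unsqueeze(0).expand(n, n, -1)
            scores = self.a(torch.cat([ei, ej], dim=-1)).squeeze(-1)   # (n, n)
            alpha = scores.softmax(dim=-1)       # alpha_ij, normalised over j
            return alpha @ self.f(E)             # e'_i = sum_j alpha_ij f_psi(e_j)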

Page 34:

Self-attention

[Figure: computing ~e′_j — each entity ~e_1 . . . ~e_n is scored by a_φ (yielding α_{1,j}, . . . , α_{n,j}), transformed by f_ψ, weighted, and summed (Σ) into ~e′_j.]

Page 35:

The Transformer architecture

- In particular, the Transformer architecture (Vaswani et al., 2017) is used for A_θ.

- Here abbreviated as MHDPA (multi-head dot-product attention).

- First, derive queries, keys and values for the attention:

~q_i = W_q ~e_i    ~k_i = W_k ~e_i    ~v_i = W_v ~e_i

- Now, use the queries and keys to derive coefficients:

α_ij = exp(⟨~q_i, ~k_j⟩ / √d_k) / ∑_m exp(⟨~q_i, ~k_m⟩ / √d_k)

where d_k is the dimensionality of the keys.

Page 36:

The Transformer architecture, cont’d

- Now, we can use the α_ij to recombine the values at each position:

~e′_i = ∑_j α_ij ~v_j

- This can be conveniently written in matrix form as:

E′ = softmax(QKᵀ / √d_k) V

- Further optimised by using multi-head attention: replicating this operation K times (each with independent parameters W_q, W_k, W_v) and featurewise-concatenating the results.
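
A single head of this operation is only a few lines (a sketch with plain weight matrices; the multi-head version runs it K times and concatenates):

    def dot_product_attention(E, Wq, Wk, Wv):
        Q, K, V = E @ Wq, E @ Wk, E @ Wv         # (n, d_k), (n, d_k), (n, d_v)
        d_k = K.size(-1)
        scores = Q @ K.T / d_k ** 0.5            # scaled pairwise logits
        return scores.softmax(dim=-1) @ V        # E' = softmax(QK^T / sqrt(d_k)) V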

Page 37:

Reinforcement learning: Box-World and StarCraft II

Box-World: a grid RL environment meant to stress relational reasoning while deciding how to act.
StarCraft II: mini-games (Vinyals et al., 2017).

In both cases, the agent receives pixel-structured inputs (minimaps, screens, etc.).

Page 38:

Relational deep reinforcement learning

- Empowers a standard CNN-based policy network (in an RL setting) with a relational module based on self-attention.

- Architectures for both tasks are very similar!

- Extract entities, ~e_i, just as before (as separate pixels in a high-level feature map).

- Then perform several rounds of the Transformer self-attention over E (each round followed by a small MLP, f_θ, and layer normalisation to introduce nonlinearity).

- Finally, perform global pooling and a small MLP to derive the policy for the RL algorithm (IMPALA; Espeholt et al., 2018). A rough sketch follows below.
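
A rough sketch of such a relational module, as I read the description (imports as before; the head count, round count, and use of PyTorch's built-in MultiheadAttention are all assumptions):

    class RelationalModule(nn.Module):
        def __init__(self, d, heads=2, rounds=2):    # d must be divisible by heads
            super().__init__()
            self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
            self.mlp = nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, d))
            self.norm = nn.LayerNorm(d)
            self.rounds = rounds

        def forward(self, E):                        # E: (batch, n_entities, d)
            for _ in range(self.rounds):
                attended, _ = self.attn(E, E, E)     # one self-attention round
                E = self.norm(self.mlp(attended))    # small MLP + layer norm
            return E.max(dim=1).values               # feature-wise global pooling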

Page 39:

The Box-World architecture

[Figure: the Box-World architecture — input → Conv. 2×2, stride 1, with ReLU (×2) → relational module: multi-head dot-product attention over queries, keys and values (applied repeatedly) → feature-wise max pooling → FC 256 with ReLU (×4) → outputs.]

Page 40:

Results on Box-World

[Figure: example Box-World observations with their underlying graphs (branch lengths 1 and 3), and fraction of levels solved vs. environment steps (×10^8) for Relational (1/2/4 blocks) and Baseline (3/6 blocks) agents.]

Page 41:

Visualising the attentional coefficients

Page 42:

Zero-shot experiments in Box-World

[Figure: zero-shot generalisation — fraction solved for (a) longer solution path lengths (trained on up to 4; tested on 6, 8, 10, not required during training) and (b) a withheld key, for Relational vs. Baseline agents.]

The non-relational baseline (ResNet CNN) fails to generalise!

Page 43:

Results on StarCraft II

Sets a new state of the art, often beating the human grandmaster.

Page 44:

Zero-shot experiments in StarCraft

[Figure: mean score with 1, 2, 3, 4, 5 and 10 marines, for Relational vs. Control agents.]

Exhibits higher—although not fully conclusive—generalisation ability from 2 marines to higher numbers.

Page 45:

Neural networks for sequential processing

- Finally, we turn our attention to architectures used for general sequential processing of data.

- In the general setting, we require a stateful system, S_θ, capable of processing incoming inputs ~x_t, and updating its internal state, ~s_t, appropriately:

~s_t = S_θ(~x_t, ~s_{t−1})

- Inference may then be performed by leveraging ~s_t.

Page 46:

Approaches to sequential processing

- Traditional approaches for modelling S_θ include recurrent neural networks (e.g. LSTM, GRU, etc.) and memory-augmented neural networks (e.g. NTM, DNC, etc.)

- Recurrent neural networks generally represent their state as a fixed-size vector, ~c_t, which gets appropriately updated at each stage of input processing.

- Memory-augmented networks have a memory matrix, M ∈ R^{n×m}, which may be read from/written to by using a recurrent controller.

Page 47:

Analysis of approaches

- Both approaches have shortcomings when explicit relational reasoning through time is required:
  - RNNs pack the entire representation into a single dense vector, making it hard to reason about entities (and therefore relations);
  - Memory-augmented networks explicitly represent entities (as rows of M), but these cannot easily interact once written to.

- Relational recurrent neural networks address both shortcomings simultaneously, explicitly allowing the rows of M to interact using self-attention!

Page 48:

Relational memory

- Assume we have a memory matrix M = {~m_1, ~m_2, . . . , ~m_n}.

- Applying (Transformer) self-attention to it, we obtain a new memory state M′ = {~m′_1, ~m′_2, . . . , ~m′_n}, explicitly taking into account the relations between memory rows ~m_i:

M′ = softmax(QKᵀ / √d_k) V

where

Q = M W_q    K = M W_k    V = M W_v

just as before.

Page 49:

Incorporating new inputs

- The interactions thus far are self-contained to what’s already in the memory; however, we’d like the memory to adapt to incoming inputs, ~x, appropriately.

- Simple extension: let the memory locations attend over ~x too!

M′ = softmax( Q [K ‖ W_k~x]ᵀ / √d_k ) [V ‖ W_v~x]

where ‖ denotes row-wise concatenation.
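
A sketch of this extended attention step (imports as before), treating the new input as a single extra row, with plain tensors for the projections purely for illustration:

    def memory_attention(M, x, Wq, Wk, Wv):
        """M: (n, m) memory; x: (1, m) new input row."""
        Q = M @ Wq                                   # queries from memory only
        K = torch.cat([M @ Wk, x @ Wk], dim=0)       # keys: memory rows, then input
        V = torch.cat([M @ Wv, x @ Wv], dim=0)       # values likewise
        d_k = K.size(-1)
        return (Q @ K.T / d_k ** 0.5).softmax(dim=-1) @ V   # updated memory M'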

Page 50:

The LSTM controller

- We do not wish to fully overwrite M by M′—we can control this process with an LSTM:

~i_{i,t} = σ(W_i ~x_t + U_i ~h_{i,t−1} + ~b_i)
~f_{i,t} = σ(W_f ~x_t + U_f ~h_{i,t−1} + ~b_f)
~o_{i,t} = σ(W_o ~x_t + U_o ~h_{i,t−1} + ~b_o)
~m_{i,t} = g_ψ(~m′_{i,t}) ⊙ ~i_{i,t} + ~m_{i,t−1} ⊙ ~f_{i,t}
~h_{i,t} = tanh(~m_{i,t}) ⊙ ~o_{i,t}

where g_ψ is a learnable function (a 2-layer MLP with layer normalisation in the paper).
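
A per-row sketch of this gated update (imports as before); the gate computation mirrors the equations above, while g_ψ and the weight shapes are placeholders:

    def gated_update(m_attn, m_prev, h_prev, x, p, g_psi):
        """p: dict of weights/biases Wi, Ui, bi, Wf, Uf, bf, Wo, Uo, bo."""
        i = torch.sigmoid(x @ p["Wi"] + h_prev @ p["Ui"] + p["bi"])   # input gate
        f = torch.sigmoid(x @ p["Wf"] + h_prev @ p["Uf"] + p["bf"])   # forget gate
        o = torch.sigmoid(x @ p["Wo"] + h_prev @ p["Uo"] + p["bo"])   # output gate
        m = g_psi(m_attn) * i + m_prev * f   # blend attended memory with previous
        h = torch.tanh(m) * o                # new per-row hidden state
        return m, h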

Page 51:

The Relational RNN

[Figure: (a) the RRNN core — the previous memory and the input pass through an attention block A with a residual connection, then an MLP with a residual connection, then gating, producing the next memory and the output (computation of gates not depicted); (b) the memory and input supply queries, keys and values to multi-head dot-product attention, yielding the updated memory; (c) attention weights are computed from queries and keys, normalised with a row-wise softmax, and used to take a weighted average of the values, returning the updated memory.]

Page 52:

Tasks under consideration

A suite of supervised and reinforcement learning tasks demanding explicit sequential relational reasoning.

[Figure: example tasks — supervised learning: N-th farthest ("What is the Nth farthest from vector m?"), program evaluation ("x = 339; for [19]: x += 597; for [94]: x += 875; x if 428 < 778 else 652; print(x)"), language modelling ("It had 24 step programming abilities, which meant it was highly _____", "A gold dollar had been proposed several times in the 1830s and 1840s, but was not initially _____", "Super Mario Land is a 1989 side scrolling platform video _____"); reinforcement learning: Box-World (agent, gem, loose key, lock/key pairs, viewport) and Mini-Pacman.]

- N-th farthest vector from a given vector;
- Program evaluation from characters (Learning to Execute);
- Language modelling;
- Mini-PacMan and Box-World with viewport!

Page 53:

Results on N-th farthest: LSTM/DNC

[Figure: batch accuracy curves for the LSTM and the DNC.]

Failing to surpass 30% batch accuracy!

Page 54:

Results on N-th farthest: RRNN

Page 55:

Attention weight visualisation

[Figure: attention weight visualisations over time (attending from memory to input vector IDs). (a) Reference vector is the last in a sequence, e.g. "Choose the 5th furthest vector from vector 7"; (b) reference vector is the first in a sequence, e.g. "Choose the 3rd furthest vector from vector 4"; (c) reference vector comes in the middle of a sequence, e.g. "Choose the 6th furthest vector from vector 6".]

Page 56:

Results on LTE

The RRNN is again highly competitive, especially in scenarios where strong relational reasoning may be required (full programs).

Page 57:

Results on LTE

Page 58:

Results on Language Modelling

The RRNN obtains competitive perplexity levels, compared to several strong baselines.

Page 59:

Results on Language Modelling: WikiText-103

[Figure: perplexity vs. training steps (billions) on WikiText-103, for the RMC and an LSTM.]

Page 60:

Results on Mini-PacMan

[Figure: Mini-PacMan scores, with viewport and without viewport.]

The RRNN outperforms an LSTM when used as a policy network (for IMPALA). Specifically, when the entire map is observed, it doubles the LSTM's performance!

Page 61:

Concluding remarks

- Empowering neural networks with various kinds of relational reasoning modules will likely be a necessary step towards strong and robust intelligent systems.

- This claim is clearly supported by several “failure modes” of the baseline architectures we considered today.

- One limitation going forward lies in the all-pairs interactions, which will limit scalability to larger object sets, especially if self-attention is used.

- The NRI (Kipf, Fetaya et al., 2018) offers one possible direction to address this, but probably not the ultimate solution. . .

- In my opinion, a particularly important avenue for future work is graph-structured memories, where we are not restricted to a matrix, and relations between slots are not all-pairs.

Page 62:

Thank you!

[email protected]

http://www.cst.cam.ac.uk/~pv273/