Fast classification using sparsely active spiking networks
Hesham Mostafa, Institute of Neural Computation, UCSD


Page 1

Fast classification using sparsely active spiking networks

Hesham Mostafa, Institute of Neural Computation, UCSD

Page 2

Artificial networks vs. spiking networks

[Figure: two four-layer networks (input layer, hidden layer 1, hidden layer 2, output layer) shown side by side. The artificial network is trained with backpropagation; the corresponding mechanism for the spiking network is marked "???".]

Multi-layer networks are extremely powerful function approximators.

Backpropagation is the most effective method we know of for solving the credit assignment problem in deep artificial networks.

How do we solve the credit assignment problem in multi-layer spiking networks?

Page 3

Neural codes and gradient descent

Rate coding:
● Spike counts/rates are discrete quantities
● Gradient is zero almost everywhere
● Only indirect or approximate gradient descent training possible

[Figure: left, rate coding maps input spikes to output spike counts (example counts: 3, 4, 2, 5); right, temporal coding maps input spike times $t_1, t_2, t_3$ to an output spike time $t_{out}$.]

Temporal coding:
● Spike times are analog quantities
● Gradient of the output spike time w.r.t. input spike times is well-defined and non-zero
● Direct gradient descent training possible

Page 4

The neuron model

$$\frac{dV_{mem}(t)}{dt} = I_{syn}(t)$$

$$I_{syn}(t) = \sum_i w_i \exp(-(t - t_i))\,\Theta(t - t_i)$$

where $\Theta(t - t_i)$ is the Heaviside step function.

[Figure: membrane voltage $V_{mem}$ and synaptic current $I_{syn}$ versus time; the firing threshold is 1.]

● Non-leaky integrate-and-fire neuron
● Exponentially decaying synaptic current
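As a concrete illustration, here is a minimal simulation sketch of this neuron model in Python (the spike times, weights, and time step are made-up values for illustration, not from the slides):

```python
import numpy as np

# Illustrative input spike times and weights (not from the slides).
spike_times = [0.5, 1.0, 1.5, 2.0]
weights = [0.6, 0.5, 0.4, 0.3]

def v_mem(t, spike_times, weights):
    """Membrane voltage of the non-leaky integrate-and-fire neuron at time t.

    Integrating the synaptic current w * exp(-(t - t_i)) from t_i to t
    gives a contribution of w * (1 - exp(-(t - t_i))) per input spike.
    """
    v = 0.0
    for t_i, w in zip(spike_times, weights):
        if t >= t_i:
            v += w * (1.0 - np.exp(-(t - t_i)))
    return v

# Scan time on a fine grid and report the first threshold crossing.
for t in np.arange(0.0, 4.0, 1e-3):
    if v_mem(t, spike_times, weights) >= 1.0:
        print(f"first output spike at t ~ {t:.3f}")
        break
```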

Page 5

The neuron's transfer function

[Figure: $V_{mem}$ crossing the threshold at $t_{out}$ after four input spikes at $t_1, t_2, t_3, t_4$; $I_{syn}$ jumps by $w_1, w_2, w_3, w_4$ at the corresponding times and decays in between.]

Page 6

The neuron's transfer function

[Figure: same as Page 5.]

In general:

$$\exp(t_{out}) = \frac{\sum_{i \in C} w_i \exp(t_i)}{\sum_{i \in C} w_i - 1}$$

where $C$ is the causal set of input spikes (the input spikes that arrive before the output spike).

Page 7

The neuron's transfer function

[Figure and general transfer function: same as Page 6.]

Time of the Lth output spike (the membrane resets after each spike, so each spike consumes one threshold unit of accumulated charge):

$$\exp(t_{out}^{L}) = \frac{\sum_{i \in C_L} w_i \exp(t_i)}{\sum_{i \in C_L} w_i - L}$$

Page 8

Change of variables

Substituting $z_i = \exp(t_i)$ and $z_{out} = \exp(t_{out})$, the neuron's transfer function becomes piecewise linear in the inputs (but not in the weights):

$$z_{out} = \frac{\sum_{i \in C} w_i z_i}{\sum_{i \in C} w_i - 1}$$

Page 9

Where is the non-linearity?

[Figure: two panels of the membrane and synaptic current traces, with time axes labeled $\ln(z_1), \ln(z_2), \ln(z_3), \ln(z_4)$ and $\ln(z_{out})$. Left, with three causal input spikes: $z_{out} = \frac{w_1 z_1 + w_2 z_2 + w_3 z_3}{w_1 + w_2 + w_3 - 1}$. Right, when a fourth input spike joins the causal set: $z_{out} = \frac{w_1 z_1 + w_2 z_2 + w_3 z_3 + w_4 z_4}{w_1 + w_2 + w_3 + w_4 - 1}$.]

Non-linearity arises from the input dependence of the causal set of input spikes.

The piecewise linear input-output relation is reminiscent of rectified linear unit (ReLU) networks.

Page 10

What is the form of computation implemented by the temporal dynamics?

To compute $z_{out}$:
● Sort $\{z_1, z_2, \ldots, z_n\}$
● Find the causal set $C$ by progressively considering more early spikes
● Calculate $z_{out} = \frac{\sum_{i \in C} w_i z_i}{\sum_{i \in C} w_i - 1}$
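A minimal Python sketch of this procedure follows. The stopping condition (accept the candidate output only if the implied spike falls before the next input spike and the causal weight sum exceeds 1) is inferred from the threshold-crossing model above; the example values are made up:

```python
import numpy as np

def z_out(z, w):
    """Compute z_out = sum_{i in C} w_i z_i / (sum_{i in C} w_i - 1).

    z : input values z_i = exp(t_i); w : the corresponding weights.
    The causal set C is grown in order of arrival until the implied
    output spike falls before the next input spike.
    """
    order = np.argsort(z)                    # earlier spike <=> smaller z
    z, w = np.asarray(z, float)[order], np.asarray(w, float)[order]
    num, den = 0.0, -1.0                     # running sum(w_i z_i), sum(w_i) - 1
    for k in range(len(z)):
        num += w[k] * z[k]
        den += w[k]
        if den > 0.0:                        # causal weight sum exceeds 1
            candidate = num / den
            if k + 1 == len(z) or candidate < z[k + 1]:
                return candidate             # spike before the next input
    return np.inf                            # threshold never reached

# Illustrative values: three input spikes at t = 0.5, 1.0, 1.5.
print(z_out(np.exp([0.5, 1.0, 1.5]), [0.8, 0.7, 0.3]))
```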

[Figure: same as Page 9, left panel.]

This cannot be reduced to the conventional ANN neuron: $z_{out} = f\left(\sum_i w_i z_i\right)$.

Page 11

Backpropagation

To use backpropagation to train a multi-layer network, we need the derivatives of the neuron's output w.r.t. the weights and the inputs. Differentiating the transfer function above, for $i \in C$:

$$\frac{\partial z_{out}}{\partial w_i} = \frac{z_i - z_{out}}{\sum_{j \in C} w_j - 1}, \qquad \frac{\partial z_{out}}{\partial z_i} = \frac{w_i}{\sum_{j \in C} w_j - 1}$$
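A small numerical sanity check of these derivatives (a sketch assuming the causal set is fixed and contains all inputs; the values are illustrative):

```python
import numpy as np

def forward(z, w):
    """z_out with the causal set fixed to all inputs (assumed for this check)."""
    return np.dot(w, z) / (np.sum(w) - 1.0)

def grads(z, w):
    """Analytical derivatives of z_out w.r.t. the weights and the inputs."""
    denom = np.sum(w) - 1.0
    return (z - forward(z, w)) / denom, w / denom

z = np.array([1.5, 2.0, 3.0])
w = np.array([0.8, 0.7, 0.3])
dw, dz = grads(z, w)

# Finite-difference comparison on one weight and one input.
eps = 1e-6
print(dw[0], (forward(z, w + np.array([eps, 0, 0])) - forward(z, w)) / eps)
print(dz[1], (forward(z + np.array([0, eps, 0]), w) - forward(z, w)) / eps)
```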

The time of the first spike encodes the neuron's value. Each neuron is allowed to spike only once in response to an input pattern:
● Forces sparse activity; training has to make maximum use of each spike
● Allows a quick classification response

Page 12

Classification Tasks

● We can relate the time of any spike differentiably to the times of all spikes that caused it

● We can impose any differentiable cost function on the spike times of the output layer and use backpropagation to minimize the cost across the training set

● In a classification setting, we use a loss function that encourages the output neuron representing the correct class to spike first

● Since we have an analytical input-output relation for each neuron, training can be done using conventional machine learning packages (Theano/TensorFlow)
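One loss with this property, sketched here as an assumption since the slide does not show the exact formula, is a cross-entropy on the softmax of the negative output values: the earliest-spiking output neuron has the smallest $z = \exp(t)$ and therefore receives the largest probability.

```python
import numpy as np

def first_to_spike_loss(z_outputs, target):
    """Cross-entropy on softmax(-z): the target output neuron is rewarded
    for having the smallest z, i.e., the earliest spike time t = ln(z)."""
    logits = -np.asarray(z_outputs, float)
    logits -= logits.max()                    # numerical stability
    p = np.exp(logits) / np.exp(logits).sum()
    return -np.log(p[target])

# Loss is low when the correct neuron (index 2) has the smallest z ...
print(first_to_spike_loss([5.0, 4.0, 2.5], target=2))
# ... and high when it spikes last.
print(first_to_spike_loss([2.5, 4.0, 5.0], target=2))
```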

Page 13

MNIST task

● Pixel values were binarized
● High-intensity pixels spike early
● Low-intensity pixels spike late
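A minimal sketch of this input encoding (the specific spike times t_early and t_late and the 0.5 threshold are illustrative assumptions; the slide only specifies early vs. late):

```python
import numpy as np

def encode_image(pixels, threshold=0.5, t_early=0.0, t_late=2.0):
    """Binarize pixel intensities and map them to input spike times.

    pixels : flat array of 784 intensities scaled to [0, 1].
    Bright pixels fire at t_early, dark pixels at t_late (both values
    are illustrative; the slide only specifies early vs. late).
    """
    return np.where(np.asarray(pixels) > threshold, t_early, t_late)

print(encode_image(np.random.rand(784))[:10])
```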

Page 14

Classification is extremely rapid

● A decision is made when the first output neuron spikes
● A decision is made after only 25 spikes (on average) from the hidden layer in the 784-800-10 network, i.e., only about 3% of the hidden layer neurons contribute to each classification decision

Page 15

FPGA prototype

● 97% test set classification accuracy on MNIST in a 784-600-10 network (8-bit weights)
● Average number of spikes until classification: 139
● Only 13% of the input-to-hidden weights are looked up
● Only 5% of the hidden-to-output weights are looked up
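The weight-lookup savings follow from event-driven processing: a weight row is fetched only when its presynaptic neuron actually fires, and processing stops at the first output spike. A simplified sketch of the idea (not the FPGA implementation; it ignores the exponential decay between events, and all names are illustrative):

```python
import numpy as np

def event_driven_layer(input_spikes, W, threshold=1.0):
    """Process one layer event by event.

    input_spikes : list of (time, neuron_index) pairs sorted by time.
    W            : weight matrix of shape (n_inputs, n_outputs).
    A weight row is fetched only when its presynaptic neuron fires, and
    processing stops at the first output spike. (Simplified: the
    exponential decay of the synaptic current between events is ignored.)
    """
    acc = np.zeros(W.shape[1])
    rows_fetched = 0
    for t, i in input_spikes:
        acc += W[i]                          # only row i is looked up
        rows_fetched += 1
        if acc.max() >= threshold:
            return int(acc.argmax()), rows_fetched
    return None, rows_fetched

rng = np.random.default_rng(0)
W = rng.uniform(0.0, 0.2, size=(784, 600))
spikes = [(0.1 * k, k) for k in range(784)]
print(event_driven_layer(spikes, W))         # (winning output, rows fetched)
```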

[Figure: histogram of timesteps to classification (mean: 167, median: 162), and histogram of the number of hidden layer spikes before the output spike (mean: 30, median: 29).]

Page 16

Acknowledgements

Institute of Neuroinformatics:
Giacomo Indiveri, Tobi Delbruck

Gert Cauwenberghs, Sadique Sheik, Bruno Pedroni

Page 17

Approximate learning

[Figure: a network with input, hidden, and output layers (neurons N1 through N9), with each connection marked by its weight sign (+/−); the correct label is indicated at the output layer.]

● Update hidden→output weights to encourage the right neuron to spike first
● Only update weights that actually contributed to the output timings

Page 18

Approximate learning

[Figure: the same network as Page 17, with time deltas of +1 or −1 propagated from the output layer toward the hidden layer through the signs of the weights.]

● Backpropagate time deltas using only the sign of the weights
● The final time delta at a hidden layer neuron can be obtained using 2 parallel popcount operations (count 1s in a bit vector) and a comparison
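A sketch of how this step could reduce to two popcounts and a comparison (the bit-vector layout and the sign convention are assumptions for illustration): each output neuron contributes a +1 or −1 to a hidden neuron's delta, with sign equal to the product of the output delta's sign and the weight's sign; counting agreements and disagreements and comparing the two counts gives the hidden delta.

```python
def hidden_delta_sign(weight_signs: int, output_deltas: int, n_outputs: int = 10) -> int:
    """Sign of a hidden neuron's time delta under sign-only backprop.

    weight_signs  : bit vector, bit k = 1 if the hidden->output weight
                    to output neuron k is positive
    output_deltas : bit vector, bit k = 1 if output neuron k's delta is +1
    A positive-weight/positive-delta or negative/negative pair pushes the
    hidden delta positive, i.e., XNOR of the two bits.
    """
    mask = (1 << n_outputs) - 1
    agree = ~(weight_signs ^ output_deltas) & mask
    disagree = (weight_signs ^ output_deltas) & mask
    pos = bin(agree).count("1")        # popcount #1
    neg = bin(disagree).count("1")     # popcount #2
    return +1 if pos > neg else -1     # the comparison

print(hidden_delta_sign(0b1100110011, 0b1010101010))
```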