Large-Scale Distributed Second-Order Optimization Using Kronecker-Factored Approximate Curvature for Deep Convolutional Neural Networks

Kazuki Osawa 1, Yohei Tsuji 1,3, Yuichiro Ueno 1, Akira Naruse 2, Rio Yokota 1,3, and Satoshi Matsuoka 1,4
1 Tokyo Tech, 2 NVIDIA, 3 AIST-Tokyo Tech RWBC-OIL, AIST, 4 RIKEN CCS


Discussion | Kazuki's-FAQ

Q1. How much does Distributed K-FAC improve the performance?
A1. Here is a plot showing that a larger # GPUs improves the performance as long as # GPUs < # layers (ResNet-50 has 107 layers in total: FC, Conv, BN).

Q2. Why do you want to accelerate training ResNet-50 on ImageNet?
A2. We chose training ResNet-50 on ImageNet as a benchmark where highly optimized SGD results are available as references. What we really want to accelerate are even larger problems.

Q3. What is the benefit of using K-FAC? It's slower than SGD!
A3. Since K-FAC converges in fewer iterations than SGD, improving the per-iteration performance will directly accelerate the training without additional tuning. Even if we end up having similar performance to the best known first-order methods, at least we will have a better understanding of why it works by starting from second-order methods.

Q4. What's next?
A4. There is still room for improvement in our distributed design to overcome the computation/communication bottleneck of K-FAC: the Fisher information matrix can be approximated more aggressively without loss of accuracy.

Q5. Is the code available?
A5. Our Chainer implementation is available on GitHub! (A PyTorch version is coming soon.)

Acknowledgements
Computational resources of the AI Bridging Cloud Infrastructure (ABCI) were awarded by the "ABCI Grand Challenge" Program, National Institute of Advanced Industrial Science and Technology (AIST). This work is supported by JST CREST Grant Number JPMJCR19F5, Japan. This work was supported by JSPS KAKENHI Grant Number JP18H03248. (Part of) This work is conducted as research activities of the AIST - Tokyo Tech Real World Big-Data Computation Open Innovation Laboratory (RWBC-OIL). This work is supported by the "Joint Usage/Research Center for Interdisciplinary Large-scale Information Infrastructures" in Japan (Project ID: jh180012-NAHI).



Our Approach | Distributed K-FAC

training time = time / iteration × # iterations
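For example, taking the |B| = 32,768 numbers from the benchmark below (~1,760 iterations to >75% accuracy, ~10 minutes of training), the implied per-iteration cost works out as follows; the per-iteration figure is our inference, not a number reported on the poster:

```latex
% Implied per-iteration cost of the |B| = 32,768 run (our inference):
\[
  \frac{\approx 600\,\text{s (10 min)}}{\approx 1{,}760\ \text{iterations}}
  \approx 0.34\ \text{s / iteration}
\]
```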

Step 1: layer-wise block-diagonal approximation of the Fisher information matrix. Step 2: Kronecker factorization of each block:

F_\ell \approx G_\ell \otimes A_\ell   (\otimes: Kronecker product)
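For a fully connected layer, the standard K-FAC factors (J. Martens+, 2015) are covariances of the layer's input activations and of the gradients with respect to its outputs. A minimal NumPy sketch with illustrative shapes (not the poster's implementation):

```python
import numpy as np

# Illustrative shapes: mini-batch of n samples through a d_in -> d_out layer.
rng = np.random.default_rng(0)
n, d_in, d_out = 256, 6, 4

a = rng.standard_normal((n, d_in))    # layer inputs (activations)
g = rng.standard_normal((n, d_out))   # gradients w.r.t. layer outputs

A = a.T @ a / n                       # input-side factor,  E[a a^T]
G = g.T @ g / n                       # output-side factor, E[g g^T]

# Per-sample weight gradients are outer products g_i a_i^T, so the layer's
# Fisher block E[vec(g a^T) vec(g a^T)^T] factorizes as G ⊗ A under an
# independence assumption on a and g. The full block is never formed in practice:
F_block = np.kron(G, A)               # (d_out*d_in) x (d_out*d_in)
```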

Optimization in Deep Learning

Set of images & class labels: S = \{(x, y)\}, mini-batch B \subset S

Goal of optimization: \theta^* = \arg\min_{\theta \in \mathbb{R}^m} L(\theta)

Loss function: L(\theta) = \frac{1}{|S|} \sum_{(x,y) \in S} -\log p_\theta(y|x)

Stochastic learning with mini-batch: L_B(\theta) = \frac{1}{|B|} \sum_{(x,y) \in B} -\log p_\theta(y|x)

• Stochastic Gradient Descent: \theta^{(t+1)} = \theta^{(t)} - \eta \nabla L_B(\theta^{(t)})
• Natural Gradient Descent (S. Amari, 1998): \theta^{(t+1)} = \theta^{(t)} - \eta F(\theta^{(t)})^{-1} \nabla L_B(\theta^{(t)})

(empirical) Fisher information matrix:
F(\theta) = \frac{1}{|B|} \sum_{(x,y) \in B} \nabla \log p_\theta(y|x) \nabla \log p_\theta(y|x)^T \in \mathbb{R}^{m \times m}
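A toy NGD step for a logistic-regression p_\theta(y|x), directly following the formulas above; the damping term is our addition for numerical invertibility, not part of the poster's equations:

```python
import numpy as np

# Toy empirical-Fisher NGD step for logistic regression, p_θ(y=1|x) = σ(xᵀθ).
rng = np.random.default_rng(0)
m, n = 5, 64                          # parameters m, mini-batch size |B|
theta = rng.standard_normal(m)
X = rng.standard_normal((n, m))
y = rng.integers(0, 2, n).astype(float)

p = 1.0 / (1.0 + np.exp(-X @ theta))  # p_θ(y=1|x) for every sample
G = (y - p)[:, None] * X              # rows: ∇_θ log p_θ(y_i|x_i)

grad_LB = -G.mean(axis=0)             # ∇L_B(θ), L_B = mean of -log p_θ(y|x)
F = G.T @ G / n                       # empirical Fisher information matrix

eta, damping = 0.1, 1e-3
theta_sgd = theta - eta * grad_LB                                            # SGD
theta_ngd = theta - eta * np.linalg.solve(F + damping * np.eye(m), grad_LB)  # NGD
```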

Which Optimizer is “Fast”?

        time / iteration   # iterations
SGD     Good               Bad
NGD     Bad                Good
K-FAC   Not Bad            Good

Layer-wise Update:

W_\ell^{(t+1)} = W_\ell^{(t)} - \eta \left( G_\ell^{-1} \otimes A_\ell^{-1} \right) \nabla_{W_\ell} L_B(\theta^{(t)})

[Figure: data and layers hybrid parallelism. Data parallelism: each GPU handles a different part of the mini-batch B, running its own copy of the DNN to compute \nabla \log p_\theta(y|x). A Reduce-Scatter distributes \{G_\ell, A_\ell, \nabla_{W_\ell} L_B(\theta)\}, \ell = 1, \ldots, L, across GPUs. Layers parallelism: each GPU applies K-FAC to the weights W_\ell of different layers (the weights will be synchronized).]
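A toy, single-process simulation of this communication pattern (no real GPUs or MPI; worker and layer counts are illustrative): every worker computes per-layer statistics on its shard of the mini-batch, and a reduce-scatter leaves the summed statistics for layer ℓ on worker ℓ only, which then performs that layer's K-FAC update.

```python
import numpy as np

# Toy reduce-scatter over per-layer K-FAC statistics (pure NumPy simulation).
rng = np.random.default_rng(0)
n_workers = n_layers = 4
d = 3

# stats[w][l]: (G_l, A_l, dW_l) computed by worker w from its part of B.
stats = [[tuple(rng.standard_normal((d, d)) for _ in range(3))
          for _ in range(n_layers)]
         for _ in range(n_workers)]

# Reduce-scatter: sum each layer's statistics over workers, placing the
# result for layer l only on its owner, worker l (modeled here as a dict).
owner_stats = {
    l: [sum(stats[w][l][k] for w in range(n_workers)) / n_workers
        for k in range(3)]
    for l in range(n_layers)
}

# Each owner would now invert (G_l, A_l) and update W_l (layers parallelism);
# an all-gather/broadcast then synchronizes the updated weights on all workers.
```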

Kazuki Osawa, Ph.D. candidate at Tokyo Tech (oosawa.k.ad at m.titech.ac.jp), supervised by Rio Yokota

Benchmark | Training ResNet-50 on ImageNet in 978 iterations / 10 minutes with K-FAC

[Figure: comparison with state-of-the-art large mini-batch training of ResNet-50 on ImageNet; all rows except this work use SGD:]

Work                       Hardware           mini-batch size |B|   # iterations   training time
This work, K-FAC (75.0%)   Tesla V100 x1024   131,072               978            10 min*
Yamazaki+, 2019 (75.0%)    Tesla V100 x2048   81,920                1,440          1.2 min
Mikami+, 2018 (75.3%)      Tesla V100 x3456   55,296                2,086          2 min
Jia+, 2018 (75.8%)         Tesla P40 x2048    65,536                1,800          6.6 min
Akiba+, 2018 (74.9%)       Tesla P100 x1024   32,768                3,519          15 min

*The 10 min time is for the |B| = 32,768 run (74.9%); 978 iterations is for the |B| = 131,072 run (>75%).

K-FAC converges faster than SGD (SOTA results)
• Our Distributed K-FAC enabled tuning large mini-batch K-FAC on ImageNet in realistic time.
• We tuned K-FAC with extremely large mini-batch sizes {4096, 8192, 16384, 32768, 65536, 131072} and achieved validation accuracy >75% in {10948, 5434, 2737, 1760, 1173, 978} iterations, respectively.
• We showed the advantage of a second-order method (K-FAC) over a first-order method (SGD) in ImageNet classification for the first time.

Competitive training time
• The data and layers hybrid parallelism introduced in our design allowed us to train on 1024 GPUs.
• We achieved 74.9% in 10 min by using K-FAC with a stale Fisher information matrix for mini-batch size |B| = 32,768 (32 images/GPU × 1024 GPUs = 32,768 images); see the sketch below.
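A minimal sketch of the stale-Fisher idea (the refresh interval and all shapes here are illustrative, not the schedule actually used): rebuild and invert the Kronecker factors only occasionally, and reuse the stale inverse in between.

```python
import numpy as np

# Reuse a stale Kronecker-factored preconditioner between refreshes.
rng = np.random.default_rng(0)
d, eta, refresh = 6, 0.1, 20
W = rng.standard_normal((d, d))
F_inv = np.eye(d * d)                  # start from an identity preconditioner

for t in range(100):
    dW = rng.standard_normal((d, d))   # stand-in for ∇_W L_B at iteration t
    if t % refresh == 0:               # occasionally rebuild from fresh factors
        G = np.cov(rng.standard_normal((d, 64))) + np.eye(d)  # damped, SPD
        A = np.cov(rng.standard_normal((d, 64))) + np.eye(d)
        F_inv = np.kron(np.linalg.inv(G), np.linalg.inv(A))
    W -= eta * (F_inv @ dW.ravel()).reshape(d, d)   # stale-preconditioned step
```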

K-FAC (J. Martens+, 2015) accelerates NGD (at least for a small DNN)

[Figure: a DNN with layers 1, \ldots, \ell, \ldots, L maps an input x through a softmax to p_\theta(y|x), from which \nabla \log p_\theta(y|x) is computed; the Fisher information matrix is approximated block-diagonally, F \approx \mathrm{diag}(F_1, \ldots, F_\ell, \ldots, F_L), one block per layer.]

F_\ell^{-1} \approx (G_\ell \otimes A_\ell)^{-1} = G_\ell^{-1} \otimes A_\ell^{-1}

Summary | Large-Scale Distributed Deep Learning

Large mini-batch training is essential
• The goal of large-scale distributed deep learning is to accelerate training by using many GPUs at the same time.
• A larger mini-batch allows us to utilize more GPUs and to converge in fewer iterations.

But, large mini-batch training is difficult …
• SGD with large mini-batch (> 8K) suffers from the generalization gap (P. Goyal+, 2018).
• Previous approaches tried to solve this problem by ad-hoc modification of SGD.
• K-FAC with large mini-batch (<= 2K) was said to generalize well (J. Ba+, 2017), but the comparison to SGD wasn't done with well-trained models.
• No one has shown a clear advantage of K-FAC over SGD, since K-FAC requires extra (heavy) computation/communication that needs to be distributed among many GPUs.

K-FAC can handle extremely large mini-batches (Our Contribution)
• We implemented Distributed K-FAC, which allowed us to scale K-FAC up to 1024 GPUs.
• We trained ResNet-50 on ImageNet with an extremely large mini-batch size (131K).
• We show the advantage of K-FAC over SGD for large mini-batches for the first time.

Code: https://github.com/tyohei/chainerkfac