NeurIPS 2024
Training Binary Neural Networks via Gaussian Variational Inference and Low-Rank Semidefinite Programming
Lorenzo Orecchia, Jiawei Hu, Xue He, Wang Mark, Xulei Yang, Min Wu, Xue Geng
Abstract
Improving the training of Binarized Neural Networks (BNNs) is a longstanding challenge whose outcome can significantly affect our ability to deploy deep learning ubiquitously. Current methods rely heavily on latent weights and the heuristic straight-through estimator (STE), which enable the application of SGD-based optimizers to the combinatorial training problem but remain poorly understood theoretically. In this paper, we propose an optimization framework for BNN training based on Gaussian variational inference. Our approach yields a non-convex linear programming formulation that theoretically motivates the use of latent weights, STE and weight clipping. More importantly, it allows us to go beyond latent weights to formulate and solve low-rank semidefinite programming (SDP) relaxations that explicitly model and learn pairwise correlations between weights during training, resulting in improved accuracy. Our empirical evaluation on the CIFAR-10, CIFAR-100, Tiny-ImageNet and ImageNet datasets shows that our method consistently outperforms all state-of-the-art algorithms for training BNNs.
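To make the idea of a low-rank SDP relaxation that "models pairwise correlations between weights" concrete, here is a minimal sketch of one standard construction. It is assumed here for illustration only and is not the paper's algorithm. The product w w^T for a binary weight vector w in {-1, +1}^n is relaxed to a Gram matrix X = V V^T with unit diagonal and a rank-k factor V (a Burer-Monteiro-style parameterization). Binary weights are then drawn by random-hyperplane rounding, w = sign(V g) with g standard Gaussian. The Gaussian draw makes the correlations explicit: for unit-norm rows v_i, E[w_i w_j] = (2/pi) arcsin(<v_i, v_j>), which the script checks by Monte Carlo. The function names, sizes, and seed are hypothetical.

```python
import numpy as np

def normalize_rows(V):
    """Project each row of V onto the unit sphere, so diag(V V^T) = 1.

    This mirrors the constraint X_ii = w_i^2 = 1 that an SDP relaxation of
    w in {-1, +1}^n keeps on its relaxed matrix X = V V^T.
    """
    return V / np.linalg.norm(V, axis=1, keepdims=True)

def sample_binary_weights(V, num_samples, rng):
    """Random-hyperplane rounding: w = sign(V g) with g ~ N(0, I_k).

    Each row of the result is one binary weight vector in {-1, +1}^n.
    """
    n, k = V.shape
    G = rng.standard_normal((k, num_samples))  # one Gaussian direction per sample
    W = np.sign(V @ G).T                       # shape (num_samples, n)
    W[W == 0] = 1.0                            # break (measure-zero) ties toward +1
    return W

def predicted_correlations(V):
    """Closed form E[w_i w_j] = (2/pi) * arcsin(<v_i, v_j>) for unit rows v_i."""
    X = np.clip(V @ V.T, -1.0, 1.0)            # relaxed Gram matrix, unit diagonal
    return (2.0 / np.pi) * np.arcsin(X)

# Hypothetical sizes: 6 binary weights coupled through a rank-2 factor.
rng = np.random.default_rng(0)
n, k = 6, 2
V = normalize_rows(rng.standard_normal((n, k)))

W = sample_binary_weights(V, num_samples=200_000, rng=rng)
empirical = (W.T @ W) / W.shape[0]             # Monte Carlo estimate of E[w w^T]
predicted = predicted_correlations(V)

print(np.round(empirical, 2))
print(np.round(predicted, 2))
```

The rank k sets how much correlation structure the factor can express. With k = 1, every normalized row of V is +1 or -1, so the sampled weights are deterministic up to a single global sign flip. Independent per-weight sampling, by contrast, would make every off-diagonal entry of E[w w^T] a product of means.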