NeurIPS 2021

Efficiently Identifying Task Groupings for Multi-Task Learning

Chris Fifty, Ehsan Amid, Zhe Zhao, Tianhe Yu, Rohan Anil, Chelsea Finn

313 citations

Abstract

Multi-task learning can leverage information learned by one task to benefit the training of other tasks. Despite this capacity, naïvely training all tasks together in one model often degrades performance, and exhaustively searching through combinations of task groupings can be prohibitively expensive. As a result, efficiently identifying the tasks that would benefit from training together remains a challenging design question without a clear solution. In this paper, we suggest an approach to select which tasks should train together in multi-task learning models. Our method determines task groupings in a single run by training all tasks together and quantifying the extent to which one task's gradient would affect another task's loss. On the large-scale Taskonomy computer vision dataset, we find this method can decrease test loss by 10.0% compared to simply training all tasks together while operating 11.6 times faster than a state-of-the-art task grouping method.

Related Work

Task Groupings. Prevailing wisdom suggests that tasks which are similar or share a similar underlying structure may benefit from training together in a multi-task system [9, 8, 4]. Early work in this domain, pertaining to the convex setting, assumes all tasks share a common latent feature representation and finds that model performance can be significantly improved by clustering tasks based on the basis vectors they share in this latent space [26, 30]. However, early convex methods to determine task groupings often make prohibitive assumptions that do not scale to deep neural networks. Deciding which tasks should train together in multi-task neural networks has traditionally been addressed with costly cross-validation techniques or high-variance human intuition. An altogether