Unity

title: Unity: Accelerating DNN Training Through Joint Optimization of Algebraic Transformations and Parallelization

authors: Colin Unger, Zhihao Jia, Wei Wu, Sina Lin, Mandeep Baines, Carlos Efrain Quintero Narvaez, Vinay Ramakrishnaiah, Nirmal Prajapati, Pat McCormick, Jamaludin Mohd-Yusof, Xi Luo, Dheevatsa Mudigere, Jongsoo Park, Misha Smelyanskiy, and Alex Aiken

(PS: This note was written for the writing test for Ph.D. recruitment at IIIS, Tsinghua, 2023SU)

Note: Unity

author: Bohan Yang

As foundation models grow, the storage and time costs of training are rising rapidly. Training models efficiently across distributed nodes at high throughput is therefore a significant factor in industrial applications. Prior work on large-scale distributed DNN training systems treated algebraic transformations and parallelization strategies separately. This intuitive approach limits further optimization of the computation graph and may lead to a suboptimal mapping. In this paper, the authors address the problem by jointly and automatically optimizing algebraic transformations and parallelization. The end-to-end system consists of a unified Parallel Computation Graph (PCG), graph substitution generation, and joint optimization by hierarchical search. In short, a single representation expresses both algebraic transformations and parallelization, so any transformation can be implemented as a graph substitution. In each iteration, candidate PCGs are evaluated with cost models, and the best substitution is selected and applied.
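To make the "suboptimal mapping" point concrete, here is a toy sketch comparing a two-stage optimizer with a joint one. Everything in it is hypothetical: the variant names, the strategy names, and the cost table are invented for illustration and are not taken from the paper.

```python
from itertools import product

# Hypothetical per-iteration costs for each
# (algebraic variant, parallelization strategy) pair.
# These numbers are invented purely for illustration.
COST = {
    ("original", "data_parallel"): 10.0,
    ("original", "model_parallel"): 12.0,
    ("fused", "data_parallel"): 9.0,
    ("fused", "model_parallel"): 11.0,
    ("split", "data_parallel"): 11.0,
    ("split", "model_parallel"): 6.0,
}
ALGEBRAIC = ["original", "fused", "split"]
PARALLEL = ["data_parallel", "model_parallel"]


def sequential_optimize(default_parallel="data_parallel"):
    """Two-stage search: fix the algebraic variant first, then parallelize it."""
    # Stage 1: choose the algebraic variant under a fixed default strategy.
    best_alg = min(ALGEBRAIC, key=lambda a: COST[(a, default_parallel)])
    # Stage 2: choose the parallelization for that variant only.
    best_par = min(PARALLEL, key=lambda p: COST[(best_alg, p)])
    return (best_alg, best_par), COST[(best_alg, best_par)]


def joint_optimize():
    """Joint search: consider every (variant, strategy) pair together."""
    best = min(product(ALGEBRAIC, PARALLEL), key=lambda pair: COST[pair])
    return best, COST[best]


print(sequential_optimize())
print(joint_optimize())
```

With this made-up table, the two-stage search commits to the "fused" variant because it looks best under data parallelism, and so never sees that "split" combined with model parallelism is cheaper. Enumerating all pairs only works at toy scale, which is why the search strategy matters.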

The biggest merit of this work is how it handles the exponential search space of joint optimization. The reason prior work restricted itself to a single dimension is the overhead of searching for the optimal strategy: hybrid optimization is attractive, but the combined search space can easily explode. This paper uses hierarchical search to converge efficiently on the optimal point. There are three levels: graph splitting, which provides scalability to large models; substitution candidate selection; and machine mapping. Unity also predefines a worst-case cost threshold for candidate pruning. Candidates in poor regions of the search space are removed from the candidate stack, and the best-performing candidate PCG is chosen and applied in each iteration. In this way, Unity's search time is reduced to an acceptable level while keeping the performance benefit of hybrid optimization.
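The threshold-pruning idea can be sketched as a small cost-guided search. This is a minimal illustration, not Unity's implementation: the relative threshold `alpha`, the `budget` parameter, the priority queue, and the toy cost model (`WORK`, `toy_cost`, `double_one`) are all assumptions made for this example, and it covers only the candidate-selection level, not graph splitting or machine mapping.

```python
import heapq
from itertools import count


def pruned_search(initial, substitutions, cost, alpha=1.05, budget=1000):
    """Cost-guided search that drops candidates costing more than alpha * best."""
    best, best_cost = initial, cost(initial)
    tie = count()  # tie-breaker so the heap never compares candidates directly
    heap = [(best_cost, next(tie), initial)]
    seen = {initial}
    while heap and budget > 0:
        budget -= 1
        cand_cost, _, cand = heapq.heappop(heap)
        if cand_cost < best_cost:
            best, best_cost = cand, cand_cost
        for substitute in substitutions:
            for new in substitute(cand):
                if new in seen:
                    continue
                seen.add(new)
                new_cost = cost(new)
                # Prune: keep only candidates within the threshold of the best so far.
                if new_cost <= alpha * best_cost:
                    heapq.heappush(heap, (new_cost, next(tie), new))
    return best, best_cost


# Hypothetical toy problem: a candidate is a tuple of per-operator
# parallel degrees; more parallelism cuts compute but adds communication.
WORK = (8.0, 2.0)  # invented per-operator compute costs


def toy_cost(degrees):
    compute = sum(w / d for w, d in zip(WORK, degrees))
    comm = sum(0.5 * (d - 1) for d in degrees)
    return compute + comm


def double_one(degrees):
    """Toy 'substitution': double the parallel degree of one operator (max 8)."""
    for i, d in enumerate(degrees):
        if d < 8:
            yield degrees[:i] + (2 * d,) + degrees[i + 1:]


print(pruned_search((1, 1), [double_one], toy_cost))
```

A larger `alpha` keeps more candidates alive and explores more of the space at the cost of search time; a value close to 1 behaves almost greedily.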

However, there are still DNN structures and PCG optimization methods that Unity does not consider. Unity assumes that modern network architectures consist of static linear chains of independent strands of parallel computation. This is not always the case: some networks use feedback branches (RNN), nonlinear branches (NASNet), or dynamic branches (SkipNet) to gain representation power. These structures violate Unity's simple assumption and call for a more general method that covers them. In addition, optimization is not limited to algebraic transformation and parallelization. Recent work proposes further techniques such as tensor offloading and rematerialization, which help on specific workloads and open up more optimization opportunities. As a result, under certain conditions Unity may settle on a suboptimal point and hurt overall performance.

As for further research, I think it can proceed in the following directions.

First, heterogeneous computing is a hot topic because real workloads stress hardware in different ways. Some rely heavily on the control flow of CPUs, while others require the data-parallel capability of GPUs/NPUs. DNN workloads on future architectures will need both, which could be very hard on the scheduler. Distributing work evenly is key to load balancing and has been studied extensively in modern operating systems, where CPU cores are often unequal in energy consumption and performance. We may be able to borrow ideas from OS core schedulers and apply them to machine learning systems (e.g. TVM with Unity) to scale to more general and heterogeneous hardware.

Second, this work focuses on optimizing DNN training and may neglect the characteristics of inference workloads. In inference, latency and throughput are the most important factors in whether an inference engine succeeds in practice. Moving to inference workloads would therefore require more work to shorten the critical path and to find more parallelization opportunities. Note that the inference computation graph may be smaller than the training PCG and has fewer data dependencies, so more techniques can be applied. I believe profiling and analyzing the inference PCG would reveal more parallelization methods.

Finally, I plan to study heuristic search further in order to converge more quickly in the search space. A good search strategy is key to convergence, and prior work relies on greedy algorithms driven by a cost function; this paper still follows that trend. However, a greedy method is not necessarily the best descent algorithm, as optimization theory shows. That is why methods such as momentum-based optimizers (e.g. Adam), BFGS, and Newton's method were proposed. Each shows its strength in particular scenarios, and their ideas might transfer to the search algorithm to accelerate search over the predictably huge parameter spaces of future LLMs. I believe there is more to be done here.