分布式訓(xùn)練加速(TorchAcc)
PAI-TorchAcc(Torch Accelerator)是基于PyTorch的訓(xùn)練加速框架,通過GraphCapture技術(shù)將PyTorch動(dòng)態(tài)圖轉(zhuǎn)換為靜態(tài)執(zhí)行圖,然后進(jìn)一步基于計(jì)算圖完成分布式優(yōu)化、計(jì)算優(yōu)化,從而提高PyTorch模型訓(xùn)練的效率,使其更加易于使用。
技術(shù)簡介
TorchAcc是動(dòng)靜一體的分布式訓(xùn)練加速框架,主要功能如下:
通過GraphCapture技術(shù)將動(dòng)態(tài)圖轉(zhuǎn)化為靜態(tài)圖。
通過編譯優(yōu)化手段提升訓(xùn)練性能。
通過顯存優(yōu)化降低資源開銷。
通過半精度通信、通信壓縮、通信overlap等通信優(yōu)化技術(shù)來提高通信效率。
提供自動(dòng)和半自動(dòng)分布式策略,支持大模型高效訓(xùn)練。
訓(xùn)練數(shù)據(jù)讀取優(yōu)化:
Prefetcher:進(jìn)行數(shù)據(jù)預(yù)取,讓數(shù)據(jù)預(yù)處理和訓(xùn)練能夠同時(shí)進(jìn)行,從而減少數(shù)據(jù)處理的等待時(shí)間,提高訓(xùn)練效率。
Packed dataset:通過高效的數(shù)據(jù)打包方式,減少無效計(jì)算,并提高數(shù)據(jù)讀取效率。
Preprocess Cache:緩存預(yù)處理后的數(shù)據(jù),減少數(shù)據(jù)預(yù)處理開銷。
產(chǎn)品架構(gòu)
深度學(xué)習(xí)框架按照執(zhí)行模式可分為兩個(gè)大的類別:
graph mode:以TensorFlow 1.*為代表的框架采用graph mode的方式運(yùn)行。其優(yōu)點(diǎn)是系統(tǒng)優(yōu)化友好、面向生產(chǎn)、訓(xùn)推一體,而缺點(diǎn)是面向用戶不夠友好、代碼撰寫不夠靈活、開發(fā)和Debug困難。
eager mode:以Pytorch為代表的框架采用eager mode的方式運(yùn)行。其優(yōu)點(diǎn)是靈活性好、容易開發(fā)和Debug。但是對于框架優(yōu)化來說不友好,系統(tǒng)優(yōu)化困難。
針對上述問題,TorchAcc的目標(biāo)是在保持Pytorch靈活性的基礎(chǔ)上,為模型訓(xùn)練提供系統(tǒng)的優(yōu)化處理。TorchAcc的架構(gòu)圖如下所示。
TorchAcc的核心邏輯如下:
通過LazyTensor+HybridDispatcher將PyTorch中的eager execution轉(zhuǎn)換為IR表達(dá)式,然后進(jìn)行計(jì)算優(yōu)化、顯存優(yōu)化以及自動(dòng)并行化等一系列優(yōu)化處理后,再將優(yōu)化處理后的IR交給后端進(jìn)行進(jìn)一步的優(yōu)化和CodeGen。同時(shí),TorchAcc還提供了許多手動(dòng)算子優(yōu)化的kernel實(shí)現(xiàn),可以針對特定算子進(jìn)行優(yōu)化,來提高算子的計(jì)算效率和性能。