Integrate a custom model with TorchAcc
Alibaba Cloud PAI provides sample models for several typical scenarios so that you can quickly integrate them with TorchAcc for training acceleration. You can also integrate models you have developed yourself. This topic describes how to integrate a custom model with TorchAcc to improve the speed and efficiency of distributed training.
Background information
TorchAcc provides the following two categories of optimization. Choose the one that fits your scenario to improve model training speed and efficiency.
Compilation optimization
TorchAcc can convert PyTorch dynamic graphs into static graphs, then optimize and compile the computation graph to accelerate training. TorchAcc rewrites the computation graph into a more efficient form and uses a JIT compiler to generate faster code, which avoids part of the overhead of PyTorch dynamic-graph execution and improves training speed and efficiency.
Custom optimization
If a model involves dynamic shapes, custom operators, or dynamic control flow, global compilation optimization cannot currently be applied to accelerate distributed training. For such scenarios, TorchAcc provides the following custom optimizations:
I/O optimization
Computation (kernel) optimization
GPU memory optimization
TorchAcc compilation optimization
Enable distributed training
To integrate with the TorchAcc compiler for distributed training, perform the following steps:
Fix the random seed.
Fixing the random seed ensures that weight initialization is identical on every worker, which serves as a substitute for broadcasting the weights.
Replace:
torch.manual_seed(SEED_NUMBER)
with:
xm.set_rng_state(SEED_NUMBER)
After obtaining the xla_device, call set_replication, wrap the dataloaders, and set the model device placement.
device = xm.xla_device()
xm.set_replication(device, [device])

# Wrapper dataloader
data_loader_train = pl.MpDeviceLoader(data_loader_train, device)
data_loader_val = pl.MpDeviceLoader(data_loader_val, device)

# Dispatch device to model
model.to(device)
Initialize the distributed process group.
Set the backend parameter of dist.init_process_group to 'xla':
dist.init_process_group(backend='xla', init_method='env://')
Perform allreduce communication on the gradients.
After the loss backward pass, perform an allreduce on the gradients:
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())
Important: If you use mixed-precision (AMP) training and call scaler.unscale_ manually, make sure to call xm.all_reduce before scaler.unscale_ so that overflow detection is performed on the gradients after the allreduce.
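The following minimal sketch shows the required call order within one AMP training step; model, optimizer, loss, and scaler are assumed to be defined as in the surrounding steps, and the gradient clipping line is included only to illustrate a case where scaler.unscale_ is called manually:

scaler.scale(loss).backward()

# Allreduce the still-scaled gradients before unscale_.
gradients = xm._fetch_gradients(optimizer)
xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())

# unscale_ (and the subsequent clipping) now operates on the reduced
# gradients, so overflow detection is based on the allreduced values.
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)
scaler.update()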
Launch the job with xlarun.
xlarun --nproc_per_node=8 YOUR_MODEL.py
Note: For multi-node training, xlarun is used in the same way as torchrun.
Enable mixed precision
Mixed-precision training can further accelerate model training. The following steps add the AMP logic on top of the single-GPU or distributed TorchAcc compilation optimization described in the previous section.
Implement AMP with native PyTorch.
TorchAcc mixed precision is used in almost the same way as native PyTorch mixed precision. First implement native PyTorch AMP by following the official PyTorch AMP documentation.
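As a reference, the following is a minimal sketch of a native PyTorch AMP training step; model, optimizer, loss_fn, and data_loader are assumed to be defined elsewhere:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for inputs, labels in data_loader:
    optimizer.zero_grad()
    # Run the forward pass in mixed precision.
    with autocast():
        loss = loss_fn(model(inputs), labels)
    # Scale the loss before backward, then step through the scaler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()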
Replace the GradScaler.
Replace torch.cuda.amp.GradScaler with torchacc.torch_xla.amp.GradScaler:
from torchacc.torch_xla.amp import GradScaler
Replace the optimizer.
Native PyTorch optimizers perform slightly worse. You can replace the torch.optim optimizers with syncfree optimizers to further improve training speed.
from torchacc.torch_xla.amp import syncfree

adam_optimizer = syncfree.Adam()
adamw_optimizer = syncfree.AdamW()
sgd_optimizer = syncfree.SGD()
Currently, syncfree optimizers are available only for these three optimizer types. For other optimizer types, continue to use the native PyTorch optimizers.
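A syncfree optimizer takes the same constructor arguments as its torch.optim counterpart. For example (the learning rate is illustrative; the same construction appears in the Bert example below):

from torchacc.torch_xla.amp import syncfree

# Drop-in replacement for torch.optim.Adam.
optimizer = syncfree.Adam(model.parameters(), lr=1e-3)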
Integration example
The following code example uses the Bert-base model:
import argparse
import os
import time
import torch
import torch.distributed as dist
from datasets import load_from_disk
from datetime import datetime as dt
from time import gmtime, strftime
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding
# PyTorch 1.12 sets allow_tf32 to False by default.
torch.backends.cuda.matmul.allow_tf32 = True
parser = argparse.ArgumentParser()
parser.add_argument("--amp-level", choices=["O1"], default="O1", help="amp level.")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile_folder", type=str, default="./profile_folder")
parser.add_argument("--dataset_path", type=str, default="./sst_data/train")
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--max_seq_length", type=int, default=512)
parser.add_argument("--break_step_for_profiling", type=int, default=20)
parser.add_argument("--model_name", type=str, default="bert-base-cased")
parser.add_argument("--local_rank", type=int, default="-1")
parser.add_argument("--log-interval", type=int, default="10")
parser.add_argument('--max-steps', type=int, default=200, help='total training epochs.')
args = parser.parse_args()
print("Job running args: ", args)
args.local_rank = int(os.getenv("LOCAL_RANK", 0))
+def enable_torchacc_compiler():
+    return os.getenv('TORCHACC_COMPILER_OPT') is not None

def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.get_rank() == 0:
        print(message, flush=True)

# Format string for the periodic training log; the fields match the
# arguments of print_test_update below.
train_format_string = ("{} | epoch: {} step: {} batch_size: {} loss: {} "
                       "step_time: {}s samples/s: {} peak_mem: {}GB")

def print_test_update(epoch, step, batch_size, loss, time_elapsed, samples_per_step, peak_mem):
    # Getting the current date and time
    dt = strftime("%a, %d %b %Y %H:%M:%S", gmtime())
    print_rank_0(train_format_string.format(dt, epoch, step, batch_size, loss, time_elapsed, samples_per_step, peak_mem))

def log_metrics(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem):
    batch_time = f"{batch_time:.3f}"
    samples_per_step = f"{samples_per_step:.3f}"
    peak_mem = f"{peak_mem:.3f}"
+    if enable_torchacc_compiler():
+        import torchacc.torch_xla.core.xla_model as xm
+        xm.add_step_closure(
+            print_test_update, args=(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem), run_async=True)
+    else:
        print_test_update(epoch, step, batch_size, loss, batch_time, samples_per_step, peak_mem)

+if enable_torchacc_compiler():
+    import torchacc.torch_xla.core.xla_model as xm
+    import torchacc.torch_xla.distributed.parallel_loader as pl
+    import torchacc.torch_xla.distributed.xla_backend
+    from torchacc.torch_xla.amp import autocast, GradScaler, syncfree
+    xm.set_rng_state(101)
+    dist.init_process_group(backend="xla", init_method="env://")
+else:
    from torch.cuda.amp import autocast, GradScaler
    dist.init_process_group(backend="nccl", init_method="env://")
dist.barrier()
args.world_size = dist.get_world_size()
args.rank = dist.get_rank()
print("world size:", args.world_size, " rank:", args.rank, " local rank:", args.local_rank)
def get_autocast_and_scaler():
+    if enable_torchacc_compiler():
+        return autocast, GradScaler()

    return autocast, GradScaler()

def loop_with_amp(model, inputs, optimizer, autocast, scaler):
    with autocast():
        outputs = model(**inputs)
        loss = outputs["loss"]

    scaler.scale(loss).backward()
+    if enable_torchacc_compiler():
+        gradients = xm._fetch_gradients(optimizer)
+        xm.all_reduce('sum', gradients, scale=1.0/xm.xrt_world_size())
    scaler.step(optimizer)
    scaler.update()

    return loss, optimizer

def loop_without_amp(model, inputs, optimizer):
    outputs = model(**inputs)
    loss = outputs["loss"]
    loss.backward()
+    if enable_torchacc_compiler():
+        xm.optimizer_step(optimizer)
+    else:
        optimizer.step()
    return loss, optimizer

def full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=None):
    model.train()
    iteration_time = time.time()
    num_steps = int(len(train_device_loader.dataset) / args.batch_size)
    for step, inputs in enumerate(train_device_loader):
        if step > args.max_steps:
            break
+        if not enable_torchacc_compiler():
            inputs.to(device)
        optimizer.zero_grad()
        if args.amp_level == "O1":
            loss, optimizer = loop_with_amp(model, inputs, optimizer, autocast, scaler)
        else:
            loss, optimizer = loop_without_amp(model, inputs, optimizer)
        if args.profile and profiler:
            profiler.step()
        if step % args.log_interval == 0:
            time_elapsed = (time.time() - iteration_time) / args.log_interval
            iteration_time = time.time()
            samples_per_step = float(args.batch_size / time_elapsed) * args.world_size
            peak_mem = torch.cuda.memory_allocated()/1024.0/1024.0/1024.0
            log_metrics(epoch, step, args.batch_size, loss, time_elapsed, samples_per_step, peak_mem)

def train_bert():
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, cache_dir="./model")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    tokenizer.model_max_length = args.max_seq_length
    training_dataset = load_from_disk(args.dataset_path)
    collator = DataCollatorWithPadding(tokenizer)
    training_dataset = training_dataset.remove_columns(['text'])
    train_device_loader = torch.utils.data.DataLoader(
        training_dataset, batch_size=args.batch_size, collate_fn=collator, shuffle=True, num_workers=4)

+    if enable_torchacc_compiler():
+        device = xm.xla_device()
+        xm.set_replication(device, [device])
+        train_device_loader = pl.MpDeviceLoader(train_device_loader, device)
+        model = model.to(device)
+    else:
        device = torch.device(f"cuda:{args.local_rank}")
        torch.cuda.set_device(device)
        model = model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

+    if enable_torchacc_compiler() and args.amp_level == "O1":
+        optimizer = syncfree.Adam(model.parameters(), lr=1e-3)
+    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    autocast, scaler = None, None
    if args.amp_level == "O1":
        autocast, scaler = get_autocast_and_scaler()

    if args.profile:
        with torch.profiler.profile(
                schedule=torch.profiler.schedule(wait=2, warmup=2, active=20),
                on_trace_ready=torch.profiler.tensorboard_trace_handler(args.profile_folder)) as prof:
            for epoch in range(args.num_epochs):
                full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler, profiler=prof)
    else:
        for epoch in range(args.num_epochs):
            full_train_epoch(epoch, model, train_device_loader, device, optimizer, autocast, scaler)

if __name__ == "__main__":
    train_bert()
TorchAcc custom optimization
I/O optimization
Data Prefetcher
Supports prefetching training data, and provides a preprocess_fn parameter for data preprocessing.
+from torchacc.runtime.io.prefetcher import Prefetcher

data_loader = build_data_loader()
model = build_model()
optimizer = build_optimizer()

# define preprocess function
preprocess_fn = None

+prefetcher = Prefetcher(data_loader, preprocess_fn)
for iter, samples in enumerate(prefetcher):
    loss = model(samples)
    loss.backward()

    # Prefetch to CPU first. Call after backward and before update.
    # At this point we are waiting for kernels launched by cuda graph
    # to finish, so CPU is idle. Take advantage of this by loading next
    # input batch before calling step.
+    prefetcher.prefetch_CPU()

    optimizer.step()

    # Prefetch to GPU. Call after optimizer step.
+    prefetcher.prefetch_GPU()
Pack Dataset
Language datasets, such as text sentences or speech, typically contain variable-length samples. To improve computational efficiency, several samples are packed together into a batch of fixed shape. This reduces the proportion of padding zeros and the dynamism of the batch data, and thus improves the (distributed) training efficiency per epoch. A minimal sketch of the idea is shown below.
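The following sketch illustrates the packing idea for already-tokenized samples; pack_samples, max_len, and pad_id are illustrative names, and a real implementation would also pack labels and attention masks accordingly:

def pack_samples(samples, max_len, pad_id=0):
    """Greedily concatenate variable-length token sequences into fixed-length rows."""
    packed, current = [], []
    for ids in samples:          # each `ids` is a list of token IDs
        ids = ids[:max_len]      # truncate overly long samples
        if len(current) + len(ids) > max_len:
            # Close the current row and pad the small remainder with pad_id.
            packed.append(current + [pad_id] * (max_len - len(current)))
            current = []
        current.extend(ids)
    if current:
        packed.append(current + [pad_id] * (max_len - len(current)))
    return packed                # every row has exactly max_len tokens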
Pin Memory
When defining the dataloader, set the pin_memory parameter and increase num_workers moderately.
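For example (the num_workers value is illustrative and should be tuned for your machine):

from torch.utils.data import DataLoader

data_loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=8,     # increase moderately to overlap data loading with compute
    pin_memory=True,   # page-locked host memory speeds up host-to-device copies
)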
Computation optimization
Kernel fusion optimization
The following optimizations are supported:
FusedLayerNorm
# Equivalent replacement kernel for LayerNorm.
from torchacc.runtime import hooks  # add before `import torch`
hooks.enable_fused_layer_norm()
FusedAdam
# Equivalent replacement kernel for Adam/AdamW.
from torchacc.runtime import hooks  # add before `import torch`
hooks.enable_fused_adam()
QuickGelu
# Replace nn.GELU with QuickGelu.
from torchacc.runtime.nn.quick_gelu import QuickGelu
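For example, a minimal sketch of using it in place of nn.GELU inside an MLP block, assuming QuickGelu is constructed without arguments like nn.GELU (the layer sizes are illustrative):

import torch.nn as nn
from torchacc.runtime.nn.quick_gelu import QuickGelu

mlp = nn.Sequential(
    nn.Linear(768, 3072),
    QuickGelu(),          # used where nn.GELU() would normally go
    nn.Linear(3072, 768),
)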
fused_bias_dropout_add
# Fuses Dropout and the element-wise bias add into a single kernel.
from torchacc.runtime.nn import dropout_add_fused_train, dropout_add_fused

if self.training:  # train mode
    with torch.enable_grad():
        x = dropout_add_fused_train(x, to_add, drop_rate)
else:  # inference mode
    x = dropout_add_fused(x, to_add, drop_rate)
WindowProcess
# The WindowProcess fused kernel combines the shift-window and window-partition
# operations in SwinTransformer, including:
# - window cyclic shift and window partition
# - window merge and reverse cyclic shift
from torchacc.runtime.nn.window_process import WindowProcess

if not fused:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
else:
    x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)

from torchacc.runtime.nn.window_process import WindowProcessReverse

if not fused:
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
    x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
    x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
FusedSwinFmha
# Fuses the qk_result + relative_position_bias + mask + softmax part of MHA in SwinTransformer.
from torchacc.runtime.nn.fmha import FusedSwinFmha

FusedSwinFmha.apply(attn, relative_pos_bias, attn_mask, batch_size, window_num, num_head, window_len)
nms/nms_normal/soft_nms/batched_soft_nms
# Fused CUDA kernel implementations of the nms, nms_normal, soft_nms, and batched_soft_nms operators.
from torchacc.runtime.nn.nms import nms, nms_normal
from torchacc.runtime.nn.nms import soft_nms, batched_soft_nms
Parallelized kernel optimization
DCN/DCNv2:
# The backward pass of dcn_v2_cuda is optimized for parallel computation.
from torchacc.runtime.nn.dcn_v2 import DCN, DCNv2
self.conv = DCN(chi, cho, kernel_size, stride, padding, dilation, deformable_groups)
Multi-stream kernel optimization
Uses multiple streams to compute a group of inputs to a function concurrently. The computation logic is the same as that of the mmdet.core.multi_apply function.
from torchacc.runtime.utils.misc import multi_apply_multi_stream
from mmdet.core import multi_apply
def test_func(t1, t2, t3):
    t1 = t1 * 2.0
    t2 = t2 + 2.0
    t3 = t3 - 2.0
    return (t1, t2, t3)

cuda = torch.device('cuda')
t1 = torch.empty((100, 1000), device=cuda).normal_(0.0, 1.0)
t2 = torch.empty((100, 1000), device=cuda).normal_(0.0, 2.0)
t3 = torch.empty((100, 1000), device=cuda).normal_(0.0, 3.0)

if enable_torchacc:
    result = multi_apply_multi_stream(test_func, 2, t1, t2, t3)
else:
    result = multi_apply(test_func, t1, t2, t3)
GPU memory optimization
Gradient Checkpointing
import torchacc
model = torchacc.auto_checkpoint(model)
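A minimal usage sketch, assuming the wrapped model is then trained as usual; the model choice and the random input tensors below are placeholders:

import torch
import torchacc
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased").cuda()
# Automatically apply gradient checkpointing to trade recomputation for GPU memory.
model = torchacc.auto_checkpoint(model)

inputs = {
    "input_ids": torch.randint(0, 1000, (8, 128)).cuda(),
    "labels": torch.randint(0, 2, (8,)).cuda(),
}
loss = model(**inputs)["loss"]
loss.backward()  # checkpointed activations are recomputed during backward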