使用 accelerator 进行多机训练
本文以DiffSynth-Studio为示例,展示基于开发机进行 accelerator 多机训练的具体操作方式。
DiffSynth-Studio 是一个面向主流 Diffusion 模型的统一训练与推理平台,支持多模型接入、可视化与高效分布式加速。本次 accelerator 多机训练以 DiffSynth-Studio中 Wan-AI为主,使用的是 Wan2.2-I2V-A14B 模型。
准备开发机
建议准备4台A800 8卡的开发机,分别命名为node0
,node1
,node2
,node3
,同时挂载一同块共享存储至/data
目录,具体参考这里。
准备源码并安装依赖
从github下载代码:
cd /data
git clone -b wan2.2 https://ghfast.top/github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
安装依赖的 pip 包:
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
安装额外依赖:
pip install deepspeed -i https://pypi.tuna.tsinghua.edu.cn/simple
apt-get update
apt install -y netcat
apt install net-tools -y
apt-get install git-lfs
注意:
- 4 台开发机都需要安装 Python 包和额外依赖。
下载示例视频数据集
DiffSynth-Studio 源码中包含了一个示例视频数据集,可用于测试训练流程。运行以下代码进行数据集下载:
cd /data
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./example_video_dataset
检查下载结果:
ls -lh ./example_video_dataset
如果显示文件大小仅为几 KB,说明实际数据未下载成功。此时执行以下命令拉取数据:
cd example_video_dataset
git lfs install
git lfs pull
准备训练配置
本算法使用 YAML 文件作为启动参数,示例位于:examples/wanvideo/model_training/full/accelerate_config_14B.yaml
。
将其复制展示如下:
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
gradient_clipping: 1.0
offload_optimizer_device: cpu
offload_param_device: cpu
zero3_init_flag: true
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_config:
dynamo_backend: INDUCTOR
dynamo_mode: default
dynamo_use_dynamic: true
dynamo_use_fullgraph: true
dynamo_use_regional_compilation: true
enable_cpu_affinity: false
machine_rank: 0
main_process_ip: '10.233.xx.xx'
main_process_port: 12345
main_training_function: main
mixed_precision: bf16
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
我们有4个节点,需要为每个节点准备一个 YAML 文件,分别命名为:
accelerate_config_14B_0.yaml
accelerate_config_14B_1.yaml
accelerate_config_14B_2.yaml
accelerate_config_14B_3.yaml
重点参数介绍
machine_rank
:当前节点在所有训练节点中的序号,4 个节点分别设置为:0, 1, 2, 3
。main_process_ip
:主节点的 IP 地址,我们选取node0
的eth0
网络的IP地址。main_process_port
:主节点用于分布式通信的端口。确保端口未被占用。示例中使用端口12345
。num_machines
:参与训练的总机器数。本实例为4
。num_processes
:总进程数,等于 机器数 × 每机 GPU 数。本实例为 4 × 8 =32
。
启动多机训练
启动脚本位于源代码该路径下:examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh
为便于区分,可将启动脚本copy4份,分别命名为:
Wan2.2-T2V-A14B_0.sh
Wan2.2-T2V-A14B_1.sh
Wan2.2-T2V-A14B_2.sh
Wan2.2-T2V-A14B_3.sh
为了确保在英博云环境中,多机训练中的通信效率和稳定性,推荐脚本开头加入 NCCL 相关环境变量及参数配置,如下所示:
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
相关启动参数修改
--config_file
:改为对应的 YAML 文件路径,需要改成上述生成的 YAML 配置文件,注意各台开发机的启动脚本中该参数,需要改为各自开发机对应的 YAML 配置文件路径。--model_paths
:改为对应模型的本地路径,英博云在共享存储中提供多种内置模型及数据集,使用本地模型无需等待下载时间,该实例使用Wan2.2-I2V-A14B
模型,可以将路径地址修改如下:
'[
[
"/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
"/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
"/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
"/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
"/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
"/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors"
],
"/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/models_t5_umt5-xxl-enc-bf16.pth",
"/public/huggingface-models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
]'
--model_id_with_origin_paths
:联网下载的 model_id 匹配,由于我们设置了本地路径参数--model_paths
,所以请删除该参数设置。
注意:
- 详细参数设置可以参考源文档:https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/README_zh.md。
- 章节位于:模型训练 -> Step 2: 加载模型
在4个node分别启动训练脚本
在 node0
开发机上运行:
bash Wan2.2-T2V-A14B_0.sh
在剩余 3 台开发机上依次运行:
bash Wan2.2-T2V-A14B_1.sh
bash Wan2.2-T2V-A14B_2.sh
bash Wan2.2-T2V-A14B_3.sh
等待训练结束,完成时间约10-20分钟