使用 accelerator 进行多机训练

本文以DiffSynth-Studio为示例,展示基于开发机进行 accelerator 多机训练的具体操作方式。

DiffSynth-Studio 是一个面向主流 Diffusion 模型的统一训练与推理平台,支持多模型接入、可视化与高效分布式加速。本次 accelerator 多机训练以 DiffSynth-Studio中 Wan-AI为主,使用的是 Wan2.2-I2V-A14B 模型。

准备开发机

建议准备4台A800 8卡的开发机,分别命名为node0node1node2node3,同时挂载一同块共享存储至/data目录,具体参考这里

准备源码并安装依赖

从github下载代码:
cd /data
git clone -b wan2.2 https://ghfast.top/github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
安装依赖的 pip 包:
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
安装额外依赖:
pip install deepspeed -i https://pypi.tuna.tsinghua.edu.cn/simple
apt-get update
apt install -y netcat
apt install net-tools -y
apt-get install git-lfs

注意:

  • 4 台开发机都需要安装 Python 包和额外依赖。
下载示例视频数据集

DiffSynth-Studio 源码中包含了一个示例视频数据集,可用于测试训练流程。运行以下代码进行数据集下载:

cd /data
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./example_video_dataset

检查下载结果:

ls -lh ./example_video_dataset

如果显示文件大小仅为几 KB,说明实际数据未下载成功。此时执行以下命令拉取数据:

cd example_video_dataset
git lfs install
git lfs pull

准备训练配置

本算法使用 YAML 文件作为启动参数,示例位于:examples/wanvideo/model_training/full/accelerate_config_14B.yaml

将其复制展示如下:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: INDUCTOR
  dynamo_mode: default
  dynamo_use_dynamic: true
  dynamo_use_fullgraph: true
  dynamo_use_regional_compilation: true
enable_cpu_affinity: false
machine_rank: 0
main_process_ip: '10.233.xx.xx'
main_process_port: 12345
main_training_function: main
mixed_precision: bf16
num_machines: 4
num_processes: 32
rdzv_backend: static
same_network: false
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

我们有4个节点,需要为每个节点准备一个 YAML 文件,分别命名为:

accelerate_config_14B_0.yaml
accelerate_config_14B_1.yaml
accelerate_config_14B_2.yaml
accelerate_config_14B_3.yaml
重点参数介绍
  • machine_rank:当前节点在所有训练节点中的序号,4 个节点分别设置为: 0, 1, 2, 3
  • main_process_ip:主节点的 IP 地址,我们选取 node0eth0 网络的IP地址。
  • main_process_port:主节点用于分布式通信的端口。确保端口未被占用。示例中使用端口 12345
  • num_machines:参与训练的总机器数。本实例为 4
  • num_processes:总进程数,等于 机器数 × 每机 GPU 数。本实例为 4 × 8 = 32

启动多机训练

启动脚本位于源代码该路径下:examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh

为便于区分,可将启动脚本copy4份,分别命名为:

Wan2.2-T2V-A14B_0.sh
Wan2.2-T2V-A14B_1.sh
Wan2.2-T2V-A14B_2.sh
Wan2.2-T2V-A14B_3.sh

为了确保在英博云环境中,多机训练中的通信效率和稳定性,推荐脚本开头加入 NCCL 相关环境变量及参数配置,如下所示:

export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
相关启动参数修改
  • --config_file:改为对应的 YAML 文件路径,需要改成上述生成的 YAML 配置文件,注意各台开发机的启动脚本中该参数,需要改为各自开发机对应的 YAML 配置文件路径。
  • --model_paths:改为对应模型的本地路径,英博云在共享存储中提供多种内置模型及数据集,使用本地模型无需等待下载时间,该实例使用 Wan2.2-I2V-A14B 模型,可以将路径地址修改如下:
'[
    [
      "/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00001-of-00006.safetensors",
      "/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00002-of-00006.safetensors",
      "/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00003-of-00006.safetensors",
      "/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00004-of-00006.safetensors",
      "/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00005-of-00006.safetensors",
      "/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/high_noise_model/diffusion_pytorch_model-00006-of-00006.safetensors"
    ],
    "/public/huggingface-models/Wan-AI/Wan2.2-I2V-A14B/models_t5_umt5-xxl-enc-bf16.pth",
    "/public/huggingface-models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"
  ]'
  • --model_id_with_origin_paths:联网下载的 model_id 匹配,由于我们设置了本地路径参数 --model_paths,所以请删除该参数设置。

注意:

在4个node分别启动训练脚本

node0 开发机上运行:

bash Wan2.2-T2V-A14B_0.sh

在剩余 3 台开发机上依次运行:

bash Wan2.2-T2V-A14B_1.sh
bash Wan2.2-T2V-A14B_2.sh
bash Wan2.2-T2V-A14B_3.sh

等待训练结束,完成时间约10-20分钟