It could only leverage GPU 0 for training, which took me a lot of time to debug. Here is the version for distributed training.
# Model / normalization options.
parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
parser.add_argument('--norm_G', type=str, default='instanceaffine', help='instance normalization or batch normalization')
parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
parser.add_argument('--encoder', type=str, default='resunet', help='`cones` for ConesEncoder or `resunet` for ResidualUNet3D')
parser.add_argument('--inr_type', type=str, default='relu', help='`relu` for LeakyReLU or `siren` for SIREN')
parser.add_argument('--context_window', type=int, default=None, help='context window for the context module')
# input/output sizes
parser.add_argument('--batchSize', type=int, default=1, help='input batch size (total across all GPUs)')
parser.add_argument('--label_nc', type=int, default=1, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dontcare_label.')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
# Hyperparameters
parser.add_argument('--lr_width', type=int, default=64, help='low res stream strided conv number of channels')
parser.add_argument('--lr_max_width', type=int, default=1024, help='low res stream conv number of channels')
parser.add_argument('--lr_depth', type=int, default=7, help='low res stream number of conv layers')
parser.add_argument('--hr_width', type=int, default=64, help='high res stream number of MLP channels')
parser.add_argument('--hr_depth', type=int, default=5, help='high res stream number of MLP layers')
parser.add_argument('--latent_dim', type=int, default=256, help='dimension of the latent code')
parser.add_argument('--reflection_pad', action='store_true', help='if specified, use reflection padding at lr stream')
parser.add_argument('--replicate_pad', action='store_true', help='if specified, use replicate padding at lr stream')
parser.add_argument('--netG', type=str, default='ASAPNets', help='selects model to use for netG')
parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
parser.add_argument('--hr_coor', choices=('cosine', 'None','siren'), default='cosine')
parser.add_argument('--add_coords_noise', action='store_true', help='Add Gaussian noise to the coordinates during training')
parser.add_argument('--nef', type=int, default=32, help='# of encoder filters in the first conv layer')
parser.add_argument('--cones_block_expansion', type=int, default=2, help='Block expansion for ConesEncoder')
parser.add_argument('--use_gan', action='store_true', help='enable training with an image encoder.')
# for training
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--niter', type=int, default=200, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
parser.add_argument('--niter_decay', type=int, default=200, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--optimizer', type=str, default='adam')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme')
parser.add_argument('--old_normalization', action='store_true', help='normalize data to [0, 1] instead of [-1, 1]')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--lr_scheduler', type=str, choices=('lambdalr', 'cosine', 'constant'), default='lambdalr', help='learning rate scheduler')
parser.add_argument('--lr_T0', type=int, default=20, help='T0 for cosine learning rate scheduler')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
# Loss weights and loss-selection flags.
parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
parser.add_argument('--lambda_MSE', type=float, default=10.0, help='weight for MSE loss')
parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
parser.add_argument('--lambda_ll', type=float, default=10.0, help='weight for the latent code regularization loss')
parser.add_argument('--no_adv_loss', action='store_true', help='if specified, do *not* use the adversarial loss')
parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
parser.add_argument('--MSE_loss', action='store_true', help='if specified, use MSE loss')
parser.add_argument('--L1_loss', action='store_true', help='if specified, use L1 loss')
parser.add_argument('--MaskedL1_loss', action='store_true', help='if specified, use L1 loss with mask')
parser.add_argument('--alpha_MaskedL1', type=float, default=0.5, help='Weight for the masked L1 loss component (0 = masked version is not used, 1 = only masked version is used)')
parser.add_argument('--latent_code_regularization', action='store_true', help='if specified, use weight decay loss on the estimated parameters from LR')
parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
# for discriminators
parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)')
parser.add_argument('--lambda_kld', type=float, default=0.05)
parser.add_argument('--netD_subarch', type=str, default='n_layer', help='architecture of each discriminator')
parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to be used in multiscale')
parser.add_argument('--n_layers_D', type=int, default=4, help='# layers in each discriminator')
parser.add_argument('--ndf_max', type=int, default=512, help='maximal number of discriminator filters')
opt = parser.parse_args()
assert not (opt.MaskedL1_loss and opt.L1_loss), 'MaskedL1_loss and L1_loss cannot be used together.'
opt.isTrain = True
# Convert gpu_ids from a comma-separated string (e.g. "0,1,2") into a list of ints,
# dropping negative ids.
if isinstance(opt.gpu_ids, str):
    str_ids = opt.gpu_ids.split(',')
    opt.gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            opt.gpu_ids.append(id)
# Initialize the distributed environment (rank/world_size/local_rank).
rank, world_size, local_rank = setup_distributed()
# Seed per-rank so each process gets a distinct but reproducible RNG stream.
setup_seed(42 + rank)
# In distributed mode, pin this process to its own GPU.
if world_size > 1:
    # Each process uses exactly one GPU in distributed training.
    torch.cuda.set_device(local_rank)
    # Restrict opt.gpu_ids to the GPU owned by this process.
    opt.gpu_ids = [local_rank]
#################### print configs ###################
if rank == 0:  # only the main process prints the configuration
    # Pretty-print all options, flagging values that differ from their defaults.
    message = ''
    message += '----------------- Options ---------------\n'
    for k, v in sorted(vars(opt).items()):
        comment = ''
        default = parser.get_default(k)
        if v != default:
            comment = '\t[default: %s]' % str(default)
        message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
    message += '----------------- End -------------------'
    print(message)
    # Also echo the exact command line for reproducibility.
    print(' '.join(sys.argv))
################ load the dataset ###################
# Load the JSON experiment configuration; abort on malformed JSON.
with open(opt.config_file) as json_data:
    try:
        configs = json.load(json_data)
        if rank == 0:
            print(configs)
    except json.JSONDecodeError:
        print('invalid json format')
        sys.exit()
#### create checkpoints directory ####
if rank == 0:  # only the main process creates directories
    if not os.path.isdir(opt.checkpoints_dir):
        os.mkdir(opt.checkpoints_dir)
    experiment_dir = os.path.join(opt.checkpoints_dir, opt.name)
    if not os.path.isdir(experiment_dir):
        os.mkdir(experiment_dir)
    if not os.path.isdir(os.path.join(experiment_dir, 'train_imgs')):
        os.mkdir(os.path.join(experiment_dir, 'train_imgs'))
# All processes wait until the main process has created the directories.
if world_size > 1:
    dist.barrier()
# Recompute on every rank so non-zero ranks also have the path.
experiment_dir = os.path.join(opt.checkpoints_dir, opt.name)
# Resume-training logic: start_iter=None means "resolve from the trainer later".
if opt.continue_train:
    # wandb initialization code...
    start_iter = None
else:
    # wandb initialization code...
    start_iter = 1
# Unpack the dataset section of the JSON config.
dataset_dict = configs['dataset']
img_height = dataset_dict['img_height']
img_width = dataset_dict['img_width']
img_depth = dataset_dict['img_depth']
data_root = dataset_dict['dataset_dir']
input_modal = dataset_dict['input_modalities']
output_modal = dataset_dict['output_modality']
modal_list = dataset_dict['modal_list']
train_test_split_file = dataset_dict['train_test_split_file']
clipping_per_modality = dataset_dict['clipping_per_modality']
# Optional key; only some configs define a target shape for BraTS volumes.
brats_desired_shape = dataset_dict['brats_desired_shape'] if 'brats_desired_shape' in dataset_dict else None
# Training-time augmentations (torchio transforms + custom Gaussian noise).
transform_tr = transforms.Compose([
    transforms.RandomGamma(log_gamma=(-2, 0.3), p=0.2),
    transforms.RandomBlur(std=(0.5,1.0), p=0.3),
    AddGaussianNoise(0, 0.02, 0.3)
])
# Masked losses are only used for ProstateX experiments (keyed off the run name).
USE_MASK = "prostatex" in opt.name
# Build the training dataset.
train_instance = MRIDataset(data_root, train_test_split_file, modal_list, \
    input_modal, output_modal,(img_height, img_width, img_depth), transform_tr, False, True,
    clipping_per_modality=clipping_per_modality, brats_desired_shape=brats_desired_shape,
    use_mask=USE_MASK, old_normalization=opt.old_normalization)
if rank == 0:
    print("dataset [%s] of size %d was created" %
          (type(train_instance).__name__, len(train_instance)))
# Build a distributed sampler so each rank sees a disjoint shard of the data.
if world_size > 1:
    train_sampler = DistributedSampler(
        train_instance,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=100,
        drop_last=True  # keep shard sizes identical across ranks
    )
    # Split the global batch size across GPUs (at least 1 per GPU).
    per_gpu_batch_size = max(1, opt.batchSize // world_size)
else:
    train_sampler = None
    per_gpu_batch_size = opt.batchSize
# DataLoader: shuffle and sampler are mutually exclusive, so shuffle only
# when no DistributedSampler is in use.
dataloader = DataLoader(
    train_instance,
    batch_size=per_gpu_batch_size,
    sampler=train_sampler,
    shuffle=(train_sampler is None),
    num_workers=int(opt.nThreads),
    drop_last=True,  # fixed to True so every process runs the same number of batches
    pin_memory=True
)
# Only the main process copies the resolved config next to the checkpoints.
if rank == 0:
    datajson_dir = os.path.join(opt.checkpoints_dir, opt.name, 'data.json')
    shutil.copy(opt.config_file, datajson_dir)
# Build the trainer.
trainer = MTransINRTrainer(opt)
# In distributed mode, wrap the model with DistributedDataParallel.
if world_size > 1 and hasattr(trainer, 'model'):
    trainer.model = DistributedDataParallel(
        trainer.model.cuda(),
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True
    )
# ========== Convert the epoch-based schedule into iterations ==========
total_samples = len(train_instance)
# NOTE(review): per_gpu_batch_size above is opt.batchSize // world_size, so the
# true effective batch size is opt.batchSize, not opt.batchSize * world_size —
# confirm which is intended before trusting iterations_per_epoch.
effective_batch_size = opt.batchSize * world_size
iterations_per_epoch = max(1, total_samples // effective_batch_size)
# Total number of epochs (constant-LR phase + decay phase).
total_epochs = opt.niter + opt.niter_decay
# Total number of iterations.
total_iterations = total_epochs * iterations_per_epoch
if rank == 0:
    print(f"\n========== 训练参数 ==========")
    print(f"总样本数: {total_samples}")
    print(f"每GPU批大小: {per_gpu_batch_size}")
    print(f"总批大小: {effective_batch_size}")
    print(f"每epoch的iteration数: {iterations_per_epoch}")
    print(f"总epoch数: {total_epochs}")
    print(f"总iteration数: {total_iterations}")
    print(f"============================\n")
# Resolve the starting iteration when resuming (start_iter is None in that case).
if start_iter is None:
    start_iter = trainer.previous_iteration + 1 if hasattr(trainer, 'previous_iteration') else 1
if opt.use_gan:
    use_gan = True
else:
    use_gan = False
best_loss = np.inf
# ========== Iteration-based training loop setup ==========
from tqdm import tqdm
iteration = start_iter - 1  # counts from 0; incremented at the top of the loop
save_interval = 100  # checkpoint every 100 iterations
# Only the main process shows a progress bar.
if rank == 0:
    pbar = tqdm(total=total_iterations, desc="训练进度")
while iteration < total_iterations:
    iteration += 1
    if rank == 0:
        pbar.update(1)
    # Current "epoch", used only to reseed the distributed sampler.
    current_epoch = iteration // iterations_per_epoch
    # At the first iteration of each epoch, reseed the sampler so every epoch
    # gets a different shuffle (required by DistributedSampler).
    if world_size > 1 and train_sampler is not None and iteration % iterations_per_epoch == 1:
        train_sampler.set_epoch(current_epoch)
    # Fetch the next batch from an effectively infinite iterator.
    # NOTE(review): the iterator is stashed as an attribute on the `main`
    # function object — fragile, but works because functions are objects.
    if iteration == start_iter or not hasattr(main, 'data_iter'):
        # Create the iterator.
        main.data_iter = iter(dataloader)
    try:
        data_i = next(main.data_iter)
    except StopIteration:
        # Dataloader exhausted: restart it.
        main.data_iter = iter(dataloader)
        data_i = next(main.data_iter)
    # Move all tensors in the batch to this process's GPU.
    for key in data_i:
        if isinstance(data_i[key], torch.Tensor):
            data_i[key] = data_i[key].cuda(non_blocking=True)
    # train generator
    g_losses, pred_img = trainer.run_generator_one_step(data_i, use_gan)
    # train discriminator
    if use_gan:
        trainer.run_discriminator_one_step(data_i)
    # ========== Log / checkpoint every 100 iterations ==========
    if iteration % 100 == 0:
        # Average the generator losses across all processes.
        if world_size > 1:
            rec_loss_tensor = torch.tensor(g_losses['L1'].item()).cuda()
            gan_loss_tensor = torch.tensor(g_losses['GAN'].item()).cuda()
            gan_feat_loss_tensor = torch.tensor(g_losses['GAN_Feat'].item()).cuda()
            latent_loss_tensor = torch.tensor(g_losses['latent_loss'].item()).cuda()
            dist.all_reduce(rec_loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(gan_loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(gan_feat_loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(latent_loss_tensor, op=dist.ReduceOp.SUM)
            avg_rec_loss = rec_loss_tensor.item() / world_size
            avg_gan_loss = gan_loss_tensor.item() / world_size
            avg_gan_feat_loss = gan_feat_loss_tensor.item() / world_size
            avg_latent_loss = latent_loss_tensor.item() / world_size
        else:
            avg_rec_loss = g_losses['L1'].item()
            avg_gan_loss = g_losses['GAN'].item()
            avg_gan_feat_loss = g_losses['GAN_Feat'].item()
            avg_latent_loss = g_losses['latent_loss'].item()
        # Update the learning rate (iteration-based).
        # NOTE(review): because of the enclosing `if`, this only runs every
        # 100 iterations — confirm the scheduler expects that cadence.
        if hasattr(trainer, 'update_learning_rate'):
            trainer.update_learning_rate(iteration)
        # Logging and checkpointing happen on the main process only.
        if rank == 0:
            logs = {}
            logs['rec_loss'] = avg_rec_loss
            logs['latent_loss'] = avg_latent_loss
            logs['GAN'] = avg_gan_loss
            logs['GAN_Feat'] = avg_gan_feat_loss
            # Current learning rates (read before the next scheduler step).
            if hasattr(trainer, 'lr_scheduler_G'):
                lr_G_before_update = trainer.lr_scheduler_G.get_last_lr()[0]
                logs['lr_G'] = lr_G_before_update
            else:
                lr_G_before_update = 0
            if use_gan and hasattr(trainer, 'lr_scheduler_D'):
                lr_D_before_update = trainer.lr_scheduler_D.get_last_lr()[0]
                logs['lr_D'] = lr_D_before_update
            else:
                lr_D_before_update = 0
            print(f'Iteration {iteration}/{total_iterations}:')
            print(f' rec_loss={avg_rec_loss:.4f}, latent_loss={avg_latent_loss:.4f}')
            print(f' GAN={avg_gan_loss:.4f}, GAN_Feat={avg_gan_feat_loss:.4f}')
            print(f' lr_G={lr_G_before_update:.6f}, lr_D={lr_D_before_update:.6f}')
            # Track the best reconstruction loss seen so far.
            is_best = avg_rec_loss < best_loss
            if is_best:
                best_loss = avg_rec_loss
                print(f' 🎯 新的最佳损失: {best_loss:.4f}')
            # (optional) save images every 1000 iterations
            # if iteration % 1000 == 0:
            #     if 'image' in data_i and 'label' in data_i:
            #         for i in range(data_i['image'].shape[1]):
            #             for b in range(min(2, data_i['image'].shape[0])):  # only the first 2 samples
            #                 save_MRI(data_i['image'][b,i,:,:,:],
            #                          os.path.join(experiment_dir, 'train_imgs',
            #                                       f'iter_{iteration}_real_{b}_{output_modal[i]}.nii.gz'))
            #         for i in range(pred_img.shape[1]):
            #             for b in range(min(2, pred_img.shape[0])):
            #                 save_MRI(pred_img[b,i,:,:,:],
            #                          os.path.join(experiment_dir, 'train_imgs',
            #                                       f'iter_{iteration}_synthesized_{b}_{output_modal[i]}.nii.gz'))
            #         for i in range(data_i['label'].shape[1]):
            #             for b in range(min(2, data_i['label'].shape[0])):
            #                 save_MRI(data_i['label'][b,i,:,:,:],
            #                          os.path.join(experiment_dir, 'train_imgs',
            #                                       f'iter_{iteration}_input_{b}_{input_modal[i]}.nii.gz'))
            # Save a checkpoint (iteration-based).
            # NOTE(review): is_best is computed above but is_best=True is passed
            # unconditionally — probably meant trainer.save(iteration, is_best=is_best).
            trainer.save(iteration, is_best=True)
            print(f' 💾 已保存模型检查点')
        # All processes wait for the main process to finish saving.
        if world_size > 1:
            dist.barrier()
    # Clear the CUDA cache every 500 iterations.
    if iteration % 500 == 0:
        torch.cuda.empty_cache()
    # Lightweight progress print every 10 iterations.
    if iteration % 10 == 0 and rank == 0:
        print(f'进度: {iteration}/{total_iterations} ({iteration/total_iterations*100:.1f}%)')
# Training finished.
if rank == 0:
    print('\n' + '='*50)
    print(f'训练完成! 总共 {total_iterations} 个iteration')
    print(f'最佳损失: {best_loss:.4f}')
    print('='*50)
# Save the final model (main process only).
if rank == 0:
    trainer.save(total_iterations, is_best=False, is_final=True)
    print('💾 已保存最终模型')
if rank == 0:
    pbar.close()
# Tear down the distributed environment.
if world_size > 1:
    dist.destroy_process_group()
It could only leverage GPU 0 for training, which took me a lot of time to debug. Here is the version for distributed training.
import argparse
import json
import os
import random
import shutil
import sys
import time
import numpy as np
import torch
import wandb
from torch.backends import cudnn
from torch.utils.data import DataLoader, DistributedSampler
from torchio import transforms
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from data.mri_dataset import MRIDataset, save_MRI, AddGaussianNoise
from trainers.mtransinr_trainer import MTransINRTrainer
def setup_seed(seed):
    """Seed every relevant RNG (Python, NumPy, torch CPU and all CUDA devices)
    and force deterministic cuDNN kernels for reproducible training runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade autotuned speed for bitwise-reproducible convolutions.
    cudnn.benchmark = False
    cudnn.deterministic = True
def setup_distributed():
    """Initialize the distributed training environment.

    When launched via torchrun, the RANK / WORLD_SIZE / LOCAL_RANK environment
    variables are set: read them and create the NCCL process group. Otherwise
    fall back to single-process (single-GPU) mode.

    Returns:
        tuple[int, int, int]: (rank, world_size, local_rank)
    """
    # Check whether we were launched with torchrun.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])
        # Fix: the original never created the process group, yet the caller
        # relies on dist.barrier()/dist.all_reduce() when world_size > 1.
        if world_size > 1 and not dist.is_initialized():
            dist.init_process_group(backend='nccl', init_method='env://')
    else:
        # Single-GPU mode.
        rank = 0
        world_size = 1
        local_rank = 0
    # Fix: the original had no return statement, so the caller's 3-tuple
    # unpacking would raise TypeError (cannot unpack None).
    return rank, world_size, local_rank
def main():
    """Main training entry point: parses options, sets up distributed training,
    builds the dataset and trainer, and runs the iteration-based training loop."""
    # parse options
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # experiment specifics
    parser.add_argument('--name', type=str, default='brats_t1', help='name of the experiment')
    parser.add_argument('--config_file', type=str, default='./configs/brats_t1_to_t2pdmra.json')
    parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data')
    parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
    parser.add_argument('--gpu_ids', type=str, default='0,1,2,3,4,5,6,7', help='gpu ids for single GPU training (ignored in distributed mode)')
    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# Fix: `name` is undefined (NameError) and the comparison string was wrong —
# the standard script-entry guard is __name__ == '__main__'.
if __name__ == '__main__':
    main()