-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
296 lines (262 loc) · 15.8 KB
/
train.py
File metadata and controls
296 lines (262 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
import argparse
import json
import os
import random
import shutil
import sys
import time
import numpy as np
import torch
import wandb
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchio import transforms
from data.mri_dataset import MRIDataset, save_MRI, AddGaussianNoise
from trainers.mtransinr_trainer import MTransINRTrainer
def setup_seed(seed):
    """Seed every RNG source (Python, NumPy, PyTorch CPU/CUDA) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels and disable the autotuner, trading
    # speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False
if __name__=='__main__':
    # NOTE(review): this block was recovered from a copy that stripped all
    # indentation; the nesting below is reconstructed from the code's syntax
    # and obvious intent — confirm against the original file.

    # parse options
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # experiment specifics
    parser.add_argument('--name', type=str, default='brats_t1ce', help='name of the experiment')
    parser.add_argument('--config_file', type=str,default='./configs/brats.json')
    parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
    parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
    parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
    parser.add_argument('--norm_G', type=str, default='instanceaffine', help='instance normalization or batch normalization')
    parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
    parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
    parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
    parser.add_argument('--encoder', type=str, default='resunet', help='`cones` for ConesEncoder or `resunet` for ResidualUNet3D')
    parser.add_argument('--inr_type', type=str, default='relu', help='`relu` for LeakyReLU or `siren` for SIREN')
    parser.add_argument('--context_window', type=int, default=None, help='context window for the context module')
    # input/output sizes
    parser.add_argument('--batchSize', type=int, default=8, help='input batch size')
    parser.add_argument('--label_nc', type=int, default=3, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.')
    parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels')
    # Hyperparameters
    parser.add_argument('--lr_width', type=int, default=64, help='low res stream strided conv number of channles')
    parser.add_argument('--lr_max_width', type=int, default=1024, help='low res stream conv number of channles')
    parser.add_argument('--lr_depth', type=int, default=7, help='low res stream number of conv layers')
    parser.add_argument('--hr_width', type=int, default=64, help='high res stream number of MLP channles')
    parser.add_argument('--hr_depth', type=int, default=5, help='high res stream number of MLP layers')
    parser.add_argument('--latent_dim', type=int, default=256, help='high res stream number of MLP layers')
    parser.add_argument('--reflection_pad', action='store_true', help='if specified, use reflection padding at lr stream')
    parser.add_argument('--replicate_pad', action='store_true', help='if specified, use replicate padding at lr stream')
    parser.add_argument('--netG', type=str, default='ASAPNets', help='selects model to use for netG')
    parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
    parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
    parser.add_argument('--hr_coor', choices=('cosine', 'None','siren'), default='cosine')
    parser.add_argument('--add_coords_noise', action='store_true', help='Add Gaussian noise to the coordinates during training')
    parser.add_argument('--nef', type=int, default=32, help='# of encoder filters in the first conv layer')
    parser.add_argument('--cones_block_expansion', type=int, default=2, help='Block expansion for ConesEncoder')
    parser.add_argument('--use_gan', action='store_true', help='enable training with an image encoder.')
    # for training
    parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
    parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
    parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate. This is NOT the total #epochs. Totla #epochs is niter + niter_decay')
    parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
    parser.add_argument('--no_TTUR', action='store_true', help='Use TTUR training scheme')
    parser.add_argument('--old_normalization', action='store_true', help='normalize data to [0, 1] instead of [-1, 1]')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
    parser.add_argument('--lr_scheduler', type=str, choices=('lambdalr', 'cosine', 'constant'), default='lambdalr', help='learning rate scheduler')
    parser.add_argument('--lr_T0', type=int, default=20, help='T0 for cosine learning rate scheduler')
    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
    parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
    parser.add_argument('--lambda_vgg', type=float, default=10.0, help='weight for vgg loss')
    parser.add_argument('--lambda_MSE', type=float, default=10.0, help='weight for MSE loss')
    parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
    parser.add_argument('--lambda_ll', type=float, default=10.0, help='weight for L1 loss')
    parser.add_argument('--no_adv_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
    parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
    parser.add_argument('--MSE_loss', action='store_true', help='if specified, use MSE loss')
    parser.add_argument('--L1_loss', action='store_true', help='if specified, use L1 loss')
    parser.add_argument('--MaskedL1_loss', action='store_true', help='if specified, use L1 loss with mask')
    parser.add_argument('--alpha_MaskedL1', type=float, default=0.5, help='Weight for the masked L1 loss component (0 = masked version is not used, 1 = only masked version is used)')
    parser.add_argument('--latent_code_regularization', action='store_true', help='if specified, use weight decay loss on the estimated parameters from LR')
    parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
    # for discriminators
    parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)')
    parser.add_argument('--lambda_kld', type=float, default=0.05)
    parser.add_argument('--netD_subarch', type=str, default='n_layer', help='architecture of each discriminator')
    parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to be used in multiscale')
    parser.add_argument('--n_layers_D', type=int, default=4, help='# layers in each discriminator')
    parser.add_argument('--ndf_max', type=int, default=512, help='maximal number of discriminator filters')
    # print options to help debugging
    opt = parser.parse_args()
    # The two L1 variants are mutually exclusive by design.
    assert not (opt.MaskedL1_loss and opt.L1_loss), 'MaskedL1_loss and L1_loss cannot be used together.'
    opt.isTrain = True   # train or test
    # Fixed seed for reproducibility of the whole run.
    setup_seed(100)
    #################### set gpu ids ####################
    # Parse '--gpu_ids' ("0,1,2" or "-1" for CPU) into a list of non-negative ints.
    str_ids = opt.gpu_ids.split(',')
    opt.gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            opt.gpu_ids.append(id)
    if len(opt.gpu_ids) > 0:
        # Make the first listed GPU the default CUDA device.
        torch.cuda.set_device(opt.gpu_ids[0])
    # DataParallel-style training requires the batch to split evenly across GPUs.
    assert len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0, \
        "Batch size %d is wrong. It must be a multiple of # GPUs %d." \
        % (opt.batchSize, len(opt.gpu_ids))
    #################### print configs ###################
    # Dump every option, flagging values that differ from the parser defaults.
    message = ''
    message += '----------------- Options ---------------\n'
    for k, v in sorted(vars(opt).items()):
        comment = ''
        default = parser.get_default(k)
        if v != default:
            comment = '\t[default: %s]' % str(default)
        message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
    message += '----------------- End -------------------'
    print(message)
    # Echo the exact command line for later reproduction.
    print(' '.join(sys.argv))
    ################ load the dataset ###################
    # Dataset/model configuration lives in a JSON file (see --config_file).
    with open(opt.config_file) as json_data:
        try:
            configs = json.load(json_data)
            print(configs)
        except json.JSONDecodeError:
            print('invalid json format')
            sys.exit()
    #### create checkpoints directory ####
    # Layout: <checkpoints_dir>/<name>/train_imgs for periodic image dumps.
    if not os.path.isdir(opt.checkpoints_dir):
        os.mkdir(opt.checkpoints_dir)
    experiment_dir = os.path.join(opt.checkpoints_dir, opt.name)
    if not os.path.isdir(experiment_dir):
        os.mkdir(experiment_dir)
    if not os.path.isdir(os.path.join(experiment_dir, 'train_imgs')):
        os.mkdir(os.path.join(experiment_dir, 'train_imgs'))
    if opt.continue_train:
        # Resuming: reattach to the previous wandb run via the id persisted on disk.
        with open(f"{experiment_dir}/wandb_id.txt", "r") as f:
            wandb_run_id = f.read()
        wandb.init(
            project="Image translation&segmentation",
            name=opt.name,
            config=vars(opt),
            dir=experiment_dir,
            resume="must",
            id=wandb_run_id
        )
        # Resolved later from the trainer's checkpoint (previous_epoch + 1).
        start_epoch = None
    else:
        # Fresh run: create a new wandb run and persist its id for future resumes.
        wandb.init(
            project="Image translation&segmentation",
            name=opt.name,
            config=vars(opt),
            # dir=experiment_dir
        )
        with open(f"{experiment_dir}/wandb_id.txt", "w") as f:
            f.write(wandb.run.id)
        start_epoch = 1
    # Unpack dataset settings from the JSON config.
    dataset_dict = configs['dataset']
    img_height = dataset_dict['img_height']
    img_width = dataset_dict['img_width']
    img_depth = dataset_dict['img_depth']
    data_root = dataset_dict['dataset_dir']
    input_modal = dataset_dict['input_modalities']
    output_modal = dataset_dict['output_modality']
    modal_list = dataset_dict['modal_list']
    train_test_split_file = dataset_dict['train_test_split_file']
    clipping_per_modality = dataset_dict['clipping_per_modality']
    # Optional key: only present for BraTS-style configs.
    brats_desired_shape = dataset_dict['brats_desired_shape'] if 'brats_desired_shape' in dataset_dict else None
    # Training-time augmentations (torchio) plus custom additive Gaussian noise.
    transform_tr = transforms.Compose([
        transforms.RandomGamma(log_gamma=(-2, 0.3), p=0.2),
        transforms.RandomBlur(std=(0.5,1.0), p=0.3),
        AddGaussianNoise(0, 0.02, 0.3)
    ])
    # Masked loading is keyed off the experiment name, not the loss flag
    # (see the commented-out alternative below).
    USE_MASK = "prostatex" in opt.name
    train_instance = MRIDataset(data_root, train_test_split_file, modal_list, \
        input_modal, output_modal,(img_height, img_width, img_depth), transform_tr, False, True,
        clipping_per_modality=clipping_per_modality, brats_desired_shape=brats_desired_shape,
        use_mask=USE_MASK, old_normalization=opt.old_normalization)
    # use_mask=opt.MaskedL1_loss)
    print("dataset [%s] of size %d was created" %
        (type(train_instance).__name__, len(train_instance)))
    dataloader = DataLoader(
        train_instance,
        batch_size=opt.batchSize,
        shuffle=True,
        num_workers=int(opt.nThreads),
        drop_last=opt.isTrain
    )
    # Snapshot the config alongside the checkpoints for provenance.
    datajson_dir = os.path.join(opt.checkpoints_dir, opt.name, 'data.json')
    shutil.copy(opt.config_file, datajson_dir)
    # initialize trainer
    trainer = MTransINRTrainer(opt)
    if start_epoch is None:
        # Resuming: continue right after the last completed epoch.
        start_epoch = trainer.previous_epoch + 1
    if opt.use_gan:
        use_gan = True
    else:
        use_gan = False
    batches_per_epoch = len(dataloader)
    total_epochs = opt.niter + opt.niter_decay
    best_loss = np.inf
    ################ training loop ###################
    for epoch in range(start_epoch, total_epochs + 1):
        epoch_start_time = time.time()
        print("epoch:%d" % epoch)
        # Per-epoch loss accumulators (summed over batches, averaged below).
        rec_loss = 0
        latent_loss = 0
        gan_loss = 0
        gan_feat_loss = 0
        for data_i in dataloader:
            # train generator
            g_losses, pred_img = trainer.run_generator_one_step(data_i, use_gan)
            rec_loss += g_losses['L1'].item()
            gan_loss += g_losses['GAN'].item()
            gan_feat_loss += g_losses['GAN_Feat'].item()
            latent_loss += g_losses['latent_loss'].item()
            # train discriminator
            if use_gan:
                trainer.run_discriminator_one_step(data_i)
        # NOTE(review): LR step assumed to be per-epoch (uses epoch-1), not
        # per-batch — confirm against the original indentation.
        lr_G_before_update = trainer.lr_scheduler_G.get_last_lr()[0]
        lr_D_before_update = trainer.lr_scheduler_D.get_last_lr()[0]
        trainer.update_learning_rate(epoch-1)
        print('Updating learning rate: lr_G:%f, lr_D:%f' % \
            (trainer.lr_scheduler_G.get_last_lr()[0], trainer.lr_scheduler_D.get_last_lr()[0]))
        logs = {}
        logs['rec_loss'] = rec_loss / batches_per_epoch   # Reconstruction loss = L1 loss
        logs['latent_loss'] = latent_loss / batches_per_epoch   # Regularization loss = torch.mean(lr_features ** 2) of the hypernetwork (they call it encoder of the generator)
        logs['GAN'] = gan_loss / batches_per_epoch   # Adversarial loss = Hinge loss
        logs['GAN_Feat'] = gan_feat_loss / batches_per_epoch   # Feature matching loss (read article for more details)
        logs['lr_G'] = lr_G_before_update
        logs['lr_D'] = lr_D_before_update
        # Track the best epoch by mean reconstruction (L1) loss.
        is_best = (rec_loss / batches_per_epoch) < best_loss
        if is_best:
            best_loss = rec_loss / batches_per_epoch
        # Image logging uses the LAST batch of the epoch (data_i / pred_img
        # survive the loop above). Middle depth slice is logged to wandb;
        # every 10th epoch full volumes are saved as NIfTI.
        print(f'Prediction image shape: {pred_img.shape}')
        middle_slice = pred_img.shape[-1] // 2
        for i in range(data_i['image'].shape[1]):
            logs['real_img_'+output_modal[i]] = wandb.Image(data_i['image'][:,i,:,:,middle_slice].unsqueeze(1))
            if epoch % 10 == 0:
                for b in range(data_i['image'].shape[0]):
                    save_MRI(data_i['image'][b,i,:,:,:], os.path.join(experiment_dir, 'train_imgs', 'epoch_' + str(epoch) + '_real_'+str(b)+'_'+output_modal[i]+'.nii.gz'))
        for i in range(pred_img.shape[1]):
            logs['synthesized_img_'+output_modal[i]] = wandb.Image(pred_img[:,i,:,:,middle_slice].unsqueeze(1))
            if epoch % 10 == 0:
                for b in range(pred_img.shape[0]):
                    save_MRI(pred_img[b,i,:,:,:], os.path.join(experiment_dir, 'train_imgs', 'epoch_' + str(epoch) + '_synthesized_'+str(b)+'_'+output_modal[i]+'.nii.gz'))
        for i in range(data_i['label'].shape[1]):
            logs['input_img_'+input_modal[i]] = wandb.Image(data_i['label'][:,i,:,:,middle_slice].unsqueeze(1))
            if epoch % 10 == 0:
                for b in range(data_i['label'].shape[0]):
                    save_MRI(data_i['label'][b,i,:,:,:], os.path.join(experiment_dir, 'train_imgs', 'epoch_' + str(epoch) + '_input_'+str(b)+'_'+input_modal[i]+'.nii.gz'))
        wandb.log(logs)
        time_per_epoch = time.time() - epoch_start_time
        print('End of epoch %d / %d \t Time Taken: %d sec' %
            (epoch, total_epochs, time_per_epoch))
        # Checkpoint every epoch; the trainer also tracks the best model.
        trainer.save(epoch, is_best=is_best)
    print('Training was successfully finished.')
    wandb.finish()