
Commit 906cf71

update for new version of torch
1 parent ce55cc9

1 file changed

tutorial-contents/504_batch_normalization.py

56 additions & 50 deletions
@@ -11,7 +11,6 @@
 from torch import nn
 from torch.nn import init
 import torch.utils.data as Data
-import torch.nn.functional as F
 import matplotlib.pyplot as plt
 import numpy as np
 

@@ -24,7 +23,7 @@
 EPOCH = 12
 LR = 0.03
 N_HIDDEN = 8
-ACTIVATION = F.tanh
+ACTIVATION = torch.tanh
 B_INIT = -0.2   # use a bad bias constant initializer
 
 # training data
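
Why this hunk: in newer PyTorch releases (0.4.1 and later), torch.nn.functional.tanh is deprecated in favor of torch.tanh, so keeping F.tanh would emit a deprecation warning on every forward pass. A minimal sketch of the two entry points, assuming PyTorch >= 0.4.1; both return identical values:

import torch
import torch.nn.functional as F

x = torch.linspace(-2, 2, steps=5)
old = F.tanh(x)        # warns: nn.functional.tanh is deprecated, use torch.tanh
new = torch.tanh(x)    # the replacement adopted by this commit
assert torch.equal(old, new)   # same result, only the entry point changed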
@@ -48,6 +47,7 @@
 plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
 plt.legend(loc='upper left')
 
+
 class Net(nn.Module):
     def __init__(self, batch_normalization=False):
         super(Net, self).__init__()
@@ -89,20 +89,20 @@ def forward(self, x):
 
 nets = [Net(batch_normalization=False), Net(batch_normalization=True)]
 
-print(*nets)    # print net architecture
+# print(*nets)    # print net architecture
 
 opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
 
 loss_func = torch.nn.MSELoss()
 
-f, axs = plt.subplots(4, N_HIDDEN+1, figsize=(10, 5))
-plt.ion()   # enable interactive plotting
-plt.show()
+
 def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
-    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
+    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
         [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
-        if i == 0: p_range = (-7, 10);the_range = (-7, 10)
-        else:p_range = (-4, 4);the_range = (-1, 1)
+        if i == 0:
+            p_range = (-7, 10);the_range = (-7, 10)
+        else:
+            p_range = (-4, 4);the_range = (-1, 1)
         ax_pa.set_title('L' + str(i))
         ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
         ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
@@ -111,44 +111,50 @@ def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
         axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')
         plt.pause(0.01)
 
-# training
-losses = [[], []]  # record loss for the two networks
-for epoch in range(EPOCH):
-    print('Epoch: ', epoch)
-    layer_inputs, pre_acts = [], []
-    for net, l in zip(nets, losses):
-        net.eval()              # set eval mode to fix moving_mean and moving_var
-        pred, layer_input, pre_act = net(test_x)
-        l.append(loss_func(pred, test_y).data[0])
-        layer_inputs.append(layer_input)
-        pre_acts.append(pre_act)
-        net.train()             # free moving_mean and moving_var
-    plot_histogram(*layer_inputs, *pre_acts)     # plot histogram
-
-    for step, (b_x, b_y) in enumerate(train_loader):
-        for net, opt in zip(nets, opts):     # train each network
-            pred, _, _ = net(b_x)
-            loss = loss_func(pred, b_y)
-            opt.zero_grad()
-            loss.backward()
-            opt.step()    # this also learns the parameters in Batch Normalization
-
-
-plt.ioff()
-
-# plot training loss
-plt.figure(2)
-plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
-plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
-plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')
-
-# evaluation
-# set net to eval mode to freeze the parameters in batch normalization layers
-[net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_var
-preds = [net(test_x)[0] for net in nets]
-plt.figure(3)
-plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
-plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
-plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
-plt.legend(loc='best')
-plt.show()
+
+if __name__ == "__main__":
+    f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
+    plt.ion()   # enable interactive plotting
+    plt.show()
+
+    # training
+    losses = [[], []]  # record loss for the two networks
+
+    for epoch in range(EPOCH):
+        print('Epoch: ', epoch)
+        layer_inputs, pre_acts = [], []
+        for net, l in zip(nets, losses):
+            net.eval()              # set eval mode to fix moving_mean and moving_var
+            pred, layer_input, pre_act = net(test_x)
+            l.append(loss_func(pred, test_y).data.item())
+            layer_inputs.append(layer_input)
+            pre_acts.append(pre_act)
+            net.train()             # free moving_mean and moving_var
+        plot_histogram(*layer_inputs, *pre_acts)     # plot histogram
+
+        for step, (b_x, b_y) in enumerate(train_loader):
+            for net, opt in zip(nets, opts):     # train each network
+                pred, _, _ = net(b_x)
+                loss = loss_func(pred, b_y)
+                opt.zero_grad()
+                loss.backward()
+                opt.step()    # this also learns the parameters in Batch Normalization
+
+    plt.ioff()
+
+    # plot training loss
+    plt.figure(2)
+    plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
+    plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
+    plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')
+
+    # evaluation
+    # set net to eval mode to freeze the parameters in batch normalization layers
+    [net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_var
+    preds = [net(test_x)[0] for net in nets]
+    plt.figure(3)
+    plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
+    plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
+    plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
+    plt.legend(loc='best')
+    plt.show()
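
Two of the remaining changes in this hunk follow from the same version bump. First, loss tensors are 0-dimensional in PyTorch 0.4 and later, so the old loss.data[0] indexing raises an IndexError; .item() is the supported way to extract the Python number. A minimal sketch, assuming PyTorch >= 0.4:

import torch

loss = torch.nn.MSELoss()(torch.zeros(3), torch.ones(3))
# loss.data[0]         # IndexError: invalid index of a 0-dim tensor
value = loss.item()    # returns the Python float 1.0
print(value)

Second, moving the training loop under if __name__ == "__main__": keeps module-level code from re-running when the file is imported. One common reason to need this guard is that a DataLoader with num_workers > 0 spawns worker processes that re-import the main module on Windows, though whether that applies here depends on how train_loader is configured, which this diff does not show.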
