Skip to content

Commit d198a44

Browse files
committed
fix dropout ddpm.unet
1 parent 4db39b5 commit d198a44

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

labml_nn/diffusion/ddpm/unet.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import torch
2828
from torch import nn
29-
import torch.nn.functional as F
3029

3130
from labml_helpers.module import Module
3231

@@ -92,13 +91,14 @@ class ResidualBlock(Module):
9291
Each resolution is processed with two residual blocks.
9392
"""
9493

95-
def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1):
94+
def __init__(self, in_channels: int, out_channels: int, time_channels: int,
95+
n_groups: int = 32, dropout: float = 0.1):
9696
"""
9797
* `in_channels` is the number of input channels
9898
* `out_channels` is the number of input channels
9999
* `time_channels` is the number channels in the time step ($t$) embeddings
100100
* `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
101-
* `dropout_rate` is the dropout rate
101+
* `dropout` is the dropout rate
102102
"""
103103
super().__init__()
104104
# Group normalization and the first convolution layer
@@ -122,6 +122,8 @@ def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_gr
122122
self.time_emb = nn.Linear(time_channels, out_channels)
123123
self.time_act = Swish()
124124

125+
self.dropout = nn.Dropout(dropout)
126+
125127
def forward(self, x: torch.Tensor, t: torch.Tensor):
126128
"""
127129
* `x` has shape `[batch_size, in_channels, height, width]`
@@ -132,7 +134,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
132134
# Add time embeddings
133135
h += self.time_emb(self.time_act(t))[:, :, None, None]
134136
# Second convolution layer
135-
h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate))
137+
h = self.conv2(self.dropout(self.act2(self.norm2(h))))
136138

137139
# Add the shortcut connection and return
138140
return h + self.shortcut(x)

0 commit comments

Comments
 (0)