import torch
from torch import nn
- import torch.nn.functional as F

from labml_helpers.module import Module

@@ -92,13 +91,14 @@ class ResidualBlock(Module):
9291 Each resolution is processed with two residual blocks.
9392 """
9493
-    def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_groups: int = 32, dropout_rate: float = 0.1):
+    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
+                 n_groups: int = 32, dropout: float = 0.1):
        """
        * `in_channels` is the number of input channels
        * `out_channels` is the number of output channels
        * `time_channels` is the number of channels in the time step ($t$) embeddings
        * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
-        * `dropout_rate` is the dropout rate
+        * `dropout` is the dropout rate
        """
        super().__init__()
        # Group normalization and the first convolution layer
@@ -122,6 +122,8 @@ def __init__(self, in_channels: int, out_channels: int, time_channels: int, n_gr
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()

+        self.dropout = nn.Dropout(dropout)
+
    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        * `x` has shape `[batch_size, in_channels, height, width]`
@@ -132,7 +134,7 @@ def forward(self, x: torch.Tensor, t: torch.Tensor):
        # Add time embeddings
        h += self.time_emb(self.time_act(t))[:, :, None, None]
        # Second convolution layer
-        h = self.conv2(F.dropout(self.act2(self.norm2(h)), self.dropout_rate))
+        h = self.conv2(self.dropout(self.act2(self.norm2(h))))

        # Add the shortcut connection and return
        return h + self.shortcut(x)
0 commit comments