@@ -77,13 +77,11 @@ def __init__(self, dim):
         self.c_fc = Conv1D(dim, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], dim)
         self.act = nn.functional.gelu
-        self.dropout = nn.Dropout(config['resid_pdrop'])

     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.c_proj(hidden_states)
-        hidden_states = self.dropout(hidden_states)
         return hidden_states


@@ -98,9 +96,6 @@ def __init__(self):
         self.c_att = Conv1D(config['n_embd'] * 3, config['n_embd'])
         self.c_proj = Conv1D(config['n_embd'], config['n_embd'])

-        self.resid_dropout = nn.Dropout(config['resid_pdrop'])
-        self.attn_dropout = nn.Dropout(config['attn_pdrop'])
-
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
         Splits hidden_size dim into attn_head_size and num_heads
@@ -123,7 +118,7 @@ def forward(self, hidden_states):
             key,
             value,
             attn_mask=None,
-            dropout_p=self.attn_dropout.p if self.training else 0.0,
+            dropout_p=0.0,
             is_causal=True,  # for the triangular mask
         )

@@ -132,7 +127,6 @@ def forward(self, hidden_states):
         attn_output = attn_output.view(batch_size, seq_length, self.embed_dim)

         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)

         return attn_output

@@ -168,8 +162,6 @@ def __init__(self):
         self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
         self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])

-        self.dropout = nn.Dropout(p=config['embd_pdrop'], inplace=False)
-
         self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

         self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
@@ -183,9 +175,7 @@ def forward(self, input_ids):
         position_ids = torch.arange(input_shape)  # T C
         position_embeddings = self.position_embedding(position_ids)  # B T C

-        embeddings = token_embeddings + position_embeddings
-
-        hidden_states = self.dropout(embeddings)
+        hidden_states = token_embeddings + position_embeddings

         for block in self.blocks:
             hidden_states = block(hidden_states)
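
For reference, a minimal standalone sketch (not part of the commit; the tensor shapes and variable names are illustrative) of how the attention call behaves once the dropout modules are removed: dropout_p becomes a constant 0.0 instead of switching on self.training, and is_causal=True still applies the triangular mask.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, n_head, seq_len, head_dim)
query = torch.randn(1, 12, 8, 64)
key = torch.randn(1, 12, 8, 64)
value = torch.randn(1, 12, 8, 64)

# After the change, dropout_p is a fixed 0.0 rather than
# "self.attn_dropout.p if self.training else 0.0"
attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,  # lower-triangular (causal) mask
)
print(attn_output.shape)  # torch.Size([1, 12, 8, 64])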