@@ -15,14 +15,6 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
-# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
-def new_gelu(x):
-    """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
-    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
-    """
-    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
-
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
 
@@ -88,12 +80,13 @@ class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        self.gelu = nn.GELU()
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
         self.dropout = nn.Dropout(config.dropout)
 
     def forward(self, x):
         x = self.c_fc(x)
-        x = new_gelu(x)
+        x = self.gelu(x)
         x = self.c_proj(x)
         x = self.dropout(x)
         return x
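A note on the swap (not part of the diff itself): the removed new_gelu is the tanh approximation of GELU from the paper linked in its docstring, while nn.GELU() as used above defaults to the exact erf-based formulation (approximate='none'), so the two differ by a small numerical amount. A minimal sanity-check sketch, assuming a PyTorch version recent enough for nn.GELU to accept the approximate= argument:

import math
import torch
import torch.nn as nn

def new_gelu(x):
    # the tanh-approximation GELU removed by this commit
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.randn(1024)
# nn.GELU(approximate="tanh") implements the same formula as new_gelu
print(torch.allclose(new_gelu(x), nn.GELU(approximate="tanh")(x), atol=1e-6))  # True
# the default nn.GELU() is the exact erf-based GELU, so it differs slightly
print((new_gelu(x) - nn.GELU()(x)).abs().max())  # small but nonzero

If parity with the old activation matters for reproducing earlier checkpoints, constructing the module as nn.GELU(approximate='tanh') preserves it.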