
Commit f08abb4

Merge pull request karpathy#274 from apivovarov/gelu
Use nn.GELU - 1.27x faster training
2 parents 18ee6b6 + 594068e commit f08abb4

1 file changed: model.py (2 additions, 9 deletions)
@@ -15,14 +15,6 @@
 import torch.nn as nn
 from torch.nn import functional as F

-# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
-def new_gelu(x):
-    """
-    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
-    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
-    """
-    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
-
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

@@ -88,12 +80,13 @@ class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        self.gelu = nn.GELU()
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
         self.dropout = nn.Dropout(config.dropout)

     def forward(self, x):
         x = self.c_fc(x)
-        x = new_gelu(x)
+        x = self.gelu(x)
         x = self.c_proj(x)
         x = self.dropout(x)
         return x
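For context (not part of the commit itself): nn.GELU() defaults to the exact, erf-based GELU, whereas the removed new_gelu was the tanh approximation used in the BERT/GPT reference code. A minimal sketch of how the two relate, assuming PyTorch 1.12+ where nn.GELU accepts approximate='tanh':

    import math
    import torch
    import torch.nn as nn

    # The hand-written activation removed by this commit: the tanh
    # approximation of GELU from the BERT / GPT reference code.
    def new_gelu(x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    x = torch.randn(4, 8)

    # After this commit the MLP calls nn.GELU(), whose default is the
    # exact (erf-based) formulation.
    exact = nn.GELU()(x)

    # nn.GELU(approximate='tanh') reproduces the old new_gelu formula.
    tanh_approx = nn.GELU(approximate='tanh')(x)

    print(torch.allclose(tanh_approx, new_gelu(x), atol=1e-6))  # True: same formula
    print((exact - tanh_approx).abs().max())                    # small, nonzero approximation gap

The two variants differ only by that approximation error; the speedup cited in the PR title presumably comes from using the single built-in activation op instead of the chain of elementwise tensor operations that new_gelu launched.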
