 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer
 from labml_nn.lora import Linear, Embedding
 
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-config = {
-    "layer_norm_epsilon": 1e-05,
-    "n_embd": 768,
-    "n_head": 12,
-    "n_layer": 12,
-    "n_positions": 1024,
-    "vocab_size": 50257,
-    "device": "cuda"
-}
-
 
 class FFN(nn.Module):
-    def __init__(self, dim):
+    def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
-        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
-        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
+        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
+        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu
 
     def forward(self, hidden_states):
@@ -31,15 +18,15 @@ def forward(self, hidden_states):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(self):
+    def __init__(self, n_embed: int, n_head: int, r: int):
         super().__init__()
-        self.embed_dim = config['n_embd']
-        self.num_heads = config['n_head']
+        self.embed_dim = n_embed
+        self.num_heads = n_head
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
 
-        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
-        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)
+        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
 
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
@@ -76,12 +63,12 @@ def forward(self, hidden_states):
 
 
 class Block(nn.Module):
-    def __init__(self):
+    def __init__(self, n_embed: int, n_head: int, layer_norm_epsilon: float, r: int):
         super().__init__()
-        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
-        self.attn = MultiHeadAttention()
-        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
-        self.ffn = FFN(config['n_embd'] * 4)
+        self.pre_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
+        self.attn = MultiHeadAttention(n_embed, n_head, r)
+        self.post_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
+        self.ffn = FFN(n_embed * 4, n_embed, r)
 
     def forward(self, hidden_states):
         residual = hidden_states
@@ -99,23 +86,27 @@ def forward(self, hidden_states):
 
 
 class GPTModel(nn.Module):
-    def __init__(self):
+    def __init__(self, layer_norm_epsilon: float, n_embd: int, n_head: int, n_layer: int,
+                 n_positions: int, vocab_size: int, r: int, device: torch.device):
         super().__init__()
 
-        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
-        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)
+        self.token_embedding = Embedding(vocab_size, n_embd, r=r)
+        self.position_embedding = Embedding(n_positions, n_embd, r=r)
+
+        self.blocks = nn.ModuleList([Block(n_embd, n_head, layer_norm_epsilon, r=r)
+                                     for _ in range(n_layer)])
 
-        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])
+        self.final_norm = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)
 
-        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
+        self.lm_head = Linear(n_embd, vocab_size, r=r, bias=False)
 
-        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)
+        self.device = device
 
     def forward(self, input_ids):
         batch_size, input_shape = input_ids.size()
 
         token_embeddings = self.token_embedding(input_ids)  # B T C
-        position_ids = torch.arange(input_shape, device=config['device'])  # T C
+        position_ids = torch.arange(input_shape, device=self.device)  # T C
         position_embeddings = self.position_embedding(position_ids)  # B T C
 
         hidden_states = token_embeddings + position_embeddings
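
With the global `config` dict deleted, the hyperparameters now travel through the constructor (with the head count threaded through as `n_head`). A minimal usage sketch, not part of this commit: it assumes the class lives at `labml_nn.lora.gpt2`, that a CUDA device is available, and that the `forward` shown above goes on to apply the blocks, final norm and `lm_head`; the input batch is made up for illustration.

import torch

from labml_nn.lora.gpt2 import GPTModel  # assumed module path for the class defined above

device = torch.device("cuda")  # the removed config used "cuda"; assumes a GPU is available

# Same values the removed config dict carried (GPT-2 small), with LoRA rank r=32
model = GPTModel(layer_norm_epsilon=1e-5, n_embd=768, n_head=12, n_layer=12,
                 n_positions=1024, vocab_size=50257, r=32, device=device).to(device)

# Dummy batch of token ids: batch size 1, sequence length 16 (illustration only)
input_ids = torch.randint(0, 50257, (1, 16), device=device)
logits = model(input_ids)  # shape (1, 16, 50257) if forward ends with lm_head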