Skip to content

Commit 39ae397

Browse files
committed
Remove pos unsqueeze(0)
1 parent 7fe4a09 commit 39ae397

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,11 +178,11 @@ def forward(self, idx, targets=None):
178178
device = idx.device
179179
b, t = idx.size()
180180
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
181-
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
181+
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
182182

183183
# forward the GPT model itself
184184
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
185-
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
185+
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
186186
x = self.transformer.drop(tok_emb + pos_emb)
187187
for block in self.transformer.h:
188188
x = block(x)

0 commit comments

Comments (0)