Skip to content

Commit 334fb05

Browse files
authored
Merge pull request #221 from lizhuoq/fix
fix: fix cls_token bug in vit.
2 parents a8ddc7d + ffafaf1 commit 334fb05

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

labml_nn/transformers/vit/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,11 @@ def forward(self, x: torch.Tensor):
191191
"""
192192
# Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]`
193193
x = self.patch_emb(x)
194-
# Add positional embeddings
195-
x = self.pos_emb(x)
196194
# Concatenate the `[CLS]` token embeddings before feeding the transformer
197195
cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
198196
x = torch.cat([cls_token_emb, x])
197+
# Add positional embeddings
198+
x = self.pos_emb(x)
199199

200200
# Pass through transformer layers with no attention masking
201201
for layer in self.transformer_layers:

0 commit comments

Comments
 (0)