2 parents a8ddc7d + ffafaf1 commit 334fb05
1 file changed
labml_nn/transformers/vit/__init__.py
@@ -191,11 +191,11 @@ def forward(self, x: torch.Tensor):
         """
         # Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]`
         x = self.patch_emb(x)
-        # Add positional embeddings
-        x = self.pos_emb(x)
         # Concatenate the `[CLS]` token embeddings before feeding the transformer
         cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
         x = torch.cat([cls_token_emb, x])
+        # Add positional embeddings
+        x = self.pos_emb(x)

         # Pass through transformer layers with no attention masking
         for layer in self.transformer_layers:
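To make the effect of the reordering concrete, here is a minimal sketch of the corrected forward pass. The stub modules and dimensions (`d_model = 64`, 16×16 patches on a 224×224 input) are hypothetical stand-ins chosen for illustration, not the repository's actual classes; only the ordering of the operations mirrors the diff.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the ViT modules (not labml_nn's actual classes).
d_model = 64
patch_emb = nn.Conv2d(3, d_model, kernel_size=16, stride=16)  # 16x16 patch embedding
pos_emb = nn.Parameter(torch.zeros(5000, 1, d_model))         # learned positional embeddings
cls_token_emb = nn.Parameter(torch.zeros(1, 1, d_model))      # learned [CLS] token

x = torch.randn(4, 3, 224, 224)               # [batch_size, channels, height, width]
x = patch_emb(x).flatten(2).permute(2, 0, 1)  # [patches, batch_size, d_model] = [196, 4, 64]

# Concatenate the [CLS] token embeddings first ...
cls = cls_token_emb.expand(-1, x.shape[1], -1)  # [1, batch_size, d_model]
x = torch.cat([cls, x])                         # [patches + 1, batch_size, d_model] = [197, 4, 64]

# ... then add positional embeddings, so the [CLS] token gets one too.
x = x + pos_emb[:x.shape[0]]
print(x.shape)  # torch.Size([197, 4, 64])
```

With the previous ordering, `pos_emb` was applied before the concatenation, so the prepended `[CLS]` token carried no positional information; applying it after `torch.cat` covers the full `patches + 1` sequence.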