Commit 3ec5fa9 ("fix typo mha")
1 parent 05321d6

2 files changed, 2 additions & 2 deletions

docs/transformers/mha.html

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ <h2>Prepare for multi-head attention</h2>
       <a href='#section-10'>#</a>
       </div>
       <p>Output has shape <code class="highlight">[seq_len, batch_size, heads, d_k]</code>
-      or <code class="highlight">[batch_size, d_model]</code>
+      or <code class="highlight">[batch_size, heads, d_model]</code>
       </p>

labml_nn/transformers/mha.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def forward(self, x: torch.Tensor):
         # Split last dimension into heads
         x = x.view(*head_shape, self.heads, self.d_k)

-        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, d_model]`
+        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_model]`
         return x
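The line being fixed documents the head-split reshape in `forward`. A minimal sketch of that reshape, using illustrative sizes (the values 4, 2, 8, 16 are made-up examples, not from the repo), shows how the last dimension `d_model` is split into `(heads, d_k)`:

```python
import torch

# Illustrative sizes: d_model must equal heads * d_k
seq_len, batch_size, heads, d_k = 4, 2, 8, 16
d_model = heads * d_k

x = torch.randn(seq_len, batch_size, d_model)

# Same pattern as the diff: keep all leading dims, split the last one
head_shape = x.shape[:-1]
x = x.view(*head_shape, heads, d_k)

print(tuple(x.shape))  # (4, 2, 8, 16), i.e. [seq_len, batch_size, heads, d_k]
```

For a 3-D input this yields the `[seq_len, batch_size, heads, d_k]` shape named in the corrected comment; a 2-D `[batch_size, d_model]` input would be split the same way along its last dimension.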
