Commit 09d0937

fix value pe double rotation
1 parent 2236f63 commit 09d0937

2 files changed

Lines changed: 2 additions & 2 deletions

docs/transformers/rope/value_pe/index.html

Lines changed: 1 addition & 1 deletion
@@ -412,7 +412,7 @@ <h2>Multi-head attention with rotary positional embeddings</h2>
 
 </div>
 <div class='code'>
-<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_rotary_pe</span><span class="p">(</span><span class="n">value</span><span class="p">))</span></pre></div>
+<div class="highlight"><pre><span class="lineno">234</span> <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ijbh,jbhd-&gt;ibhd&quot;</span><span class="p">,</span> <span class="n">attn</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span></pre></div>
 </div>
 </div>
 <div class='section' id='section-21'>

labml_nn/transformers/rope/value_pe/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ def forward(self, *,
 
 # Multiply by values
 # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
-x = torch.einsum("ijbh,jbhd->ibhd", attn, self.value_rotary_pe(value))
+x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
 
 # Rotate in the opposite direction so that each embedding hold the relative positions
 x = self.value_reverse_rotary_pe(x)
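
Why a second forward rotation matters can be sketched with a toy model. This is not the repository's code: it assumes the value was already rotated once earlier in `forward` (the hunk does not show that line), it stores a single 2-D feature pair as a complex number, and `THETA`, `rotate`, `attend`, and `close` are hypothetical names chosen for illustration. With one forward rotation followed by the reverse rotation, the output depends only on the relative offset `j - i`; with two forward rotations it changes when every position is shifted by the same amount.

```python
import cmath

THETA = 0.5  # hypothetical rotation frequency for one feature pair


def rotate(v: complex, pos: int) -> complex:
    """Rotate one 2-D feature pair (held as a complex number) by pos * THETA."""
    return v * cmath.exp(1j * THETA * pos)


def attend(attn, values, offset=0, double=False):
    """Weighted sum of rotated values, then reverse rotation of each output.

    offset shifts every absolute position by the same amount.
    double=True applies the forward rotation twice (the behaviour being fixed).
    """
    out = []
    for i, row in enumerate(attn):
        acc = 0j
        for j, w in enumerate(row):
            v = rotate(values[j], j + offset)
            if double:
                v = rotate(v, j + offset)
            acc += w * v
        # Reverse rotation by the query position.
        out.append(rotate(acc, -(i + offset)))
    return out


def close(xs, ys, tol=1e-9):
    return all(abs(x - y) < tol for x, y in zip(xs, ys))


attn = [[0.7, 0.3], [0.2, 0.8]]   # toy attention weights, rows sum to 1
values = [1 + 0j, 0.5 + 0.5j]     # toy value vectors

single_a = attend(attn, values, offset=0)
single_b = attend(attn, values, offset=5)
double_a = attend(attn, values, offset=0, double=True)
double_b = attend(attn, values, offset=5, double=True)

print(close(single_a, single_b))  # True: only relative positions survive
print(close(double_a, double_b))  # False: absolute position leaks through
```

In the single-rotation case the phase of each term is `THETA * (j - i)`, so the shift cancels. In the double-rotation case it is `THETA * (2 * (j + offset) - (i + offset))`, which leaves a residual factor that depends on `offset`.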
