
Commit 6170531: enable sdpa for nonzero dropout

Parent: ae3a8d5

1 file changed: model.py (3 additions, 3 deletions)
```diff
@@ -49,10 +49,10 @@ def __init__(self, config):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
-        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
         if not self.flash:
-            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
+            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
```
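
For context on why the `and self.dropout == 0.0` guard could be dropped: `torch.nn.functional.scaled_dot_product_attention` in PyTorch >= 2.0 accepts a `dropout_p` argument, so attention dropout can be applied inside the fused kernel instead of forcing the manual fallback. Below is a minimal sketch of a forward pass wired up this way; the `attention_forward` helper name and the `q`/`k`/`v` shapes are illustrative, not the commit's exact code.

```python
import math
import torch
import torch.nn.functional as F

def attention_forward(self, q, k, v):
    # q, k, v: (B, n_head, T, head_dim) tensors, already split into heads
    if self.flash:
        # SDPA in PyTorch >= 2.0 takes dropout_p directly, so nonzero
        # attention dropout no longer requires the slow path
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
    else:
        # slow path: manual attention using the registered causal mask
        T = q.size(-2)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = F.dropout(att, p=self.dropout, training=self.training)
        y = att @ v
    return y
```

Note that `dropout_p` is zeroed at eval time, mirroring what dropout does in the manual branch during `model.eval()`.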
