File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -49,10 +49,10 @@ def __init__(self, config):
4949 self .n_head = config .n_head
5050 self .n_embd = config .n_embd
5151 self .dropout = config .dropout
52- # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
53- self .flash = hasattr (torch .nn .functional , 'scaled_dot_product_attention' ) and self . dropout == 0.0
52+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
53+ self .flash = hasattr (torch .nn .functional , 'scaled_dot_product_attention' )
5454 if not self .flash :
55- print ("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0" )
55+ print ("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0" )
5656 # causal mask to ensure that attention is only applied to the left in the input sequence
5757 self .register_buffer ("bias" , torch .tril (torch .ones (config .block_size , config .block_size ))
5858 .view (1 , 1 , config .block_size , config .block_size ))
You can’t perform that action at this time.
0 commit comments