Skip to content

Commit eba36e8

Browse files
authored
Merge pull request karpathy#309 from ho2103/master
Fix AssertionError on macOS - need to check CUDA availability for bf16
2 parents 4eb7a96 + 1eaceae commit eba36e8

3 files changed

Lines changed: 3 additions & 3 deletions

File tree

bench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
 real_data = True
 seed = 1337
 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
-dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
+dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 compile = True # use PyTorch 2.0 to compile the model to be faster
 profile = False # use pytorch profiler, or just simple benchmarking?
 exec(open('configurator.py').read()) # overrides from command line or config file

sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
 top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
 seed = 1337
 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
-dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
+dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 compile = False # use PyTorch 2.0 to compile the model to be faster
 exec(open('configurator.py').read()) # overrides from command line or config file
 # -----------------------------------------------------------------------------

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
 backend = 'nccl' # 'nccl', 'gloo', etc.
 # system
 device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
-dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
+dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
 compile = True # use PyTorch 2.0 to compile the model to be faster
 # -----------------------------------------------------------------------------
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]

0 commit comments

Comments
 (0)