File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 77torchvision
88matplotlib
99"""
10+ # library
11+ # standard library
12+ import os
13+
14+ # third-party library
1015import torch
1116import torch .nn as nn
1217from torch .autograd import Variable
2025EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
2126BATCH_SIZE = 50
2227LR = 0.001 # learning rate
23- DOWNLOAD_MNIST = True # set to False if you have downloaded
28+ DOWNLOAD_MNIST = False
2429
2530
2631# Mnist digits dataset
32+ if not (os .path .exists ('./mnist/' )) or not os .listdir ('./mnist/' ):
33+ # not mnist dir or mnist is empyt dir
34+ DOWNLOAD_MNIST = True
35+
2736train_data = torchvision .datasets .MNIST (
2837 root = './mnist/' ,
2938 train = True , # this is training data
3039 transform = torchvision .transforms .ToTensor (), # Converts a PIL.Image or numpy.ndarray to
3140 # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
32- download = DOWNLOAD_MNIST , # download it if you don't have it
41+ download = DOWNLOAD_MNIST ,
3342)
3443
3544# plot one example
You can’t perform that action at this time.
0 commit comments