from tensorflow.examples.tutorials.mnist import input_data
First, download the dataset over the network:
mnist = input_data.read_data_sets(train_dir='./mnist_data', one_hot=True)
# If there is no mnist_data folder in the current directory, the folder is created first and the MNIST dataset is then downloaded into it
Splitting the data into training, test, and validation sets:
X_train, y_train = mnist.train.images, mnist.train.labels
# X_train is a NumPy multidimensional array of shape (55000, 784)
X_test, y_test = mnist.test.images, mnist.test.labels
# (10000, 784)
X_valid, y_valid = mnist.validation.images, mnist.validation.labels
# (5000, 784)
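To make the shapes above concrete, here is a minimal NumPy-only sketch. It uses synthetic arrays rather than the real dataset. It shows how a flattened 784-element row maps back to a 28*28 image and how a one-hot label row maps back to a digit. The helper names `to_one_hot` and `from_one_hot` are hypothetical and are not part of TensorFlow.

```python
import numpy as np


def to_one_hot(labels, num_classes=10):
    """Encode integer class labels as one-hot rows (hypothetical helper)."""
    labels = np.asarray(labels)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set a single 1.0 per row, in the column given by the label.
    one_hot[np.arange(labels.shape[0]), labels] = 1.0
    return one_hot


def from_one_hot(one_hot):
    """Recover integer class labels from one-hot rows."""
    return np.argmax(one_hot, axis=1)


# Synthetic stand-in for a few MNIST rows: 3 flattened 28*28 images.
fake_images = np.zeros((3, 784), dtype=np.float32)
fake_labels = to_one_hot([7, 0, 3])

first_image = fake_images[0].reshape(28, 28)  # back to a 2-D image
digits = from_one_hot(fake_labels)            # back to class indices
```

With `one_hot=True`, the label arrays returned by the loader have one row of length 10 per sample, so a conversion such as `np.argmax(labels, axis=1)` is useful whenever a plain digit is needed, for example in plot captions.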
You can also read the data iteratively, one batch of a given batch_size at a time:
mnist.train.next_batch(100)
mnist.train.next_batch() returns two values: the image data and the class labels corresponding to those images.
>>> X_batch, y_batch = mnist.train.next_batch(m)
>>> X_batch.shape
(m, 784)
>>> y_batch.shape
(m, 10)  # one-hot encoded
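As a rough illustration of what batched reading amounts to, the sketch below yields shuffled mini-batches from two NumPy arrays. It is not TensorFlow's implementation of `next_batch`. The function name `iterate_minibatches` and its parameters are assumptions made for this example, and the data is synthetic.

```python
import numpy as np


def iterate_minibatches(images, labels, batch_size, seed=0):
    """Yield (images, labels) mini-batches in shuffled order for one epoch.

    Hypothetical helper; the final batch may be smaller than batch_size.
    """
    assert images.shape[0] == labels.shape[0]
    rng = np.random.default_rng(seed)
    order = rng.permutation(images.shape[0])  # shuffled sample indices
    for start in range(0, images.shape[0], batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]


# Synthetic data with MNIST-like shapes: 250 samples, 784 features, 10 classes.
X = np.zeros((250, 784), dtype=np.float32)
y = np.zeros((250, 10), dtype=np.float32)

batch_shapes = [(xb.shape, yb.shape) for xb, yb in iterate_minibatches(X, y, 100)]
```

Each yielded pair keeps images and labels aligned because both arrays are indexed with the same shuffled indices.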
1. Visualization
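import matplotlib.pyplot as plt

image_shp = (28, 28)  # each flattened 784-element row is a 28*28 image

# images: numpy.ndarray of shape (9, 28*28)
# y_: the true labels; y: optional predicted labels
# (for one-hot labels, pass class indices, e.g. np.argmax(labels, axis=1))
def plot_mnist_3_3(images, y_, y=None):
    assert images.shape[0] == len(y_)
    fig, axes = plt.subplots(3, 3)
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i].reshape(image_shp), cmap='binary')
        if y is None:
            xlabel = 'True: {}'.format(y_[i])
        else:
            xlabel = 'True: {0}, Pred: {1}'.format(y_[i], y[i])
        ax.set_xlabel(xlabel)
        # Hide the axis ticks so that only the digit and its caption show
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()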