mnist作为最基础的图片数据集,在以后的cnn,rnn任务中都会用到
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data #数据集存放地址,采用0-1编码 mnist = input_data.read_data_sets( 'F:/mnist/data/' ,one_hot = True ) print (mnist.train.num_examples) print (mnist.test.num_examples) trainimg = mnist.train.images trainlabel = mnist.train.labels testimg = mnist.test.images testlabel = mnist.test.labels #打印相关信息 print ( type (trainimg)) print (trainimg.shape,) print (trainlabel.shape,) print (testimg.shape,) print (testlabel.shape,) nsample = 5 randidx = np.random.randint(trainimg.shape[ 0 ],size = nsample) #输出几张数字的图 for i in randidx: curr_img = np.reshape(trainimg[i,:],( 28 , 28 )) curr_label = np.argmax(trainlabel[i,:]) plt.matshow(curr_img,cmap = plt.get_cmap( 'gray' )) plt.title(" "+str(i)+" th Training Data "+" label is " + str (curr_label)) print (" "+str(i)+" th Training Data "+" label is " + str (curr_label)) plt.show() |
程序运行结果如下:
1
2
3
4
5
6
7
8
9
10
11
12
|
Extracting F: / mnist / data / train - images - idx3 - ubyte.gz Extracting F: / mnist / data / train - labels - idx1 - ubyte.gz Extracting F: / mnist / data / t10k - images - idx3 - ubyte.gz Extracting F: / mnist / data / t10k - labels - idx1 - ubyte.gz 55000 10000 < class 'numpy.ndarray' > ( 55000 , 784 ) ( 55000 , 10 ) ( 10000 , 784 ) ( 10000 , 10 ) 52636th |
输出的图片如下:
Training Datalabel is9
下面还有四张其他的类似图片
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/Missayaaa/article/details/80056103