tensorflow-datasets (tfds.load)

MNIST
import tensorflow_datasets as tfds

# Fetch MNIST through TensorFlow Datasets. Because with_info=True, the call
# returns a pair: the dataset object and its accompanying DatasetInfo.
mnist, mnist_info = tfds.load(
    name='mnist',
    with_info=True,
    shuffle_files=False,
)

# Show the dataset metadata (features, splits, citation, ...).
print(mnist_info)
tfds.core.DatasetInfo(
    name='mnist',
    full_name='mnist/3.0.1',
    description="""
    The MNIST database of handwritten digits.
    """,
    homepage='http://yann.lecun.com/exdb/mnist/',
    data_path='C:\\Users\\csian\\tensorflow_datasets\\mnist\\3.0.1',
    download_size=11.06 MiB,
    dataset_size=21.00 MiB,
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=10000, num_shards=1>,
        'train': <SplitInfo num_examples=60000, num_shards=1>,
    },
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
)
# `mnist` is keyed by split; list the splits that were loaded.
print(mnist.keys())
dict_keys([Split('train'), Split('test')])
# Build a small input pipeline over the training split.
ds_train = mnist['train']

# Each element is a feature dict; convert it to an (image, label) tuple.
ds_train = ds_train.map(lambda item: (item['image'], item['label']))

# Group the examples into batches of 10 and pull the first batch.
ds_train = ds_train.batch(10)
batch = next(iter(ds_train))

# batch[0] holds the images, batch[1] the matching labels.
print(batch[0].shape, batch[1])
(10, 28, 28, 1) tf.Tensor([4 1 0 7 8 1 2 7 1 6], shape=(10,), dtype=int64)
# Draw the ten images of the first batch in a 2x5 grid, titled with their labels.
# NOTE(review): assumes `matplotlib.pyplot` was imported as `plt` earlier in
# the notebook and that `batch` is the (images, labels) pair built above.
fig = plt.figure(figsize=(15, 6))
for i, (image, label) in enumerate(zip(batch[0], batch[1])):
    # Subplot indices are 1-based, hence i + 1.
    ax = fig.add_subplot(2, 5, i + 1)
    # Hide the axis ticks; they carry no information for image tiles.
    ax.set_xticks([])
    ax.set_yticks([])
    # Each image is (28, 28, 1); drop the channel axis so imshow gets a
    # 2-D array. 'gray_r' renders dark digits on a light background.
    ax.imshow(image[:, :, 0], cmap='gray_r')
    ax.set_title(' {}'.format(label), size=15)
plt.show()