• 大小: 7.99MB
    文件类型: .zip
    金币: 2
    下载: 1 次
    发布日期: 2023-11-10
  • 语言: Python
  • 标签: USPS  python  

资源简介

数据集为usps手写数据集(.mat形式),共9298张图片,维度16*16,内附有python版的使用代码

资源截图

代码片段和文件信息

import numpy as np  #用于数据处理
import matplotlib.pyplot as plt  # 用于展示图片
import scipy.io as sio  # 用于读取.mat

def load_dataset(dataset=‘usps‘):
    # 加载usps数据集
    if dataset == ‘usps‘:
        data = sio.loadmat(‘usps_resampled.mat‘)
        x_train y_train x_test y_test = data[‘train_patterns‘].T data[‘train_labels‘].T data[‘test_patterns‘].T data[‘test_labels‘].T
        x = np.concatenate((x_train x_test))
        y_train = [np.argmax(l) for l in y_train]  # 将onehot编码转成一般编码
        y_test = [np.argmax(l) for l in y_test]  # 将onehot编码转成一般编码
        y = np.concatenate((np.array(y_train) np.array(y_test))).astype(np.int32)
        x = x.reshape((-1 16 16 1)).astype(np.float32)   # 便于使用卷积层
        # x = x.reshape((x.shape[0] 16*16)).astype(np.float32)   # 便于使用全连接层
        x = np.divide(x 255.)  # 归一化
        print(‘USPS:‘ x.shape y.shape)  # (9298 16 16 1)
        return x y

    else:
        print(‘The dataSet name is useless‘)
        exit(0)


def show_figure(data):  # 显示前200张图片
    digit_size = data.shape[1]  # 16 或者 28
    data = np.squeeze(data)  # 去掉1维
    figure = np.zeros((digit_size * 10 digit_size * 20))
    t = 0
    for i in range(10):  # 10行
        for j in range(20):  # 每行展示20个数据
            figure[i * digit_size: (i+1) * digit_size j * digit_size: (j+1) * digit_size] = data[t]
            t = t + 1
    plt.figure(figsize=(15 15))
    plt.imshow(figure)
    plt.show()

if __name__ == ‘__main__‘:
    # load dataset
    x y = load_dataset(‘usps‘)
    print(y[:200])  # 展示前200个样本的标签
    show_figure(x)  # 展示前200个样本数据

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2020-05-25 10:11  usps手写数据集+使用代码\
     文件        1759  2020-05-25 10:14  usps手写数据集+使用代码\test_usps.py
     文件    19228688  2006-03-13 20:48  usps手写数据集+使用代码\usps_resampled.mat

评论

共有 条评论