• 大小:
    文件类型: .zip
    金币: 2
    下载: 1 次
    发布日期: 2021-08-02
  • 语言: Python
  • 标签: SVM,MNIST  

资源简介

SVM分类手动鼠标手写数字-python版本

资源截图

代码片段和文件信息

from sklearn import datasets
#导入交叉验证库
from sklearn import cross_validation
#导入SVM分类算法库
from sklearn import svm
#导入图表库
import cv2
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from sklearn.externals import joblib

drawing = False #鼠标按下为真
mode = True #如果为真,画矩形,按m切换为曲线
ixiy=-1-1
mnist = input_data.read_data_sets(“MNIST“ one_hot=True)

# X_train = mnist.train.images
# X_test =  mnist.test.images
number_train = len(mnist.train.images)
number_test = len(mnist.test.images)
X_train = mnist.train.images.reshape((number_train 784))
X_test =  mnist.test.images.reshape((number_test 784))
# print(X_train.shape)
y_train = np.zeros(number_train)
y_test = np.zeros(number_test)

# print(mnist.test.labels[2])
for i in range (054999):
    for j in range (09):
        if mnist.train.labels[i][j] == 1:
            y_train[i] = j
            j = 0
            break

# print(y_train[10])
for i in range (09999):
    for j in range (09):
        if mnist.test.labels[i][j] == 1:
            y_test[i] = j
# print(y_test)

#生成SVM分类模型
clf = svm.SVC(max_iter= 20000)
#使用训练集对svm分类模型进行训练
clf.fit(X_train y_train)
joblib.dump(clf“model_2/SVM_MNIST.pkl“)
print(1)
score = clf.score(X_testy_test)
print(“准确率 : “score)


# def draw(eventxyflagsparam):
#     global ixiydrawingmode
#
#     if event == cv2.EVENT_LBUTTONDOWN:
#         drawing = True
#         ixiy=xy
#
#
#     elif event == cv2.EVENT_MOUSEMOVE:
#         if drawing == True:
#             cv2.circle(img (x y) 5 (255 255 255) -1)
#     elif event == cv2.EVENT_LBUTTONUP:
#         drawing = False
#         cv2.circle(img (x y) 5 (255 255 255) -1)
#
#
# img = np.zeros((1281281)np.uint8)
#
# cv2.namedWindow(‘image‘)
# cv2.setMouseCallback(‘image‘draw)
#
# while(1):
#
#     cv2.imshow(‘image‘img)
#     resized_image = cv2.resize(img(2828))
#     resized_image = cv2.normalize(resized_imageresized_image01cv2.NORM_MINMAXcv2.CV_32F)
#     # print(resized_image)
#     # print(resized_image.shape)
#     resized_image = resized_image.reshape((1 784))
#     # cv2.imshow(“zero“resized_image)
#     k = cv2.waitKey(1) & 0xFF
#     if k == ord(‘m‘) :
#         mode = not mode
#     elif k == 13:
#         predict = clf.predict(resized_image)
#         print(“predict : “ predict)
#     elif k == 27:
#         break
# cv2.destroyAllWindows()

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     文件        2586  2018-01-11 13:07  MNIST.py
     文件        1596  2018-01-11 13:05  手写3.py

评论

共有 条评论