• 大小: 2KB
    文件类型: .zip
    金币: 2
    下载: 1 次
    发布日期: 2021-06-17
  • 语言: Python
  • 标签: python  svm  

资源简介

本代码用于二维数据不能采用简单线性划分去做划分的情况,采用python,支持向量机方式实现,对数据进行二分类,并自动绘制出三维立体图像。内含数据集txt格式,可直接运行。

资源截图

代码片段和文件信息

from numpy import *
import matplotlib.pyplot as plt
from matplotlib import cm
import mpl_toolkits.mplot3d
from sklearn import svm
from sklearn.model_selection import train_test_split

# 加载数据集
def loadDataSet(filename):
    dataMat = []; labelMat = []
    fr = open(filename)
    for line in fr.readlines():
        lineArr = line.strip().split(‘\t‘)
        dataMat.append([float(lineArr[0]) float(lineArr[1])])
        labelMat.append(float(lineArr[2]))
    return array(dataMat) array(labelMat)


if __name__ == “__main__“:
    data target = loadDataSet(“testSetRBF2.txt“)
    c_data = []
    for item in data:
        c_data.append([item[0] ** 2 item[0] * item[1] item[1] ** 2])
    c_data = array(c_data)

    index1 = where(target == 1)
    X1 = c_data[index1]
    index2 = where(target == -1)
    X2 = c_data[index2]


    cls = svm.LinearSVC()
    cls.fit(c_data target)
    print(‘Coefficients:%s intercept %s‘ % (cls.coef_ cls.intercept_))

    w = cls.coef_[0]
    d = cls.intercept_[0]
    # w[0] * x + w[1] * y + w[2] * z + d = 0 ---- z =

    # 转换为三维空间
    ax = plt.figure().add_subplot(111 projection=‘3d‘)
    ax.scatter(X1[: 0] X1[: 1] X1[: 2] c=‘r‘ marker=‘o‘)
    ax.scatter(X2[: 0] X2[: 1] X2[: 2] c=‘b‘ marker=‘x‘)

    X = arange(0 0.4 0.1)
    Y = arange(-0.4 0.4 0.1)
    X Y = meshgrid(X Y)
    Z = (-d - w[0] * X - w[1] * Y) / w[2]
    ax.plot_surface(X Y Z rstride=1 cstride=1
                    cmap=cm.jet linewidth=0 antialiased=False)
    plt.show()

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     文件        2151  2018-05-28 18:30  testSetRBF2.txt
     文件        1592  2018-05-28 20:41  testSetRBF2_Axes3D.py

评论

共有 条评论