资源简介

代码用于猫-非猫图片的二分类问题(附件内给出h5格式的数据集),基于Pytorch神经网络工具包,采用比较经典的逻辑回归(Logistic Regression)算法。

资源截图

代码片段和文件信息

import numpy as np
import torch as t
import h5py
import os
os.environ[‘KMP_DUPLICATE_LIB_OK‘] = ‘True‘


class LogisticRegression(t.nn.Module):
    def __init__(self):
        super(LogisticRegression self).__init__()
        self.lg = t.nn.Sequential(
            t.nn.Linear(12288 1) t.nn.Sigmoid()
        )

    def forward(self x):
        output = self.lg(x)
        return output


lg_model = LogisticRegression()
cost_func = t.nn.BCELoss()
optimizer = t.optim.SGD(lg_model.parameters() lr=0.001 momentum=0.9)
epochs = 500

train_dataset = h5py.File(‘train_catvnoncat.h5‘ “r“)
train_set_x_orig = t.tensor(train_dataset[“train_set_x“][:])
train_set_y = t.tensor(np.array(train_dataset[“train_set_y“][:]))/1.0

test_dataset = h5py.File(‘test_catvnoncat.h5‘ “r“)
test_set_x_orig = t.tensor(np.array(test_dataset[“test_set_x“][:]))
test_set_y = t.tensor(np.array(test_dataset[“test_set_y“][:]))/1.0

num_train = train_set_x_orig.shape[0]
num_test = test_set_x_orig.shape[0]

train_set_x = train_set_x_orig.reshape(num_train -1)/255.0
test_set_x = test_set_x_orig.reshape(num_test -1)/255.0
train_set_y = train_set_y.reshape(num_train 1)
test_set_y = test_set_y.reshape(num_test 1)

train_loss = 0
for epoch in range(epochs):
    lg_model.train()
    y_out = lg_model(train_set_x)
    train_loss = cost_func(y_out train_set_y)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    with t.no_grad():
y_pred = y_out.ge(0.5).float()
        num_correct = (y_pred == train_set_y).sum().item()
        acc_rate = num_correct * 100.0 / num_train
        print(“世代数: %d 训练集正确率: %.1f%%“ % (epoch acc_rate))

lg_model.eval()
        y_out = lg_model(test_set_x)
        y_pred = y_out.ge(0.5).float()
        num_correct = (y_pred == test_set_y).sum().item()
        acc_rate = num_correct * 100.0 / num_test
        print(“测试集正确率: %.1f%%“ % acc_rate)
    pass

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     文件        1991  2020-08-04 17:05  猫-非猫图二分类\Logistic_Regression.py
     文件      616958  2020-07-31 15:31  猫-非猫图二分类\test_catvnoncat.h5
     文件     2572022  2020-07-31 15:31  猫-非猫图二分类\train_catvnoncat.h5
     目录           0  2020-08-04 17:00  猫-非猫图二分类\

评论

共有 条评论