• 大小: 2KB
    文件类型: .py
    金币: 2
    下载: 1 次
    发布日期: 2021-06-17
  • 语言: Python
  • 标签: pytorch  data  

资源简介

对自己准备训练的数据集进行读取,即选取路径,读入数据其次将数据加载如train_loader中对图像进行训练等操作

资源截图

代码片段和文件信息

# -*- coding: utf-8 -*-
“““
Created on Sat Oct  6 09:19:17 2018

@author: Administrator
“““
import torch
import torchvision
from torchvision import datasets transforms
import matplotlib.pyplot as plt 
import numpy as np
import os
D=299
num=3000
change=300

# Data augmentation and normalization for training 
# Just normalization for validation
data_transforms = {
    ‘train‘: transforms.Compose([
        transforms.RandomSizedCrop(224)
        transforms.RandomHorizontalFlip()
        transforms.ToTensor()
        transforms.Normalize([0.485 0.456 0.406] [0.229 0.224 0.225])
    ])
    ‘val‘: transforms.Compose([
        transforms.Scale(256)
        transforms.CenterCrop(224)
        transforms.ToTensor()
        transforms.Normalize([0.485 0.456 0.406] [0.229 0.224 0.225])
    ])
}


data_dir = ‘C:\\Users\\Administrator.SKY-20180518VHY\\Desktop\\rx‘

train_sets = datasets.ImageFolder(os.path.join(data_dir ‘train‘) data_transforms[‘train‘])
train_loader = torch.utils.data.DataLoader(train_sets batch_size=10 shuffle=True num_workers=4)
train_size = len(train_sets

评论

共有 条评论