• 大小: 9KB
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-05-17
  • 语言: Python
  • 标签: mnist  

资源简介

模仿mnist数据集制作自己的数据集,并读取自己的数据集

资源截图

代码片段和文件信息

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License Version 2.0 (the “License“);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing software
# distributed under the License is distributed on an “AS IS“ BASIS
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

“““Functions for downloading and reading MNIST data.“““
#-*-coding:utf-8-*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip

import numpy
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.python.framework import dtypes

SOURCE_URL = ‘http://yann.lecun.com/exdb/mnist/‘


def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder(‘>‘)
  return numpy.frombuffer(bytestream.read(4) dtype=dt)[0]


def extract_images(f):
  “““Extract the images into a 4D uint8 numpy array [index y x depth].

  Args:
    f: A file object that can be passed into a gzip reader.

  Returns:
    data: A 4D uint8 numpy array [index y x depth].

  Raises:
    ValueError: If the bytestream does not start with 2051.

  “““
  print(‘Extracting‘ f.name)
  with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(‘Invalid magic number %d in MNIST image file: %s‘ %
                       (magic f.name))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf dtype=numpy.uint8)
    data = data.reshape(num_images rows cols 1)
    return data


def dense_to_one_hot(labels_dense num_classes):
  “““Convert class labels from scalars to one-hot vectors.“““
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot


def extract_labels(f one_hot=False num_classes=2):              #这里打num_classes手动设置为自己的类别数
  “““Extract the labels into a 1D uint8 numpy array [index].

  Args:
    f: A file object that can be passed into a gzip reader.
    one_hot: Does one hot encoding for the result.
    num_classes: Number of classes for the one hot encoding.

  Returns:
    labels: a 1D uint8 numpy array.

  Raises:
    ValueError: If the bystream doesn‘t start with 2049.
  “““
  print(‘Extracting‘ f.name)
  with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read

评论

共有 条评论