• 大小: 8KB
    文件类型: .py
    金币: 2
    下载: 1 次
    发布日期: 2021-06-17
  • 语言: Python
  • 标签: tensorflow  

资源简介

VGG,V3,RESNET迁移学习,tensorflow和keras写的程序

资源截图

代码片段和文件信息

# -*- coding: utf-8 -*-  
import os  
from keras.utils import plot_model  
from keras.applications.resnet50 import ResNet50  
from keras.applications.vgg19 import VGG19  
from keras.applications.inception_v3 import InceptionV3  
from keras.layers import DenseFlattenGlobalAveragePooling2D  
from keras.models import Modelload_model  
from keras.optimizers import SGD  
from keras.preprocessing.image import ImageDataGenerator  
import matplotlib.pyplot as plt  
  
class PowerTransferMode:  
    #数据准备  
    def DataGen(self dir_path img_row img_col batch_size is_train):  
        if is_train:  
            datagen = ImageDataGenerator(rescale=1./255  #值将在执行其他处理前乘到整个图像上,
                                                          # 我们的图像在RGB通道都是0~255的整数,
                                                    # 这样的操作可能使图像的值过高或过低,所以我们将这个值定为0~1之间的数。
                zoom_range=0.25        #随机缩放的幅度
                rotation_range=15.      #数据提升时图片随机转动的角度
                channel_shift_range=25.    #随机通道偏移的幅度
                width_shift_range=0.02     #数据提升时图片随机水平偏移的幅度
                height_shift_range=0.02    #数据提升时图片随机竖直偏移的幅度
                horizontal_flip=True     #水平旋转
                fill_mode=‘constant‘)  #当进行变换时超出边界的点将根据本参数给定的方法进行处理
        else:  
            datagen = ImageDataGenerator(rescale=1./255)  
  
        generator = datagen.flow_from_directory(  
            dir_path target_size=(img_row img_col)  
            batch_size=batch_size  
            #class_mode=‘binary‘  
            shuffle=is_train)  
  
        return generator  
  
    #ResNet模型  
    def ResNet50_model(self lr=0.005 decay=1e-6 momentum=0.9 nb_classes=2 img_rows=197 img_cols=197 RGB=True is_plot_model=False):  
        color = 3 if RGB else 1  
        base_model = ResNet50(weights=‘imagenet‘ include_top=False pooling=None input_shape=(img_rows img_cols color)  
                              classes=nb_classes)  
  
        #冻结base_model所有层,这样就可以正确获得bottleneck特征  
        for layer in base_model.layers:  
            layer.trainable = False  
  
        x = base_model.output  
        #添加自己的全链接分类层  
        x = Flatten()(x)  
        #x = GlobalAveragePooling2D()(x)  
        #x = Dense(1024 activation=‘relu‘)(x)  
        predictions = Dense(nb_classes activation=‘softmax‘)(x)  
  
        #训练模型  
        model = Model(inputs=base_model.input outputs=predictions)  
        sgd = SGD(lr=lr decay=decay momentum=momentum nesterov=True)  
        model.compile(loss=‘categorical_crossentropy‘ optimizer=sgd metrics=[‘accuracy‘])  
  
        #绘制模型  
        if is_plot_model:  
            plot_model(model to_file=‘resnet50_model.png‘show_shapes=True)  
  
        return model  
  
  
    #VGG模型  
    def VGG19_model(self lr=0.005 decay=1e-6 momentum=0.9 nb_classes=2 img_rows=197 img_cols=197 RGB=True is_plot_model=False):  
        color = 3 if RGB else 1  
        base_model = VGG19(weights=‘imagenet‘ include_top=False pooling=

评论

共有 条评论