• 大小: 14KB
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-06-14
  • 语言: Python
  • 标签: yolov3  keras  python3  

资源简介

使用 Keras 版 YOLOv3 绘制 loss 曲线的程序。将该文件替换原工程中的 train.py，运行即可。

资源截图

代码片段和文件信息

"""
Retrain the YOLO model for your own dataset.

This variant also records loss/accuracy histories through the ``LossHistory``
Keras callback below and plots them with matplotlib.
"""
import time

import numpy as np
import keras.backend as K
from keras.layers import Input, Lambda
from keras.models import Model
from keras.optimizers import Adam
from keras.callbacks import (
    TensorBoard,
    ModelCheckpoint,
    ReduceLROnPlateau,
    EarlyStopping,
)

from yolo3.model import preprocess_true_boxes, yolo_body, tiny_yolo_body, yolo_loss
from yolo3.utils import get_random_data

import keras
import matplotlib.pyplot as plt

# Keras callback that records training metrics and plots them as curves.
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self logs={}):
        self.losses = {‘batch‘: [] ‘epoch‘: []}
        self.accuracy = {‘batch‘: [] ‘epoch‘: []}
        self.val_loss = {‘batch‘: [] ‘epoch‘: []}
        self.val_acc = {‘batch‘: [] ‘epoch‘: []}
    def on_batch_end(self batch logs={}):
        self.losses[‘batch‘].append(logs.get(‘loss‘))
        self.accuracy[‘batch‘].append(logs.get(‘acc‘))
        self.val_loss[‘batch‘].append(logs.get(‘val_loss‘))
        self.val_acc[‘batch‘].append(logs.get(‘val_acc‘))
        if int(time.time()) % 5 == 0:
            self.draw_loss(self.losses[‘batch‘] ‘loss‘ ‘train_batch‘)
            self.draw_loss_50(self.losses[‘batch‘] ‘loss‘ ‘train_batch_50‘)
            self.draw_loss_100(self.losses[‘batch‘] ‘loss‘ ‘train_batch_100‘)
            self.draw_loss_200(self.losses[‘batch‘] ‘loss‘ ‘train_batch_200‘)
            self.draw_loss_500(self.losses[‘batch‘] ‘loss‘ ‘train_batch_500‘)
            self.draw_loss_1000(self.losses[‘batch‘] ‘loss‘ ‘train_batch_1000‘)
            self.draw_p(self.accuracy[‘batch‘] ‘acc‘ ‘train_batch‘)
            self.draw_p(self.val_loss[‘batch‘] ‘loss‘ ‘val_batch‘)
            self.draw_p(self.val_acc[‘batch‘] ‘acc‘ ‘val_batch‘)
    def on_epoch_end(self batch logs={}):
        self.losses[‘epoch‘].append(logs.get(‘loss‘))
        self.accuracy[‘epoch‘].append(logs.get(‘acc‘))
        self.val_loss[‘epoch‘].append(logs.get(‘val_loss‘))
        self.val_acc[‘epoch‘].append(logs.get(‘val_acc‘))
        if int(time.time()) % 5 == 0:
            self.draw_loss(self.losses[‘epoch‘] ‘loss‘ ‘train_epoch‘)
            self.draw_loss_50(self.losses[‘batch‘] ‘loss‘ ‘train_batch_50‘)
            self.draw_loss_100(self.losses[‘batch‘] ‘loss‘ ‘train_batch_100‘)
            self.draw_loss_200(self.losses[‘batch‘] ‘loss‘ ‘train_batch_200‘)
            self.draw_loss_500(self.losses[‘batch‘] ‘loss‘ ‘train_batch_500‘)
            self.draw_loss_1000(self.losses[‘batch‘] ‘loss‘ ‘train_batch_500‘)
            self.draw_p(self.accuracy[‘epoch‘] ‘acc‘ ‘train_epoch‘)
            self.draw_p(self.val_loss[‘epoch‘] ‘loss‘ ‘val_epoch‘)
            self.draw_p(self.val_acc[‘epoch‘] ‘acc‘ ‘val_epoch‘)
    def draw_p(self lists label type):
        plt.figure()
        plt.plot(range(len(lists)) lists ‘r‘ label=label)
        #plt.ylim((0 150))
        plt.ylabel(label)
        plt.xlabel(type)

        plt.legend(loc=“upper right“)
        plt.sa

评论

共有 条评论