• 大小: 14KB
    文件类型: .py
    金币: 2
    下载: 1 次
    发布日期: 2021-06-17
  • 语言: Python
  • 标签:

资源简介

在此之前,脑肿瘤专栏中的2D网络预测的时候,是把所有的切片预测完指标再求平均值,这样测的值极容易收到一些差的切片而影响整体的指标.所以以后的2D网络预测都采用下面方式进行计算指标,即把所有预测的切片拼接回3D,然后对3D数据整体进行计算指标.这样计算的值会偏高点.不只是2D网络这样,3D网络也是如此,把所有分块拼接后再对整体进行指标的计算.这样统一之后,我们就可以将2D和3D网络进行对比了.此外,代码预测生成的数据都是NII格式的,可以通过ITK-SNAP软件查看三维的分割效果,如果想看2D切片的分割效果,可以用该软件导出即可.

资源截图

代码片段和文件信息

# -*- coding: utf-8 -*-

import time
import os
import math
import argparse
from glob import glob
from collections import OrderedDict
import random
import warnings
from datetime import datetime

import numpy as np
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from skimage.io import imread imsave

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import datasets models transforms

from dataset import Dataset

import Unet
from metrics import dice_coef batch_iou mean_iou iou_score ppvsensitivity
import losses
from utils import str2bool count_params
from sklearn.externals import joblib
#from hausdorff import hausdorff_distance
import imageio
#import ttach as tta
import SimpleITK as sitk

wt_dices = []
tc_dices = []
et_dices = []
wt_sensitivities = []
tc_sensitivities = []
et_sensitivities = []
wt_ppvs = []
tc_ppvs = []
et_ppvs = []
wt_Hausdorf = []
tc_Hausdorf = []
et_Hausdorf = []


def hausdorff_distance(lTlP):
    labelPred=sitk.GetImageFromArray(lP isVector=False)
    labelTrue=sitk.GetImageFromArray(lT isVector=False)
    hausdorffcomputer=sitk.HausdorffDistanceImageFilter()
    hausdorffcomputer.Execute(labelTrue>0.5labelPred>0.5)
    return hausdorffcomputer.GetAverageHausdorffDistance()#hausdorffcomputer.GetHausdorffDistance()

def CalculateWTTCET(wtpbregionwtmaskregiontcpbregiontcmaskregionetpbregionetmaskregion):
    #开始计算WT
    dice = dice_coef(wtpbregionwtmaskregion)
    wt_dices.append(dice)
    ppv_n = ppv(wtpbregion wtmaskregion)
    wt_ppvs.append(ppv_n)
    Hausdorff = hausdorff_distance(wtmaskregion wtpbregion)
    wt_Hausdorf.append(Hausdorff)
    sensitivity_n = sensitivity(wtpbregion wtmaskregion)
    wt_sensitivities.append(sensitivity_n)
    # 开始计算TC
    dice = dice_coef(tcpbregion tcmaskregion)
    tc_dices.append(dice)
    ppv_n = ppv(tcpbregion tcmaskregion)
    tc_ppvs.append(ppv_n)
    Hausdorff = hausdorff_distance(tcmaskregion tcpbregion)
    tc_Hausdorf.append(Hausdorff)
    sensitivity_n = sensitivity(tcpbregion tcmaskregion)
    tc_sensitivities.append(sensitivity_n)
    # 开始计算ET
    dice = dice_coef(etpbregion etmaskregion)
    et_dices.append(dice)
    ppv_n = ppv(etpbregion etmaskregion)
    et_ppvs.append(ppv_n)
    Hausdorff = hausdorff_distance(etmaskregion etpbregion)
    et_Hausdorf.append(Hausdorff)
    sensitivity_n = sensitivity(etpbregion etmaskregion)
    et_sensitivities.append(sensitivity_n)


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(‘--name‘ default=None
                        help=‘model name‘)
    parser.add_argument(‘--mode‘ default=None
                        help=‘‘)


评论

共有 条评论

相关资源