• 大小: 4KB
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-05-12
  • 语言: Python
  • 标签: TF分类  

资源简介

针对已训练好的tensorflow模型,模型是根据自身需要训练的,将模型其应用的遥感影像分类中,并显示分类结果。

资源截图

代码片段和文件信息

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.io as scio
import cv2
import datetimetime

os.environ[“CUDA_VISIBLE_DEVICES“] = ‘01‘

def discrete_matshow(data labels_names=[] title=““):
    # get discrete colormap
    cmap = plt.get_cmap(‘Paired‘ np.max(data) - np.min(data) + 1)
    # set limits .5 outside true range
    mat = plt.matshow(data
                      cmap=cmap
                      vmin=np.min(data) - .5
                      vmax=np.max(data) + .5)
    # tell the colorbar to tick at integers
    cax = plt.colorbar(mat
                       ticks=np.arange(np.min(data) np.max(data) + 1))

    # The names to be printed aside the colorbar
    if labels_names:
        cax.ax.set_yticklabels(labels_names)



    if title:
        plt.suptitle(title fontsize=14 fontweight=‘bold‘)


def next_batch(image ii h):
    j = 14
    temp = []
    while j < h - 14:
        rgb = image[ii - 14:ii + 14 j - 14:j + 14 :]
        temp.append(rgb)
        j += 1
    temp = np.array(temp)
    # print(temp.shape)
    # assert temp.shape[0] == 3972
    # print(temp.shape)
    return temp

img = cv2.imread(‘jimo_resize_2000.tif‘)
img = cv2.cvtColor(img cv2.COLOR_BGR2RGB)
img = np.multiply(img 1.0/255.0)
print(img.shape)
m = img.shape[0]
n = img.shape[1]

print(‘load the model....‘)
vgg_saver = tf.train.import_meta_graph(‘2017.09.11-03.31.ckpt.meta‘)
vgg_graph = tf.get_default_graph()

# for n in tf.get_default_graph().as_graph_def().node:
#     print(n.name)

x = tf.get_default_graph().get_tensor_by_name(‘Placeholder:0‘)
z = tf.get_default_graph().get_tensor_by_name(‘Placeholder_1:0‘)



feature = vgg_graph.get_tensor_by_name(“D_conv_mnist/fully_connected_2/BiasAdd:0“)
print(feature)
pred = tf.nn.softmax(feature)
print(‘extract jimo image feature...‘)
result = []

start_time = datetime.datetime.now()
with tf.Session() as sess:
    vgg_saver.restore(sess ‘./2017.09.11-03.31.ckpt‘)
    i = 14
    segmentation_ = []
    z_sample 

评论

共有 条评论

相关资源