• 大小: 20KB
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-06-03
  • 语言: Python
  • 标签: CNN  ResNet  

资源简介

Attention-CNN 注意力机制细腻度图片分类。 ResNet改造

资源截图

代码片段和文件信息

import mxnet as mx
import proposal
import proposal_target
from rcnn.config import config

eps = 2e-5
use_global_stats = True
workspace = 512
res_deps = {‘50‘: (3 4 6 3) ‘101‘: (3 4 23 3) ‘152‘: (3 8 36 3) ‘200‘: (3 24 36 3)}
units = res_deps[‘101‘]
filter_list = [256 512 1024 2048]


def residual_unit(data num_filter stride dim_match name):
    bn1 = mx.sym.BatchNorm(data=data fix_gamma=False eps=eps use_global_stats=use_global_stats name=name + ‘_bn1‘)
    act1 = mx.sym.Activation(data=bn1 act_type=‘relu‘ name=name + ‘_relu1‘)
    conv1 = mx.sym.Convolution(data=act1 num_filter=int(num_filter * 0.25) kernel=(1 1) stride=(1 1) pad=(0 0)
                               no_bias=True workspace=workspace name=name + ‘_conv1‘)
    bn2 = mx.sym.BatchNorm(data=conv1 fix_gamma=False eps=eps use_global_stats=use_global_stats name=name + ‘_bn2‘)
    act2 = mx.sym.Activation(data=bn2 act_type=‘relu‘ name=name + ‘_relu2‘)
    conv2 = mx.sym.Convolution(data=act2 num_filter=int(num_filter * 0.25) kernel=(3 3) stride=stride pad=(1 1)
                               no_bias=True workspace=workspace name=name + ‘_conv2‘)
    bn3 = mx.sym.BatchNorm(data=conv2 fix_gamma=False eps=eps use_global_stats=use_global_stats name=name + ‘_bn3‘)
    act3 = mx.sym.Activation(data=bn3 act_type=‘relu‘ name=name + ‘_relu3‘)
    conv3 = mx.sym.Convolution(data=act3 num_filter=num_filter kernel=(1 1) stride=(1 1) pad=(0 0) no_bias=True
                               workspace=workspace name=name + ‘_conv3‘)
    if dim_match:
        shortcut = data
    else:
        shortcut = mx.sym.Convolution(data=act1 num_filter=num_filter kernel=(1 1) stride=stride no_bias=True
                                      workspace=workspace name=name + ‘_sc‘)
    sum = mx.sym.ElementWiseSum(*[conv3 shortcut] name=name + ‘_plus‘)
    return sum


def get_resnet_conv(data):
    # res1
    data_bn = mx.sym.BatchNorm(data=data fix_gamma=True eps=eps use_global_stats=use_global_stats name=‘bn_data‘)
    conv0 = mx.sym.Convolution(data=data_bn num_filter=64 kernel=(7 7) stride=(2 2) pad=(3 3)
                               no_bias=True name=“conv0“ workspace=workspace)
    bn0 = mx.sym.BatchNorm(data=conv0 fix_gamma=False eps=eps use_global_stats=use_global_stats name=‘bn0‘)
    relu0 = mx.sym.Activation(data=bn0 act_type=‘relu‘ name=‘relu0‘)
    pool0 = mx.symbol.Pooling(data=relu0 kernel=(3 3) stride=(2 2) pad=(1 1) pool_type=‘max‘ name=‘pool0‘)

    # res2
    unit = residual_unit(data=pool0 num_filter=filter_list[0] stride=(1 1) dim_match=False name=‘stage1_unit1‘)
    for i in range(2 units[0] + 1):
        unit = residual_unit(data=unit num_filter=filter_list[0] stride=(1 1) dim_match=True name=‘stage1_unit%s‘ % i)

    # res3
    unit = residual_unit(data=unit num_filter=filter_list[1] stride=(2 2) dim_match=False name=‘stage2_unit1‘)
    for i in range(2 units[1] + 1):
      

评论

共有 条评论