• 大小: 858KB
    文件类型: .zip
    金币: 2
    下载: 1 次
    发布日期: 2021-06-17
  • 语言: Python
  • 标签:

资源简介

Pytorch implementation of CRAFT text detector

资源截图

代码片段和文件信息

“““  
Copyright (c) 2019-present NAVER Corp.
MIT License
“““

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F

from basenet.vgg16_bn import vgg16_bn init_weights

class double_conv(nn.Module):
    def __init__(self in_ch mid_ch out_ch):
        super(double_conv self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + mid_ch mid_ch kernel_size=1)
            nn.BatchNorm2d(mid_ch)
            nn.ReLU(inplace=True)
            nn.Conv2d(mid_ch out_ch kernel_size=3 padding=1)
            nn.BatchNorm2d(out_ch)
            nn.ReLU(inplace=True)
        )

    def forward(self x):
        x = self.conv(x)
        return x


class CRAFT(nn.Module):
    def __init__(self pretrained=False freeze=False):
        super(CRAFT self).__init__()

        “““ base network “““
        self.basenet = vgg16_bn(pretrained freeze)

        “““ U network “““
        self.upconv1 = double_conv(1024 512 256)
        self.upconv2 = double_conv(512 256 128)
        self.upconv3 = double_conv(256 128 64)
        self.upconv4 = double_conv(128 64 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32 32 kernel_size=3 padding=1) nn.ReLU(inplace=True)
            nn.Conv2d(32 32 kernel_size=3 padding=1) nn.ReLU(inplace=True)
            nn.Conv2d(32 16 kernel_size=3 padding=1) nn.ReLU(inplace=True)
            nn.Conv2d(16 16 kernel_size=1) nn.ReLU(inplace=True)
            nn.Conv2d(16 num_class kernel_size=1)
        )

        init_weights(self.upconv1.modules())
        init_weights(self.upconv2.modules())
        init_weights(self.upconv3.modules())
        init_weights(self.upconv4.modules())
        init_weights(self.conv_cls.modules())
        
    def forward(self x):
        “““ base network “““
        sources = self.basenet(x)

        “““ U network “““
        y = torch.cat([sources[0] sources[1]] dim=1)
        y = self.upconv1(y)

        y = F.interpolate(y size=sources[2].size()[2:] mode=‘bilinear‘ align_corners=False)
        y = torch.cat([y sources[2]] dim=1)
        y = self.upconv2(y)

        y = F.interpolate(y size=sources[3].size()[2:] mode=‘bilinear‘ align_corners=False)
        y = torch.cat([y sources[3]] dim=1)
        y = self.upconv3(y)

        y = F.interpolate(y size=sources[4].size()[2:] mode=‘bilinear‘ align_corners=False)
        y = torch.cat([y sources[4]] dim=1)
        feature = self.upconv4(y)

        y = self.conv_cls(feature)

        return y.permute(0231) feature

if __name__ == ‘__main__‘:
    model = CRAFT(pretrained=True).cuda()
    output _ = model(torch.randn(1 3 768 768).cuda())
    print(output.shape)

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2019-08-02 13:10  CRAFT-pytorch-master\
     文件          40  2019-08-02 13:10  CRAFT-pytorch-master\.gitignore
     文件        1064  2019-08-02 13:10  CRAFT-pytorch-master\LICENSE
     文件        3586  2019-08-02 13:10  CRAFT-pytorch-master\README.md
     目录           0  2019-08-02 13:10  CRAFT-pytorch-master\basenet\
     文件           0  2019-08-02 13:10  CRAFT-pytorch-master\basenet\__init__.py
     文件        2805  2019-08-02 13:10  CRAFT-pytorch-master\basenet\vgg16_bn.py
     文件        2753  2019-08-02 13:10  CRAFT-pytorch-master\craft.py
     文件        9099  2019-08-02 13:10  CRAFT-pytorch-master\craft_utils.py
     目录           0  2019-08-02 13:10  CRAFT-pytorch-master\figures\
     文件      870634  2019-08-02 13:10  CRAFT-pytorch-master\figures\craft_example.gif
     文件        2870  2019-08-02 13:10  CRAFT-pytorch-master\file_utils.py
     文件        2195  2019-08-02 13:10  CRAFT-pytorch-master\imgproc.py
     文件          95  2019-08-02 13:10  CRAFT-pytorch-master\requirements.txt
     文件        4833  2019-08-02 13:10  CRAFT-pytorch-master\test.py

评论

共有 条评论