# Halcom 发表于 2021-5-30 16:54:40

# OCR深度学习模型：CRNN+BiLSTM（模型1）

# OCR深度学习模型 CRNN+BiLSTM 的 PyTorch 实现：
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        nIn: feature size of each input time step.
        nHidden: hidden size of each LSTM direction.
        nOut: feature size of each output time step.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        # Forward and backward hidden states are concatenated by nn.LSTM,
        # so the projection consumes 2 * nHidden features per step.
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(2 * nHidden, nOut)

    def forward(self, input):
        """Run the BiLSTM and project every time step to ``nOut`` features.

        Args:
            input: 3-D sequence tensor laid out (T, b, nIn); the LSTM is built
                without ``batch_first``, so dim 0 is time.

        Returns:
            Tensor of shape (T, b, nOut).
        """
        seq, _ = self.rnn(input)
        steps, batch, width = seq.size()
        # Fold time and batch together so the Linear layer sees a 2-D matrix,
        # then restore the (T, b, ·) layout.
        projected = self.embedding(seq.view(steps * batch, width))
        return projected.view(steps, batch, -1)

class CRNN(nn.Module):
    """CNN backbone + two stacked bidirectional LSTMs for sequence recognition.

    A 7-layer convolutional stack turns an image into a feature map of height
    1. Each remaining column becomes one time step for the recurrent head,
    which emits log-probabilities over ``nclass`` classes per step.

    Args:
        imgH: input image height. It must be a multiple of 16 (asserted
            below). With the kernel/padding settings used here, only
            imgH == 32 yields the feature-map height of 1 that ``forward``
            asserts. 16 is too small for conv6, and 48 or more leaves
            height > 1.
        nc: number of input image channels.
        nclass: number of output classes per time step.
        nh: hidden size of each LSTM direction.
        n_rnn: accepted for API compatibility but not used. The recurrent
            head always has two BidirectionalLSTM layers.
        leakyRelu: use LeakyReLU(0.2) instead of ReLU after every conv.
    """

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # Per-layer settings for conv0..conv6 (list index = layer index):
        # kernel size, padding, stride and output channels. conv6 uses a
        # 2x2 kernel without padding to squeeze the height from 2 to 1.
        # The last layer outputs 512 channels, which the first BiLSTM consumes.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            """Append conv{i}, an optional batchnorm{i}, and relu{i} to ``cnn``."""
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Shape comments are C x H x W for an nc x 32 x W input,
        # with W assumed divisible by 4.
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64 x 16 x W/2
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128 x 8 x W/4
        convRelu(2, True)
        convRelu(3)
        # Stride (2, 1) with width padding 1 halves the height but keeps the
        # horizontal resolution (the width actually grows by one), so the
        # sequence model gets more time steps.
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256 x 4 x (W/4 + 1)
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512 x 2 x (W/4 + 2)
        convRelu(6, True)  # 512 x 1 x (W/4 + 1)

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(nm[-1], nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        """Map a batch of images to per-time-step class log-probabilities.

        Args:
            input: 4-D image batch (b, nc, H, W). The CNN output must have
                height 1 (asserted).

        Returns:
            Tensor of shape (T, b, nclass), log-softmax over the class dim,
            where T is the CNN output width.
        """
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)  # b x c x w
        # Time-major (w x b x c), matching the LSTMs built without batch_first.
        conv = conv.permute(2, 0, 1)
        output = F.log_softmax(self.rnn(conv), dim=2)

        return output

def weights_init(m):
    """Initialise a single module in place; meant for ``model.apply(weights_init)``.

    - Modules whose class name contains 'Conv': weight ~ N(0, 0.02). The bias
      is left untouched.
    - Modules whose class name contains 'BatchNorm': weight ~ N(1, 0.02) and
      bias = 0.
    - Every other module is left unchanged.

    Parameters that are missing or None are skipped rather than raising
    AttributeError. Examples are BatchNorm built with ``affine=False``, or a
    user container whose class name merely contains 'Conv'.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            # nn.init draws under no_grad, equivalent to weight.data.normal_().
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.zeros_(bias)

def get_crnn(config):
    """Build a CRNN from ``config`` and initialise it with ``weights_init``.

    Reads ``config.MODEL.IMAGE_SIZE.H`` (input height),
    ``config.MODEL.NUM_CLASSES`` and ``config.MODEL.NUM_HIDDEN``. The network
    is built with 1 input channel and NUM_CLASSES + 1 output classes. The
    extra class is presumably the CTC blank; confirm against the training code.
    """
    model_cfg = config.MODEL
    net = CRNN(model_cfg.IMAGE_SIZE.H, 1, model_cfg.NUM_CLASSES + 1, model_cfg.NUM_HIDDEN)
    # nn.Module.apply visits every submodule recursively and returns the module.
    return net.apply(weights_init)


# 页: [1]
# 查看完整版本: OCR深度学习模型CRNN+BiLSTM 模型1