GTX_AI 发表于 2019-5-12 23:31:34

Pytorch_Unet图像分割

Pytorch_Unet图像分割:
# -*- coding: utf-8 -*-
import sys
sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\u_net_liver-master")
import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
import cv2
import random

# 是否使用cuda
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
#print(device)

x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(, )
])

# mask只需要转换为tensor
y_transforms = transforms.ToTensor()

#参数解析
parse=argparse.ArgumentParser()

def train_model(model, criterion, optimizer, dataload, num_epochs=20):
    model.cuda(0)
    for epoch in range(num_epochs):
      print('Epoch {}/{}'.format(epoch, num_epochs - 1))
      print('-' * 10)
      dt_size = len(dataload.dataset)
      epoch_loss = 0
      step = 0
      #with torch.no_grad():
      for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            del x,y
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
            
            del inputs,labels,outputs
            torch.cuda.empty_cache()
            
      print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
    torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model

#训练模型
def train():
    model = Unet(3, 1).to(device)
    batch_size = args.batch_size
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, dataloaders)

#显示模型的输出结果
def test_gpu():
    model = Unet(3, 1).to(device)
#    model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
    model.load_state_dict(torch.load(args.ckp, map_location='cuda'))
    liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.cuda(0)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
      for x, _ in dataloaders:
            y=model(x.to(device))
            img_y=torch.squeeze(y).cpu().numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
            
            torch.cuda.empty_cache()
            
            cv2.imwrite('./resultsImages/'+str(random.random())+'.tiff', img_y)

      plt.show()
    torch.cuda.empty_cache()
   
def test_cpu():
    model = Unet(3, 1)
    model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
    liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
      for x, _ in dataloaders:
            y=model(x)
            img_y=torch.squeeze(y).numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
      plt.show()
      
    torch.cuda.empty_cache()
   
if __name__ == '__main__':
    parse = argparse.ArgumentParser()
#    parse.add_argument("action", type=str, help="train or test")
    parse.add_argument("--action", type=str, help="train or test", default="train")
    parse.add_argument("--batch_size", type=int, default=4)
#    parse.add_argument("--ckp", type=str, help="the path of model weight file")
    parse.add_argument("--ckp", type=str, help="the path of model weight file", default="weights_19.pth")
    args = parse.parse_args()

    print(args.action)

    if args.action=="train":
      train()
    elif args.action=="test":
      test_gpu()
      
    test_gpu()
    尤其要注意的是:
del inputs,labels,outputs
torch.cuda.empty_cache()
清理GPU缓存,不然报GPU内存不够错误。

unet.py如下:
import torch.nn as nn
import torch
from torch import autograd

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
      super(DoubleConv, self).__init__()
      self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
      )

    def forward(self, input):
      return self.conv(input)


class Unet(nn.Module):
    def __init__(self,in_ch,out_ch):
      super(Unet, self).__init__()

      self.conv1 = DoubleConv(in_ch, 64)
      self.pool1 = nn.MaxPool2d(2)
      self.conv2 = DoubleConv(64, 128)
      self.pool2 = nn.MaxPool2d(2)
      self.conv3 = DoubleConv(128, 256)
      self.pool3 = nn.MaxPool2d(2)
      self.conv4 = DoubleConv(256, 512)
      self.pool4 = nn.MaxPool2d(2)
      self.conv5 = DoubleConv(512, 1024)
      self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
      self.conv6 = DoubleConv(1024, 512)
      self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
      self.conv7 = DoubleConv(512, 256)
      self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
      self.conv8 = DoubleConv(256, 128)
      self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
      self.conv9 = DoubleConv(128, 64)
      self.conv10 = nn.Conv2d(64,out_ch, 1)

    def forward(self,x):
      c1=self.conv1(x)
      p1=self.pool1(c1)
      c2=self.conv2(p1)
      p2=self.pool2(c2)
      c3=self.conv3(p2)
      p3=self.pool3(c3)
      c4=self.conv4(p3)
      p4=self.pool4(c4)
      c5=self.conv5(p4)
      up_6= self.up6(c5)
      merge6 = torch.cat(, dim=1)
      c6=self.conv6(merge6)
      up_7=self.up7(c6)
      merge7 = torch.cat(, dim=1)
      c7=self.conv7(merge7)
      up_8=self.up8(c7)
      merge8 = torch.cat(, dim=1)
      c8=self.conv8(merge8)
      up_9=self.up9(c8)
      merge9=torch.cat(,dim=1)
      c9=self.conv9(merge9)
      c10=self.conv10(c9)
      out = nn.Sigmoid()(c10)
      return out
参考:
【1】读取文件夹内图像--方法1
【2】读取文件夹内图像--方法2
【3】https://github.com/gupta-abhay/pytorch-modelzoo
【4】https://github.com/yt4766269/pytorch_zoo
【5】https://github.com/takahiro-itazuri/model-zoo-pytorch
【6】http://222.195.93.137/gitlab/winston.wen/kaggle-1
【7】3-6之百度网盘下载链接:https://pan.baidu.com/s/17pXf2M3lIAFkJuoMXOZ9pA 提取码:k0ou
















页: [1]
查看完整版本: Pytorch_Unet图像分割