Pytorch_Unet图像分割
Pytorch_Unet图像分割:
# -*- coding: utf-8 -*-
import sys
sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\u_net_liver-master")
import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
import cv2
import random
# Use CUDA when it is available, otherwise fall back to the CPU.
# To force CPU execution, replace this with: device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input images: convert to a tensor, then normalise each of the three
# channels with mean 0.5 and std 0.5.
# NOTE(review): the Normalize arguments were missing from the source text;
# per-channel 0.5/0.5 is the conventional choice for a 3-channel input --
# confirm against the statistics the saved weights were trained with.
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Masks only need to be converted to tensors.
y_transforms = transforms.ToTensor()

# Argument parsing (the parser is created again inside the __main__ block).
parse = argparse.ArgumentParser()
def train_model(model, criterion, optimizer, dataload, num_epochs=20):
    """Train ``model`` on ``dataload`` and save a checkpoint after every epoch.

    Args:
        model: network to optimise; it is moved to the module-level ``device``.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer that updates the parameters of ``model``.
        dataload: DataLoader yielding ``(input, target)`` batches.
        num_epochs: number of passes over ``dataload``.

    Returns:
        The trained model.
    """
    # Follow the globally selected device instead of hard-coding GPU 0, so the
    # function also works on machines without CUDA.
    model.to(device)
    model.train()

    total_steps = len(dataload)  # number of batches per epoch
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0
        for step, (x, y) in enumerate(dataload, start=1):
            inputs = x.to(device)
            labels = y.to(device)
            del x, y

            # Zero the parameter gradients, then forward / backward / update.
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, total_steps, loss.item()))

            # Drop the batch tensors and release cached GPU blocks before the
            # next batch is loaded; this keeps peak GPU memory down.
            del inputs, labels, outputs
            torch.cuda.empty_cache()
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        # Checkpoint per epoch; the last file is weights_<num_epochs - 1>.pth.
        torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
    return model
#训练模型
def train():
    """Build the U-Net, its loss, optimizer and training loader, then train.

    Reads ``args.batch_size`` from the module-level ``args`` and loads the
    training samples from ``data/train``.
    """
    net = Unet(3, 1).to(device)
    samples_per_batch = args.batch_size
    loss_fn = torch.nn.BCELoss()
    adam = optim.Adam(net.parameters())
    train_set = LiverDataset(
        "data/train",
        transform=x_transforms,
        target_transform=y_transforms,
    )
    train_loader = DataLoader(
        train_set,
        batch_size=samples_per_batch,
        shuffle=True,
        num_workers=4,
    )
    train_model(net, loss_fn, adam, train_loader)
#显示模型的输出结果
def test_gpu():
    """Predict every sample in ``data/val``, display it and save it as TIFF.

    The weights named by ``args.ckp`` are loaded onto the module-level
    ``device`` (GPU when available, CPU otherwise). Each prediction is shown
    with matplotlib and written to ``./resultsImages/<index>.tiff``, where
    ``index`` is the position of the sample in the validation loader.
    """
    import os

    import matplotlib.pyplot as plt

    out_dir = './resultsImages/'
    # cv2.imwrite does not create missing directories; make sure it exists.
    os.makedirs(out_dir, exist_ok=True)

    model = Unet(3, 1).to(device)
    # map_location=device keeps loading working on machines without CUDA.
    model.load_state_dict(torch.load(args.ckp, map_location=device))
    liver_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()

    plt.ion()
    with torch.no_grad():
        for index, (x, _) in enumerate(dataloaders):
            y = model(x.to(device))
            img_y = torch.squeeze(y).cpu().numpy()
            plt.imshow(img_y)
            plt.pause(0.01)
            torch.cuda.empty_cache()
            # Index-based names keep outputs ordered and traceable to inputs.
            out_path = os.path.join(out_dir, '%04d.tiff' % index)
            if not cv2.imwrite(out_path, img_y):
                print("failed to write %s" % out_path)
    plt.show()
    torch.cuda.empty_cache()
def test_cpu():
    """Run the checkpointed U-Net on the CPU over ``data/val`` and plot it.

    Weights come from ``args.ckp`` (mapped to the CPU). Each prediction is
    displayed with matplotlib; nothing is written to disk.
    """
    import matplotlib.pyplot as plt

    net = Unet(3, 1)
    cpu_weights = torch.load(args.ckp, map_location='cpu')
    net.load_state_dict(cpu_weights)

    val_set = LiverDataset(
        "data/val",
        transform=x_transforms,
        target_transform=y_transforms,
    )
    val_loader = DataLoader(val_set, batch_size=1)

    net.eval()
    plt.ion()
    with torch.no_grad():
        for image, _ in val_loader:
            prediction = net(image)
            # Drop the singleton batch/channel axes before plotting.
            mask = torch.squeeze(prediction).numpy()
            plt.imshow(mask)
            plt.pause(0.01)
    plt.show()
    torch.cuda.empty_cache()
if __name__ == '__main__':
    # Command-line entry point. ``args`` is intentionally kept at module level
    # because train(), test_gpu() and test_cpu() read it as a global.
    parse = argparse.ArgumentParser()
    # ``choices`` makes an unknown action an explicit error instead of a
    # silent no-op.
    parse.add_argument("--action", type=str, help="train or test",
                       choices=("train", "test"), default="train")
    parse.add_argument("--batch_size", type=int, default=4)
    parse.add_argument("--ckp", type=str, help="the path of model weight file",
                       default="weights_19.pth")
    args = parse.parse_args()
    print(args.action)
    if args.action == "train":
        train()
    elif args.action == "test":
        # Run inference once; the source called test_gpu() twice in a row.
        test_gpu()
尤其要注意的是:
del inputs,labels,outputs
torch.cuda.empty_cache()
这两行用于释放不再使用的张量并清理GPU缓存,否则训练过程中容易报GPU显存不足(out of memory)的错误。
unet.py如下:
import torch.nn as nn
import torch
from torch import autograd
class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> batch norm -> ReLU stages.

    The first stage maps ``in_ch`` channels to ``out_ch``; the second keeps
    ``out_ch``. ``padding=1`` preserves the spatial size of the input.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Stage 1 consumes in_ch channels, stage 2 consumes out_ch channels.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        # Kept under the name ``conv`` so state_dict keys stay ``conv.<i>.*``.
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        features = self.conv(input)
        return features
class Unet(nn.Module):
    """U-Net with four down-sampling and four up-sampling stages.

    The encoder widens 64 -> 128 -> 256 -> 512 -> 1024 channels with a 2x2
    max-pool between stages. The decoder up-samples with stride-2 transposed
    convolutions and concatenates the encoder feature map of the same
    resolution along the channel axis before each ``DoubleConv``. A final
    1x1 convolution followed by a sigmoid produces ``out_ch`` maps in (0, 1).

    Because the input is halved four times and doubled four times, its height
    and width must be divisible by 16 for the skip concatenations to line up.

    Args:
        in_ch: number of channels of the input image.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Encoder.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # Decoder: each DoubleConv takes (up-sampled + skip) channels.
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        """Return the sigmoid-activated output maps for input batch ``x``."""
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)

        # The tensor lists passed to torch.cat were missing from the source
        # text. The channel counts of conv6..conv9 require each up-sampled map
        # to be joined with the encoder map of the same resolution.
        # NOTE(review): the order [up, skip] matters for loading existing
        # checkpoints -- confirm against the weights in use.
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # Functional sigmoid avoids building a new nn.Sigmoid module per call.
        out = torch.sigmoid(c10)
        return out
参考:
【1】读取文件夹内图像--方法1
【2】读取文件夹内图像--方法2
【3】https://github.com/gupta-abhay/pytorch-modelzoo
【4】https://github.com/yt4766269/pytorch_zoo
【5】https://github.com/takahiro-itazuri/model-zoo-pytorch
【6】http://222.195.93.137/gitlab/winston.wen/kaggle-1
【7】3-6之百度网盘下载链接:https://pan.baidu.com/s/17pXf2M3lIAFkJuoMXOZ9pA 提取码:k0ou
页:
[1]