Vgg16预训练下的Unet
Vgg16预训练下的Unet
# -*- coding: utf-8 -*-
"""
Created on Tue Aug 27 22:06:49 2019
@author: Solem
"""
import sys
sys.path.append(r"D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\Vgg16_Unet")
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import torch
import torch.utils.data as data
from torch import autograd, optim
from torchvision.transforms import transforms
from torch.autograd import Variable
# 配置文件
from configs import configXML
from pytorch_zoo import unet, resnet38unet
# begin
# ---------------------------------------------------------------------------
# Training script: builds `unet.Vgg16bn`, optionally restores its weights,
# then trains it on `configXML.train_loader` with BCE loss and SGD, writing
# the weights to ./checkPoints/Vgg16_Val_normal.pth after every epoch.
# ---------------------------------------------------------------------------

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = unet.Vgg16bn(num_classes=configXML.num_classes)
if configXML.initial_checkpoints is not None:
    # The checkpoint written below contains only a state_dict, so it has to be
    # loaded INTO the module; assigning the result of torch.load() to `net`
    # would replace the model with a plain dict of tensors.
    net.load_state_dict(
        torch.load(configXML.initial_checkpoints, map_location=device)
    )
net = net.to(device)

# NOTE(review): BCELoss expects probabilities in [0, 1] and a target with the
# same shape as the prediction. Confirm that unet.Vgg16bn ends in a sigmoid
# and that the mask shape matches its output; otherwise switch to
# BCEWithLogitsLoss and/or add a channel axis to the mask.
criterion1 = torch.nn.BCELoss()

# Created once: rebuilding the optimizer at every epoch would throw away the
# accumulated momentum buffers.
optimizer = optim.SGD(
    net.parameters(),
    lr=configXML.lr_init,
    momentum=configXML.moment_init,
)

best_val_loss = np.inf  # placeholder; no validation pass updates it yet
num_epoch_has_trained = 0

# torch.save does not create missing directories.
os.makedirs('./checkPoints', exist_ok=True)

for epoch in range(num_epoch_has_trained, configXML.num_epochs):
    iter_count = 0
    train_losses = 0.0
    torch.cuda.empty_cache()
    net.train()

    for i, (train_image, train_mask) in enumerate(configXML.train_loader):
        iter_count += 1
        train_image = train_image.to(device)
        train_mask = train_mask.to(device)

        train_logits = net(train_image)
        train_loss = criterion1(train_logits, train_mask)
        train_losses += train_loss.item()

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

    # max(..., 1) keeps the report safe when the loader yields no batches.
    print("epoch %d: mean train loss %.6f"
          % (epoch, train_losses / max(iter_count, 1)))

    torch.cuda.empty_cache()
    # NOTE(review): the pasted source lost its indentation, so it is not
    # certain whether the checkpoint was meant to be written per epoch or once
    # at the very end; per-epoch saving is the safer reading.
    torch.save(net.state_dict(), './checkPoints/Vgg16_Val_normal.pth')
configXML.py:
# -*- coding: utf-8 -*-
"""
Created on Sun Sep8 20:47:37 2019
@author: Solem
"""
from XMLdataSet import dataset_improved, dataset_package
import torch
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
# --- Dataset locations (absolute, machine-specific paths) -------------------
Train_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images_train.txt'
Val_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images_val.txt'
Imagepath = r"D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images"
Images_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Images.txt'
Maskspath = r"D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Masks"
Masks_file = r'D:\2-LearningCode\992-DataEnhanced\Samples\EnhancedImages\Masks.txt'

# --- Model / training hyper-parameters --------------------------------------
num_classes = 1
img_size = 256
batch_size = 4
num_workers = 0
# Set to a .pth path to resume from saved weights; None starts from scratch.
initial_checkpoints = None
#initial_checkpoints = r'D:\2-LearningCode\999-AI-Pytorch\3_AI_nets\u_net_liver-master\Vgg16_Unet\checkPoints\Vgg16_Val_normal.pth'
num_epochs = 1000
lr_init = 0.01
moment_init = 0.9

# Mean 0.5, std 0.5 per channel.
# NOTE(review): the Normalize arguments were missing in the pasted source and
# are restored from the comment above; three channels (RGB) are assumed.
# These transforms are not passed to the opencvDateset loaders built below.
x_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
# The mask only needs to be converted to a tensor.
y_transforms = transforms.ToTensor()

#train_liver_dataset = dataset_improved.LiverDataset(imagepath, Train_file, transform=x_transforms, target_transform=y_transforms)
#train_loader = DataLoader(train_liver_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)
train_dataset = dataset_package.opencvDateset(
    img_root_path=Imagepath,
    mask_root_path=Maskspath,
    txt_file=Train_file,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

#val_liver_dataset = dataset_improved.LiverDataset(imagepath, Val_file, transform=x_transforms, target_transform=y_transforms)
#val_loader = DataLoader(val_liver_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers)
val_dataset = dataset_package.opencvDateset(
    img_root_path=Imagepath,
    mask_root_path=Maskspath,
    txt_file=Val_file,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
# dataset_package.py follows:
# -*- coding: utf-8 -*-
"""
Created on Sun May 19 22:28:08 2019
@author: Solem
"""
import os
import PIL.Image as Image
import numpy as np
import cv2
import torch
import torch.utils.data as data
from torch import autograd, optim
from torchvision.transforms import transforms
from torch.autograd import Variable
from configs import configXML
def make_dataset(root, num_groups=100, per_group=16):
    """Build the list of (image_path, mask_path) pairs expected under *root*.

    File names follow the patterns ``"%03d_%d.png"`` and
    ``"%03d_%d_mask.png"``: the first number runs over ``range(num_groups)``
    and the second over ``1..per_group``. Paths are only constructed; nothing
    checks that the files exist.

    Args:
        root: Directory that holds both the images and their masks.
        num_groups: How many values the zero-padded first index takes
            (defaults to the previously hard-coded 100).
        per_group: How many values the second, 1-based index takes
            (defaults to the previously hard-coded 16).

    Returns:
        A list of ``(image_path, mask_path)`` tuples ordered by first index,
        then second index.
    """
    return [
        (
            os.path.join(root, "%03d_%d.png" % (i, j + 1)),
            os.path.join(root, "%03d_%d_mask.png" % (i, j + 1)),
        )
        for i in range(num_groups)
        for j in range(per_group)
    ]
class LiverDataset(data.Dataset):
    """Dataset of (image, mask) pairs whose paths come from ``make_dataset``.

    Each item is opened with PIL and passed through the optional transforms.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """Collect the sample paths under *root* and remember the transforms.

        Args:
            root: Directory handed to ``make_dataset``.
            transform: Optional callable applied to the opened image.
            target_transform: Optional callable applied to the opened mask.
        """
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair stored at position *index*."""
        # Select ONE pair; unpacking the whole list would raise ValueError.
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        """Return the number of (image, mask) pairs."""
        return len(self.imgs)
class opencvDateset(data.Dataset):
    """Image/mask dataset read with OpenCV and listed by a text index file.

    Every non-blank line of the index file names one sample. The part of the
    line before the first ``'.'`` is the sample stem: the image is
    ``<img_root_path>/<stem>.jpg`` and the mask is
    ``<mask_root_path>/<stem>_label.png``.
    """

    def __init__(self, img_root_path, mask_root_path, txt_file, transform=None):
        """Read the sample index.

        Args:
            img_root_path: Directory containing the ``.jpg`` images.
            mask_root_path: Directory containing the ``_label.png`` masks.
            txt_file: Text file with one sample name per line.
            transform: Optional iterable of callables, each called as
                ``trans(img, mask)`` and returning the new ``(img, mask)``.
        """
        self.transform = transform
        self.img_root_path = img_root_path
        self.mask_root_path = mask_root_path
        with open(txt_file) as f:
            # Strip newlines and skip blank lines so that every entry maps to
            # a real sample and __len__ is not inflated by trailing blanks.
            self.indexs = [line.strip() for line in f if line.strip()]

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image, mask)`` as float tensors.

        Returns:
            image: float32 tensor, channels first, resized to
                ``configXML.img_size`` square and scaled by 1/255.
            mask: float32 tensor of 0.0/1.0 values with the same height and
                width as the image.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image or the mask.
        """
        index = self.indexs[idx].split('.')[0]
        image_path = os.path.join(self.img_root_path, index + '.jpg')
        mask_path = os.path.join(self.mask_root_path, index + '_label.png')
        size = (configXML.img_size, configXML.img_size)

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError("cannot read image: %s" % image_path)
        img = cv2.resize(img, size)
        img = img.astype(np.float32) / 255.0

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError("cannot read mask: %s" % mask_path)
        mask = cv2.resize(mask, size)

        if self.transform is not None:
            for trans in self.transform:
                img, mask = trans(img, mask)

        # OpenCV yields height x width x channel; move channels to the front,
        # which is the layout PyTorch convolution layers consume.
        img = np.asarray(img, dtype=np.float32).transpose(2, 0, 1)
        img = torch.from_numpy(np.ascontiguousarray(img))
        # Binarise so the target is a valid 0/1 map for BCE-style losses.
        # NOTE(review): assumes masks are stored as 0/255 images - confirm;
        # a 0/1 label image would need a different threshold.
        mask = torch.from_numpy((np.asarray(mask) > 128).astype(np.float32))
        return img, mask

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.indexs)
参考:
【1】Pytorch_Unet图像分割
【2】http://222.195.93.137/gitlab/winston.wen/kaggle-1
# Reader comment: lines 70 and 71 of dataset_package.py could probably be
# changed to use the helpers below; otherwise errors are likely.
def image_to_tensor(image, mean=0.0, std=1.0):
    """Normalise an image array and convert it to a channels-first tensor.

    Args:
        image: NumPy array with three axes; the last axis is moved to the
            front by ``transpose(2, 0, 1)``.
        mean: Value subtracted from every element.
        std: Value every element is divided by after the subtraction.

    Returns:
        A tensor created with ``torch.from_numpy`` from the normalised,
        transposed array.
    """
    image = (image - mean) / std
    image = image.transpose(2, 0, 1)
    image = torch.from_numpy(image)
    return image
def mask_to_tensor(mask):
    """Binarise a mask array and convert it to a float32 tensor.

    Elements strictly greater than 128 become 1.0; all others become 0.0.

    Args:
        mask: NumPy array of mask values.

    Returns:
        A float32 tensor of zeros and ones with the same shape as *mask*.
    """
    mask = (mask > 128).astype(np.float32)
    mask = torch.from_numpy(mask)
    return mask
# Reader comment: lines 61 and 65 seem redundant.
针对主程序而言:
from configs import configXML
这样import的话,确实需要通过configXML.initial_checkpoints、configXML.num_epochs这样的写法来访问configXML模块中的变量(它们是模块级变量,而不是类的成员变量)
可以改为
from configs.configXML import *
然后你使用的都是全局变量了,可以直接使用initial_checkpoints,num_epochs
记得每个文件夹放一个__init__.py文件。
页:
[1]