
Spine Segmentation with UNet

As we collect more data every day, artificial intelligence (AI) will be applied more and more in healthcare. One key application of AI in medicine is diagnosis, where it supports decision-making, management, automation, and more.

The spine is an essential part of the musculoskeletal system: it supports the body and its organ structures and plays a major role in our mobility and load transfer. It also protects the spinal cord from injuries and mechanical shocks caused by impacts.

In an automated spine-processing pipeline, vertebra labelling and segmentation are two fundamental tasks.

Reliable and accurate spine image processing is expected to support clinical decision-support systems for diagnosis, surgery planning, and population-based analysis of spine and bone health. Designing automated algorithms for spine processing is challenging, mainly because of the considerable variation in anatomy and acquisition protocols and the severe shortage of publicly available data.

In this blog I focus only on segmenting the spine in a given CT-scan dataset. Labelling each vertebra and performing further diagnosis are not covered here and can be taken up as a continuation of this task.

Spine (or vertebral) segmentation is a crucial step in all applications that automatically quantify spinal morphology and pathology.

With the advent of deep learning, large and diverse data has become a key resource for tasks involving computed tomography (CT). However, until recently there was no large-scale public dataset for this purpose.

VerSe is a large, multi-detector, multi-site CT spine dataset consisting of 374 scans from 355 patients. Datasets were released for both 2019 and 2020. In this blog I merge the two into a single dataset to benefit from more data.

The data is provided under the CC BY-SA 4.0 licence, so it is fully open source.

NIfTI (Neuroimaging Informatics Technology Initiative) is a file format for neuroimaging. NIfTI files are very commonly used in imaging informatics for neuroscience and even neuroradiology research. Each NIfTI file can hold up to 7 dimensions of data plus metadata, and supports multiple data types.

The first three dimensions define the three spatial axes x, y and z, while the fourth dimension defines the time points t. The remaining dimensions (the fifth through the seventh) are reserved for other uses; the fifth dimension, however, can still have predefined uses, such as storing voxel-specific distribution parameters or holding vector-based data.
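For example, here is a minimal NiBabel sketch (using one of the dataset files that is loaded later in this post) that reads these dimensions straight from the header; the exact values will of course depend on the scan:

import nibabel as nib

img = nib.load("sub-verse521_dir-ax_ct.nii.gz")
print(img.header['dim'])        # dim[0] = number of dimensions used, dim[1:4] = x, y, z sizes
print(img.header.get_zooms())   # voxel spacing along the spatial axes (usually in mm)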

ITK-SNAP is a software application for segmenting structures in 3D medical images. It is open-source software that can be installed on different platforms. I used it to visualize NIfTI files in a 3D view and to load and overlay 3D masks on top of the raw images. I highly recommend it for this task.

Computed tomography (CT) is an X-ray imaging procedure in which X-rays are aimed at the patient while rotating rapidly around the body. The signals collected by the machine are stored in a computer to produce cross-sectional images of the body, also known as "slices".

These slices are called tomographic images and contain more detailed information than conventional X-rays. A series of slices can be digitally "stacked" together to form a 3D image of the patient, making it easier to identify and locate basic structures as well as possible tumours or abnormalities.

The steps are as follows. First, download the 2019 and 2020 datasets.

Then merge the two datasets into their respective train, validation and test folders. The next step is to read the CT scans and convert each slice into a series of PNG raw images and masks. After that, the UNet model from this GitHub repository is used and trained.

Data understanding: before starting the data processing and training, I wanted to load a few NIfTI files to become more familiar with their 3D data structure, be able to visualize them, and extract metadata from the images.

After downloading the VerSe dataset, I opened one *.nii.gz* file. By reading a file and looking at one particular slice of the CT scan, I was able to use NumPy's transpose function to view that slice in three different orientations: axial, sagittal and coronal.

After getting more familiar with the raw images and being able to extract a single slice from the raw 3D image, it was time to look at the mask file for the same slice.

As you can see in the image below, the mask slice can be overlaid on the raw image slice. The reason we see a gradient of colours here is that the mask files not only define a region for each vertebra, they also carry different labels (shown in different colours), i.e. the number or label of each vertebra. For a better understanding of vertebra labelling, you can refer to this page.

Data preparation:

The data-preparation task is to generate image slices from each 3D CT-scan file, for both the raw images and the mask files.

It starts by reading the raw and mask images (the ".nii.gz" files) with the NiBabel library and converting them into NumPy arrays. It then inspects each 3D image, checks its orientation, and tries to convert most of the images to the sagittal view.

Next, I generate a PNG file from each slice and store it in "L" mode, i.e. as grayscale values. We do not need to generate RGB images for this task.

For this task the UNet architecture is used so that semantic segmentation can be applied to the dataset. For a better understanding of UNet and semantic segmentation, I recommend checking out this blog.

PyTorch and torchvision were used for this task. As mentioned, this repository provides a good PyTorch implementation of UNet, and some of its code is used here.

Since we are working with NIfTI files, the NiBabel library is used to read them in Python. NiBabel is a Python library for reading and writing some common medical and neuroimaging file formats, such as NIfTI files.

Dice score: to evaluate how well our model performs on the semantic-segmentation task, we can use the Dice score. The Dice coefficient is 2 times the area of overlap (between the predicted mask and the ground-truth mask) divided by the total number of mask pixels in both images.
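As a minimal illustration (a small NumPy sketch of the metric itself, separate from the DiceLoss class used for training later), the score for two binary masks can be computed like this:

import numpy as np

def dice_score(pred_mask, true_mask, eps=1e-7):
    # both inputs are binary arrays of the same shape (1 = spine pixel, 0 = background)
    pred = pred_mask.astype(bool)
    true = true_mask.astype(bool)
    overlap = np.logical_and(pred, true).sum()
    # 2 * overlap divided by the total number of positive pixels in the two masks
    return (2.0 * overlap + eps) / (pred.sum() + true.sum() + eps)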

Training: first I defined the UNet class, then a PyTorch dataset class that handles reading and preprocessing the images. The preprocessing consists of loading the PNG files, resizing them all to one size (250x250 in this example), and converting them to NumPy arrays and then to PyTorch tensors.

By calling the dataset class (VerSeDataset) we can prepare the data in the batches I defined. To make sure the mapping between raw images and masks is correct, I call next(iter(valid_dataloader)) to get the next item in a batch and visualize it.

The model is then defined as model = UNet(n_channels=1, n_classes=1). The number of channels is 1 because the images are grayscale rather than RGB; if your images are RGB you can change n_channels to 3. The number of classes is 1 because there is only one class deciding whether a pixel is part of the spine or not. If your problem is multi-class segmentation, set the number of classes to however many classes you have.

Then the model is trained. For each batch, the loss value is computed first and the parameters are updated through backpropagation. Afterwards, all batches of the validation dataset are iterated over, this time only computing and storing the loss. Finally, the train and validation loss values are plotted to track the model's performance.

After saving the model, I could take one of the images, pass it to the trained model, and receive a predicted mask image. By plotting the three images side by side, raw image, true mask and predicted mask, I could evaluate the result visually.

As you can see in the figure above, the model performs quite well on both sagittal and axial views, since the predicted masks are very similar to the ground-truth mask regions.

The full code:

Author: Mazi Boustani

Date: December 24, 2021

Purpose: Train a UNet model with PyTorch to segment the spine using the VerSe dataset

import numpy as np 

import pandas as pd

import os

from os import listdir

from os.path import splitext

import glob

import shutil

import random

from pathlib import Path

from PIL import Image

from tqdm import tqdm

import matplotlib.pyplot as plt

%matplotlib inline

try:

   import nibabel as nib

except:

   raise ImportError('Install NIBABEL')


import torch

import torch.nn as nn

from torch import Tensor

import torch.nn.functional as F

from torch import optim

import torchvision.transforms as T

from torch.utils.data import DataLoader, random_split

from torch.utils.data import Dataset


# set folder paths for train and validation data

data_folder_path = "/Users/mazi/Projects/other/CT/data"

train_data = data_folder_path + "/verse_19_20_training/"

validation_data = data_folder_path + "/verse_19_20_validation/"


Data understanding

# get one image to load


train_data_raw_image = train_data + "/rawdata/sub-verse521/sub-verse521_dir-ax_ct.nii.gz"

one_image = nib.load(train_data_raw_image)

# look at image shape

print(one_image.shape)

# look at image header. To understand header please refer to: https://brainder.org/2012/09/23/the-nifti-file-format/

print(one_image.header)

# look at the raw data

one_image_data = one_image.get_fdata()

print(one_image_data)


# Visualize one image in three different angles

one_image_data_axial = one_image_data

# change the view

one_image_data_sagittal = np.transpose(one_image_data, [2,1,0])

one_image_data_sagittal = np.flip(one_image_data_sagittal, axis=0)

# change the view

one_image_data_coronal = np.transpose(one_image_data, [2,0,1])

one_image_data_coronal = np.flip(one_image_data_coronal, axis=0)

fig, ax = plt.subplots(1, 3, figsize = (60, 60))

ax[0].imshow(one_image_data_axial[:,:,10], cmap ='bone')

ax[0].set_title("Axial view", fontsize=60)

ax[1].imshow(one_image_data_sagittal[:,:,260], cmap ='bone')

ax[1].set_title("Sagittal view", fontsize=60)

ax[2].imshow(one_image_data_coronal[:,:,200], cmap ='bone')

ax[2].set_title("Coronal view", fontsize=60)

plt.show()

# Overlay a mask on top of raw image (one slice of CT-scan)

train_data_mask_image = train_data + "derivatives/sub-verse521/sub-verse521_dir-ax_seg-vert_msk.nii.gz"

train_data_mask_image = nib.load(train_data_mask_image).get_fdata()

plt.figure(figsize=(10,10))

rotated_raw = np.transpose(one_image_data, [2,1,0])

rotated_raw = np.flip(rotated_raw, axis=0)

plt.imshow(rotated_raw[:,:,260], cmap ='bone', interpolation='none')

train_data_mask_image[train_data_mask_image == 0 ] = np.nan

rotated_mask = np.transpose(train_data_mask_image, [2,1,0])

rotated_mask = np.flip(rotated_mask, axis=0)

plt.imshow(rotated_mask[:,:,260], cmap ='cool')

Preprocess data

# Set paths to store processed train and validation raw images and masks

processed_train = "./processed_train/"

processed_validation = "./processed_validation/"


processed_train_raw_images = processed_train + "raw_images/"

processed_train_masks = processed_train + "masks/"


processed_validation_raw_images = processed_validation + "raw_images/"

processed_validation_masks = processed_validation + "masks/"


# Read all 2019 and 2020 raw files, both train and validation

raw_train_files = glob.glob(os.path.join(train_data, 'rawdata/*/*.nii.gz'))

raw_validation_files = glob.glob(os.path.join(validation_data, 'rawdata/*/*.nii.gz'))

print("Raw images count train: {0}, validation: {1}".format(len(raw_train_files), len(raw_validation_files)))


# Read all 2019 and 2020 derivatives files, both train and validation

masks_train_files = glob.glob(os.path.join(train_data, 'derivatives/*/*.nii.gz'))

masks_validation_files = glob.glob(os.path.join(validation_data, 'derivatives/*/*.nii.gz'))


print("Masks images count train: {0}, validation: {1}".format(len(masks_train_files), len(masks_validation_files)))


def read_file(nii_file):

   '''

   Read .nii.gz file.

   Args:

     nii_file (str): a file path.
   

   Return:

     3D numpy array of CT image data.

   '''
   

   return np.asanyarray(nib.load(nii_file).dataobj)


def save_file(raw_data, label_data, file_name, index, output_raw_file_path, output_label_file_path):

   '''

   Save file into npz format.


   Args:

     raw_data (array): 2D numpy array of raw image data.

     label_data (array): 2D numpy array of label image data.

     file_name (str): file name.

     index (int): slice of CT image.

     output_raw_file_path (str): Path to all raw files.

     output_label_file_path (str): Path to all mask files.

   '''

   # replace all non-zero pixels to 1

   label_data = np.where(label_data > 0, 1, label_data)

   unique_values = np.unique(label_data)
   

   # if data has pixel with value of 1 means it is a positive datapoint

   if len(unique_values) > 1:


       raw_file_name = "{0}{1}_{2}.png".format(output_raw_file_path, file_name, index)

       im = Image.fromarray(raw_data)

       im = im.convert("L")

       im.save(raw_file_name)

       label_file_name = "{0}{1}_{2}.png".format(output_label_file_path, file_name, index)

       im = Image.fromarray(label_data)

       im = im.convert("L")

       im.save(label_file_name)


def is_diagonal(matrix):

   '''

   Check if the given matrix is diagonal or not.

   

   Args:

       matrix (np array): numpy array

   '''

   for i in range(0, 3):

       for j in range(0, 3) :

           if ((i != j) and (matrix[i][j] != 0)):

               return False

   return True


def generate_data(raw_file, label_file, file_name, output_raw_file_path, output_label_file_path):

   '''

   Main function to read each raw and label file and generate series of images

   per each slice.

   Args:

     raw_file (str): path to raw file.

     label_file (str): path to label file.

     file_name (str): file name.

     output_raw_file_path (str): Path to all raw files.

     output_label_file_path (str): Path to all mask files.

   '''
   

   # Skip some slices: adjacent slices can be very similar to each other and
   # would generate redundant data

   skip_slice = 3

   continue_it = True

   raw_data = read_file(raw_file)

   label_data = read_file(label_file)

   if "split" in raw_file:

       continue_it = False
   

   affine = nib.load(raw_file).affine


   if is_diagonal(affine[:3, :3]):

       transposed_raw_data = np.transpose(raw_data, [2,1,0])

       transposed_raw_data = np.flip(transposed_raw_data)

       transposed_label_data = np.transpose(label_data, [2,1,0])

       transposed_label_data = np.flip(transposed_label_data)


   else:

       transposed_raw_data = np.rot90(raw_data)

       transposed_raw_data = np.flip(transposed_raw_data)

       transposed_label_data = np.rot90(label_data)

       transposed_label_data = np.flip(transposed_label_data) 

   if continue_it:

       if transposed_raw_data.shape:

           slice_count = transposed_raw_data.shape[-1]

           print("File name: ", file_name, " - Slice count: ", slice_count)

           # skip some slices

           for each_slice in range(1, slice_count, skip_slice):

              save_file(transposed_raw_data[:,:,each_slice],

                        transposed_label_data[:,:,each_slice],

                         file_name,

                         each_slice,

                         output_raw_file_path,

                         output_label_file_path)


# Loop over raw images and masks and generate 'PNG' images.

print("Processing started.")

for each_raw_file in raw_train_files:

   raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]

   for each_mask_file in masks_train_files:

       if raw_file_name in each_mask_file.split("/")[-1]:

           generate_data(each_raw_file,
                         each_mask_file,
                         raw_file_name,
                         processed_train_raw_images,
                         processed_train_masks)

print("Processing train data done.")

# Loop over raw images and masks and generate 'PNG' images.

for each_raw_file in raw_validation_files:

   raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]

   for each_mask_file in masks_validation_files:

       if raw_file_name in each_mask_file.split("/")[-1]:

           generate_data(each_raw_file,
                         each_mask_file,
                         raw_file_name,
                         processed_validation_raw_images,
                         processed_validation_masks)

print("Processing validation data done.")


Training

# Define model parameters

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# image size to convert to

IMAGE_HEIGHT = 250

IMAGE_WIDTH = 250

LEARNING_RATE = 1e-4

BATCH_SIZE = 10

EPOCHS = 10

NUM_WORKERS = 8


# Set the device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# UNet model parts

# Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py

class DoubleConv(nn.Module):

   """(convolution => [BN] => ReLU) * 2"""

   def __init__(self, in_channels, out_channels, mid_channels=None):

       super().__init__()

       if not mid_channels:

           mid_channels = out_channels

       self.double_conv = nn.Sequential(

           nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),

           nn.BatchNorm2d(mid_channels),

           nn.ReLU(inplace=True),

           nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),

           nn.BatchNorm2d(out_channels),

           nn.ReLU(inplace=True)

       )

   def forward(self, x):

       return self.double_conv(x)


class Down(nn.Module):

   """Downscaling with maxpool then double conv"""

   def __init__(self, in_channels, out_channels):

       super().__init__()

       self.maxpool_conv = nn.Sequential(

           nn.MaxPool2d(2),

           DoubleConv(in_channels, out_channels)

       )

   def forward(self, x):

       return self.maxpool_conv(x)


class Up(nn.Module):

   """Upscaling then double conv"""

   def __init__(self, in_channels, out_channels, bilinear=True):

       super().__init__()

      

 # if bilinear, use the normal convolutions to reduce the number of channels

       if bilinear:
           self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
           self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
       else:
           self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
           self.conv = DoubleConv(in_channels, out_channels)


   def forward(self, x1, x2):

       x1 = self.up(x1)

       # input is CHW

       diffY = x2.size()[2] - x1.size()[2]

       diffX = x2.size()[3] - x1.size()[3]


       x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,

                       diffY // 2, diffY - diffY // 2])
       

       # if you have padding issues, see

       # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a

       # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

       x = torch.cat([x2, x1], dim=1)

       return self.conv(x)

class OutConv(nn.Module):

   def __init__(self, in_channels, out_channels):

       super(OutConv, self).__init__()

       self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)


   def forward(self, x):

       return self.conv(x)


# Defining UNet architecture

# Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py

class UNet(nn.Module):

   def __init__(self, n_channels, n_classes, bilinear=True):

        super(UNet, self).__init__()

       self.n_channels = n_channels

       self.n_classes = n_classes

       self.bilinear = bilinear

       self.inc = DoubleConv(n_channels, 64)

       self.down1 = Down(64, 128)

       self.down2 = Down(128, 256)

       self.down3 = Down(256, 512)

       factor = 2 if bilinear else 1

       self.down4 = Down(512, 1024 // factor)

       self.up1 = Up(1024, 512 // factor, bilinear)

       self.up2 = Up(512, 256 // factor, bilinear)

       self.up3 = Up(256, 128 // factor, bilinear)

       self.up4 = Up(128, 64, bilinear)

       self.outc = OutConv(64, n_classes)

   def forward(self, x):

       x1 = self.inc(x)

       x2 = self.down1(x1)

       x3 = self.down2(x2)

       x4 = self.down3(x3)

       x5 = self.down4(x4)

       x = self.up1(x5, x4)

       x = self.up2(x, x3)

       x = self.up3(x, x2)

       x = self.up4(x, x1)

       logits = self.outc(x)

       return logits


# Define PyTorch dataset class

# This class will access the images and masks, preprocess them for training and validation


class VerSeDataset(Dataset):

   def __init__(self, raw_images_path, masks_path, images_name):

       self.raw_images_path = raw_images_path

       self.masks_path = masks_path

       self.images_name = images_name
       

   def __len__(self):

       return len(self.images_name)

   def __getitem__(self, index):
       

       # get image and mask for a given index

       img_path = os.path.join(self.raw_images_path, self.images_name[index])

       mask_path = os.path.join(self.masks_path, self.images_name[index])
       

       # read the image and mask

       image = Image.open(img_path)

       mask = Image.open(mask_path)
       

       # resize image and change the shape to (1, image_width, image_h(yuǎn)eight)

       w, h = image.size

       image = image.resize((w, h), resample=Image.BICUBIC)

       image = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(image)

        image_ndarray = np.asarray(image)

       image_ndarray = image_ndarray.reshape(1, image_ndarray.shape[0], image_ndarray.shape[1])

       # resize the mask. Mask shape is (image_width, image_h(yuǎn)eight)

       mask = mask.resize((w, h), resample=Image.NEAREST)

       mask = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(mask)

        mask_ndarray = np.asarray(mask)
       

        return {
            'image': torch.as_tensor(image_ndarray.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask_ndarray.copy()).float().contiguous()
        }


# Get path for all images and masks

train_images_paths = os.listdir(processed_train_raw_images)

train_masks_paths = os.listdir(processed_train_masks)

validation_images_paths = os.listdir(processed_validation_raw_images)

validation_masks_paths = os.listdir(processed_validation_masks)


# Load both images and masks data

train_data = VerSeDataset(processed_train_raw_images, processed_train_masks, train_images_paths)

valid_data = VerSeDataset(processed_validation_raw_images, processed_validation_masks, validation_images_paths)

# Create PyTorch DataLoader

train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)


# Looking at one image and mask from one batch just to check them visually

next_image = next(iter(valid_dataloader))

fig, ax = plt.subplots(1, 2, figsize = (60, 60))

ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')

ax[0].set_title("Raw image", fontsize=60)

ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')

ax[1].set_title("Mask image", fontsize=60)

plt.show()

# Defining Dice loss class

# Source code: https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch

class DiceLoss(nn.Module):

   def __init__(self, weight=None, size_average=True):

       super(DiceLoss, self).__init__()

   def forward(self, inputs, targets, smooth=1):


       inputs = torch.sigmoid(inputs)      
       

       # flatten label and prediction tensors

       inputs = inputs.view(-1)

       targets = targets.view(-1)
       

       intersection = (inputs * targets).sum()                            

       dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
       

        # subtract the dice value from 1 to turn it into a loss
        return 1 - dice


# Define model as UNet


model = UNet(n_channels=1, n_classes=1)

model.to(device=device)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


# Train and validate


train_loss = []

val_loss = []

for epoch in range(EPOCHS):  
   

   model.train()

   train_running_loss = 0.0

   counter = 0
   

   with tqdm(total=len(train_data), desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='img') as pbar:

       for batch in train_dataloader:
           

           counter+=1

           image = batch['image'].to(DEVICE)

           mask = batch['mask'].to(DEVICE)
           

           optimizer.zero_grad()

           outputs = model(image)

           outputs = outputs.squeeze(1)

           loss = DiceLoss()(outputs, mask)

           train_running_loss += loss.item()

           loss.backward()

           optimizer.step()
           

           pbar.update(image.shape[0])

           pbar.set_postfix(**{'loss (batch)': loss.item()})

        train_loss.append(train_running_loss/counter)
   

   model.eval()

   valid_running_loss = 0.0

   counter = 0

   with torch.no_grad():

       for i, data in enumerate(valid_dataloader):

           counter += 1
           

           image = data['image'].to(DEVICE)

           mask = data['mask'].to(DEVICE)

           outputs = model(image)

           outputs = outputs.squeeze(1)
           

           loss = DiceLoss()(outputs, mask)

           valid_running_loss += loss.item()
           

        val_loss.append(valid_running_loss/counter)


Epoch 1/10: 100%|██████████| 4790/4790 [4:00:34<00:00,  3.01s/img, loss (batch)=0.385]  

Epoch 2/10: 100%|██████████| 4790/4790 [4:00:02<00:00,  3.01s/img, loss (batch)=0.268]  

Epoch 3/10: 100%|██████████| 4790/4790 [3:57:30<00:00,  2.98s/img, loss (batch)=0.152]  

Epoch 4/10: 100%|██████████| 4790/4790 [3:57:05<00:00,  2.97s/img, loss (batch)=0.105]  

Epoch 5/10: 100%|██████████| 4790/4790 [4:08:29<00:00,  3.11s/img, loss (batch)=0.103]   

Epoch 6/10: 100%|██████████| 4790/4790 [4:04:12<00:00,  3.06s/img, loss (batch)=0.0874]  

Epoch 7/10: 100%|██████████| 4790/4790 [4:02:00<00:00,  3.03s/img, loss (batch)=0.0759]  

Epoch 8/10: 100%|██████████| 4790/4790 [3:58:32<00:00,  2.99s/img, loss (batch)=0.0655]  

Epoch 9/10: 100%|██████████| 4790/4790 [4:00:47<00:00,  3.02s/img, loss (batch)=0.0644]  

Epoch 10/10: 100%|██████████| 4790/4790 [4:08:54<00:00,  3.12s/img, loss (batch)=0.0604]  


# Plot train vs validation loss

plt.figure(figsize=(10, 7))

plt.plot(train_loss, color="orange", label='train loss')

plt.plot(val_loss, color="red", label='validation loss')

plt.xlabel("Epochs")

plt.ylabel("Loss")

plt.legend()

plt.show()

# Save the trained model

torch.save({

   'epoch': EPOCHS,

   'model_state_dict': model.state_dict(),

   'optimizer_state_dict': optimizer.state_dict(),
}, "./unet_model.pth")


# Visually look at one prediction 

next_image = next(iter(valid_dataloader))

# do predict

outputs = model(next_image['image'].float().to(device))

outputs = outputs.detach().cpu()

loss = DiceLoss()(outputs, next_image['mask'])

print("Dice Score: ", 1- loss.item())

outputs[outputs<=0.0] = 0

outputs[outputs>0.0] = 1.0

# plot all three images

fig, ax = plt.subplots(1, 3, figsize = (60, 60))

ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')

ax[0].set_title("Raw Image", fontsize=60)

ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')

ax[1].set_title("True Mask", fontsize=60)

ax[2].imshow(outputs[0,0,:,:], cmap ='bone')

ax[2].set_title("Predicted Mask", fontsize=60)

plt.show()

Future work: this task could also be done with a 3D UNet, which may be a better way to learn the structure of the spine.

Since we have a label for each vertebra's mask region, we could go further and perform multi-class mask segmentation. Also, the model performs best when the image view is sagittal, so converting all slices to the sagittal view may give better results.
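As a rough sketch of that multi-class direction (an assumption about one possible setup, reusing the UNet class and imports defined above; nothing like this was trained in this post), the model and loss could be switched to per-vertebra labels:

# hypothetical multi-class setup: one output channel per vertebra label plus background
N_VERTEBRA_CLASSES = 25   # assumed label count; adjust to the dataset's labelling scheme
model = UNet(n_channels=1, n_classes=N_VERTEBRA_CLASSES + 1)
criterion = nn.CrossEntropyLoss()

# the masks would then keep their original integer labels (0 = background, 1..N = vertebrae)
# instead of being binarized, and the loss would be: criterion(model(image), mask.long())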

Thanks for reading!
