Image segmentation with CamVid
Three lines of notebook magic
%reload_ext autoreload
%autoreload 2
%matplotlib inline
Required libraries
from fastai import *
from fastai.vision import *
from fastai.callbacks.hooks import *
Download the dataset from GitHub
The One Hundred Layer Tiramisu paper used a modified version of CamVid, with smaller images and fewer classes. You can get it from the CamVid directory of this repo:
git clone https://github.com/alexgkendall/SegNet-Tutorial.git
Build the path to the dataset folder
path = Path('./data/camvid-tiramisu')
path.ls()
[PosixPath('data/camvid-tiramisu/valannot'),
PosixPath('data/camvid-tiramisu/test'),
PosixPath('data/camvid-tiramisu/val'),
PosixPath('data/camvid-tiramisu/val.txt'),
PosixPath('data/camvid-tiramisu/trainannot'),
PosixPath('data/camvid-tiramisu/testannot'),
PosixPath('data/camvid-tiramisu/train'),
PosixPath('data/camvid-tiramisu/test.txt'),
PosixPath('data/camvid-tiramisu/train.txt'),
PosixPath('data/camvid-tiramisu/models')]
Data
Collect the files in a folder into a list of paths
fnames = get_image_files(path/'val')
fnames[:3]
[PosixPath('data/camvid-tiramisu/val/0016E5_08065.png'),
PosixPath('data/camvid-tiramisu/val/0016E5_07989.png'),
PosixPath('data/camvid-tiramisu/val/0016E5_08041.png')]
lbl_names = get_image_files(path/'valannot')
lbl_names[:3]
[PosixPath('data/camvid-tiramisu/valannot/0016E5_08065.png'),
PosixPath('data/camvid-tiramisu/valannot/0016E5_07989.png'),
PosixPath('data/camvid-tiramisu/valannot/0016E5_08041.png')]
Open a file path as an Image, then display it
img_f = fnames[0]
img = open_image(img_f)
img.show(figsize=(5,5))
Find the annotation file that matches an image file, then open it as a mask
def get_y_fn(x): return Path(str(x.parent)+'annot')/x.name
codes = array(['Sky', 'Building', 'Pole', 'Road', 'Sidewalk', 'Tree',
    'Sign', 'Fence', 'Car', 'Pedestrian', 'Cyclist', 'Void'])
mask = open_mask(get_y_fn(img_f))
mask.show(figsize=(5,5), alpha=1)
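To make the path mapping concrete, here is a minimal sketch of what get_y_fn does to one of the validation file names listed above. It uses only pathlib, so it runs without fastai and without the dataset on disk.

```python
from pathlib import Path

def get_y_fn(x):
    # Append 'annot' to the parent folder name and keep the file name.
    return Path(str(x.parent) + 'annot') / x.name

img_f = Path('data/camvid-tiramisu/val/0016E5_08065.png')
label_f = get_y_fn(img_f)
print(label_f)
```

The same rule maps train to trainannot and test to testannot, which is why one function can label all three folders.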
Inspect the size and data of the mask
src_size = np.array(mask.shape[1:])
src_size,mask.data
(array([360, 480]), tensor([[[1, 1, 1, ..., 5, 5, 5],
[1, 1, 1, ..., 5, 5, 5],
[1, 1, 1, ..., 5, 5, 5],
...,
[4, 4, 4, ..., 3, 3, 3],
[4, 4, 4, ..., 3, 3, 3],
[4, 4, 4, ..., 3, 3, 3]]]))
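Each integer in the mask is an index into codes. As a small illustration in plain Python (the class list is copied from the cell above; visible_ids is a hypothetical name for the values that show up in the printed tensor), the ids decode like this:

```python
codes = ['Sky', 'Building', 'Pole', 'Road', 'Sidewalk', 'Tree',
         'Sign', 'Fence', 'Car', 'Pedestrian', 'Cyclist', 'Void']

# Class ids visible at the corners of the printed mask tensor.
visible_ids = [1, 5, 4, 3]
decoded = {i: codes[i] for i in visible_ids}
print(decoded)
```

Read against the tensor, this frame has Building at the top left, Tree at the top right, Sidewalk at the bottom left and Road at the bottom right.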
Datasets
Set the mini-batch size and the image size
bs,size = 8,src_size//2
Create the segmentation data source
src = (SegmentationItemList.from_folder(path)
    .split_by_folder(valid='val')
    .label_from_func(get_y_fn, classes=codes))
Create a DataBunch from the data source
data = (src.transform(get_transforms(), tfm_y=True)
    .databunch(bs=bs)
    .normalize(imagenet_stats))
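normalize(imagenet_stats) standardizes each RGB channel with the ImageNet mean and standard deviation, so the inputs match what the pretrained ResNet saw during training. Below is a minimal plain-Python sketch of the per-pixel arithmetic. The helper name normalize_pixel is hypothetical, and the constants are the commonly quoted ImageNet statistics, assumed here to be what imagenet_stats holds.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Standardize one RGB pixel with values in [0, 1], channel by channel."""
    return tuple((c - m) / s for c, m, s in zip(rgb, mean, std))

print(normalize_pixel((0.485, 0.456, 0.406)))  # a pixel equal to the mean
print(normalize_pixel((1.0, 1.0, 1.0)))        # a pure white pixel
```

Only the input images are normalized; the masks hold class ids and are left as integers.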
show_batch overlays each image with its annotation mask, much like printing an image together with its label
data.show_batch(2, figsize=(10,7))
Model
Find the mask code that corresponds to 'Void'
name2id = {v:k for k,v in enumerate(codes)}
void_code = name2id['Void']
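The dictionary comprehension inverts the class list: enumerate yields (index, name) pairs, and the comprehension stores them as name to index. A self-contained sketch with the same class list, using a plain Python list instead of an array:

```python
codes = ['Sky', 'Building', 'Pole', 'Road', 'Sidewalk', 'Tree',
         'Sign', 'Fence', 'Car', 'Pedestrian', 'Cyclist', 'Void']

# k is the position in the list, v is the class name.
name2id = {v: k for k, v in enumerate(codes)}
void_code = name2id['Void']
print(void_code)
```

'Void' is the last of the twelve classes, so its code is the highest index in the mask.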
An accuracy metric designed for CamVid
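def acc_camvid(input, target):
    # Drop the channel dimension of the target: (N, 1, H, W) -> (N, H, W)
    target = target.squeeze(1)
    # Score only the pixels that are not labelled 'Void'
    mask = target != void_code
    return (input.argmax(dim=1)[mask]==target[mask]).float().mean()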
metrics=acc_camvid
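The effect of excluding 'Void' pixels is easiest to see on a handful of values. The following is a plain-Python analogue of the metric, not the fastai function itself: masked_accuracy is a hypothetical helper, and it assumes the predictions have already been reduced to class ids (the argmax step).

```python
VOID = 11  # index of 'Void' in codes

def masked_accuracy(pred_ids, target_ids, void_code=VOID):
    """Fraction of non-void pixels whose predicted class equals the target."""
    kept = [(p, t) for p, t in zip(pred_ids, target_ids) if t != void_code]
    if not kept:
        raise ValueError("every target pixel is void")
    return sum(p == t for p, t in kept) / len(kept)

# Six pixels flattened into a list; the last two are labelled Void.
target = [1, 1, 5, 3, 11, 11]
pred = [1, 5, 5, 3, 0, 4]
print(masked_accuracy(pred, target))
```

Three of the four scored pixels are correct. Counting the two void pixels as errors would instead give three out of six, which understates the model.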
Set the weight decay
wd=1e-2
Create the U-Net model
learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd, bottle=True)
Run the learning rate finder, plot loss against learning rate, and pick a learning rate
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
lr=2e-3
Train for 10 epochs (with pct_start set)
learn.fit_one_cycle(10, slice(lr), pct_start=0.8)
learn.save('stage-1')
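pct_start controls what fraction of the run is spent raising the learning rate before it is annealed back down. The sketch below is a simplified one-cycle schedule in plain Python, not fastai's implementation. The function names are hypothetical, and the constants div_factor=25 and final_div=25e4 are assumptions modelled on what fastai v1 is understood to use by default.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(step, total_steps, lr_max, pct_start=0.8,
                 div_factor=25.0, final_div=25e4):
    """Learning rate at a given step of a two-phase one-cycle schedule."""
    warm_steps = int(total_steps * pct_start)
    if step < warm_steps:
        # Phase 1: rise from lr_max / div_factor up to lr_max.
        return annealing_cos(lr_max / div_factor, lr_max, step / warm_steps)
    # Phase 2: fall from lr_max down towards lr_max / final_div.
    pct = (step - warm_steps) / (total_steps - warm_steps)
    return annealing_cos(lr_max, lr_max / final_div, pct)

total = 100
schedule = [one_cycle_lr(s, total, lr_max=2e-3) for s in range(total)]
peak_step = max(range(total), key=lambda s: schedule[s])
print(peak_step, schedule[0], schedule[peak_step])
```

With pct_start=0.8 the peak arrives four fifths of the way through training, so most of the 10 epochs are spent at a rising learning rate and only the last two anneal it.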
Load the saved model, unfreeze it, set a learning rate range with slice, and train for 12 epochs with pct_start
learn.load('stage-1');
learn.unfreeze()
lrs = slice(lr/100,lr)
learn.fit_one_cycle(12, lrs, pct_start=0.8)
Total time: 05:52
epoch | train_loss | valid_loss | acc_camvid |
---|---|---|---|
1 | 0.277594 | 0.273819 | 0.913931 |
2 | 0.271254 | 0.266760 | 0.916620 |
3 | 0.269084 | 0.269211 | 0.915474 |
4 | 0.273889 | 0.295377 | 0.914132 |
5 | 0.268701 | 0.312179 | 0.906329 |
6 | 0.295838 | 0.363080 | 0.902990 |
7 | 0.304576 | 0.323809 | 0.898795 |
8 | 0.290066 | 0.267403 | 0.920294 |
9 | 0.274901 | 0.274512 | 0.914693 |
10 | 0.275207 | 0.273877 | 0.920632 |
11 | 0.248439 | 0.236959 | 0.931970 |
12 | 0.224031 | 0.253183 | 0.926807 |
learn.save('stage-2');
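The two stages pass different slices: slice(lr) while the encoder is frozen, and slice(lr/100, lr) after unfreezing. The sketch below shows how such a slice could be expanded into one learning rate per layer group. It is an illustration of the idea, not fastai's code: the helper names and the choice of three layer groups are assumptions, and the rules (geometric spacing for a full slice, a tenth of the rate for the earlier groups when only the stop is given) reflect how fastai v1 is understood to behave.

```python
import math

def even_mults(start, stop, n):
    """n values from start to stop, each a constant multiple of the previous."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

def lr_range(lr, n_groups):
    """Expand a slice into one learning rate per layer group (assumed rule)."""
    if lr.start is not None:
        return even_mults(lr.start, lr.stop, n_groups)
    # Only a stop value: earlier groups get a tenth of the final group's rate.
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]

lr = 2e-3
frozen = lr_range(slice(lr), 3)
unfrozen = lr_range(slice(lr / 100, lr), 3)
print(frozen)
print(unfrozen)
```

After unfreezing, the earliest pretrained layers therefore move about 100 times more slowly than the newly added decoder layers.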
Go big
Free memory
learn=None
gc.collect()
4194
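The number printed above is the return value of gc.collect(): the count of unreachable objects the collector found. Setting learn=None first drops the last reference to the learner so that its objects become collectable. A small standard-library sketch of the same mechanism, using self-referencing lists as stand-ins for the learner's object graph:

```python
import gc

gc.collect()  # clear any garbage that already exists

# Build 100 self-referencing lists; reference counting alone cannot free them.
cycles = []
for _ in range(100):
    item = []
    item.append(item)
    cycles.append(item)

# Drop every outside reference, then let the cycle collector find them.
del cycles, item
collected = gc.collect()
print(collected)
```

The exact count depends on what else is unreachable at that moment, so it will differ from run to run and from the 4194 shown here.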
Adjust the mini-batch size to the image size
You may have to restart your kernel and come back to this stage if you run out of memory; you may also need to decrease bs.
size = src_size
bs=8
Prepare the DataBunch
data = (src.transform(get_transforms(), size=size, tfm_y=True)
.databunch(bs=bs)
.normalize(imagenet_stats))
Build the U-Net and load the previously trained weights
learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd, bottle=True).load('stage-2');
Run the learning rate finder and plot the result
lr_find(learn)
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Pick a learning rate and start training, using slice and pct_start
lr=1e-3
learn.fit_one_cycle(10, slice(lr), pct_start=0.8)
learn.save('stage-1-big')
Load the model, unfreeze it, set a learning rate range with slice, and train for 10 epochs
learn.load('stage-1-big');
learn.unfreeze()
lrs = slice(lr/1000,lr/10)
learn.fit_one_cycle(10, lrs)
learn.save('stage-2-big')
Load the trained model and show the results
learn.load('stage-2-big');
learn.show_results(rows=3, figsize=(9,11))
fin
# start: 480x360
Summarize the model
print(learn.summary())
======================================================================
Layer (type) Output Shape Param # Trainable
======================================================================
Conv2d [8, 64, 180, 240] 9408 False
______________________________________________________________________
BatchNorm2d [8, 64, 180, 240] 128 True
______________________________________________________________________
ReLU [8, 64, 180, 240] 0 False
______________________________________________________________________
MaxPool2d [8, 64, 90, 120] 0 False
______________________________________________________________________
Conv2d [8, 64, 90, 120] 36864 False
______________________________________________________________________
BatchNorm2d [8, 64, 90, 120] 128 True
______________________________________________________________________
ReLU [8, 64, 90, 120] 0 False
______________________________________________________________________
Conv2d [8, 64, 90, 120] 36864 False
______________________________________________________________________
BatchNorm2d [8, 64, 90, 120] 128 True
______________________________________________________________________
Conv2d [8, 64, 90, 120] 36864 False
______________________________________________________________________
BatchNorm2d [8, 64, 90, 120] 128 True
______________________________________________________________________
ReLU [8, 64, 90, 120] 0 False
______________________________________________________________________
Conv2d [8, 64, 90, 120] 36864 False
______________________________________________________________________
BatchNorm2d [8, 64, 90, 120] 128 True
______________________________________________________________________
Conv2d [8, 64, 90, 120] 36864 False
______________________________________________________________________
BatchNorm2d [8, 64, 90, 120] 128 True
______________________________________________________________________
ReLU [8, 64, 90, 120] 0 False
______________________________________________________________________
Conv2d [8, 64, 90, 120] 36864 False
______________________________________________________________________
BatchNorm2d [8, 64, 90, 120] 128 True
______________________________________________________________________
Conv2d [8, 128, 45, 60] 73728 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
ReLU [8, 128, 45, 60] 0 False
______________________________________________________________________
Conv2d [8, 128, 45, 60] 147456 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
Conv2d [8, 128, 45, 60] 8192 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
Conv2d [8, 128, 45, 60] 147456 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
ReLU [8, 128, 45, 60] 0 False
______________________________________________________________________
Conv2d [8, 128, 45, 60] 147456 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
Conv2d [8, 128, 45, 60] 147456 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
ReLU [8, 128, 45, 60] 0 False
______________________________________________________________________
Conv2d [8, 128, 45, 60] 147456 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
Conv2d [8, 128, 45, 60] 147456 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
ReLU [8, 128, 45, 60] 0 False
______________________________________________________________________
Conv2d [8, 128, 45, 60] 147456 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
Conv2d [8, 256, 23, 30] 294912 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
ReLU [8, 256, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 256, 23, 30] 32768 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
ReLU [8, 256, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
ReLU [8, 256, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
ReLU [8, 256, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
ReLU [8, 256, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
ReLU [8, 256, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 256, 23, 30] 589824 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 512, 12, 15] 1179648 False
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
ReLU [8, 512, 12, 15] 0 False
______________________________________________________________________
Conv2d [8, 512, 12, 15] 2359296 False
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
Conv2d [8, 512, 12, 15] 131072 False
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
Conv2d [8, 512, 12, 15] 2359296 False
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
ReLU [8, 512, 12, 15] 0 False
______________________________________________________________________
Conv2d [8, 512, 12, 15] 2359296 False
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
Conv2d [8, 512, 12, 15] 2359296 False
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
ReLU [8, 512, 12, 15] 0 False
______________________________________________________________________
Conv2d [8, 512, 12, 15] 2359296 False
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
BatchNorm2d [8, 512, 12, 15] 1024 True
______________________________________________________________________
ReLU [8, 512, 12, 15] 0 False
______________________________________________________________________
Conv2d [8, 1024, 12, 15] 4719616 True
______________________________________________________________________
ReLU [8, 1024, 12, 15] 0 False
______________________________________________________________________
Conv2d [8, 512, 12, 15] 4719104 True
______________________________________________________________________
ReLU [8, 512, 12, 15] 0 False
______________________________________________________________________
Conv2d [8, 1024, 12, 15] 525312 True
______________________________________________________________________
PixelShuffle [8, 256, 24, 30] 0 False
______________________________________________________________________
ReplicationPad2d [8, 256, 25, 31] 0 False
______________________________________________________________________
AvgPool2d [8, 256, 24, 30] 0 False
______________________________________________________________________
ReLU [8, 1024, 12, 15] 0 False
______________________________________________________________________
BatchNorm2d [8, 256, 23, 30] 512 True
______________________________________________________________________
Conv2d [8, 512, 23, 30] 2359808 True
______________________________________________________________________
ReLU [8, 512, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 512, 23, 30] 2359808 True
______________________________________________________________________
ReLU [8, 512, 23, 30] 0 False
______________________________________________________________________
ReLU [8, 512, 23, 30] 0 False
______________________________________________________________________
Conv2d [8, 1024, 23, 30] 525312 True
______________________________________________________________________
PixelShuffle [8, 256, 46, 60] 0 False
______________________________________________________________________
ReplicationPad2d [8, 256, 47, 61] 0 False
______________________________________________________________________
AvgPool2d [8, 256, 46, 60] 0 False
______________________________________________________________________
ReLU [8, 1024, 23, 30] 0 False
______________________________________________________________________
BatchNorm2d [8, 128, 45, 60] 256 True
______________________________________________________________________
Conv2d [8, 384, 45, 60] 1327488 True
______________________________________________________________________
ReLU [8, 384, 45, 60] 0 False
______________________________________________________________________
Conv2d [8, 384, 45, 60] 1327488 True
______________________________________________________________________
ReLU [8, 384, 45, 60] 0 False
______________________________________________________________________
ReLU [8, 384, 45, 60] 0 False
______________________________________________________________________
Conv2d [8, 768, 45, 60] 295680 True
______________________________________________________________________
PixelShuffle [8, 192, 90, 120] 0 False
______________________________________________________________________
ReplicationPad2d [8, 192, 91, 121] 0 False
______________________________________________________________________
AvgPool2d [8, 192, 90, 120] 0 False
______________________________________________________________________
ReLU [8, 768, 45, 60] 0 False
______________________________________________________________________
BatchNorm2d [8, 64, 90, 120] 128 True
______________________________________________________________________
Conv2d [8, 256, 90, 120] 590080 True
______________________________________________________________________
ReLU [8, 256, 90, 120] 0 False
______________________________________________________________________
Conv2d [8, 256, 90, 120] 590080 True
______________________________________________________________________
ReLU [8, 256, 90, 120] 0 False
______________________________________________________________________
ReLU [8, 256, 90, 120] 0 False
______________________________________________________________________
Conv2d [8, 512, 90, 120] 131584 True
______________________________________________________________________
PixelShuffle [8, 128, 180, 240] 0 False
______________________________________________________________________
ReplicationPad2d [8, 128, 181, 241] 0 False
______________________________________________________________________
AvgPool2d [8, 128, 180, 240] 0 False
______________________________________________________________________
ReLU [8, 512, 90, 120] 0 False
______________________________________________________________________
BatchNorm2d [8, 64, 180, 240] 128 True
______________________________________________________________________
Conv2d [8, 96, 180, 240] 165984 True
______________________________________________________________________
ReLU [8, 96, 180, 240] 0 False
______________________________________________________________________
Conv2d [8, 96, 180, 240] 83040 True
______________________________________________________________________
ReLU [8, 96, 180, 240] 0 False
______________________________________________________________________
ReLU [8, 192, 180, 240] 0 False
______________________________________________________________________
Conv2d [8, 384, 180, 240] 37248 True
______________________________________________________________________
PixelShuffle [8, 96, 360, 480] 0 False
______________________________________________________________________
ReplicationPad2d [8, 96, 361, 481] 0 False
______________________________________________________________________
AvgPool2d [8, 96, 360, 480] 0 False
______________________________________________________________________
ReLU [8, 384, 180, 240] 0 False
______________________________________________________________________
MergeLayer [8, 99, 360, 480] 0 False
______________________________________________________________________
Conv2d [8, 49, 360, 480] 43708 True
______________________________________________________________________
ReLU [8, 49, 360, 480] 0 False
______________________________________________________________________
Conv2d [8, 99, 360, 480] 43758 True
______________________________________________________________________
ReLU [8, 99, 360, 480] 0 False
______________________________________________________________________
MergeLayer [8, 99, 360, 480] 0 False
______________________________________________________________________
Conv2d [8, 12, 360, 480] 1200 True
______________________________________________________________________
Total params: 41133018
Total trainable params: 19865370
Total non-trainable params: 21267648
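Several rows of the summary can be checked by hand. The sketch below recomputes a few parameter counts and output shapes in plain Python. The helper names are hypothetical, and the kernel sizes, strides and paddings are the standard ResNet-34 stem values, which the summary itself does not print and which are assumed here.

```python
def conv2d_params(c_in, c_out, kernel, bias=False):
    """Number of weights (plus optional biases) in a square-kernel Conv2d."""
    return c_in * c_out * kernel * kernel + (c_out if bias else 0)

def conv_out_size(size, kernel, stride, padding):
    """Output length along one spatial dimension of a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

# Parameter counts to compare with the Param # column.
stem = conv2d_params(3, 64, 7)                 # first 7x7 conv, no bias
block = conv2d_params(64, 64, 3)               # 3x3 conv in the first stage
head = conv2d_params(99, 12, 1, bias=True)     # final 1x1 conv to 12 classes
print(stem, block, head)

# Spatial sizes for a 360x480 input.
h, w = 360, 480
h, w = conv_out_size(h, 7, 2, 3), conv_out_size(w, 7, 2, 3)   # stem conv
after_stem = (h, w)
h, w = conv_out_size(h, 3, 2, 1), conv_out_size(w, 3, 2, 1)   # max pool
after_pool = (h, w)
print(after_stem, after_pool)

# Trainable and non-trainable subtotals should add up to the total.
total = 19865370 + 21267648
print(total)
```

These values line up with the first Conv2d row (9408 parameters, 180 x 240 output), the MaxPool2d row (90 x 120), the last Conv2d row (1200 parameters) and the reported total of 41133018.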