Lesson 6: pets revisited
三行魔法代码和所需 library
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
设置批量大小
bs = 64
下载数据,获取图片文件夹地址
下载数据,获取图片文件夹地址
path = untar_data(URLs.PETS)/'images'
Data augmentation
对图片做特定处理
对图片做特定处理
# Heavier-than-default augmentation: up to 20 deg rotation, 1.3x zoom,
# strong lighting/warp changes; p_affine=1. and p_lighting=1. apply the
# affine and lighting transforms to every training image.
tfms = get_transforms(max_rotate=20, max_zoom=1.3, max_lighting=0.4, max_warp=0.4,
p_affine=1., p_lighting=1.)
查看get_transforms文档
查看get_transforms文档
doc(get_transforms)
构建数据src
构建数据src
src = ImageList.from_folder(path).random_split_by_pct(0.2, seed=2)
创建一个定制函数来构建DataBunch
创建一个定制函数来构建DataBunch
def get_data(size, bs, padding_mode='reflection'):
    """Build a normalized DataBunch from the module-level `src` split.

    size: target image size for the transforms.
    bs: batch size.
    padding_mode: how transformed images are padded ('reflection' or 'zeros').
    """
    # Label each image from its filename, e.g. 'great_pyrenees_173.jpg'
    # -> 'great_pyrenees'.  The dot before 'jpg' is escaped: the original
    # pattern used '.jpg', where the unescaped '.' matched ANY character.
    return (src.label_from_re(r'([^/]+)_\d+\.jpg$')
            .transform(tfms, size=size, padding_mode=padding_mode)
            .databunch(bs=bs).normalize(imagenet_stats))
展示同一张图片的各种变形效果(padding=0)
展示同一张图片的各种变形效果(padding=0)
data = get_data(224, bs, 'zeros')
def _plot(i, j, ax):
    # Grid coordinates i, j are ignored on purpose: every call re-draws the
    # SAME training item (index 3), so each cell of the plot_multi grid
    # shows a fresh random augmentation of one image.
    img, lbl = data.train_ds[3]
    img.show(ax, y=lbl)
plot_multi(_plot, 3, 3, figsize=(8,8))
展示同一张图片的各种变形效果(padding=reflection)
展示同一张图片的各种变形效果(padding=reflection)
data = get_data(224,bs)
plot_multi(_plot, 3, 3, figsize=(8,8))
Train a model
释放内存空间
释放内存空间
gc.collect()
用迁移学习构建模型 (bn_final=True)
用迁移学习构建模型 (bn_final=True)
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True)
训练模型 (pct_start=0.8)
训练模型 (pct_start=0.8)
learn.fit_one_cycle(3, slice(1e-2), pct_start=0.8)
Total time: 01:22
epoch |
train_loss |
valid_loss |
error_rate |
1 |
2.573282 |
1.364505 |
0.271989 |
2 |
1.545074 |
0.377077 |
0.094046 |
3 |
0.937992 |
0.270508 |
0.068336 |
解冻,再训练 max_lr=slice(1e-6,1e-3)
解冻,再训练 max_lr=slice(1e-6,1e-3)
learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-3), pct_start=0.8)
Total time: 00:55
epoch |
train_loss |
valid_loss |
error_rate |
1 |
0.721187 |
0.294177 |
0.058187 |
2 |
0.675999 |
0.285875 |
0.050744 |
改变数据的图片大小
改变数据的图片大小
data = get_data(352,bs)
learn.data = data
再训练 max_lr=slice(1e-6,1e-4)
再训练 max_lr=slice(1e-6,1e-4)
learn.fit_one_cycle(2, max_lr=slice(1e-6,1e-4))
Total time: 01:37
epoch |
train_loss |
valid_loss |
error_rate |
1 |
0.627055 |
0.286791 |
0.058863 |
2 |
0.602765 |
0.286951 |
0.058863 |
保存模型
保存模型
learn.save('352')
Convolution kernel
改变数据批量大小 (缩小)
改变数据批量大小 (缩小)
data = get_data(352,16)
加载上次训练的模型
加载上次训练的模型
learn = create_cnn(data, models.resnet34, metrics=error_rate, bn_final=True).load('352')
展示验证集中的第一个数据点(图和label)
展示验证集中的第一个数据点(图和label)
idx=0
x,y = data.valid_ds[idx]
x.show()
data.valid_ds.y[idx]
Category american_pit_bull_terrier
创建一个kernel or filter
创建一个kernel or filter
# Hand-built 3x3 convolution kernel, replicated over the 3 input channels
# by expand() to shape (out_channels=1, in_channels=3, 3, 3), then scaled
# by 1/6.  Applied below with F.conv2d to produce an edge-like response.
# NOTE(review): expand() returns a view sharing one 3x3 tensor across the
# channel dimension — fine here since the kernel is never written to.
k = tensor([
[0. ,-5/3,1],
[-5/3,-5/3,1],
[1. ,1 ,1],
]).expand(1,3,3,3)/6
k
tensor([[[[ 0.0000, -0.2778, 0.1667],
[-0.2778, -0.2778, 0.1667],
[ 0.1667, 0.1667, 0.1667]],
[[ 0.0000, -0.2778, 0.1667],
[-0.2778, -0.2778, 0.1667],
[ 0.1667, 0.1667, 0.1667]],
[[ 0.0000, -0.2778, 0.1667],
[-0.2778, -0.2778, 0.1667],
[ 0.1667, 0.1667, 0.1667]]]])
k.shape
torch.Size([1, 3, 3, 3])
从验证数据中提起一个数据点的图片tensor
从验证数据中提起一个数据点的图片tensor
# Pull the raw pixel tensor of the first validation image: (3, 352, 352).
t = data.valid_ds[0][0].data; t.shape
torch.Size([3, 352, 352])
将3D tensor变成4D
将3D tensor变成4D
# t[None] adds a leading batch axis: conv2d expects (N, C, H, W).
t[None].shape
torch.Size([1, 3, 352, 352])
对这个4D tensor做filter处理
对这个4D tensor做filter处理
# Convolve the image with the hand-made kernel k defined above.
edge = F.conv2d(t[None], k)
显示filter处理结构
显示filter处理结构
show_image(edge[0], figsize=(5,5));
查看data.c
查看data.c
data.c
37
查看模型结构
查看模型结构
learn.model
Sequential(
(0): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(7): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)
(1): Sequential(
(0): AdaptiveConcatPool2d(
(ap): AdaptiveAvgPool2d(output_size=1)
(mp): AdaptiveMaxPool2d(output_size=1)
)
(1): Flatten()
(2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Dropout(p=0.25)
(4): Linear(in_features=1024, out_features=512, bias=True)
(5): ReLU(inplace)
(6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): Dropout(p=0.5)
(8): Linear(in_features=512, out_features=37, bias=True)
(9): BatchNorm1d(37, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
打印模型总结
打印模型总结
print(learn.summary())
'======================================================================\nLayer (type) Output Shape Param # Trainable \n======================================================================\nConv2d [16, 64, 176, 176] 9408 False \n______________________________________________________________________\nBatchNorm2d [16, 64, 176, 176] 128 True \n______________________________________________________________________\nReLU [16, 64, 176, 176] 0 False \n______________________________________________________________________\nMaxPool2d [16, 64, 88, 88] 0 False \n______________________________________________________________________\nConv2d [16, 64, 88, 88] 36864 False \n______________________________________________________________________\nBatchNorm2d [16, 64, 88, 88] 128 True \n______________________________________________________________________\nReLU [16, 64, 88, 88] 0 False \n______________________________________________________________________\nConv2d [16, 64, 88, 88] 36864 False \n______________________________________________________________________\nBatchNorm2d [16, 64, 88, 88] 128 True \n______________________________________________________________________\nConv2d [16, 64, 88, 88] 36864 False \n______________________________________________________________________\nBatchNorm2d [16, 64, 88, 88] 128 True \n______________________________________________________________________\nReLU [16, 64, 88, 88] 0 False \n______________________________________________________________________\nConv2d [16, 64, 88, 88] 36864 False \n______________________________________________________________________\nBatchNorm2d [16, 64, 88, 88] 128 True \n______________________________________________________________________\nConv2d [16, 64, 88, 88] 36864 False \n______________________________________________________________________\nBatchNorm2d [16, 64, 88, 88] 128 True \n______________________________________________________________________\nReLU [16, 64, 88, 88] 0 False 
\n______________________________________________________________________\nConv2d [16, 64, 88, 88] 36864 False \n______________________________________________________________________\nBatchNorm2d [16, 64, 88, 88] 128 True \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 73728 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nReLU [16, 128, 44, 44] 0 False \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 147456 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 8192 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 147456 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nReLU [16, 128, 44, 44] 0 False \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 147456 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 147456 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nReLU [16, 128, 44, 44] 0 False \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 147456 False 
\n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 147456 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nReLU [16, 128, 44, 44] 0 False \n______________________________________________________________________\nConv2d [16, 128, 44, 44] 147456 False \n______________________________________________________________________\nBatchNorm2d [16, 128, 44, 44] 256 True \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 294912 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nReLU [16, 256, 22, 22] 0 False \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 32768 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nReLU [16, 256, 22, 22] 0 False \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 
True \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nReLU [16, 256, 22, 22] 0 False \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nReLU [16, 256, 22, 22] 0 False \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nReLU [16, 256, 22, 22] 0 False \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nReLU [16, 256, 22, 22] 0 False 
\n______________________________________________________________________\nConv2d [16, 256, 22, 22] 589824 False \n______________________________________________________________________\nBatchNorm2d [16, 256, 22, 22] 512 True \n______________________________________________________________________\nConv2d [16, 512, 11, 11] 1179648 False \n______________________________________________________________________\nBatchNorm2d [16, 512, 11, 11] 1024 True \n______________________________________________________________________\nReLU [16, 512, 11, 11] 0 False \n______________________________________________________________________\nConv2d [16, 512, 11, 11] 2359296 False \n______________________________________________________________________\nBatchNorm2d [16, 512, 11, 11] 1024 True \n______________________________________________________________________\nConv2d [16, 512, 11, 11] 131072 False \n______________________________________________________________________\nBatchNorm2d [16, 512, 11, 11] 1024 True \n______________________________________________________________________\nConv2d [16, 512, 11, 11] 2359296 False \n______________________________________________________________________\nBatchNorm2d [16, 512, 11, 11] 1024 True \n______________________________________________________________________\nReLU [16, 512, 11, 11] 0 False \n______________________________________________________________________\nConv2d [16, 512, 11, 11] 2359296 False \n______________________________________________________________________\nBatchNorm2d [16, 512, 11, 11] 1024 True \n______________________________________________________________________\nConv2d [16, 512, 11, 11] 2359296 False \n______________________________________________________________________\nBatchNorm2d [16, 512, 11, 11] 1024 True \n______________________________________________________________________\nReLU [16, 512, 11, 11] 0 False \n______________________________________________________________________\nConv2d [16, 512, 11, 11] 
2359296 False \n______________________________________________________________________\nBatchNorm2d [16, 512, 11, 11] 1024 True \n______________________________________________________________________\nAdaptiveAvgPool2d [16, 512, 1, 1] 0 False \n______________________________________________________________________\nAdaptiveMaxPool2d [16, 512, 1, 1] 0 False \n______________________________________________________________________\nFlatten [16, 1024] 0 False \n______________________________________________________________________\nBatchNorm1d [16, 1024] 2048 True \n______________________________________________________________________\nDropout [16, 1024] 0 False \n______________________________________________________________________\nLinear [16, 512] 524800 True \n______________________________________________________________________\nReLU [16, 512] 0 False \n______________________________________________________________________\nBatchNorm1d [16, 512] 1024 True \n______________________________________________________________________\nDropout [16, 512] 0 False \n______________________________________________________________________\nLinear [16, 37] 18981 True \n______________________________________________________________________\nBatchNorm1d [16, 37] 74 True \n______________________________________________________________________\n\nTotal params: 21831599\nTotal trainable params: 563951\nTotal non-trainable params: 21267648\n'
Heatmap
提取模型正向传递计算
提取模型正向传递计算
m = learn.model.eval();
提取一个数据点 (只用X部分)
提取一个数据点 (只用X部分)
xb,_ = data.one_item(x)
对数据点X部分做denormalization处理,在转化为图片格式
对数据点X部分做denormalization处理,在转化为图片格式
xb_im = Image(data.denorm(xb)[0])
对数据点X部分做GPU计算设置
对数据点X部分做GPU计算设置
xb = xb.cuda()
调用callbacks.hooks全部功能
调用callbacks.hooks全部功能
from fastai.callbacks.hooks import *
构建函数提取模型激活层数据
构建函数提取模型激活层数据
def hooked_backward(cat=y):
    """Forward/backward pass through the CNN body, capturing activations
    and gradients of the last conv block via output hooks.

    cat: class index to backprop from; defaults to the label `y` that was
    in scope when this function was defined.
    Returns (activation_hook, gradient_hook); read `.stored` on each.
    """
    with hook_output(m[0]) as hook_a, hook_output(m[0], grad=True) as hook_g:
        preds = m(xb)
        # Backprop only from the requested class's score, so the captured
        # gradients reflect what drives that class's prediction.
        preds[0, int(cat)].backward()
    return hook_a, hook_g
hook_a,hook_g = hooked_backward()
提取激活层数据,纵向做均值处理
提取激活层数据,纵向做均值处理
# Stored activations of the conv body for the single batch item: (512, 11, 11).
acts = hook_a.stored[0].cpu()
acts.shape
torch.Size([512, 11, 11])
# Average over the 512 channels to get one 11x11 activation map.
avg_acts = acts.mean(0)
avg_acts.shape
torch.Size([11, 11])
构建heatmap作图函数
构建heatmap作图函数
def show_heatmap(hm, img_size=352):
    """Overlay heatmap `hm` on the current image `xb_im` (module global).

    hm: 2D tensor/array, e.g. an 11x11 activation or Grad-CAM map.
    img_size: pixel extent of the displayed image; defaults to 352 to match
    get_data(352, ...).  Previously this was hard-coded, so the function
    broke silently for any other image size.
    """
    _, ax = plt.subplots()
    xb_im.show(ax)
    # extent stretches the low-res heatmap over the whole image; the
    # (0, s, s, 0) order puts the origin at the top-left like the image.
    ax.imshow(hm, alpha=0.6, extent=(0, img_size, img_size, 0),
              interpolation='bilinear', cmap='magma');
show_heatmap(avg_acts)
Grad-CAM
论文提出的制作heatmap方法
论文提出的制作heatmap方法
Paper: Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
案例1
案例1
# Grad-CAM: gradients of the class score w.r.t. the last conv activations.
grad = hook_g.stored[0][0].cpu()
# Mean gradient per channel = importance weight of each activation map.
grad_chan = grad.mean(1).mean(1)
grad.shape,grad_chan.shape
(torch.Size([512, 11, 11]), torch.Size([512]))
# Weight each activation map by its channel importance, then average
# over channels to get the 11x11 class-discriminative heatmap.
mult = (acts*grad_chan[...,None,None]).mean(0)
show_heatmap(mult)
案例2
案例2
# Case 2: run Grad-CAM on an external image instead of a validation item.
fn = path/'../other/bulldog_maine.jpg' #Replace with your own image
x = open_image(fn); x
# one_item wraps the single image into a model-ready batch.
xb,_ = data.one_item(x)
# De-normalized copy for display; the normalized batch goes to the model.
xb_im = Image(data.denorm(xb)[0])
xb = xb.cuda()
# NOTE(review): hooked_backward() still uses its default cat=y, the label
# captured earlier from the validation item — confirm that is intended.
hook_a,hook_g = hooked_backward()
acts = hook_a.stored[0].cpu()
grad = hook_g.stored[0][0].cpu()
grad_chan = grad.mean(1).mean(1)
mult = (acts*grad_chan[...,None,None]).mean(0)
show_heatmap(mult)
案例3: 通过处理数据类别,heatmap从聚焦猫到了狗
案例3: 通过处理数据类别,heatmap从聚焦猫到了狗
data.classes[0]
'american_bulldog'
# Case 3: backprop from class index 0 ('american_bulldog', per
# data.classes[0] above) instead of the default label, which shifts the
# heatmap's focus to the regions evidencing that class.
hook_a,hook_g = hooked_backward(0)
acts = hook_a.stored[0].cpu()
grad = hook_g.stored[0][0].cpu()
grad_chan = grad.mean(1).mean(1)
mult = (acts*grad_chan[...,None,None]).mean(0)
show_heatmap(mult)