I am trying to colorize grayscale images using a perceptual loss (implemented in PyTorch), but the results are poor (the colorization fails). My complete code is below, with explanations; any suggestions would be appreciated.
perceptual_colorize.py, which is in charge of training:
import bcolz
import numpy as np
import sys
sys.path.insert(0, '../utils') #I put BcolzArrayIterator in utils
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torchvision
from bcolz_array_iterator import BcolzArrayIterator
from color_converter import ycrcb_to_rgb, ycrcb_to_rgb_torch
from colorize_net_0 import colorization_net
from loss_net import vgg_net
from torch.autograd import Variable
from torchvision import datasets, models, transforms
from tqdm import tqdm
# Luminance array: holds the Y channel of YCrCb, converted with OpenCV's formula; each image is 72*72*1
lumi_arr = bcolz.open('image_net_full_opencv_y_72.bc')
# Colour array: holds the RGB channels; each image is 72*72*3
origin_arr = bcolz.open('/home/ramsus/Qt/computer_vision_dataset/super_res/image_net_full_72.bc')
# Mean-squared-error criterion; train() applies it to pairs of feature maps
mse_loss = torch.nn.MSELoss()
# Network that provides the features for the perceptual loss.
# Per the author, layers 5, 10 and 17 are the convolution layers that follow
# a max-pool layer, tapped before the ReLU activation.
vgg_loss = vgg_net([5, 10, 17])
# The transform network being trained
transform_net = colorization_net()
# Move both networks onto the GPU
vgg_loss.cuda()
transform_net.cuda()
def preprocess_input(input_data):
    """Convert a channels-last NumPy batch into a channels-first Variable.

    Args:
        input_data: NumPy array indexed as (batch, height, width, channels).

    Returns:
        A ``Variable`` holding the same values reordered to
        (batch, channels, height, width). It is placed on the GPU when CUDA
        is available and stays on the CPU otherwise.
    """
    # PyTorch convolutions expect (N, C, H, W); the stored batches are (N, H, W, C).
    channels_first = input_data.transpose((0, 3, 1, 2))
    tensor = torch.from_numpy(channels_first)
    # Fall back to the CPU instead of crashing on hosts without CUDA.
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    return Variable(tensor)
def preprocess_perceptual_input(input_data):
    """Normalise an RGB batch for the perceptual network (pretrained VGG16).

    Args:
        input_data: Tensor indexed as (batch, channel, height, width); values
            are expected on a 0-255 scale, with channels 0, 1, 2 being R, G, B.

    Returns:
        A new tensor scaled to 0-1 and then standardised per channel with the
        mean/std constants below. The argument itself is not modified.
    """
    channel_means = (0.485, 0.456, 0.406)
    channel_stds = (0.229, 0.224, 0.225)
    # The division allocates a fresh tensor, so the in-place channel writes
    # below never touch the caller's data.
    normalised = input_data / 255.0
    for channel, (mean, std) in enumerate(zip(channel_means, channel_stds)):
        normalised[:, channel, :, :] = (normalised[:, channel, :, :] - mean) / std
    return normalised
def _loss_value(loss):
    """Return a scalar loss as a Python number on both old and new PyTorch."""
    # 0-dim tensors (PyTorch >= 0.4) expose .item(); older Variables keep a
    # one-element tensor in .data, where indexing with [0] is the way to read it.
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


def train(epoch, niter, weights=(0.2, 0.8, 0.2), lr_rate=(0.01, 0.001)):
    """Train ``transform_net`` with a weighted perceptual (feature MSE) loss.

    Args:
        epoch: Number of epochs to run.
        niter: Number of mini-batches (of 16 images) processed per epoch.
        weights: One weight per feature map returned by ``vgg_loss``; the loss
            is the weighted sum of the per-feature MSE values.
        lr_rate: Learning rate for each epoch. When there are more epochs than
            entries, the last entry is reused for the remaining epochs.

    Raises:
        ValueError: If ``weights`` or ``lr_rate`` is empty.

    Side effects:
        Prints the running and current loss every iteration and, at the end of
        each epoch, saves the network's state dict to
        ``color_transform_net_epoch_<epoch>_iter_<niter>``.
    """
    if not weights:
        raise ValueError("weights must contain at least one entry")
    if not lr_rate:
        raise ValueError("lr_rate must contain at least one entry")

    # Training mode; this affects dropout and batch-norm layers.
    transform_net.train()
    batch_size = 16
    for e in range(epoch):
        # Reuse the last learning rate instead of raising IndexError when
        # fewer rates than epochs were supplied.
        learning_rate = lr_rate[min(e, len(lr_rate) - 1)]
        # A fresh optimizer per epoch applies that epoch's learning rate
        # (and, as in the original code, restarts Adam's running statistics).
        optimizer = optim.Adam(transform_net.parameters(), learning_rate)
        agg_loss = 0  # aggregate loss over this epoch
        bc = BcolzArrayIterator(lumi_arr, origin_arr, batch_size=batch_size)
        for i in range(niter):
            optimizer.zero_grad()
            # lumi_bc holds the Y-channel batch, origin_bc the matching RGB batch.
            lumi_bc, origin_bc = next(bc)
            # The arrays are channels-last; PyTorch expects channels-first.
            lumi_bc = preprocess_input(lumi_bc)
            origin_bc = preprocess_input(origin_bc)
            # The network outputs predicted RGB on a 0-255 scale.
            predict_rgb = transform_net(lumi_bc)
            predict_features = vgg_loss(preprocess_perceptual_input(predict_rgb))
            origin_features = vgg_loss(preprocess_perceptual_input(origin_bc))
            # Perceptual loss: weighted sum of feature-space MSE terms.
            loss = 0
            for j, weight in enumerate(weights):
                loss += mse_loss(predict_features[j], origin_features[j]) * weight
            current_loss = _loss_value(loss)
            agg_loss += current_loss
            print("epoch {}, {} iteration, agg loss={}, loss={}".format(e, i, agg_loss / (i+1), current_loss))
            # Back-propagate and update the transform network's weights.
            loss.backward()
            optimizer.step()
        torch.save(transform_net.state_dict(), 'color_transform_net_epoch_{}_iter_{}'.format(e+1, niter))
# Run two epochs of 2000 iterations: learning rate 0.01 for the first epoch,
# 0.001 for the second.
train(
    epoch=2,
    niter=2000,
    weights=[0.2, 0.8, 0.2],
    lr_rate=[0.01, 0.001],
)
This function (from color_converter) converts YCrCb to RGB; it has been tested and works as expected:
def ycrcb_to_rgb_torch(input_tensor, delta=0.5):
    """Convert a batch of YCrCb images to RGB.

    Args:
        input_tensor: 4-D tensor whose second axis holds the Y, Cr and Cb
            planes at indices 0, 1 and 2.
        delta: Offset subtracted from both chroma planes before mixing.

    Returns:
        Tensor with the R, G and B planes placed along axis 1, in that order.
    """
    luma = input_tensor[:, 0, :, :]
    chroma_red = input_tensor[:, 1, :, :] - delta
    chroma_blue = input_tensor[:, 2, :, :] - delta

    red = luma + 1.403 * chroma_red
    green = luma - 0.714 * chroma_red - 0.344 * chroma_blue
    blue = luma + 1.773 * chroma_blue
    # stack() inserts the channel axis at position 1, matching unsqueeze + cat.
    return torch.stack([red, green, blue], 1)
colorize_net_0.py: this is the transform net. It accepts the Y channel in [0, 255] and converts it to RGB channels. I tried three approaches:
1: Assume the net converts Y to Cb, Cr -> concatenate Y, Cb, Cr -> convert YCbCr to RGB -> clamp the RGB values to [0, 255] (leaving them unclamped fails as well).
2: Assume the net converts Y to Cb, Cr -> concatenate Y, Cb, Cr -> activate with tanh -> rescale to [0, 255].
3: Assume the net converts Y to R, G, B directly -> rescale the tanh activation to [0, 255].
from color_converter import ycrcb_to_rgb_torch
import torch
import torch.nn as nn
class conv_block(nn.Module):
    """Convolution followed by batch normalisation and an optional ReLU.

    The convolution is padded with ``(kernal - 1) // 2`` pixels on each side.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        kernal: Convolution kernel size.
        stride: Convolution stride.
        activate: When true, a ReLU is appended after the batch norm.
    """

    def __init__(self, in_size, out_size, kernal=3, stride=1, activate=True):
        super(conv_block, self).__init__()
        same_padding = (kernal - 1) // 2
        convolution = nn.Conv2d(in_size, out_size, kernal, stride, same_padding)
        batch_norm = nn.BatchNorm2d(out_size)
        self._m_conv = nn.Sequential(convolution, batch_norm)
        if activate:
            self._m_conv.add_module('conv_block_relu', nn.ReLU())

    def forward(self, x):
        """Apply the conv -> batch norm (-> ReLU) pipeline to ``x``."""
        return self._m_conv(x)
class res_block(nn.Module):
    """Two stacked ``conv_block`` layers whose output is added to the input.

    NOTE(review): the skip connection adds ``x`` element-wise to the output of
    the second block, so the shapes must match; in practice that means
    ``in_size == out_size`` and a stride that keeps the spatial size - confirm
    before using other settings.
    """

    def __init__(self, in_size, out_size, kernal=3, stride=1):
        super(res_block, self).__init__()
        self._m_conv1 = conv_block(in_size, out_size, kernal, stride)
        self._m_conv2 = conv_block(out_size, out_size, kernal, stride)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        transformed = self._m_conv2(self._m_conv1(x))
        return transformed + x
"""
This net accept Y channel of YCrCb(convert from rgb image which follow the formula from opencv)
within [0,255].
"""
class colorization_net(nn.Module):
    """Transform network that maps a luminance image to a colour image.

    Input is the Y channel of YCrCb (one channel, values in [0, 255]). In the
    active configuration the final convolution has three output channels that
    are treated as RGB and rescaled from the tanh range to [0, 255].

    Output strategies the author experimented with:

    1. Predict CrCb (two output channels), concatenate with Y, convert with
       ``ycrcb_to_rgb_torch(out, 128)`` and clamp to [0, 255]. Reported to give
       single-colour images; without the clamp the colours were very odd.
    2. As 1, but pass the converted RGB through tanh and rescale instead of
       clamping. Reported to give almost plain white images.
    3. (active) Predict RGB directly with three output channels.
    """

    def __init__(self):
        super(colorization_net, self).__init__()
        self._m_conv1 = conv_block(1, 64, 9)
        self._m_res1 = res_block(64, 64)
        self._m_res2 = res_block(64, 64)
        self._m_res3 = res_block(64, 64)
        self._m_res4 = res_block(64, 64)
        self._m_conv2 = nn.Sequential(
            # Three output channels: the RGB head (strategy 3). Strategies 1
            # and 2 used two output channels (CrCb) here instead.
            nn.Conv2d(64, 3, 9, 1, 4),
            nn.Tanh()
        )
        # Not used by the active RGB head; kept for the tanh-based YCrCb
        # variant (strategy 2).
        self._m_tanh = nn.Tanh()

    def forward(self, x):
        """Return predicted RGB in [0, 255] for a batch of Y-channel images."""
        out = self._m_conv1(x)
        for block in (self._m_res1, self._m_res2, self._m_res3, self._m_res4):
            out = block(out)
        out = self._m_conv2(out)
        # tanh yields values in [-1, 1]; (t + 1) * 127.5 maps that onto the
        # full [0, 255] range. The earlier t * 127 + 128 only covered
        # [1, 255], so a value of 0 could never be produced.
        return (out + 1) * 127.5