I grabbed a few images to share that show the improvements in results after implementing new features and running experiments on the Style Transfer notebook. The first 3 images show my progression, and the last 2 use different style images.
Some of the things I’ve implemented or tried so far:
- PyTorch forward hooks for feature extraction
- Replacing max pooling with average pooling in VGG16 (Paper)
- Weighting the style loss more heavily (multiplying by 100-500 so far seems to work best) (Paper)
- Dividing each layer's component of the loss by the total number of layers used in the loss (Paper)
- Using the same layers for both style loss (1, 6, 11, 18, 25) and content loss (same plus 29), but weighting each layer's style loss by `1 / (layer_idx + 1)` and each layer's content loss by `1 / (n_layers - layer_idx)`. Essentially this weights the early layers more heavily for style and the later layers more heavily for content, which is the general recommendation in the paper, though they accomplish it by omitting layers.
```python
# Weighted
class StyleLossToTarget():
    def __init__(self, target_im, target_layers=(1, 6, 11, 18, 25)):
        fc.store_attr()
        with torch.no_grad(): self.target_grams = calc_grams(target_im, target_layers)

    def __call__(self, input_im):
        return sum((f1 - f2).pow(2).mean() / len(self.target_layers) * (1. / (idx + 1))
                   for idx, (f1, f2) in
                   enumerate(zip(calc_grams(input_im, self.target_layers), self.target_grams)))


class ContentLossToTarget():
    def __init__(self, target_im, target_layers=(1, 6, 11, 18, 25, 29)):
        fc.store_attr()
        with torch.no_grad():
            self.target_features = calc_features(target_im, target_layers)

    def __call__(self, input_im):
        return sum((f1 - f2).pow(2).mean() / len(self.target_layers) * (1. / (len(self.target_layers) - idx))
                   for idx, (f1, f2) in
                   enumerate(zip(calc_features(input_im, self.target_layers), self.target_features)))
```
- Increasing the steps to 1,200
- Sometimes starting with noise, sometimes the content image and sometimes a combination of the 2. Ex:
noise = torch.rand_like(content_im)/2.
model = TensorModel(noise+(content_im-noise.mean()))
Progression w/ Spider Web Style
Farewell To Anger Style
Starry Night Style
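For completeness, a sketch of how the pieces above fit together in the optimization loop. `TensorModel`, the Adam settings, and the `optimize` helper are assumptions; the style weight and step count reflect the numbers mentioned earlier:

```python
import torch
import torch.nn as nn

class TensorModel(nn.Module):
    """Wraps an image tensor as a learnable parameter (as used above)."""
    def __init__(self, t):
        super().__init__()
        self.t = nn.Parameter(t.clone())

    def forward(self):
        return self.t

def optimize(content_im, style_loss, content_loss, steps=1200, style_weight=200., lr=1e-2):
    # Start from the content image; a noise mix also works, as noted above.
    model = TensorModel(content_im)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for step in range(steps):
        opt.zero_grad()
        im = model()
        # Style loss weighted more heavily (100-500x worked best so far).
        loss = style_weight * style_loss(im) + content_loss(im)
        loss.backward()
        opt.step()
    return model().detach()
```

The image itself is the only thing being optimized; VGG stays frozen throughout.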