Removing or reducing the number of skip connections in a U-Net

Hello everyone!

I am trying to create an autoencoder for image reconstruction on the MVTec AD dataset. The idea is to train the autoencoder on good (defect-free) samples only, so that when I feed the trained model an image containing a defect or anomaly, the reconstruction will be poor and I can conclude that the image must contain an anomaly.
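
A minimal sketch of the scoring step described above, assuming images and reconstructions are PyTorch tensors of shape (batch, channels, height, width). `anomaly_map` and `anomaly_score` are hypothetical helper names, and taking the maximum per-pixel error is only one common choice (the mean is another).

```python
import torch


def anomaly_map(image: torch.Tensor, recon: torch.Tensor) -> torch.Tensor:
    """Per-pixel absolute reconstruction error, averaged over channels.

    image, recon: (batch, channels, height, width) -> (batch, height, width)
    """
    return (image - recon).abs().mean(dim=1)


def anomaly_score(image: torch.Tensor, recon: torch.Tensor) -> torch.Tensor:
    """One score per image: the largest per-pixel error.

    Taking the max keeps a small, localized defect (e.g. a scratch)
    from being averaged away over the whole image.
    """
    return anomaly_map(image, recon).flatten(start_dim=1).max(dim=1).values


# Toy check: the "model" reproduces only the defect-free appearance
good = torch.zeros(1, 3, 8, 8)
defective = good.clone()
defective[:, :, 2, 3] = 1.0              # one bright "scratch" pixel
recon = torch.zeros(1, 3, 8, 8)          # stand-in reconstruction

print(anomaly_score(good, recon).item())       # 0.0
print(anomaly_score(defective, recon).item())  # 1.0
```

The threshold that separates good from defective scores would still have to be chosen on held-out data; nothing here fixes its value.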

I used unet_learner to try to achieve this, but unfortunately the model reconstructs the images far too well (I assume because of the U-Net's skip connections). For example, when I give my trained model an image of a metal nut with a scratch, it reconstructs the image perfectly, scratch included, which leaves me unable to locate the anomaly.

So my question is: is there any way to remove the skip connections from unet_learner, or to reduce how much the model relies on them (for example, by giving the skip connections less weight)? Or is there simply no way around building an autoencoder model myself?

Hello,

Yes, the reason U-Net can reconstruct your dataset so perfectly is that the original input is fed to one of the final layers with almost no modification. The purpose of an autoencoder is to encode the input efficiently by learning the most important features of the data rather than every small detail. U-Net, with its many cross (skip) connections that let fine detail bypass the bottleneck, clearly defeats that purpose.
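
To make the skip-connection argument concrete, here is a toy one-level U-Net in plain PyTorch (not fastai's `DynamicUnet`). `TinyUNet` and its `skip_scale` parameter are hypothetical names for illustration. The full-resolution features reach the decoder directly through the skip, so with `skip_scale=1.0` the network can learn to copy details past the bottleneck, while `skip_scale=0.0` forces everything through the downsampled path.

```python
import torch
import torch.nn as nn


class TinyUNet(nn.Module):
    """One-level U-Net whose skip connection can be scaled or switched off.

    skip_scale=1.0 -> ordinary U-Net skip connection
    skip_scale=0.0 -> skip removed; the decoder only sees the bottleneck path
    """

    def __init__(self, channels: int = 3, hidden: int = 16, skip_scale: float = 1.0):
        super().__init__()
        self.skip_scale = skip_scale
        self.enc = nn.Sequential(nn.Conv2d(channels, hidden, 3, padding=1), nn.ReLU())
        self.down = nn.Sequential(nn.Conv2d(hidden, hidden, 3, stride=2, padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(hidden, hidden, kernel_size=2, stride=2)
        # The decoder always takes 2 * hidden channels, so its shape does not
        # depend on skip_scale.
        self.dec = nn.Conv2d(2 * hidden, channels, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = self.enc(x)               # full-resolution features (the shortcut)
        up = self.up(self.down(skip))    # features that went through the bottleneck
        return self.dec(torch.cat([up, self.skip_scale * skip], dim=1))


x = torch.randn(2, 3, 32, 32)            # height and width must be even here
with_skip = TinyUNet(skip_scale=1.0)
without_skip = TinyUNet(skip_scale=0.0)
print(with_skip(x).shape, without_skip(x).shape)  # both torch.Size([2, 3, 32, 32])
```

A caveat on the "less weight" idea: a fixed factor between 0 and 1 can largely be compensated by the decoder learning larger weights for those channels, so 0 (removing the skip) is the setting that reliably stops the copying.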

You could remove them, but I believe it'd be simpler to create your own autoencoder. It is easier than you might imagine: extract any network's body (the encoder), implement a corresponding upsampler (the decoder), and add a few tweaks here and there for better performance. There are many variations you could try (different CNNs for the encoder, various upsampling blocks, etc.), but at a high level they are very similar. Here's an example with fastai to get you started.
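
A minimal sketch of such an autoencoder in plain PyTorch, assuming 3-channel images scaled to [0, 1] with height and width divisible by 16. For brevity a small hand-written encoder stands in for a pretrained network body, and nearest-neighbour upsampling plus a convolution is just one of the many possible upsampling blocks. `ConvAutoencoder`, `conv_down` and `conv_up` are hypothetical names; this is not the linked example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_down(c_in: int, c_out: int) -> nn.Sequential:
    # Halves height and width: stride-2 convolution + batch norm + ReLU
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )


def conv_up(c_in: int, c_out: int) -> nn.Sequential:
    # Doubles height and width: nearest-neighbour upsampling + convolution
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )


class ConvAutoencoder(nn.Module):
    """Encoder -> bottleneck -> decoder with no skip connections."""

    def __init__(self, channels: int = 3, widths=(32, 64, 128, 256)):
        super().__init__()
        layers, c = [], channels
        for w in widths:                      # each step halves the resolution
            layers.append(conv_down(c, w))
            c = w
        self.encoder = nn.Sequential(*layers)

        layers = []
        for w in reversed(widths[:-1]):       # mirror the encoder on the way up
            layers.append(conv_up(c, w))
            c = w
        layers += [
            nn.Upsample(scale_factor=2, mode="nearest"),   # last doubling
            nn.Conv2d(c, channels, kernel_size=3, padding=1),
            nn.Sigmoid(),                     # assumes pixel values in [0, 1]
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


# Usage: one reconstruction step on a stand-in batch of "good" images
model = ConvAutoencoder()
x = torch.rand(4, 3, 128, 128)
recon = model(x)
loss = F.mse_loss(recon, x)               # the loss to minimise during training
print(recon.shape)                        # torch.Size([4, 3, 128, 128])
```

Because every pixel has to pass through a bottleneck 16 times smaller per side, the model cannot simply copy its input. To train it with fastai, the module could be wrapped in a plain `Learner` with an MSE loss such as `MSELossFlat()`, using DataLoaders that yield (image, image) pairs built from good images only.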

Have a great weekend!


Thank you!


Hey!
I am somewhat new to fastai and am also curious whether there is an easy way to remove some of the skip connections. I am using the U-Net as a feature-map extractor, so I would like to see whether removing some of the skip connections gives me better or more informative feature maps. I have tried using hooks to remove the skip connections, but without luck. Can anyone help?
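
One way hooks can do this in plain PyTorch: a forward hook that returns a value replaces the module's output, so a hook on a module that sits only on the skip path can zero that path. `UpBlock` and `zero_output` below are hypothetical names for a toy block. In the fastai versions I have looked at, each `UnetBlock` inside `DynamicUnet` passes the stored encoder features through a BatchNorm submodule named `bn` before concatenating them, so registering the same hook on that submodule of the blocks you want to disable should have a similar effect. Treat that as an assumption and check the module names with `print(learn.model)` first.

```python
import torch
import torch.nn as nn


class UpBlock(nn.Module):
    """Toy decoder block: concatenates upsampled features with a skip tensor.

    The skip goes through its own submodule (skip_norm), so a forward hook
    on that submodule intercepts exactly the skip path.
    """

    def __init__(self, c_up: int, c_skip: int, c_out: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.skip_norm = nn.BatchNorm2d(c_skip)
        self.conv = nn.Conv2d(c_up + c_skip, c_out, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        return self.conv(torch.cat([self.up(x), self.skip_norm(skip)], dim=1))


def zero_output(module, inputs, output):
    # Returning a value from a forward hook replaces the module's output,
    # so the block receives zeros instead of the skip features.
    return torch.zeros_like(output)


block = UpBlock(c_up=8, c_skip=4, c_out=8).eval()
x, skip = torch.randn(1, 8, 4, 4), torch.randn(1, 4, 8, 8)

handle = block.skip_norm.register_forward_hook(zero_output)
out_without_skip = block(x, skip)
out_other_skip = block(x, torch.randn(1, 4, 8, 8))  # skip content no longer matters
handle.remove()                                       # restores the skip connection
```

If the hook is registered before training, the decoder learns to work without those skips; registering it only on an already trained model changes its outputs in a way it was never trained for, which can make the resulting feature maps hard to interpret.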