How to use weights of a TensorFlow model in PyTorch?

There are many research papers that open-source their code and use TensorFlow. The weights they provide are in the form of checkpoints. So is there any way to convert these weights to PyTorch weights?

I think we can do it manually, e.g. by opening the checkpoint file in TensorFlow, reading all the weight values into a dict, and then copying them into the state_dict of the PyTorch model.

Is there some other way of doing it?


There are no tricks to it: you just have to dump the weights and reload them. It's quite time-consuming! A lot of models have already been converted here, BTW: https://github.com/Cadene/pretrained-models.pytorch/
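For what it's worth, here is a rough sketch of what "dump and reload" usually looks like. It is not taken from the repo above, and the helper names (`read_tf_checkpoint`, `tf_to_torch_layout`, `convert`, `load_into_model`) and the `name_map` argument are made up for illustration. It assumes plain conv2d and dense kernels; the shape-based rule is a simplification, since other 2-D variables such as embeddings should not be transposed, and depthwise kernels need their own handling.

```python
import numpy as np


def read_tf_checkpoint(ckpt_path):
    """Return {variable_name: numpy array} for every variable in a TF checkpoint."""
    import tensorflow as tf  # imported lazily so the helpers below work without TF

    reader = tf.train.load_checkpoint(ckpt_path)
    return {name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map()}


def tf_to_torch_layout(array):
    """Reorder axes from TensorFlow's default kernel layout to PyTorch's.

    Assumption: 4-D arrays are conv2d kernels and 2-D arrays are dense kernels.
    """
    if array.ndim == 4:  # conv2d kernel: (H, W, in, out) -> (out, in, H, W)
        return np.ascontiguousarray(array.transpose(3, 2, 0, 1))
    if array.ndim == 2:  # dense kernel: (in, out) -> (out, in)
        return np.ascontiguousarray(array.T)
    return array  # biases, batch-norm vectors, scalars: left unchanged


def convert(tf_weights, name_map):
    """Build a {pytorch_name: numpy array} dict from the TF weights.

    name_map is a hand-written {tf_variable_name: pytorch_parameter_name}
    dict (hypothetical here; it has to be written per model).
    """
    return {torch_name: tf_to_torch_layout(tf_weights[tf_name])
            for tf_name, torch_name in name_map.items()}


def load_into_model(model, converted):
    """Copy the converted arrays into a PyTorch model."""
    import torch  # imported lazily, same reason as above

    state = {name: torch.from_numpy(arr) for name, arr in converted.items()}
    # strict loading (the default) raises if any key is missing or unexpected
    model.load_state_dict(state)
```

The tedious part is writing `name_map` by hand for each architecture. Printing the checkpoint's variable names next to `model.state_dict().keys()` helps, and so does comparing the two models' outputs on the same input afterwards.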
