Keras Weight Tying / Sharing

PyTorch is really nice. I converted from Keras to PyTorch in early 2017 and to be honest haven't really looked back since. Anyhow, I'm trying to do an NLP project in Keras and am having some issues.

In PyTorch, Linear layers very conveniently have their weights initialized like this:

self.weight = Parameter(torch.Tensor(out_features, in_features))

Note how out_features curiously comes first(!). Embedding layers in PyTorch have weights that look like this:

self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim))

This “feature” makes it really easy to share / tie weights between embeddings and linear layers, just as fastai has done with their LinearDecoder by essentially just setting lineardecoder.linear.weights = rnnencoder.embedding.weights and letting PyTorch do the rest.
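For concreteness, here is a minimal PyTorch sketch of that kind of tying (the variable names and sizes are mine, not fastai's):

import torch.nn as nn

vocab_size, emb_dim = 10000, 300  # example sizes, chosen arbitrarily

embedding = nn.Embedding(vocab_size, emb_dim)  # weight shape: (vocab_size, emb_dim)
decoder = nn.Linear(emb_dim, vocab_size)       # weight shape: (vocab_size, emb_dim)

# Both weights are (vocab_size, emb_dim), so tying is a single assignment:
decoder.weight = embedding.weight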

From my reading, the Keras paradigm for weight sharing is actually layer reuse with the functional API. Unfortunately, one cannot simply swap an Embedding and a Dense layer. To further complicate things, Keras Dense layers have their kernels defined as:

self.kernel = self.add_weight(shape=(input_dim, self.units), .....

So even if one could simply assign the weights of one layer to another (which we can't in Keras), there would still be issues, such as needing to transpose the kernel.
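To illustrate what I mean by layer reuse, here is a minimal functional-API sketch (layer and tensor names are just placeholders): calling the same layer instance on two inputs makes both calls share one kernel.

from keras.layers import Input, Dense

# One Dense instance applied to two inputs: both calls share the same weights.
shared_dense = Dense(64, activation='relu')

input_a = Input(shape=(32,))
input_b = Input(shape=(32,))

out_a = shared_dense(input_a)
out_b = shared_dense(input_b)

That works for sharing a layer with itself, but it doesn't help when the two layers are different types (Embedding vs. Dense) with transposed kernel shapes.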

Anyhow, my question is: does anyone know how to share/tie weights between a Keras Embedding and Dense layer, similar to what fastai does in its language models? For reference, here is one custom-layer approach:

from keras.layers import Layer
import keras.backend as K
from keras import activations


class TiedEmbeddingsTransposed(Layer):
    """Layer for tying embeddings in an output layer.

    A regular embedding layer has shape V x H (V: size of the vocabulary,
    H: size of the projected space). This layer uses the transpose, H x V,
    with the same weights as the tied embedding layer. In addition, it may
    have an activation.

    # References
        - Using the Output Embedding to Improve Language Models
    """

    def __init__(self, tied_to=None, activation=None, **kwargs):
        super(TiedEmbeddingsTransposed, self).__init__(**kwargs)
        self.tied_to = tied_to
        self.activation = activations.get(activation)

    def build(self, input_shape):
        # Reuse (transposed) the embedding matrix; no new weights are created,
        # so gradients flow back into the tied embedding layer.
        self.transposed_weights = K.transpose(self.tied_to.weights[0])
        self.built = True

    def compute_mask(self, inputs, mask=None):
        return mask

    def compute_output_shape(self, input_shape):
        # The output dimension is the vocabulary size V.
        return input_shape[0], K.int_shape(self.tied_to.weights[0])[0]

    def call(self, inputs, mask=None):
        output = K.dot(inputs, self.transposed_weights)
        if self.activation is not None:
            output = self.activation(output)
        return output

    def get_config(self):
        # Note: `tied_to` is not serializable, so it is not included here.
        config = {'activation': activations.serialize(self.activation)}
        base_config = super(TiedEmbeddingsTransposed, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
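For what it's worth, here is a sketch of how I'd expect to wire this into a small next-token model (the names and sizes are placeholders, and this assumes the RNN's hidden size matches the embedding dimension):

from keras.layers import Input, Embedding, LSTM
from keras.models import Model

vocab_size, emb_dim, seq_len = 10000, 300, 50  # placeholder sizes

tokens = Input(shape=(seq_len,), dtype='int32')
embedding = Embedding(vocab_size, emb_dim)

x = embedding(tokens)
x = LSTM(emb_dim)(x)  # (batch, emb_dim): hidden state used to predict the next token

# Project back onto the vocabulary with the transposed embedding matrix.
next_token_probs = TiedEmbeddingsTransposed(tied_to=embedding, activation='softmax')(x)

model = Model(inputs=tokens, outputs=next_token_probs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

Note that the Embedding layer is called before the tied layer, so its weights already exist when TiedEmbeddingsTransposed.build grabs them.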