My “final” version with no pandas dependencies. Pretty minimal, if I do say so myself:
import csv
import numpy as np
import torch
from collections import OrderedDict
from fastai.basic_train import Learner

defaultlr = 1e-3
def write_encoding_dict(filename, learner, cat_names, cat):
    # Dump one category's embedding to CSV: one row per class,
    # first column is the class label, the rest are the embedding values.
    classes = learner.data.label_list.train.x.classes[cat]
    weight_matrix = learner.model.embeds[cat_names.index(cat)].weight
    with open(filename, 'w') as csvFile:
        writer = csv.writer(csvFile, lineterminator='\n')
        for i in range(len(classes)):
            writer.writerow([classes[i], *weight_matrix[i].tolist()])
def get_encoding_dict(filename):
    # Read the CSV back into an OrderedDict mapping class label -> list of floats.
    with open(filename, 'r') as csvFile:
        reader = csv.reader(csvFile)
        lines = list(reader)
    d = OrderedDict()
    for line in lines:
        d[line[0]] = [float(x) for x in line[1:]]
    return d
def load_embed_weights(filename, learner, cat_names, cat):
    # Copy saved encodings into the learner's embedding matrix.
    # Classes missing from the file get a random normal initialisation.
    encodings = get_encoding_dict(filename)
    classes = learner.data.label_list.train.x.classes[cat]
    weight_matrix = learner.model.embeds[cat_names.index(cat)].weight
    emb_dim = weight_matrix.shape[1]
    with torch.no_grad():
        for i, value in enumerate(classes):
            try:
                enc = encodings[value]
                for j in range(emb_dim):
                    weight_matrix[i][j] = enc[j]
            except KeyError:
                for j in range(emb_dim):
                    weight_matrix[i][j] = np.random.normal(scale=0.6)
def freeze_embedding(learner: Learner, index=0):
    # Stop gradient updates for one embedding and rebuild the optimizer
    # so the frozen parameter is excluded.
    learner.model.embeds[index].weight.requires_grad = False
    learner.create_opt(defaultlr)

def unfreeze_embedding(learner: Learner, index=0):
    learner.model.embeds[index].weight.requires_grad = True
    learner.create_opt(defaultlr)
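For context, here is a rough usage sketch. The learner variable, category names, and file name are placeholders I made up, not part of the snippet above; assume `learn` is an existing fastai v1 tabular learner built with those `cat_names`:

# Hypothetical usage, assuming a trained fastai v1 tabular learner `learn`
# whose categorical columns are cat_names = ['store', 'item'] (placeholders).
cat_names = ['store', 'item']

# Save the learned embedding for the 'store' category to disk.
write_encoding_dict('store_embedding.csv', learn, cat_names, 'store')

# Later, on a new learner whose 'store' classes overlap with the saved ones,
# load the saved weights and freeze that embedding while the rest trains.
load_embed_weights('store_embedding.csv', learn, cat_names, 'store')
freeze_embedding(learn, index=cat_names.index('store'))

# Once the other layers have settled, let the embedding fine-tune again.
unfreeze_embedding(learn, index=cat_names.index('store'))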