`Learner.export` is not intended to work if you didn't build your data with the fastai library: it records the transforms/preprocessing you applied so that it can reapply them for you at inference time.
You can try patching a `get_state` method onto `TensorDataset` that accepts arbitrary args/kwargs and returns an empty dictionary, but I can't guarantee it will work.

Got it, thanks! Actually, it looks like every single method in the prediction and inference API fails when using plain PyTorch datasets, so that should probably be noted in the docs.

Do you have any idea why the TB callback breaks with RNNs? Here's the full model, loss, and metric:

class CustomLoss(torch.nn.Module):
    """Cross-entropy applied independently at every position of a sequence output.

    `input` holds class scores in its last dimension, e.g. (batch, seq_len, n_classes);
    `target` holds the matching integer class indices, e.g. (batch, seq_len).
    Both are flattened so that every position counts as one sample.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, target):
        # (..., n_classes) -> (N, n_classes) and (...) -> (N,), where N is the
        # product of the leading dimensions. Using shape[-1] (not shape[2])
        # also accepts already-flattened 2-D input.
        n_classes = input.shape[-1]
        # reshape instead of view: view raises on non-contiguous tensors.
        # Call the module rather than .forward() so nn.Module hooks still run.
        return self.loss(input.reshape(-1, n_classes), target.reshape(-1))
def accuracy(input:Tensor, targs:Tensor)->Rank0Tensor:
    """Return the fraction of positions where `input.argmax(-1)` equals `targs`.

    `input` holds class scores in its last dimension, e.g. (bs, seq_len, n_classes);
    `targs` holds integer class indices with the same leading layout, e.g.
    (bs, seq_len). Both are flattened per batch row before comparison.
    """
    n = targs.shape[0]
    # reshape instead of view: view raises on non-contiguous targets.
    preds = input.argmax(dim=-1).reshape(n, -1)
    targs = targs.reshape(n, -1)
    return (preds == targs).float().mean()
# Bidirectional recurrent neural network (many-to-many): one class score vector per time step
class BiRNN(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear classifier.

    Args:
        input_size: number of features per time step.
        hidden_size: hidden units per LSTM direction.
        num_layers: number of stacked LSTM layers.
        num_classes: number of output classes per time step.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        # Kept as attributes for backward compatibility with external readers.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=True)
        # Forward and backward states are concatenated, hence hidden_size * 2.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape (batch, seq_len, num_classes).

        `x` is either (batch, seq_len), one feature per step (requires
        input_size == 1), or already (batch, seq_len, input_size).
        """
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        # Omitting (h0, c0) makes nn.LSTM start from zero states on x's own
        # device and dtype, so no global `device` is required.
        out, _ = self.lstm(x)  # (batch, seq_len, hidden_size * 2)
        out = self.fc(out)     # (batch, seq_len, num_classes)
        # Normalize over the class axis; dim=1 would normalize across time steps.
        return F.log_softmax(out, dim=-1)