Hey, I didn’t change the model structure. I just modified this line in the definition of TextData.from_splits, from
trn_iter,val_iter = torchtext.data.BucketIterator.splits(splits, batch_size=bs)
to
trn_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(splits, batch_sizes=(bs, bs, 1))
The torchtext source code shows that torchtext.data.BucketIterator.splits takes a batch_sizes tuple argument, which sets a separate batch size for each dataset in splits (here, bs for training and validation and 1 for test).
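To illustrate the dispatch logic, here is a simplified, dependency-free sketch of how a classmethod like `splits` can map batch sizes to datasets. It is not the actual torchtext code: `splits_sketch` is a hypothetical stand-in that returns one plain dict per dataset instead of real iterators. It assumes that, when `batch_sizes` is omitted, a single `batch_size` keyword is shared by all splits, and that only the first split is marked as the training split. Because one result comes back per dataset, a `splits` holding three datasets needs three names on the left-hand side (`trn_iter, val_iter, test_iter`).

```python
from typing import Any, Optional, Sequence, Tuple


def splits_sketch(datasets: Sequence[Any],
                  batch_sizes: Optional[Sequence[int]] = None,
                  **kwargs: Any) -> Tuple[dict, ...]:
    """Hypothetical stand-in for BucketIterator.splits (not the real API).

    Returns one config dict per dataset instead of a real iterator.
    """
    if batch_sizes is None:
        # No tuple given: reuse the single batch_size keyword for every split.
        batch_sizes = [kwargs.pop("batch_size")] * len(datasets)
    configs = []
    for i, (dataset, bs) in enumerate(zip(datasets, batch_sizes)):
        # Only the first split (index 0) is treated as the training split.
        configs.append({"dataset": dataset, "batch_size": bs,
                        "train": i == 0, **kwargs})
    return tuple(configs)


# One batch size per split, mirroring batch_sizes=(bs, bs, 1).
bs = 64
trn, val, test = splits_sketch(["trn", "val", "test"], batch_sizes=(bs, bs, 1))
print(trn["batch_size"], val["batch_size"], test["batch_size"])  # 64 64 1
```

Passing only `batch_size=bs` instead would give every split the same size, which is the behaviour of the original two-iterator line.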