Unexpected DataLoader behavior on MNIST_SAMPLE task

For practice, I’ve re-implemented parts of the 04_mnist_basics notebook. However, I’ve run into some unexpected behavior when setting the shuffle parameter of the training DataLoader:

In my example, everything works fine with shuffle=False, but with shuffle=True I see no decrease in loss and no increase in accuracy. There must be something wrong with my function definitions, but I can’t seem to find it. Does anyone have a pointer for me? (Sorry for the wall of code.)
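To be concrete, the only line I change between the two runs is the definition of dl_train (full code below):

dl_train = DataLoader(dataset=list(zip(x_train, y_train)), batch_size=BATCH_SIZE, shuffle=True)   # no learning
dl_train = DataLoader(dataset=list(zip(x_train, y_train)), batch_size=BATCH_SIZE, shuffle=False)  # trains fine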

Code:

from fastai2.vision.all import *
from utils import *

dir_data = untar_data(URLs.MNIST_SAMPLE, dest=Path('.'))  # download/extract the 3s-vs-7s sample set

# read every image as a float tensor scaled to [0, 1], grouped first by
# split ('valid'/'train'), then by class ('3'/'7')
img_dict = {}
for data_set in ['valid', 'train']:
    img_dict[data_set] = dict(
        (data_cls, [tensor(Image.open(f)).float()/255. for f in (dir_data/data_set/data_cls).iterdir() if f.is_file()])
        for data_cls in ['3', '7']
    )
    
# stack each list of 28x28 tensors into a single rank-3 tensor per split/class
tns_dict = dict((data_set, dict((data_cls, torch.stack(img_dict[data_set][data_cls]))
                                for data_cls in ['3', '7']))
                for data_set in ['valid', 'train'])

# flatten each 28x28 image into a 784-vector and label 3s as 1., 7s as 0.
x_train = torch.cat([tns_dict['train']['3'].view(-1, 28*28), tns_dict['train']['7'].view(-1, 28*28)])
y_train = tensor([1.]*len(tns_dict['train']['3']) + [0.]*len(tns_dict['train']['7'])).unsqueeze(1)

x_valid = torch.cat([tns_dict['valid']['3'].view(-1, 28*28), tns_dict['valid']['7'].view(-1, 28*28)])
y_valid = tensor([1.]*len(tns_dict['valid']['3']) + [0.]*len(tns_dict['valid']['7'])).unsqueeze(1)

BATCH_SIZE = 256
dl_train = DataLoader(dataset=list(zip(x_train, y_train)), batch_size=BATCH_SIZE, shuffle=False)  # <- the flag I toggle
dl_valid = DataLoader(dataset=list(zip(x_valid, y_valid)), batch_size=BATCH_SIZE, shuffle=False)

def loss_func(yhat, y):
    # squash raw outputs into (0, 1), then measure the distance from the target
    yhat = yhat.sigmoid()
    return torch.where(y==1., 1-yhat, yhat).mean()
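
# worked example of the intended loss behavior: for a confident positive
# prediction, loss_func(tensor([[2.0]]), tensor([[1.]])) gives
# 1 - sigmoid(2.0) ≈ 1 - 0.88 = 0.12, i.e. a small loss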

def linear(x, w, b):
    return x@w + b

def train_epoch(dataloader, model, params, lr):
    losses = []
    for xb, yb in dataloader:
        yhat = model(xb, *params)
        loss = loss_func(yhat, yb)
        losses.append(loss.item())
        loss.backward()
        # manual SGD step, then reset the gradients for the next batch
        for param in params:
            param.data -= lr*param.grad
            param.grad.zero_()
    return round(np.mean(losses), 4)
            
def valid_epoch(dataloader, model, params):
    def accuracy(yhat, y):
        yhat = yhat.sigmoid()
        correct = ((yhat>.5)==y).float()
        return correct.mean().item()
    # average the per-batch accuracies over the whole validation set
    return round(np.mean([accuracy(model(xb, *params), yb) for xb, yb in dataloader]), 2)

EPOCHS = 10
LR = 1.
# random init: one weight per pixel, plus a scalar bias
weights, bias = torch.randn(784).requires_grad_(), torch.randn(1).requires_grad_()
params = weights, bias

for epoch in range(EPOCHS):
    train_result = train_epoch(dl_train, linear, params, LR)
    print(f'Mean train loss of epoch {epoch+1:02}/{EPOCHS}: {train_result:.2f}')
    valid_result = valid_epoch(dl_valid, linear, params)
    print(f'Mean valid accuracy of epoch {epoch+1:02}/{EPOCHS}: {valid_result:.2f}')

Executing with shuffle=True yields:

Mean train loss of epoch 01/10: 0.50
Mean valid accuracy of epoch 01/10: 0.47
Mean train loss of epoch 02/10: 0.50
Mean valid accuracy of epoch 02/10: 0.48
Mean train loss of epoch 03/10: 0.50
Mean valid accuracy of epoch 03/10: 0.48

Executing with shuffle=False yields:

Mean train loss of epoch 01/10: 0.35
Mean valid accuracy of epoch 01/10: 0.72
Mean train loss of epoch 02/10: 0.24
Mean valid accuracy of epoch 02/10: 0.79
Mean train loss of epoch 03/10: 0.19
Mean valid accuracy of epoch 03/10: 0.83
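
In case it helps with narrowing this down, here’s a quick sanity check of one batch (assuming the code above has already run); the shapes come out the same regardless of the shuffle setting:

xb, yb = next(iter(dl_train))
print(xb.shape, yb.shape)         # torch.Size([256, 784]) torch.Size([256, 1])
print(linear(xb, *params).shape)  # shape of the raw model output for this batch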