Training a BERT-like model from scratch with the fastai Learner - need help setting up masked tokens

I am trying to build a custom model similar to BERT (not for NLP), and so far I have been working in plain PyTorch. I'd like to wrap the final model in a fastai Learner for training so I can use fit_one_cycle, freeze, etc. The tricky part is how BERT masks tokens for prediction: the code below kicks off training much like fit_one_cycle would, and I need to hook the masking step (the call to the mask_tokens function) into the fastai training loop so it trains the way I want. I'm struggling to navigate the fastai code (my Python experience is mostly limited to hacking on existing code and making minor changes). Can someone please offer some suggestions?

def train(train_dataloader, model, config):
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(arg.num_train_epochs), desc="Epoch", disable=arg.local_rank not in [-1, 0])
    set_seed(arg.seed_value, arg.n_gpu)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=arg.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            batch = batch.float()  # to resolve the runtime error
            model.train()

            inputs_, labels, amounts, masked_indices = mask_tokens(batch, tok, config)
            inputs = {"input_data": inputs_,
                      "labels": labels,
                      "amounts": amounts}

            outputs = model(**inputs)
            loss = outputs[0]  

            if arg.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if arg.gradient_accumulation_steps > 1:
                loss = loss / arg.gradient_accumulation_steps

            if arg.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            print(loss)


            tr_loss += loss.item()
            if (step + 1) % arg.gradient_accumulation_steps == 0:
                if arg.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), arg.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), arg.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1


                if arg.local_rank in [-1, 0] and arg.save_steps > 0 and global_step % arg.save_steps == 0:
                    # Save model checkpoint
                    output_dir_ = os.path.join(output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir_):
                        os.makedirs(output_dir_)
                    model_to_save = (model.module if hasattr(model, "module") else model)  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir_)

            if arg.max_steps > 0 and global_step > arg.max_steps:
                epoch_iterator.close()
                break
        if arg.max_steps > 0 and global_step > arg.max_steps:
            train_iterator.close()
            break
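
In fastai terms, my best guess is that the masking belongs in a Callback that rewrites each batch in before_batch, using the mask_tokens function shown further down. Here is a rough, untested sketch of what I have in mind, assuming fastai v2, assuming the DataLoaders yield the combined (ids, amounts) tensor as the single input, and assuming my model's forward accepts (input_data, labels, amounts) positionally and returns the loss as its first output:

from fastai.callback.core import Callback

class MaskTokensCallback(Callback):
    "Apply BERT-style masking to every batch before it reaches the model (sketch)."
    def __init__(self, tok, config):
        self.tok, self.config = tok, config

    def before_batch(self):
        # self.xb is the tuple of inputs the Learner is about to pass to the model;
        # here it is assumed to contain only the raw (ids, amounts) tensor
        batch = self.xb[0].float()
        inputs_, labels, amounts, masked_indices = mask_tokens(batch, self.tok, self.config)
        # Swap in the masked inputs and tell fastai what the targets are
        self.learn.xb = (inputs_, labels, amounts)   # passed positionally to model.forward
        self.learn.yb = (labels,)                    # what the loss function will see as the target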

Here is the code for the mask_tokens function:

from typing import Tuple

import numpy as np
import torch


def mask_tokens(inputs: torch.Tensor, tok, config) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    mlm_probability: float = 0.15

    input_ids = inputs[:, :, 0]
    input_amounts = inputs[:, :, 1]

    device = input_ids.device
    x, y = input_ids.size()

    labels = input_ids.clone().detach()

    # Sample a few tokens in each sequence for masked-LM training
    # (probability mlm_probability, 0.15 as in BERT/RoBERTa), never masking special tokens
    probability_matrix = torch.full(labels.shape, mlm_probability, device=device)
    special_tokens_mask = torch.tensor(
        tok.get_special_tokens_mask(labels, already_has_special_tokens=True),
        dtype=torch.bool,
        device=device,
    )
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()

    labels[~masked_indices] = -100  # we only compute the loss on masked tokens

    # 80% of the time, replace masked input tokens with the [MASK] token id
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=device)).bool() & masked_indices
    input_ids[indices_replaced] = tok.special_token_IDs[tok.mask_token]

    # 10% of the time, replace masked input tokens with a random token id
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5, device=device)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(config.FS_size, labels.shape, dtype=torch.float, device=device)  # consider all tok ids excluding special toks
    input_ids[indices_random] = random_words[indices_random]

    # The remaining 10% of the time, keep the masked input tokens unchanged

    # Keep amounts only at masked positions; zeroed positions are then set to NaN via filter_values
    input_amounts = input_amounts * masked_indices.float()
    input_amounts[filter_values(input_amounts, [0])] = np.nan

    outputs = torch.cat([input_ids.reshape(x, y, 1), input_amounts.reshape(x, y, 1)], dim=-1)
    return outputs, labels.long(), input_amounts, masked_indices
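
If that Callback idea is roughly right, I imagine wiring it up would look something like this (also untested; dls here would be fastai DataLoaders over my data, and model_loss just returns the loss the model has already computed as outputs[0], HuggingFace style):

from fastai.basics import *              # Learner, DataLoaders, etc. (fastai v2)
from fastai.callback.schedule import *   # fit_one_cycle

def model_loss(outputs, *targets):
    # the model computes its own masked-LM loss internally and returns it first
    return outputs[0]

learn = Learner(dls, model,
                loss_func=model_loss,
                cbs=[MaskTokensCallback(tok, config)])
learn.fit_one_cycle(2, lr_max=1e-4)

I am also guessing that learn.to_fp16() would stand in for the apex/amp part and the GradientAccumulation callback for the gradient_accumulation_steps logic, but I have not gotten that far. Does this look like the right direction, or is there a better place in fastai to hook in the masking?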