Create Language Model for Chemical Structures

Molecules can be represented as SMILES strings, e.g. CC(Oc1cc(-c2cnn(C3CCNCC3)c2)cnc1N)c1c(Cl)ccc(F)c1Cl.
In analogy to NLP, a molecule would be a sentence, and the atom symbols and special characters would be the words.
From a corpus of ca. 500k valid SMILES strings I would like to train a language model. As a next step, I would like to use the language model to generate molecules that are similar to bioactive compounds or have desired properties (transfer learning).
I have a SMILES tokenizer to produce a character-level vocabulary for one-hot encoding (approx. 35 characters).
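A regex-based tokenizer is one common way to build such a vocabulary. The sketch below is only an illustration, not the tokenizer used in this post: the token pattern, the function names `tokenize`, `build_vocab` and `one_hot`, and the special tokens PAD/GO/END are all assumptions. Multi-character tokens such as Cl, Br and bracket atoms like [nH] are kept whole, so each token is an atom symbol or a special character.

```python
import re

# Assumed token pattern: bracket atoms ([nH], [C@@H], ...) stay whole,
# two-letter halogens (Br, Cl) are tried before single letters,
# and two-digit ring closures use the %NN syntax.
SMILES_PATTERN = re.compile(
    r"(\[[^\]]+\]|Br|Cl|N|O|S|P|F|I|B|C|b|c|n|o|s|p|"
    r"\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)

def tokenize(smiles):
    """Split a SMILES string into tokens; raise if any character is not covered."""
    tokens = SMILES_PATTERN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError(f"untokenizable characters in {smiles!r}")
    return tokens

def build_vocab(corpus, specials=("PAD", "GO", "END")):
    """Return (stoi, itos): special tokens first, then every corpus token, sorted."""
    tokens = sorted({t for smi in corpus for t in tokenize(smi)})
    itos = list(specials) + tokens
    stoi = {t: i for i, t in enumerate(itos)}
    return stoi, itos

def one_hot(tokens, stoi):
    """Encode a token list as one-hot rows (plain Python lists of 0/1)."""
    rows = []
    for t in tokens:
        row = [0] * len(stoi)
        row[stoi[t]] = 1
        rows.append(row)
    return rows

print(tokenize("CC(Cl)Br"))  # ['C', 'C', '(', 'Cl', ')', 'Br']
```

Checking that the joined tokens reproduce the input is a cheap way to catch characters the pattern does not cover before they silently disappear from the corpus.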

In the lectures we used a pre-trained language model. How can I generate a language model from scratch?
Is the ULMFiT procedure applicable for transfer learning in this special case?

Just create the same architecture, but don’t load the pre-trained weights.

That said, I’d advise starting with a simple one-layer LSTM/GRU model and only later making it more complex.
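In the spirit of starting simple: the smallest language model one can train from scratch on tokenized SMILES is a count-based bigram model. The sketch below is that baseline, not the ULMFiT/AWD-LSTM approach. It assumes the input is already tokenized, and the GO/END start and stop tokens are an assumption that mirrors the tokens used in the callback code later in this thread.

```python
import random
from collections import Counter, defaultdict

GO, END = "GO", "END"  # assumed start/stop tokens

def train_bigram(token_lists):
    """Count next-token frequencies for every token, including GO and END."""
    counts = defaultdict(Counter)
    for tokens in token_lists:
        seq = [GO] + list(tokens) + [END]
        for prev, nxt in zip(seq, seq[1:]):
            counts[prev][nxt] += 1
    return counts

def next_token_probs(counts, prev):
    """Normalize the successor counts of `prev` into a probability distribution."""
    total = sum(counts[prev].values())
    return {tok: c / total for tok, c in counts[prev].items()}

def sample(counts, rng, max_len=150):
    """Sample tokens autoregressively from GO until END or max_len tokens."""
    out, prev = [], GO
    for _ in range(max_len):
        probs = next_token_probs(counts, prev)
        toks = list(probs)
        prev = rng.choices(toks, weights=[probs[t] for t in toks], k=1)[0]
        if prev == END:
            break
        out.append(prev)
    return "".join(out)

# Usage on a toy, already-tokenized corpus.
corpus = [["C", "C", "O"], ["C", "O"], ["c", "1", "c", "c", "c", "c", "c", "1"]]
model = train_bigram(corpus)
rng = random.Random(42)
print([sample(model, rng) for _ in range(3)])
```

A bigram model only remembers the previous token, so it cannot keep ring-closure digits paired or parentheses balanced. That longer-range memory is what a one-layer LSTM/GRU adds; sampling from it follows the same loop (start from GO, draw from the predicted distribution, feed the token back, stop at END).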

Please do let me know how you go with this. I’ve heard about some folks that have done this in industry and had great success, but they didn’t publish the results AFAIK - so would be nice to have something that’s open.

I was working on a similar project.
This paper might be interesting.


They also released the code and data here:
https://figshare.com/s/b35a2d3bf442a6f15b6e

Thanks for the feedback!

Thanks for the reference!

In the meantime several publications have become available reporting remarkable success applying RNNs to generate valid chemical structures:



I have just recently started the DL MOOC and will definitely try out the fastai library.

Recent work published in J. Chem. Inf. Model.

I would really like to implement something similar with fastai!

@Bill, have you seen this competition? https://www.kaggle.com/c/champs-scalar-coupling/overview

Not exactly what you are talking about, but it is in the realm of computational chemistry, esp. for drug discovery.

@adinho, Thanks for sending me this!
To be honest, I haven’t seen it.
It looks interesting!

SMILES are great, but they are not necessarily unique (many SMILES strings can represent the same molecule). SMILES are also a lossy format, e.g. they don’t encode all information regarding stereochemistry or bond lengths. You’ve probably figured this out by now, but I’d look into using molecular graphs and graph convolutions.

How can we use deep learning to improve a drug so that it is better than the existing version? How can we redesign a drug so that it interferes with the interaction between two proteins? How can we use deep learning to design peptides that interfere with protein/protein interactions? Can you please point me to some good deep learning papers in this space?

… Completely agree; I was looking into molecular graphs a while ago.

Can anyone in the group help me with this?

I adopted the ULMFiT method and applied it to chemical property prediction tasks. I pre-trained a ‘language model’ on 1 million molecules (SMILES) from ChEMBL and then fine-tuned the pre-trained model on other chemical property prediction tasks.

Inductive Transfer Learning for Molecular Activity Prediction: Next-Gen QSAR Models with MolPMoFiT
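For reference, two ingredients of the ULMFiT fine-tuning stage are discriminative learning rates (lower layers get smaller rates, divided by 2.6 per layer group in the ULMFiT paper) and the slanted triangular learning-rate schedule (a short linear warm-up followed by a long linear decay). The pure-Python sketch below only evaluates those formulas; the function names are hypothetical, the defaults cut_frac=0.1 and ratio=32 follow the paper's suggested values, and fastai's own training loop may differ in details.

```python
import math

def slanted_triangular_lr(t, total_steps, eta_max, cut_frac=0.1, ratio=32):
    """Learning rate at step t of the slanted triangular schedule from ULMFiT.

    The rate rises linearly over the first cut_frac of the steps, from
    eta_max / ratio up to eta_max, then decays linearly back toward
    eta_max / ratio by the last step.
    """
    cut = math.floor(total_steps * cut_frac)
    if t < cut:
        p = t / cut
    else:
        p = 1 - (t - cut) / (cut * (1 / cut_frac - 1))
    return eta_max * (1 + p * (ratio - 1)) / ratio

def discriminative_lrs(top_lr, n_groups, factor=2.6):
    """Learning rates for n_groups layer groups, lowest group first.

    Each group gets the rate of the group above it divided by `factor`,
    so the top (most task-specific) group trains fastest.
    """
    return [top_lr / factor ** (n_groups - 1 - i) for i in range(n_groups)]

# Example: 1000 fine-tuning steps with eta_max = 0.01 and three layer groups.
schedule = [slanted_triangular_lr(t, 1000, 0.01) for t in range(1001)]
print(max(schedule), schedule.index(max(schedule)))
print(discriminative_lrs(0.01, 3))
```

The idea behind both is the same: the pre-trained lower layers hold general knowledge (here, SMILES grammar) that should change slowly, while the top layers adapt to the new task.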

@jeremy, I am working extensively on this as a part of my post doc. If you would like any information, I would be happy to share.

@aparente, do you have any example code implementing graph convs for chemical structures in fastai I could take a look at?

@Xinhao, I noticed that in your work you enumerated the SMILES before training and merged all the files into one training set. It is more customary to apply data augmentation online, i.e. to generate a new randomized SMILES for each molecule at every epoch. I am currently implementing SMILES enumeration and sampling as custom callbacks. Here is my current code:

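import random
import numpy as np
import torch
import torch.nn.functional as F
from rdkit import Chem
from fastai.text import *  # fastai v1: Learner, LearnerCallback, TextLMDataBunch
# MolTokenizer is a custom SMILES tokenizer defined elsewhere (not part of fastai)

class SampleSMILES(LearnerCallback):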
    def __init__(self, learn:Learner, path, vocab, debug):
        super().__init__(learn)
        self.path, self.vocab, self.debug = path, vocab,debug
        self.encode_dict = MolTokenizer(lang='en').encode_dict
        self.max_seq_length = 150
    def confirm_vocab(self, epoch):
        if( self.learn.data.train_ds.x.vocab != self.vocab):
            print('non equal vocabs in sample smiles on epoch:', epoch )
        else:
            print('we have passed vocab check in sample smiles on epoch:', epoch )
        print('print vocab for epoch end:', epoch, self.vocab.stoi)
    def log_sampler_results(self, smiles, batch_sample, epoch):
        #----log number of valid compounds made on this epoch
        valid = 0
        for smi in smiles:
            mol = Chem.MolFromSmiles(smi)
            # a sample counts as valid if RDKit parses it into a molecule with at least one atom
            if smi != '' and mol is not None and mol.GetNumAtoms() > 0:
                valid += 1
        with open(self.path + 'valid_smiles.txt', 'a') as f1:
            if epoch == 0:
                f1.write('number of valid smiles, batch sample size, max seq length, epoch' + '\n')
            f1.write(str(valid) + ',' + str(batch_sample) + ',' + str(self.max_seq_length) + ',' + str(epoch) + '\n')
        if self.debug:
            print('number of valid compounds:', valid)
    def decode_smi(self, smiles ):
        #---replace encoded tokens with chemicals
        temp_smiles = smiles
        for symbol, token in self.encode_dict.items():
            temp_smiles = temp_smiles.replace(token,symbol)
        return temp_smiles
    def action_to_smiles(self, array, epoch):
        #---convert action tensor to smiles
        smiles_strings = []
        for row in array:
            predicted_chars = []
            for j in row:
                next_char = self.vocab.itos[j.item()]
                if next_char == 'END':
                    break
                predicted_chars.append(next_char)
            smi = ''.join(predicted_chars)
            smi = self.decode_smi(smi)
            smiles_strings.append(smi)
        if( self.debug == True):
            print('we are now writing to file')
            f1 = open(self.path + 'a2s_'+str(epoch) + '.txt','w')
            for smi in smiles_strings:
                if( ' ' in smi):
                    print('we have a space in smi:', smi)
                f1.write( smi.strip() + '\n')
            f1.close()
        return smiles_strings
    def sampler(self, epoch):
        #---sample a batch of compounds at the end of the epoch
        self.learn.model.eval()
        self.learn.model.reset()  # start sampling from a fresh hidden state
        batch_sample = 1024
        device = next(self.learn.model.parameters()).device
        with torch.no_grad():
            # use the callback's own vocab instead of the global `learner`
            go_int = self.vocab.stoi['GO']
            xb = torch.full((batch_sample, 1), go_int, dtype=torch.long, device=device)
            actions = torch.zeros((batch_sample, self.max_seq_length), dtype=torch.long, device=device)
            for i in range(self.max_seq_length):
                # the RNN keeps its hidden state between calls, so only the newest token is fed;
                # the first element of the output holds the next-token logits
                output = self.learn.model(xb)[0].reshape(batch_sample, -1)
                output_probs = F.softmax(output, dim=1)
                # draw one token per sequence from the predicted distribution
                idx = torch.multinomial(output_probs, num_samples=1)
                actions[:, i] = idx.squeeze(1)
                xb = idx  # feed the sampled token back in as the next input
        smiles = self.action_to_smiles(actions, epoch)
        self.log_sampler_results(smiles, batch_sample, epoch)
        self.learn.model.train()
    def export_learner(self, epoch):
        self.learn.export('learner_'+str(epoch) + '.pkl')
    def on_epoch_end(self, **kwargs):
        #===unpack kwargs
        epoch = kwargs['epoch']
        print('beginning sample:', self.max_seq_length, epoch)
        if( self.debug == True):
            self.confirm_vocab(epoch)
            self.export_learner(epoch)
        self.sampler(epoch)
        if(self.debug == True):
            print('mode of model:', self.learn.model.training)
            print('we have completed sampler')

class ShuffleSMILES(LearnerCallback):
    def __init__(self, learn:Learner, path, bs, vocab, tok, train_data, valid_data, debug):
        super().__init__(learn)
        self.path, self.bs, self.vocab, self.tok = path, bs, vocab, tok
        self.train_data, self.valid_data = train_data, valid_data
        self.mol_tok = MolTokenizer(lang='en')
        self.debug = debug
        print('init shuffle smiles:', self.path, self.bs)
        print('init shuffle smiles vocab:')
        print( self.vocab.itos)
    def confirm_vocab(self, epoch):
        if( self.learn.data.train_ds.x.vocab != self.vocab):
            print('non equal vocabs in shuffle smiles on epoch:', epoch )
        else:
            print('we have passed vocab check in shuffle smiles on epoch:', epoch )
        print('print vocab for epoch:', epoch, self.vocab.stoi)
    def check_vocab(self, smiles):
        tokens = self.tok.process_text(smiles, self.mol_tok)
        for tt in tokens:
            if( tt not in self.vocab.stoi.keys() ):
                return 0
        return 1
    def shuffle_pd(self, df, epoch):
        smiles = list( df.smiles )
        random_smiles = []
        for smi in smiles:
            try:
                mol = Chem.MolFromSmiles(smi)
                new_atom_order = list(range(mol.GetNumAtoms()))
                random.shuffle(new_atom_order)
                random_mol    = Chem.RenumberAtoms(mol, newOrder=new_atom_order)
                random_smi    = Chem.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
                in_vocab = self.check_vocab(random_smi)
                if( in_vocab == 1):
                    random_smiles.append(random_smi)
                else:
                    random_smiles.append(smi)
                    if( self.debug == True):
                        print('we have generated a compound with a new token:', smi, epoch)
            except Exception:
                if( self.debug == True):
                    print('failed for:', smi)
                random_smiles.append(smi)
        df['random_smiles'] = random_smiles
        return df
    def on_epoch_begin(self, **kwargs):
        #===get state
        epoch = kwargs['epoch']
        print('beginning shuffle on :', epoch)
        #===shuffle and prep data
        self.train_data = self.shuffle_pd(self.train_data, epoch)
        self.valid_data = self.shuffle_pd(self.valid_data, epoch)
        self.train_data['r_len'] = self.train_data.random_smiles.apply(lambda x: len(x) )
        self.valid_data['r_len'] = self.valid_data.random_smiles.apply(lambda x: len(x) )
        print('max length on epoch:', self.train_data.r_len.max(), epoch)
        #===dump our new data frames to files
        if(self.debug == True):
            print('saving data for later analysis')
            self.train_data.to_csv(self.path + 'train_'+str(epoch) + '.csv.gzip',index=None,compression='gzip')
            self.valid_data.to_csv(self.path + 'valid_'+str(epoch) + '.csv.gzip',index=None,compression='gzip')
        #====prepare databunch
        self.newData = TextLMDataBunch.from_df(self.path, self.train_data, self.valid_data, text_cols='random_smiles', bs=self.bs, tokenizer=self.tok, vocab=self.vocab, min_freq=1, include_bos=False, include_eos=False)
        self.learn.data.train_dl.dl, self.learn.data.valid_dl.dl = self.newData.train_dl.dl, self.newData.valid_dl.dl
        if( self.debug == True):
            self.confirm_vocab(epoch)
            self.newData.save( 'newData_' + str(epoch)+'.pkl')
            torch.save( self.learn.data.train_ds.x.vocab.stoi, self.path + '/' + 'vocab_' + str(epoch) + '.pkl')
        print('we have completed on epoch begin')

It is currently giving me bizarre training behavior, however. The validation loss keeps decreasing overall (apart from a bump at epochs 4 to 6), but the number of valid molecules sampled after each epoch rises to about 90% (927 of 1024 at epoch 12) and then, after epoch 13, suddenly starts getting worse, reaching 0 valid molecules at epoch 16. Here are my training stats (epoch, cross-entropy training loss, cross-entropy validation loss, accuracy, time):

epoch  train_loss  valid_loss  accuracy  time
0      0.735757    0.734926    0.734893  58:55
1      0.692634    0.692826    0.748313  59:24
2      0.677761    0.676239    0.753382  1:00:02
3      0.673893    0.674327    0.753982  1:00:17
4      0.682018    0.680836    0.751632  1:00:18
5      0.690983    0.690529    0.748319  1:00:11
6      0.693963    0.692387    0.747694  1:01:02
7      0.687516    0.685323    0.749805  1:00:52
8      0.677629    0.676371    0.752961  1:00:31
9      0.670770    0.668978    0.755372  1:01:01
10     0.663913    0.662514    0.757452  1:01:11
11     0.657783    0.657987    0.759115  1:01:29
12     0.654865    0.652931    0.760745  1:01:13
13     0.649611    0.648693    0.761964  1:01:12
14     0.646384    0.644538    0.763589  1:01:36
15     0.642046    0.640353    0.764990  1:01:32
16     0.640199    0.638433    0.765466  1:01:08
17     0.637574    0.634706    0.766821  1:01:38
18     0.634607    0.633236    0.767360  1:01:24
19     0.632348    0.630206    0.768307  1:01:16

and here are the number of valid smiles generated after each epoch:

number of valid smiles, batch sample size, max seq length, epoch
625,1024,150,0
766,1024,150,1
817,1024,150,2
686,1024,150,3
865,1024,150,4
862,1024,150,5
667,1024,150,6
825,1024,150,7
926,1024,150,8
877,1024,150,9
922,1024,150,10
926,1024,150,11
927,1024,150,12
918,1024,150,13
731,1024,150,14
407,1024,150,15
0,1024,150,16
1,1024,150,17
10,1024,150,18
339,1024,150,19
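To put these counts on a common scale, the log can be turned into the fraction of valid SMILES per epoch. The sketch below simply parses the valid_smiles.txt contents posted above (copied verbatim into a string) with the standard library; the function name `validity_by_epoch` is hypothetical. It shows that the peak is 927/1024 ≈ 90.5% at epoch 12, that the fraction falls to 0% at epoch 16, and that it partially recovers to about 33% at epoch 19.

```python
import csv
import io

# The valid_smiles.txt log posted above, copied verbatim.
LOG = """number of valid smiles, batch sample size, max seq length, epoch
625,1024,150,0
766,1024,150,1
817,1024,150,2
686,1024,150,3
865,1024,150,4
862,1024,150,5
667,1024,150,6
825,1024,150,7
926,1024,150,8
877,1024,150,9
922,1024,150,10
926,1024,150,11
927,1024,150,12
918,1024,150,13
731,1024,150,14
407,1024,150,15
0,1024,150,16
1,1024,150,17
10,1024,150,18
339,1024,150,19
"""

def validity_by_epoch(text):
    """Return {epoch: fraction of sampled SMILES counted as valid} from the log."""
    reader = csv.reader(io.StringIO(text))
    next(reader)  # skip the header line
    fractions = {}
    for row in reader:
        if not row:
            continue
        valid, batch, _max_len, epoch = (int(x) for x in row)
        fractions[epoch] = valid / batch
    return fractions

fractions = validity_by_epoch(LOG)
best = max(fractions, key=fractions.get)
for epoch, frac in sorted(fractions.items()):
    print(f"epoch {epoch:2d}: {frac:6.1%}")
print("peak epoch:", best, f"{fractions[best]:.1%}")
```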

I am not sure what is going on. If anyone sees a bug in my code, please let me know.

I haven’t seen any molecular graph conv implementations in fastai either and have also been interested in this. DeepChem, currently the best tool for GCNs that I know of, is being updated to TF2, and a lot of changes will come with that. If you find any implementations in fastai, please let me know!