@jeremy, I am working extensively on this as part of my postdoc. If you would like any information, I would be happy to share it.
@aparente, do you have any example code implementing graph convolutions for chemical structures in fastai that I could take a look at?
@Xinhao, I noticed in your work you enumerated the SMILES before training and merged all the files into one training set. It is more customary to apply data augmentation in an online setting, so I am currently working on implementing SMILES enumeration and sampling as custom callbacks. Here is my current code:
class SampleSMILES(LearnerCallback):
    def __init__(self, learn:Learner, path, vocab, debug):
        super().__init__(learn)
        self.path, self.vocab, self.debug = path, vocab, debug
        self.encode_dict = MolTokenizer(lang='en').encode_dict
        self.max_seq_length = 150

    def confirm_vocab(self, epoch):
        if self.learn.data.train_ds.x.vocab != self.vocab:
            print('non-equal vocabs in SampleSMILES on epoch:', epoch)
        else:
            print('we have passed vocab check in SampleSMILES on epoch:', epoch)
        print('vocab at epoch end:', epoch, self.vocab.stoi)

    def log_sampler_results(self, smiles, batch_sample, epoch):
        # log the number of valid compounds generated on this epoch
        valid = 0
        for smi in smiles:
            mol = Chem.MolFromSmiles(smi)
            if smi != '' and mol is not None and mol.GetNumAtoms() > 0:
                valid += 1
        with open(self.path + 'valid_smiles.txt', 'a') as f1:
            if epoch == 0:
                f1.write('number of valid smiles,batch sample size,max seq length,epoch\n')
            f1.write(str(valid) + ',' + str(batch_sample) + ',' + str(self.max_seq_length) + ',' + str(epoch) + '\n')
        if self.debug:
            print('number of valid compounds:', valid)

    def decode_smi(self, smiles):
        # replace encoded tokens with their chemical symbols
        for symbol, token in self.encode_dict.items():
            smiles = smiles.replace(token, symbol)
        return smiles

    def action_to_smiles(self, array, epoch):
        # convert a tensor of sampled token indices to SMILES strings
        smiles_strings = []
        for row in array:
            predicted_chars = []
            for j in row:
                next_char = self.vocab.itos[j.item()]
                if next_char == 'END':
                    break
                predicted_chars.append(next_char)
            smi = ''.join(predicted_chars)
            smi = self.decode_smi(smi)
            smiles_strings.append(smi)
        if self.debug:
            print('we are now writing to file')
        with open(self.path + 'a2s_' + str(epoch) + '.txt', 'w') as f1:
            for smi in smiles_strings:
                if ' ' in smi:
                    print('we have a space in smi:', smi)
                f1.write(smi.strip() + '\n')
        return smiles_strings

    def sampler(self, epoch):
        # sample a batch of compounds at the end of each epoch
        self.learn.model.eval()
        with torch.no_grad():
            batch_sample = 1024
            go_int = self.learn.data.train_ds.x.vocab.stoi['GO']
            xb = torch.full((batch_sample, 1), go_int, dtype=torch.long, device='cuda')
            actions = torch.zeros((batch_sample, self.max_seq_length), dtype=torch.long, device='cuda')
            for i in range(self.max_seq_length):
                output = self.learn.model(xb)[0].squeeze()
                output_probs = F.softmax(output, dim=1)
                # one-hot sample per row; nonzero recovers the sampled column index
                m = torch.distributions.Multinomial(probs=output_probs)
                action = m.sample()
                idx = action.nonzero()[:, 1]
                actions[:, i] = idx
                xb = idx.unsqueeze(1)
            smiles = self.action_to_smiles(actions, epoch)
            self.log_sampler_results(smiles, batch_sample, epoch)
        self.learn.model.train()

    def export_learner(self, epoch):
        self.learn.export('learner_' + str(epoch) + '.pkl')

    def on_epoch_end(self, **kwargs):
        epoch = kwargs['epoch']
        print('beginning sample:', self.max_seq_length, epoch)
        if self.debug:
            self.confirm_vocab(epoch)
        self.export_learner(epoch)
        self.sampler(epoch)
        if self.debug:
            print('mode of model:', self.learn.model.training)
            print('we have completed sampler')
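One note on the sampler method above: the Multinomial-sample-then-nonzero step draws exactly one token index per row, which, as far as I can tell, is the same draw that torch.multinomial makes directly. A minimal standalone sketch with made-up logits, nothing taken from the callback:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)            # stand-in (batch, vocab) scores
    probs = F.softmax(logits, dim=1)

    # one-hot count vector per row, then recover the sampled column index
    m = torch.distributions.Multinomial(total_count=1, probs=probs)
    idx_a = m.sample().nonzero()[:, 1]

    # the same distribution, sampled as an index directly
    idx_b = torch.multinomial(probs, num_samples=1).squeeze(1)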
class ShuffleSMILES(LearnerCallback):
    def __init__(self, learn:Learner, path, bs, vocab, tok, train_data, valid_data, debug):
        super().__init__(learn)
        self.path, self.bs, self.vocab, self.tok = path, bs, vocab, tok
        self.train_data, self.valid_data = train_data, valid_data
        self.mol_tok = MolTokenizer(lang='en')
        self.debug = debug
        print('init ShuffleSMILES:', self.path, self.bs)
        print('init ShuffleSMILES vocab:')
        print(self.vocab.itos)

    def confirm_vocab(self, epoch):
        if self.learn.data.train_ds.x.vocab != self.vocab:
            print('non-equal vocabs in ShuffleSMILES on epoch:', epoch)
        else:
            print('we have passed vocab check in ShuffleSMILES on epoch:', epoch)
        print('vocab at epoch:', epoch, self.vocab.stoi)

    def check_vocab(self, smiles):
        # return 1 if every token of this SMILES is already in the vocab, else 0
        tokens = self.tok.process_text(smiles, self.mol_tok)
        for tt in tokens:
            if tt not in self.vocab.stoi:
                return 0
        return 1

    def shuffle_pd(self, df, epoch):
        # atom-renumbering SMILES enumeration: one random alternative per molecule
        smiles = list(df.smiles)
        random_smiles = []
        for smi in smiles:
            try:
                mol = Chem.MolFromSmiles(smi)
                new_atom_order = list(range(mol.GetNumAtoms()))
                random.shuffle(new_atom_order)
                random_mol = Chem.RenumberAtoms(mol, newOrder=new_atom_order)
                random_smi = Chem.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
                if self.check_vocab(random_smi) == 1:
                    random_smiles.append(random_smi)
                else:
                    # fall back to the original SMILES so the vocab never grows
                    random_smiles.append(smi)
                    if self.debug:
                        print('we have generated a compound with a new token:', smi, epoch)
            except Exception:
                if self.debug:
                    print('enumeration failed for:', smi)
                random_smiles.append(smi)
        df['random_smiles'] = random_smiles
        return df

    def on_epoch_begin(self, **kwargs):
        epoch = kwargs['epoch']
        print('beginning shuffle on:', epoch)
        # shuffle and prep data
        self.train_data = self.shuffle_pd(self.train_data, epoch)
        self.valid_data = self.shuffle_pd(self.valid_data, epoch)
        self.train_data['r_len'] = self.train_data.random_smiles.apply(len)
        self.valid_data['r_len'] = self.valid_data.random_smiles.apply(len)
        print('max length on epoch:', self.train_data.r_len.max(), epoch)
        # dump the new data frames to files
        if self.debug:
            print('saving data for later analysis')
            self.train_data.to_csv(self.path + 'train_' + str(epoch) + '.csv.gzip', index=False, compression='gzip')
            self.valid_data.to_csv(self.path + 'valid_' + str(epoch) + '.csv.gzip', index=False, compression='gzip')
        # rebuild the databunch on the enumerated SMILES and swap in the new dataloaders
        self.newData = TextLMDataBunch.from_df(self.path, self.train_data, self.valid_data,
                                               text_cols='random_smiles', bs=self.bs, tokenizer=self.tok,
                                               vocab=self.vocab, min_freq=1, include_bos=False, include_eos=False)
        self.learn.data.train_dl.dl, self.learn.data.valid_dl.dl = self.newData.train_dl.dl, self.newData.valid_dl.dl
        if self.debug:
            self.confirm_vocab(epoch)
            self.newData.save('newData_' + str(epoch) + '.pkl')
            torch.save(self.learn.data.train_ds.x.vocab.stoi, self.path + '/' + 'vocab_' + str(epoch) + '.pkl')
        print('we have completed on epoch begin')
It is currently giving me bizarre training behavior, however. The validation loss decreases fairly steadily, but the number of valid molecules sampled after each epoch of training climbs to roughly 90% of the batch by epoch ~13 and then suddenly gets worse each epoch. For example, here are my training stats (epoch, cross-entropy training loss, cross-entropy validation loss, accuracy, wall time); before each epoch the ShuffleSMILES callback prints "we have completed on epoch begin":
epoch  train_loss  valid_loss  accuracy  time
0      0.735757    0.734926    0.734893    58:55
1      0.692634    0.692826    0.748313    59:24
2      0.677761    0.676239    0.753382  1:00:02
3      0.673893    0.674327    0.753982  1:00:17
4      0.682018    0.680836    0.751632  1:00:18
5      0.690983    0.690529    0.748319  1:00:11
6      0.693963    0.692387    0.747694  1:01:02
7      0.687516    0.685323    0.749805  1:00:52
8      0.677629    0.676371    0.752961  1:00:31
9      0.670770    0.668978    0.755372  1:01:01
10     0.663913    0.662514    0.757452  1:01:11
11     0.657783    0.657987    0.759115  1:01:29
12     0.654865    0.652931    0.760745  1:01:13
13     0.649611    0.648693    0.761964  1:01:12
14     0.646384    0.644538    0.763589  1:01:36
15     0.642046    0.640353    0.764990  1:01:32
16     0.640199    0.638433    0.765466  1:01:08
17     0.637574    0.634706    0.766821  1:01:38
18     0.634607    0.633236    0.767360  1:01:24
19     0.632348    0.630206    0.768307  1:01:16
and here are the counts of valid SMILES generated after each epoch (the contents of valid_smiles.txt):
number of valid smiles, batch sample size, max seq length, epoch
625,1024,150,0
766,1024,150,1
817,1024,150,2
686,1024,150,3
865,1024,150,4
862,1024,150,5
667,1024,150,6
825,1024,150,7
926,1024,150,8
877,1024,150,9
922,1024,150,10
926,1024,150,11
927,1024,150,12
918,1024,150,13
731,1024,150,14
407,1024,150,15
0,1024,150,16
1,1024,150,17
10,1024,150,18
339,1024,150,19
I am not sure what is going on. If anyone sees a bug in my code, please let me know.