Custom ItemList init args are None when re-creating learner via "load_learner()"

Here is my init code for my custom ItemList:

    def __init__(self, items:Iterator, tokenization_config:TokenizationConfig=None, max_seq_len=128,
                 id_col='row_id', items_col='items', procs=None, **kwargs):
        # the DataFrame lives in inner_df; items is just a range of row indices
        super().__init__(items=items, **kwargs)
        
        self.tokenization_config = tokenization_config
        self.max_seq_len = max_seq_len
        
        self.id_col = id_col
        self.items_col = items_col
        
        self.procs = procs
    
        # add any ItemList state into "copy_new" that needs to be copied each time "new()" is called; 
        # your ItemList acts as a prototype for training, validation, and/or test ItemList instances that
        # are created via ItemList.new()
        self.copy_new += [ 'id_col', 'items_col', 'tokenization_config', 'max_seq_len', 'procs' ]
        
        self.preprocessed = False

When I later export and reload the learner via learn.export() and load_learner(), the tokenization_config object is None when I try to call learn.predict().

Any ideas on how to fix?
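The `copy_new` comments above describe the ItemList as a prototype: every attribute named in `copy_new` is handed back to the constructor when `new()` builds the training, validation, or test list. Below is a minimal, stdlib-only model of that idea; `ToyItemList`, `ToyTokenList`, and this `new()` are hypothetical simplifications, not fastai's actual classes.

```python
from typing import Any, Iterable, List


class ToyItemList:
    "Hypothetical stand-in for an ItemList that acts as a prototype for new()."

    def __init__(self, items: Iterable[Any], path: str = "."):
        self.items = list(items)
        self.path = path
        # names of attributes that new() forwards to the constructor
        self.copy_new: List[str] = ["path"]

    def new(self, items: Iterable[Any], **kwargs) -> "ToyItemList":
        "Build a sibling list of the same class, forwarding every attribute named in copy_new."
        copied = {name: getattr(self, name) for name in self.copy_new}
        # explicit kwargs win over copied values
        return self.__class__(items=items, **{**copied, **kwargs})


class ToyTokenList(ToyItemList):
    "Hypothetical input list carrying the same kind of state as the custom ItemList above."

    def __init__(self, items: Iterable[Any], tokenization_config: Any = None,
                 max_seq_len: int = 128, **kwargs):
        super().__init__(items, **kwargs)
        self.tokenization_config = tokenization_config
        self.max_seq_len = max_seq_len
        self.copy_new += ["tokenization_config", "max_seq_len"]


proto = ToyTokenList(range(10), tokenization_config={"lower": True}, max_seq_len=64)
valid = proto.new(range(3))
print(valid.tokenization_config, valid.max_seq_len)  # {'lower': True} 64
```

In this toy, the forwarding happens only between live objects; nothing is pickled, so by itself it says nothing about what survives learn.export() and load_learner().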

Mmm, not sure. It’s been a while since I dug into the v1 code, and v2 works way better because it pickles everything. Are you 100% positive it wasn’t None before exporting?

Yah, everything is there before exporting … the problems are only with the custom ItemList I’m using for my labels (my custom ItemList for my inputs works fine) …

    def __init__(self, items:Iterator, labels_col:str=None, classes:Collection=None,
                 ignore_token:str='[xIGNx]', ignore_token_id:int=CrossEntropyLoss().ignore_index,
                 **kwargs):
        
        super().__init__(items=items, **kwargs)   
        
        self.labels_col = labels_col
        self.classes = classes
        self.ignore_token = ignore_token
        self.ignore_token_id = ignore_token_id
        
        # add any ItemList state into "copy_new" that needs to be copied each time "new()" is called; 
        # your ItemList acts as a prototype for training, validation, and/or test ItemList instances that
        # are created via ItemList.new()
        self.copy_new += [ 'labels_col', 'classes', 'ignore_token', 'ignore_token_id']
        

When I attempt to reconstruct …

    def reconstruct(self, t, x):
        raw_input_items = [[ el for el in x.items if el not in ['[PAD]', '[CLS]', '[SEP]'] ]]
        raw_label_items = [['O'] * len(raw_input_items[0])]
        
        df = pd.DataFrame({ self.x.id_col: x.row_id, self.x.items_col: raw_input_items, 
                           self.labels_col: raw_label_items })
        
        features = hft_process_token_seq_targs(df, self.x.id_col, self.x.items_col,  
                                               self.labels_col, self.classes,
                                               self.x.tokenization_config, self.x.max_seq_len)
        ...

self.labels_col is None (even though I set it to “labels” when training) and self.x.tokenization_config is None as well.
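The first step of reconstruct() strips the special tokens from the input and pairs each remaining token with a dummy 'O' label, so the targets can be pushed through the same processing as at training time. Here is a stdlib-only sketch of just that step, with a plain dict standing in for the one-row DataFrame; `build_inference_row` and its parameter names are hypothetical.

```python
from typing import Any, Dict, Optional, Sequence

SPECIAL_TOKENS = ("[PAD]", "[CLS]", "[SEP]")


def build_inference_row(row_id: Any, tokens: Sequence[str], id_col: str, items_col: str,
                        labels_col: Optional[str], dummy_label: str = "O") -> Dict[Any, list]:
    "Drop special tokens and attach one placeholder label per remaining token."
    if labels_col is None:
        # guard against the symptom described above: no column name to put the labels under
        raise ValueError("labels_col is None; it was not restored after load_learner()")
    raw_items = [tok for tok in tokens if tok not in SPECIAL_TOKENS]
    # one-element lists mirror the single-row DataFrame built in reconstruct()
    return {id_col: [row_id],
            items_col: [raw_items],
            labels_col: [[dummy_label] * len(raw_items)]}


row = build_inference_row(7, ["[CLS]", "hello", "world", "[SEP]", "[PAD]"],
                          "row_id", "items", "labels")
print(row)  # {'row_id': [7], 'items': [['hello', 'world']], 'labels': [['O', 'O']]}
```

Failing loudly when labels_col is None makes the problem show up at the point where the missing state is first needed, rather than as a confusing error deeper in the processing code.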

SOLVED

(I think)

It appears that the rehydration of init args does not work for ItemList classes used as labels (it works as expected for custom ItemLists used as inputs).
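One way to picture this symptom, under an assumed mechanism that is not confirmed anywhere in this thread: if export saves only the label list's class, and inference rebuilds it by calling that class with an empty item list, then every init argument silently falls back to its default, which here is None. Here is a stdlib-only illustration with a hypothetical `LabelTokenList` and hypothetical `export_state`/`load_state` helpers:

```python
import pickle
from typing import Any, Dict, Iterable, Optional


class LabelTokenList:
    "Hypothetical label list whose important state arrives only through __init__."

    def __init__(self, items: Iterable[Any], labels_col: Optional[str] = None):
        self.items = list(items)
        self.labels_col = labels_col


def export_state(obj: LabelTokenList) -> bytes:
    "Assumed lossy export: save the class, not the instance's attributes."
    return pickle.dumps({"y_cls": type(obj)})


def load_state(blob: bytes) -> LabelTokenList:
    "Rebuild by calling the saved class with an empty item list and default init args."
    state: Dict[str, Any] = pickle.loads(blob)
    return state["y_cls"]([])


train_labels = LabelTokenList(["O", "B-PER"], labels_col="labels")
restored = load_state(export_state(train_labels))
print(restored.labels_col)  # None
```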

Anyhow, after digging around, I figured out a way to save the required state in my custom PreProcessor class:

    class HFT_TokenSeqTargetCategoryProcessor(CategoryProcessor):
        "`PreProcessor` that tokenizes the items in `ds`."

        def __init__(self, ds:ItemList, labels_col:str=None, config:TokenizationConfig=None, max_seq_len=None):
            super().__init__(ds)

            self.labels_col = ifnone(labels_col, ds.labels_col)
            self.tokenization_config = ifnone(config, ds.x.tokenization_config)
            self.max_seq_len = ifnone(max_seq_len, ds.x.max_seq_len)

            # we need to add these "state_attrs" so we can get them back at inference (the init args
            # for ItemLists used as labels do not get rehydrated as they should)
            self.state_attrs += ['labels_col', 'tokenization_config', 'max_seq_len']
            ...

… and then, in process(), make sure your ds has those state_attrs set, like so:

    def process(self, ds):
        # x.inner_df will be None during inference, but we still need labels_col and
        # tokenization_config, which unfortunately aren't rehydrated like the INPUT ItemList's init args
        ds.labels_col = self.labels_col
        ds.tokenization_config = self.tokenization_config
        ds.max_seq_len = self.max_seq_len
        ...
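The workaround boils down to two moves: make the processor's pickled state include the extra attributes, then have process() copy them back onto the freshly rebuilt label list. The stdlib-only sketch below models this under an explicit assumption: the processor pickles only the attributes named in `state_attrs` (via `__getstate__`/`__setstate__`), which would explain why registering the names there is what makes them survive. `ToyLabelList`, `ToyCategoryProcessor`, and `TokenSeqTargetProcessor` are hypothetical stand-ins, not fastai classes.

```python
import pickle
from typing import Any, Dict, Iterable, List, Optional


class ToyLabelList:
    "Hypothetical label list; at inference it is rebuilt with every init arg at its default."

    def __init__(self, items: Iterable[Any], labels_col: Optional[str] = None,
                 classes: Optional[List[str]] = None, tokenization_config: Any = None):
        self.items = list(items)
        self.labels_col = labels_col
        self.classes = classes
        self.tokenization_config = tokenization_config


class ToyCategoryProcessor:
    "Assumed behavior: only attributes named in state_attrs survive pickling."

    def __init__(self, ds: ToyLabelList):
        self.classes = ds.classes
        self.state_attrs: List[str] = ["classes"]

    def __getstate__(self) -> Dict[str, Any]:
        return {name: getattr(self, name) for name in self.state_attrs}

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.state_attrs = list(state)
        for name, value in state.items():
            setattr(self, name, value)


class TokenSeqTargetProcessor(ToyCategoryProcessor):
    def __init__(self, ds: ToyLabelList, labels_col: Optional[str] = None, config: Any = None):
        super().__init__(ds)
        # prefer explicit args, otherwise read them off the training label list (like ifnone)
        self.labels_col = labels_col if labels_col is not None else ds.labels_col
        self.tokenization_config = config if config is not None else ds.tokenization_config
        # register the extra attributes so __getstate__ keeps them
        self.state_attrs += ["labels_col", "tokenization_config"]

    def process(self, ds: ToyLabelList) -> None:
        # push the saved state back onto a label list whose init args were lost
        ds.labels_col = self.labels_col
        ds.tokenization_config = self.tokenization_config


train_ds = ToyLabelList(["O"], labels_col="labels", classes=["O", "B-PER"],
                        tokenization_config={"lower": True})
blob = pickle.dumps(TokenSeqTargetProcessor(train_ds))  # stands in for what export saves

restored_proc = pickle.loads(blob)
inference_ds = ToyLabelList([])  # every init arg is back at its default (None)
restored_proc.process(inference_ds)
print(inference_ds.labels_col, inference_ds.tokenization_config)  # labels {'lower': True}
```

In the toy, any attribute that is not listed in state_attrs is dropped by the pickle round trip, so forgetting to register a name reproduces the original None problem.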