Learn.save pickle error

I’m training a model using AWS SageMaker. I have my model code inside a train() function. When I try to execute learn.save('path/to/model'), I get the error pickler.dump(obj) AttributeError: Can't pickle local object. When I instead place all the train() code inside main(), saving works.

From googling, I think this has to do with pickled objects needing to be in the global namespace? I tried declaring my data and learn objects global inside the train() function, but I still got the error.

Any idea how I should save my model inside a function?

Hello @austinmw, do you have some sample code somewhere I could look at? I routinely put my trainer in a function to loop through different items and have not run into any issues. Here is one example:

def trainer(step):

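    # path_img, valid_i, get_label, codes, transforms, size, bs, metrics, wd and lr are defined outside this function
    model_name = 'step_{}'.format(step)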
    src = (SegmentationItemList.from_folder(path_img).split_by_idx(valid_i[step]).label_from_func(get_label, classes=codes))
    data = (src.transform(transforms, size=size, tfm_y=True).databunch(bs=bs).normalize(imagenet_stats))
    learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd).to_fp16()
    learn.model = torch.nn.DataParallel(learn.model)
    learn.fit_one_cycle(10, slice(lr), pct_start=0.9)
    #learn.fit_one_cycle(3, lrs, pct_start=0.8)

    return learn

@sariabod Thanks for your reply. I can post my example code, but to reproduce it you’d need a Docker image with fastai built and uploaded to ECR, and then run this code through that container in SageMaker. I was hoping it might be an easier issue to track down, but maybe not :frowning:

# SageMaker CamVid U-Net

import ast
import argparse
import logging
import os

from pathlib import *
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.utils.mem import *
from fastai.callbacks.mem import *
from fastai.callbacks.hooks import *
from fastai.callbacks.tracker import *
import multiprocessing as mp

# ignore the PIL warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

# setup the logger
logger = logging.getLogger(__name__)

# get the host from environment variable
HOSTNAME = os.environ.get('SM_CURRENT_HOST', 'train-host')

def _train(args):

	path = Path(args.data_dir)
	path_lbl = path/'labels'
	path_img = path/'images'
	model_dir = Path(args.model_dir)
	print(f'Model save dir: {model_dir}')
	if args.verify_images == 1:
		print('verifying images...')
		verify_images(path_img, max_size=int(800))

	fnames = get_image_files(path_img)
	lbl_names = get_image_files(path_lbl)
	get_y_fn = lambda x: path_lbl/f'{x.stem}_P{x.suffix}'
	img_f = fnames[0]
	img = open_image(img_f)
	mask = open_mask(get_y_fn(img_f))
	src_size = np.array(mask.shape[1:])
	codes = np.loadtxt(path/'codes.txt', dtype=str)
	size = src_size//2

	# the max size of bs depends on the available GPU RAM
	free = gpu_mem_get_free_no_cache()
	print(f'free GPU memory: {free}')
	src = (SegmentationItemList.from_folder(path_img)
	   .label_from_func(get_y_fn, classes=codes))
	data = (src.transform(get_transforms(), size=size, tfm_y=True)
		.databunch(bs=args.bs, num_workers=args.num_workers))
	print(data.classes, data.c, len(data.train_ds), len(data.valid_ds))
	name2id = {v:k for k,v in enumerate(codes)}
	void_code = name2id['Void']

	def acc_camvid(input, target):
		target = target.squeeze(1)
		mask = target != void_code
		return (input.argmax(dim=1)[mask]==target[mask]).float().mean()
	callback_fns = [PeakMemMetric, partial(EarlyStoppingCallback, monitor='accuracy', min_delta=0.01, patience=3)]
	callbacks = [TerminateOnNaNCallback()]
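	# metrics and wd are not defined in the posted snippet; these values are assumptions
	metrics = acc_camvid
	wd = 1e-2
	learn = unet_learner(data, models.resnet34, metrics=metrics, wd=wd,
						 callbacks=callbacks, callback_fns=callback_fns)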
	lr = 3e-3  # default learning rate when the LR finder is skipped
	if args.lr_finder == 1:
		print('starting lr finder...')
		learn.lr_find()
		fig = learn.recorder.plot(suggestion=True, return_fig=True)
		lr = learn.recorder.min_grad_lr
	print('starting fit 1...')
	learn.fit_one_cycle(args.epochs, slice(lr), pct_start=0.9)

	if args.fine_tune == 1:
		learn.unfreeze()  # also train the pretrained encoder in the second stage
		if args.lr_finder == 1:
			print('starting lr finder...')
			learn.lr_find()
			fig2 = learn.recorder.plot(suggestion=True, return_fig=True)
			lr = learn.recorder.min_grad_lr
		lrs = slice(lr/400, lr/4)

		print('starting fit 2...')
		learn.fit_one_cycle(12, lrs, pct_start=0.8)

	print(f'Writing classes to model dir')
	save_texts(model_dir/'classes.txt', data.classes)

if __name__ == '__main__':
	parser = argparse.ArgumentParser()
	parser.add_argument('--verify_images', type=int, default=0,
						help='check each image and delete if unable to read')
	parser.add_argument('--fine_tune', type=int, default=1,
						help='second training stage')
	parser.add_argument('--lr_finder', type=int, default=0,
						help='run lr finder')
	parser.add_argument('--num_workers', type=int, default=mp.cpu_count(), 
						help='number of cpus to use')
	parser.add_argument('--epochs', type=int, default=1, metavar='E',
						help='number of total epochs to run')
	parser.add_argument('--bs', type=int, default=4, metavar='BS',
						help='batch size (default: 4)')
	parser.add_argument('--save_name', type=str, default='model', 
						help='will save as <save_name>.pth')
	# The parameters below retrieve their default values from SageMaker environment variables, which are
	# instantiated by the SageMaker containers framework.
	# https://github.com/aws/sagemaker-containers#how-a-script-is-executed-inside-the-container
	parser.add_argument('--hosts', type=str, default=ast.literal_eval(os.environ['SM_HOSTS']))
	parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
	parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
	parser.add_argument('--data-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
	parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS'])

	args = parser.parse_args()
	_train(args)


So if I just place all of the _train code directly inside main, it works, but as is, saving causes that pickle error. Hmm, I wonder if it has something to do with multiprocessing…
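Two local objects in the script above are candidates: get_y_fn is a lambda and acc_camvid is a nested function, both defined inside _train(). Anything that ends up pickling an object holding them will hit this error, and so will worker processes started with multiprocessing's 'spawn' start method, whose targets and arguments are pickled. A sketch of one possible refactor, assuming the rest of the script stays as is: move the helpers to module level, pass what they used to capture as parameters, and bind those values with functools.partial. label_fn is a hypothetical name; acc_camvid keeps the original body with void_code as an explicit parameter, and void_code=30 is a made-up value.

```python
import pickle
from functools import partial
from pathlib import Path

def label_fn(x, path_lbl):
    """Module-level replacement for the get_y_fn lambda:
    images/foo.png -> <path_lbl>/foo_P.png."""
    return Path(path_lbl) / f'{x.stem}_P{x.suffix}'

def acc_camvid(input, target, void_code):
    """Same body as the nested metric, with void_code passed in explicitly.
    Expects PyTorch tensors, but needs no torch import to be defined."""
    target = target.squeeze(1)
    mask = target != void_code
    return (input.argmax(dim=1)[mask] == target[mask]).float().mean()

if __name__ == '__main__':
    get_y_fn = partial(label_fn, path_lbl='labels')
    metric = partial(acc_camvid, void_code=30)  # hypothetical void code
    # Both partials survive a pickle round trip because they only reference
    # module-level functions and plain values.
    restored = pickle.loads(pickle.dumps(get_y_fn))
    print(restored(Path('images/0001.png')))
    print(pickle.loads(pickle.dumps(metric)).keywords)
```

Inside _train() this would become get_y_fn = partial(label_fn, path_lbl=path_lbl) and, once void_code is known, metrics = partial(acc_camvid, void_code=void_code). It is worth checking how fastai displays a partial's name in the training log.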

@austinmw I got your code running locally and can now reproduce the error. Unfortunately, nothing stands out as the problem. I will try to look into it more tonight. If you figure it out, please post it, as I am very curious.


@sariabod Thank you for reproducing and confirming that I’m not crazy! :joy: If I figure it out I’ll definitely let you know.

@austinmw it turns out the model is actually being saved without any error; the error only shows up when you try to load it again. Even my own version has errors when I try to load the model again. I guess I never noticed because saving doesn't display any issues. If you just remove the load-model line, your script works, but you don't end up with a model you can load later, since the saved file is “corrupt”. Maybe returning the model from the function and then saving it in main would work as a workaround.
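A minimal sketch of that workaround idea combined with an immediate reload check, so a file that cannot be read back fails at save time rather than in a later session. Plain pickle stands in for fastai's learn.save/learn.load here, and train, save_and_verify and main are hypothetical names. Under that assumption, the fastai analogue would be to have _train() return learn and then call learn.save(...) followed by learn.load(...) in main.

```python
import pickle
import tempfile
from pathlib import Path

def train():
    # Hypothetical stand-in for _train(): returns only plain, picklable data
    # (think: model weights plus the class list), with no lambdas or nested functions.
    return {'weights': [0.1, 0.2, 0.3], 'classes': ['Void', 'Road']}

def save_and_verify(obj, path):
    """Pickle obj to path, then immediately load it back.

    Reloading right after saving makes an unreadable file fail now
    instead of in a later session.
    """
    path = Path(path)
    with path.open('wb') as f:
        pickle.dump(obj, f)
    with path.open('rb') as f:
        return pickle.load(f)

def main(out_dir):
    state = train()  # train inside a function...
    restored = save_and_verify(state, Path(out_dir) / 'model.pkl')  # ...save in main
    assert restored == state, 'reloaded object differs from what was saved'
    return restored

if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        print(main(tmp))
```

Returning the object from the function only helps if what gets saved no longer references functions defined inside that function: pickle cares where a function was defined, not where it is saved from.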

I think the issue might be partially related to this thread: https://discourse.pymc.io/t/what-does-it-mean-can-not-pickle-due-to-not-finding-function-for-parallel-chain-sampling/1127