Thanks for the help! After refactoring some of the fastai code, I’ve gotten my GAN learner to export, but I’m now running into recursion issues in Python when I try to import my model. Here’s the code that I think is causing the problem, along with the stack trace, in case anything immediately stands out:
class testGANLoss(GANModule):
    """Wrapper around `loss_funcG` (for the generator) and `loss_funcC` (for the critic).

    Args:
        loss_funcG: generator loss, called as
            ``loss_funcG(fake_pred, output, target, weights_gen, loss_crit, loss_gen)``.
        loss_funcC: critic loss, called as ``loss_funcC(real_pred, fake_pred, loss_crit)``.
        gan_model: module exposing the ``generator`` and ``critic`` sub-models used below.
        loss_gen: loss comparing the generator ``output`` with ``target``.
        loss_crit: loss applied to critic predictions.
        weights_gen: weights for the (critic term, ``loss_gen`` term) of the generator loss,
            or None (``_loss_G`` then uses ``(1., 1.)``).
    """
    def __init__(self, loss_funcG:Callable, loss_funcC:Callable, gan_model:GANModule,
                 loss_gen, loss_crit, weights_gen):
        super().__init__()
        self.loss_funcG, self.loss_funcC, self.gan_model = loss_funcG, loss_funcC, gan_model
        # Keep `loss_gen` on the instance so it travels with the exported learner
        # instead of being looked up as a global when the loss is evaluated.
        self.loss_gen, self.loss_crit, self.weights_gen = loss_gen, loss_crit, weights_gen

    def generator(self, output, target):
        "Evaluate the `output` with the critic then use `self.loss_funcG` to combine it with `target`."
        fake_pred = self.gan_model.critic(output)
        # `output` comes before `target`, matching `_loss_G(fake_pred, output, target, ...)`,
        # so `loss_gen` is called as loss_gen(prediction, target).
        return self.loss_funcG(fake_pred, output, target, self.weights_gen, self.loss_crit, self.loss_gen)

    def critic(self, real_pred, input):
        """Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.loss_funcC`.

        Note: ``requires_grad_(False)`` is applied to `input` in place.
        """
        fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
        fake_pred = self.gan_model.critic(fake)
        return self.loss_funcC(real_pred, fake_pred, self.loss_crit)


def _loss_G(fake_pred, output, target, weights_gen, loss_crit, loss_gen=None):
    """Generator loss: weighted sum of an adversarial term and a reconstruction term.

    Returns ``weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)``,
    where ``ones`` is a target of length ``fake_pred.shape[0]`` built with ``new_ones``.

    Args:
        fake_pred: critic predictions on the generator output.
        output: generator output.
        target: ground-truth target for `output`.
        weights_gen: pair of weights, or None for ``(1., 1.)``.
        loss_crit: loss applied to `fake_pred`.
        loss_gen: loss applied to ``(output, target)``. If None, a module-level
            ``loss_gen`` is used (the previous behaviour).

    Raises:
        NameError: if `loss_gen` is None and no module-level ``loss_gen`` exists.
    """
    if loss_gen is None:
        # Backward compatibility with five-argument calls that relied on a global `loss_gen`.
        if 'loss_gen' not in globals():
            raise NameError("name 'loss_gen' is not defined; pass `loss_gen` to `_loss_G` explicitly")
        loss_gen = globals()['loss_gen']
    ones = fake_pred.new_ones(fake_pred.shape[0])
    weights_gen = ifnone(weights_gen, (1.,1.))
    return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)
def _loss_C(real_pred, fake_pred, loss_crit):
ones = real_pred.new_ones (real_pred.shape[0])
zeros = fake_pred.new_zeros(fake_pred.shape[0])
return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2
class GANLearner(Learner):
    """A `Learner` suitable for GANs.

    Wraps `generator` and `critic` in a `GANModule`, builds a `testGANLoss` from
    `gen_loss_func`/`crit_loss_func` and the `loss_func`s of `learn_gen` and
    `learn_crit`, and appends a `GANTrainer` callback.

    Raises:
        ValueError: if `learn_gen` or `learn_crit` is None.
    """
    def __init__(self, data:DataBunch, generator:nn.Module, critic:nn.Module, gen_loss_func:LossFunction,
                 crit_loss_func:LossFunction, switcher:Callback=None, learn_gen=None, learn_crit=None,
                 gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:float=None,
                 weights_gen=(1.,50.), **learn_kwargs):
        # Both sub-learners' loss functions are needed by the GAN loss; fail early with a
        # clear message instead of an AttributeError on `None.loss_func`.
        if learn_gen is None or learn_crit is None:
            raise ValueError("`learn_gen` and `learn_crit` are required because their `loss_func`s "
                             "are used by the GAN loss; see `GANLearner.from_learners`.")
        gan = GANModule(generator, critic)
        loss_func = testGANLoss(gen_loss_func, crit_loss_func, gan,
                                learn_gen.loss_func, learn_crit.loss_func, weights_gen)
        # Default switcher: FixedGANSwitcher with n_crit=5, n_gen=1.
        switcher = ifnone(switcher, partial(FixedGANSwitcher, n_crit=5, n_gen=1))
        super().__init__(data, gan, loss_func=loss_func, callback_fns=[switcher], **learn_kwargs)
        # Forward `gen_first` so the trainer actually honours it.
        trainer = GANTrainer(self, clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        self.gan_trainer = trainer
        self.callbacks.append(trainer)

    @classmethod
    def from_learners(cls, learn_gen:Learner, learn_crit:Learner, switcher=None, **learn_kwargs):
        "Create a GAN from `learn_gen` and `learn_crit` (default switcher when `switcher` is None)."
        return cls(learn_gen.data, learn_gen.model, learn_crit.model, _loss_G, _loss_C, switcher,
                   learn_gen=learn_gen, learn_crit=learn_crit, **learn_kwargs)