- # Cell
class Learner(GetAttr):
    """Tie together data loaders, a model, a loss function, optimizer settings,
    callbacks and metrics.

    ``_default = 'model'`` names the ``GetAttr`` delegation target. Attribute
    lookups that fail on the Learner itself are presumably forwarded to
    ``self.model`` (fastcore ``GetAttr`` convention; confirm against its
    definition).
    """
    _default='model'

    def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,
                 metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
                 moms=(0.95,0.85,0.95)):
        """Build a Learner, register its callbacks and fire ``"after_create"``.

        Args:
            dls: data object. Its ``train_ds`` is consulted when ``loss_func``
                is None. Its ``path`` attribute (if any) is the default for
                ``path``.
            model: the model; stored as ``self.model``.
            loss_func: loss function. When None it is read from
                ``dls.train_ds.loss_func``.
            opt_func, lr, splitter, model_dir, wd, wd_bn_bias, train_bn, moms:
                not used in this method. They are presumably stored as
                same-named attributes by ``store_attr`` (fastcore). Note that
                the default for ``lr`` is the value of ``defaults.lr`` when
                the class body is executed, not at call time.
            cbs: extra callbacks, added after ``defaults.callbacks``. None
                means no extra callbacks.
            metrics: metric(s). Assigning ``self.metrics`` goes through the
                property setter, which maps each entry through ``mk_metric``.
            path: base directory. An explicit value is converted with
                ``Path``. Otherwise ``dls.path`` is used as-is when present,
                falling back to ``Path('.')``.

        Raises:
            AssertionError: if ``loss_func`` is None and ``dls.train_ds`` has
                no non-None ``loss_func`` attribute.
        """
        path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))
        if loss_func is None:
            loss_func = getattr(dls.train_ds, 'loss_func', None)
            # Raise explicitly instead of using `assert`: asserts are stripped
            # under `python -O`, which would let a None loss_func through
            # silently. AssertionError is kept so callers see the same type.
            if loss_func is None:
                raise AssertionError("Could not infer loss function from the data, please pass a loss function.")
        self.dls,self.model = dls,model
        # NOTE(review): store_attr (fastcore) appears to read this frame's
        # arguments, so it must stay directly inside __init__ and run after
        # `path`/`loss_func` have been resolved above.
        store_attr(but='dls,model,cbs')
        # Runtime state: not training yet, no optimizer, empty callback list
        # (filled by add_cbs below), and `print` as the default logger.
        self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
        self.add_cbs(L(defaults.callbacks)+L(cbs))
        self("after_create")

    @property
    def metrics(self):
        """Metrics as normalized by the setter (an ``L`` of ``mk_metric`` results)."""
        return self._metrics

    @metrics.setter
    def metrics(self, v):
        # Wrap `v` in an `L` and map each entry through `mk_metric`.
        self._metrics = L(v).map(mk_metric)
-