In the end I created this thread so that we have room to write a bit more during the investigation.
@stas the memory issue is quite important to me as well; we need it fixed somehow in order to work on the multilingual version of ULMFiT. There are also some CUDA crashes other than OOM that I hit from time to time when I play with the RNN API. So how about we create a new thread and start discussing there?
Regarding cyclic references in Python: I remember that they were an issue in old versions, but nowadays it is just a matter of running the GC, as opposed to the quick reference-counting strategy used for acyclic graphs.
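To make the difference concrete, here is a minimal sketch of how CPython treats the two cases. Node, make_cycle and make_acyclic are hypothetical names invented for this illustration, not fastai code, and the behaviour shown assumes CPython's reference counting.

```python
import gc
import weakref

class Node:
    """Hypothetical stand-in for any object that can take part in a cycle."""
    def __init__(self):
        self.other = None

def make_acyclic():
    # The only strong reference is the local variable, which dies on return.
    a = Node()
    return weakref.ref(a)

def make_cycle():
    # a and b keep each other alive after the local variables are gone.
    a, b = Node(), Node()
    a.other, b.other = b, a
    return weakref.ref(a)

gc.disable()  # stop the automatic collector so the demo is deterministic
try:
    acyclic_probe = make_acyclic()
    acyclic_freed_immediately = acyclic_probe() is None

    cyclic_probe = make_cycle()
    cycle_alive_before_gc = cyclic_probe() is not None

    gc.collect()  # the cycle collector is what finally reclaims a and b
    cycle_freed_after_gc = cyclic_probe() is None
finally:
    gc.enable()
```

If the objects in the cycle hold on to GPU tensors, that memory would presumably also stay allocated until the collector runs, which is why a cycle matters more here than in ordinary Python code.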
The cyclic reference comes in part from the Callbacks (Learner Callbacks): they are defined as dataclasses that declare learn as a field, and some of the callbacks are in turn stored in the learner. Besides creating a cyclic reference, this makes it hard to list callbacks, because the whole Learner object, including the Model object, is printed for each callback.
I would vote for changing the Callbacks signature to make learn a getter and use weak references to remove the cyclic dependency. What do you think @sgugger?
@dataclass
class LearnerCallback(Callback):
    learn: Learner # <- source of cyclic references
    def __post_init__(self):
        if self.cb_name: setattr(self.learn, self.cb_name, self) # <- source of cyclic references
I would suggest using something along these lines:
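import weakref
from dataclasses import dataclass, field, InitVar

@dataclass
class LearnerCallback(Callback):
    learn: InitVar[Learner]  # passed to __post_init__, not stored as a field
    _learn: weakref.ref = field(init=False, repr=False, compare=False)
    def __post_init__(self, learn):
        if self.cb_name: setattr(learn, self.cb_name, self) # <- still a problem
        self._learn = weakref.ref(learn)  # weak reference back to the learner
    @property
    def learn(self):
        # returns None once the learner has been garbage collected
        return self._learn()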
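To check that the idea holds up, here is a self-contained sketch using stand-in classes. The Learner and Callback below are hypothetical placeholders rather than the real fastai classes, and the setattr step is left out on purpose. One side effect worth knowing: because the property shares its name with the InitVar, dataclasses picks up the property object as the default value for learn, so learn should always be passed explicitly.

```python
import gc
import weakref
from dataclasses import dataclass, field, InitVar

class Learner:
    """Hypothetical stand-in for fastai's Learner."""
    def __init__(self):
        self.callbacks = []  # strong references: learner -> callback

class Callback:
    """Hypothetical stand-in for fastai's Callback."""

@dataclass
class LearnerCallback(Callback):
    learn: InitVar[Learner]
    _learn: weakref.ref = field(init=False, repr=False, compare=False)

    def __post_init__(self, learn):
        # Only a weak reference points back: callback -/-> learner
        self._learn = weakref.ref(learn)

    @property
    def learn(self):
        return self._learn()

gc.disable()  # rule out the cycle collector; only reference counting is active
try:
    learner = Learner()
    cb = LearnerCallback(learner)
    learner.callbacks.append(cb)

    getter_works = cb.learn is learner
    cb_repr = repr(cb)  # no longer drags the whole learner along

    probe = weakref.ref(learner)
    del learner  # drop the last strong reference
    learner_freed_without_gc = probe() is None
    getter_after_free = cb.learn  # None, the learner is gone
finally:
    gc.enable()
```

With the weak back-reference the learner is released by plain reference counting, and printing a callback no longer prints the learner and its model.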
What remains an issue is the setattr on the learner. @sgugger, can we replace that with a weakref dictionary and some lookup function? Why do we need it, do you have an example at hand?
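A rough sketch of what I mean by a weakref dictionary plus lookup function. Everything here (register_callback, get_callback, _cb_by_name, the stand-in classes) is a made-up name for illustration and not an existing fastai API; it also assumes that something other than this dictionary owns the callbacks.

```python
import weakref

class Callback:
    """Hypothetical stand-in for a fastai callback."""
    def __init__(self, cb_name):
        self.cb_name = cb_name

class Learner:
    """Hypothetical stand-in for fastai's Learner."""
    def __init__(self):
        # Values are held weakly: an entry disappears when its callback dies.
        self._cb_by_name = weakref.WeakValueDictionary()

    def register_callback(self, cb):
        # Replaces setattr(learn, cb.cb_name, cb)
        self._cb_by_name[cb.cb_name] = cb

    def get_callback(self, name):
        # Lookup function; returns None if the callback is unknown or gone
        return self._cb_by_name.get(name)

learner = Learner()
cb = Callback("recorder")
learner.register_callback(cb)

found = learner.get_callback("recorder") is cb
del cb  # the owner drops the callback
gone = learner.get_callback("recorder") is None
```

The trade-off is that learn.recorder becomes learn.get_callback("recorder"), so existing code that relies on the attribute would need to change.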