I think I found the error. I had assumed the `enumerate` behavior was still there, but it was recently removed in Colab, so I removed the `on_batch_begin()` workaround that compensated for it.
However, I am now dealing with an error that is unfortunately going to be tougher to fix. I get the following error from the validation loop:
```
File "/content/tpu_distributed_fastai.py", line 118, in train_loop
  learn.fit(1)
File "/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py", line 200, in fit
  fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
File "/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py", line 104, in fit
  if not cb_handler.skip_validate and not learn.data.empty_val:
File "/usr/local/lib/python3.6/dist-packages/fastai/basic_data.py", line 122, in __getattr__
  def __getattr__(self,k:int)->Any: return getattr(self.train_dl, k)
AttributeError: 'PerDeviceLoader' object has no attribute 'empty_val'
```
Here are the problems with fixing this. For one, `empty_val` is a property of `learn.data` and should be returned as such. Instead, the lookup is falling through to this `__getattr__` function, which just assumes it's an attribute of `self.train_dl`, which I have redefined to be the `PerDeviceLoader` iterator. I have no idea why it isn't using the `def empty_val(self):` that has a `@property` decorator.
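One possible explanation, sketched with stand-in classes (not the real fastai objects): Python only calls `__getattr__` when normal attribute lookup fails. The `@property` *is* found first, but if the property's body itself raises `AttributeError`, Python swallows it and retries the original name through `__getattr__`, so the error gets blamed on `train_dl`:

```python
class Inner:
    pass  # plays the role of PerDeviceLoader: has no 'empty_val'

class Data:
    def __init__(self):
        self.train_dl = Inner()

    @property
    def empty_val(self):
        # If anything inside the property body raises AttributeError...
        return self.missing_attr

    def __getattr__(self, k):
        # ...lookup of 'empty_val' is retried here and forwarded,
        # just like fastai's DataBunch.__getattr__.
        return getattr(self.train_dl, k)

d = Data()
try:
    d.empty_val
except AttributeError as e:
    # Note the name in the message: 'empty_val', not 'missing_attr' --
    # the real cause is masked, matching the PerDeviceLoader error above.
    print(e)
```

If that's what is happening here, the property is being used, but something inside it (probably the `valid_ds` chain below) raises `AttributeError`, and the fallback hides the real failure.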
Even if I do fix it, there is still one more challenge, which I have a better idea of how to fix. Here is the code for the `empty_val` attribute:
```python
@property
def empty_val(self)->bool:
    if not hasattr(self, 'valid_dl') or self.valid_dl is None: return True
    if hasattr(self.valid_ds, 'items') and len(self.valid_ds.items) == 0: return True
    return (len(self.valid_ds) == 0)
```
So it's going to look for `self.valid_ds`:

```python
@property
def valid_ds(self)->Dataset: return self._grab_dataset(self.valid_dl)
```
Neither of the above properties can be monkey-patched; otherwise, I would just change them directly. Here is the definition of `_grab_dataset`:
```python
def _grab_dataset(self, dl:DataLoader):
    ds = dl.dl.dataset
    while hasattr(ds, 'dataset'): ds = ds.dataset
    return ds
```
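On the monkey-patching point above: the reason instance-level patching fails is that a property is a data descriptor stored on the *class*, so assigning to it on an instance raises `AttributeError`, while replacing it on the class does work. A generic Python sketch (stand-in class, not fastai's `DataBunch`; whether class-level patching is acceptable here is a separate question):

```python
class Data:  # stand-in, not fastai's DataBunch
    @property
    def empty_val(self):
        return False

d = Data()

# Instance-level patching fails: the property is a data descriptor on
# the class with no setter, so assignment raises AttributeError.
try:
    d.empty_val = True
except AttributeError:
    print("can't patch on the instance")

# Class-level patching works, and affects every instance of Data.
Data.empty_val = property(lambda self: True)
print(d.empty_val)  # True
```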
So I think the solution here is to set `self.learn.data.valid_dl.dataset = self.old_valid_dl.dataset` in the callback.
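As a sanity check that this assignment is at least legal (unlike the property case): wrapper instances like `PerDeviceLoader` are ordinary Python objects, so a `dataset` attribute can be attached at runtime. A toy demo with stand-in classes; `old_valid_dl` here mimics the original loader I save in my callback:

```python
class Dataset:  # stand-in for the real dataset
    def __init__(self, items): self.items = items
    def __len__(self): return len(self.items)

class PerDeviceLoader:  # stand-in for torch_xla's wrapper
    pass                # note: no 'dataset' attribute of its own

class OldValidDL:       # stand-in for the saved original valid_dl
    def __init__(self, dataset): self.dataset = dataset

old_valid_dl = OldValidDL(Dataset([0, 1, 2]))
valid_dl = PerDeviceLoader()

# The patch from the text: re-attach the original dataset so code
# that inspects valid_dl.dataset can find it again.
valid_dl.dataset = old_valid_dl.dataset
print(len(valid_dl.dataset))  # 3
```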
If you have any ideas for solving the `__getattr__` problem, please let me know. Otherwise, I will probably come up with another hackish way to fix it.