takotab (Tako Tabak), January 28, 2020, 10:57am
What do I need to change in `MyDataLoader` so that it passes the test `test_eq(dl_new.horizon, dl.horizon)`?
```python
from fastai2.data.all import *

@delegates(TfmdDL.__init__)
class MyDataLoader(TfmdDL):
    def __init__(self, dataset, horizon=1, **kwargs):
        store_attr(self, 'horizon')
        super().__init__(dataset=dataset, horizon=horizon, **kwargs)

dl = MyDataLoader([1,2,3], horizon=2)
dl.horizon
dl_new = dl.new()
test_eq(dl_new.horizon, dl.horizon)  # fails: dl_new does not keep horizon=2
```
PS: Suggestions for a better title are welcome. I doubt that people searching for the same problem will find it under this one.
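A likely cause, shown with toy classes rather than fastai itself: `new` appears to rebuild the loader by calling the class again with only the keyword arguments the base class knows about, so a subclass-only argument such as `horizon` silently falls back to its default. `ToyDL` and `ToyHorizonDL` are hypothetical stand-ins for `TfmdDL` and `MyDataLoader`; whether fastai's `new` matches this exactly is an assumption.

```python
class ToyDL:
    """Hypothetical stand-in for TfmdDL; not fastai's implementation."""
    def __init__(self, dataset, bs=64, **kwargs):
        self.dataset, self.bs = dataset, bs

    def new(self, dataset=None, **kwargs):
        # Rebuild with only the arguments the base class knows about;
        # subclass-specific attributes such as `horizon` are not forwarded.
        cur_kwargs = dict(dataset=self.dataset if dataset is None else dataset, bs=self.bs)
        cur_kwargs.update(kwargs)
        return type(self)(**cur_kwargs)


class ToyHorizonDL(ToyDL):
    """Hypothetical analogue of MyDataLoader with an extra `horizon` argument."""
    def __init__(self, dataset, horizon=1, **kwargs):
        self.horizon = horizon
        super().__init__(dataset, **kwargs)


dl = ToyHorizonDL([1, 2, 3], horizon=2)
dl_new = dl.new()
print(dl.horizon, dl_new.horizon)  # 2 1: `new` re-ran __init__ with the default horizon
```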
takotab (Tako Tabak), January 28, 2020, 12:48pm
OK, so I refactored the problem to:
```python
from fastai2.data.all import *

@delegates(TfmdDL.__init__)
class MyDataLoader(TfmdDL):
    def __init__(self, dataset, addition=1, **kwargs):
        store_attr(self, 'addition')
        super().__init__(dataset=dataset, addition=addition, **kwargs)

    def create_item(self, idx):
        # Add the stored offset to every item
        return self.dataset[idx] + self.addition

dl = MyDataLoader([1,2,3,4], addition=2)
dl_new = dl.new()
test_eq(dl.one_batch(), dl_new.one_batch())
```
And then the solution is obvious:
```python
from fastai2.data.all import *

class Addition(Transform):
    def __init__(self, addition):
        store_attr(self, 'addition')
    def encodes(self, o):
        return o + self.addition
    def decodes(self, o):
        return o - self.addition

@delegates(TfmdDL.__init__)
class MyDataLoader(TfmdDL):
    def __init__(self, dataset, **kwargs):
        super().__init__(dataset=dataset, **kwargs)

dl = MyDataLoader([1,2,3,4], after_item=[Addition(2)])
dl_new = dl.new()
test_eq(dl.one_batch(), dl_new.one_batch())
```
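The transform version works because the `Addition` instance travels with the loader as its `after_item` hook, which `new` appears to forward to the rebuilt loader. A toy sketch of that idea, with hypothetical classes rather than fastai's `TfmdDL` and `Transform` (the forwarding behavior is an assumption about fastai, modeled here by hand):

```python
class ToyDL:
    """Hypothetical stand-in for TfmdDL: `new` forwards the dataset, bs and the after_item hook."""
    def __init__(self, dataset, bs=64, after_item=None, **kwargs):
        self.dataset, self.bs = dataset, bs
        # Default hook is the identity function
        self.after_item = after_item if after_item is not None else (lambda o: o)

    def create_item(self, idx):
        return self.after_item(self.dataset[idx])

    def one_batch(self):
        return [self.create_item(i) for i in range(min(self.bs, len(self.dataset)))]

    def new(self, dataset=None, **kwargs):
        # The hook is part of what gets forwarded, so the addition survives
        cur_kwargs = dict(dataset=self.dataset if dataset is None else dataset,
                          bs=self.bs, after_item=self.after_item)
        cur_kwargs.update(kwargs)
        return type(self)(**cur_kwargs)


class Addition:
    """Plain-Python analogue of the Addition transform (encode direction only)."""
    def __init__(self, addition):
        self.addition = addition

    def __call__(self, o):
        return o + self.addition


dl = ToyDL([1, 2, 3, 4], after_item=Addition(2))
dl_new = dl.new()
print(dl.one_batch(), dl_new.one_batch())  # [3, 4, 5, 6] [3, 4, 5, 6]
```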
That’s one very good way of doing it. If you had to stay at the DataLoader level, you would need to override the `new` method and pass that attribute to `super().new`. I’ll do that soon for `LMDataLoader` (for `seq_len`) and `SortedDL` (for `sort_func`), for instance (you made me realize I forgot that).
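A minimal sketch of that suggestion, again using hypothetical toy classes instead of the real `TfmdDL`: the subclass overrides `new` and hands its extra attribute to `super().new`, so the rebuilt loader keeps `horizon=2`. The exact signature of fastai's own `new` is not shown here and may differ.

```python
class ToyDL:
    """Hypothetical stand-in for TfmdDL; not fastai's implementation."""
    def __init__(self, dataset, bs=64, **kwargs):
        self.dataset, self.bs = dataset, bs

    def new(self, dataset=None, **kwargs):
        # Base class only knows about dataset and bs
        cur_kwargs = dict(dataset=self.dataset if dataset is None else dataset, bs=self.bs)
        cur_kwargs.update(kwargs)
        return type(self)(**cur_kwargs)


class ToyHorizonDL(ToyDL):
    def __init__(self, dataset, horizon=1, **kwargs):
        self.horizon = horizon
        super().__init__(dataset, **kwargs)

    def new(self, dataset=None, **kwargs):
        # Forward the subclass attribute unless the caller overrides it explicitly
        kwargs.setdefault('horizon', self.horizon)
        return super().new(dataset=dataset, **kwargs)


dl = ToyHorizonDL([1, 2, 3], horizon=2)
print(dl.new().horizon)           # 2
print(dl.new(horizon=5).horizon)  # 5
```

Using `setdefault` rather than always passing `self.horizon` is a design choice: callers can still write `dl.new(horizon=5)` to change the value on purpose.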
takotab (Tako Tabak), January 28, 2020, 3:27pm
Yes, indeed, that’s how I’m now going to solve it, because I need access to `TfmdDL.dataset` and also need some initialization (like `make_chunks` in `LMDataLoader`). The transform-based solution is not sufficient for that, so overriding `new` and using `super().new` is the better option.
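Overriding `new` this way can also cover dataset-dependent setup: in the toy below, `new` rebuilds the object through `__init__`, so initialization done there (a hypothetical `make_chunks`, loosely named after the `LMDataLoader` step) runs again for the new dataset. The classes are illustrative stand-ins; that fastai's `new` re-runs `__init__` in the same way is an assumption.

```python
class ToyDL:
    """Hypothetical stand-in for TfmdDL; not fastai's implementation."""
    def __init__(self, dataset, bs=64, **kwargs):
        self.dataset, self.bs = dataset, bs

    def new(self, dataset=None, **kwargs):
        cur_kwargs = dict(dataset=self.dataset if dataset is None else dataset, bs=self.bs)
        cur_kwargs.update(kwargs)
        return type(self)(**cur_kwargs)


class ToyChunkDL(ToyDL):
    """Hypothetical loader whose setup depends on the dataset."""
    def __init__(self, dataset, seq_len=2, **kwargs):
        super().__init__(dataset, **kwargs)
        self.seq_len = seq_len
        self.make_chunks()

    def make_chunks(self):
        # Split the dataset into consecutive pieces of length seq_len
        self.chunks = [self.dataset[i:i + self.seq_len]
                       for i in range(0, len(self.dataset), self.seq_len)]

    def new(self, dataset=None, **kwargs):
        # Keep seq_len; __init__ (and thus make_chunks) runs again on the new dataset
        kwargs.setdefault('seq_len', self.seq_len)
        return super().new(dataset=dataset, **kwargs)


dl = ToyChunkDL([1, 2, 3, 4, 5, 6], seq_len=3)
dl_new = dl.new([10, 20, 30, 40])
print(dl.chunks)      # [[1, 2, 3], [4, 5, 6]]
print(dl_new.chunks)  # [[10, 20, 30], [40]]
```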