Prediction using GANLearner [SOLVED]

I created a GANLearner using the from_learners method, as shown in Lesson 7.

When I try to predict the output for an image using learn.predict(img) (where img was opened with open_image), I get the following error:

    AttributeError                            Traceback (most recent call last)
    <ipython-input-81-3b6ef5ec6f96> in <module>
    ----> 1 learn.predict(img)

    /opt/anaconda3/lib/python3.7/site-packages/fastai/basic_train.py in predict(self, item, **kwargs)
        353         "Return predicted class, label and probabilities for `item`."
        354         batch = self.data.one_item(item)
    --> 355         res = self.pred_batch(batch=batch)
        356         pred,x = res[0],batch[0]
        357         norm = getattr(self.data,'norm',False)

    /opt/anaconda3/lib/python3.7/site-packages/fastai/basic_train.py in pred_batch(self, ds_type, batch, reconstruct)
        332         else: xb,yb = self.data.one_batch(ds_type, detach=False, denorm=False)
        333         cb_handler = CallbackHandler(self.callbacks)
    --> 334         xb,yb = cb_handler.on_batch_begin(xb,yb, train=False)
        335         preds = loss_batch(self.model.eval(), xb, yb, cb_handler=cb_handler)
        336         res = _loss_func2activ(self.loss_func)(preds[0])

    /opt/anaconda3/lib/python3.7/site-packages/fastai/callback.py in on_batch_begin(self, xb, yb, train)
        261         self.state_dict.update(dict(last_input=xb, last_target=yb, train=train, 
        262             stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
    --> 263         self('batch_begin', mets = not self.state_dict['train'])
        264         return self.state_dict['last_input'], self.state_dict['last_target']
        265 

    /opt/anaconda3/lib/python3.7/site-packages/fastai/callback.py in __call__(self, cb_name, call_mets, **kwargs)
        233         if call_mets:
        234             for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
    --> 235         for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
        236 
        237     def set_dl(self, dl:DataLoader):

    /opt/anaconda3/lib/python3.7/site-packages/fastai/callback.py in _call_and_update(self, cb, cb_name, **kwargs)
        223     def _call_and_update(self, cb, cb_name, **kwargs)->None:
        224         "Call `cb_name` on `cb` and update the inner state."
    --> 225         new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
        226         for k,v in new.items():
        227             if k not in self.state_dict:

    /opt/anaconda3/lib/python3.7/site-packages/fastai/vision/gan.py in on_batch_begin(self, last_input, last_target, **kwargs)
        112         if self.clip is not None:
        113             for p in self.critic.parameters(): p.data.clamp_(-self.clip, self.clip)
    --> 114         return {'last_input':last_input,'last_target':last_target} if self.gen_mode else {'last_input':last_target,'last_target':last_input}
        115 
        116     def on_backward_begin(self, last_loss, last_output, **kwargs):

    /opt/anaconda3/lib/python3.7/site-packages/fastai/basic_train.py in __getattr__(self, k)
        414         setattr(self.learn, self.cb_name, self)
        415 
    --> 416     def __getattr__(self,k): return getattr(self.learn, k)
        417     def __setstate__(self,data:Any): self.__dict__.update(data)
        418 

    AttributeError: 'GANLearner' object has no attribute 'gen_mode'

How do I get the predicted image?


UPDATE:

After looking at the source code, I tried explicitly setting gen_mode to True:

    learn.gen_mode = True

Now learn.predict(img) returns the following output:

    (Image (3, 1, 121),
     tensor([[[ 0.2183,  0.2188,  0.2186,  0.2159,  0.2161,  0.2174,  0.2175,
                0.2235,  0.2296,  0.2340,  0.2326,  0.2633,  0.2615,  0.2548,
                0.2367,  0.2248,  0.2269,  0.2179,  0.2106,  0.2218,  0.2387,
                0.2412,  0.2807,  0.2792,  0.2721,  0.2514,  0.2372,  0.2390,
                0.2255,  0.2020,  0.1902,  0.2023,  0.2238,  0.2684,  0.2731,
                0.2764,  0.2703,  0.2557,  0.2448,  0.2353,  0.1992,  0.1446,
                0.1371,  0.1628,  0.2459,  0.2540,  0.2645,  0.2603,  0.2149,
                0.1721,  0.1652,  0.1329,  0.0727,  0.0875,  0.1453,  0.1996,
                0.2010,  0.1990,  0.1754,  0.0966,  0.0197,  0.0131,  0.0128,
               -0.0293,  0.0066,  0.0999,  0.1318,  0.1018,  0.0558,  0.0324,
               -0.0181, -0.0891, -0.0838, -0.0713, -0.1172, -0.0960,  0.0036,
                0.0683,  0.0244, -0.0268, -0.0390, -0.0481, -0.0931, -0.0853,
               -0.0672, -0.1085, -0.1022, -0.0010,  0.0377, -0.0113, -0.0299,
               -0.0194, -0.0009, -0.0211,  0.0013,  0.0398,  0.0283,  0.0332,
                0.0956,  0.0328,  0.0044,  0.0094,  0.0353,  0.0575,  0.0377,
                0.0636,  0.1108,  0.1172,  0.1213,  0.1481,  0.0708,  0.0432,
                0.0427,  0.0592,  0.0795,  0.0683,  0.0915,  0.1401,  0.1550,
                0.1577,  0.1697]],

             [[ 0.1952,  0.1956,  0.1954,  0.1928,  0.1930,  0.1942,  0.1943,
                0.2002,  0.2062,  0.2105,  0.2091,  0.2391,  0.2373,  0.2309,
                0.2132,  0.2015,  0.2035,  0.1947,  0.1875,  0.1985,  0.2151,
                0.2175,  0.2561,  0.2547,  0.2477,  0.2275,  0.2136,  0.2153,
                0.2021,  0.1792,  0.1677,  0.1795,  0.2005,  0.2442,  0.2488,
                0.2520,  0.2460,  0.2317,  0.2211,  0.2118,  0.1765,  0.1230,
                0.1157,  0.1408,  0.2221,  0.2301,  0.2404,  0.2362,  0.1918,
                0.1499,  0.1432,  0.1116,  0.0527,  0.0672,  0.1237,  0.1768,
                0.1782,  0.1763,  0.1531,  0.0760,  0.0009, -0.0056, -0.0059,
               -0.0471, -0.0120,  0.0793,  0.1105,  0.0812,  0.0362,  0.0133,
               -0.0362, -0.1056, -0.1004, -0.0882, -0.1331, -0.1123, -0.0149,
                0.0484,  0.0054, -0.0446, -0.0566, -0.0655, -0.1095, -0.1018,
               -0.0841, -0.1245, -0.1184, -0.0193,  0.0185, -0.0294, -0.0477,
               -0.0374, -0.0193, -0.0391, -0.0171,  0.0205,  0.0093,  0.0141,
                0.0751,  0.0137, -0.0141, -0.0092,  0.0161,  0.0379,  0.0185,
                0.0438,  0.0900,  0.0963,  0.1002,  0.1264,  0.0508,  0.0238,
                0.0234,  0.0395,  0.0593,  0.0484,  0.0711,  0.1186,  0.1332,
                0.1359,  0.1476]],

             [[ 0.1440,  0.1444,  0.1442,  0.1416,  0.1418,  0.1431,  0.1432,
                0.1491,  0.1551,  0.1594,  0.1580,  0.1882,  0.1864,  0.1799,
                0.1621,  0.1503,  0.1524,  0.1436,  0.1363,  0.1474,  0.1640,
                0.1664,  0.2052,  0.2038,  0.1968,  0.1765,  0.1626,  0.1643,
                0.1510,  0.1279,  0.1164,  0.1283,  0.1493,  0.1932,  0.1978,
                0.2011,  0.1951,  0.1807,  0.1700,  0.1607,  0.1252,  0.0715,
                0.0642,  0.0894,  0.1711,  0.1791,  0.1894,  0.1853,  0.1406,
                0.0986,  0.0918,  0.0601,  0.0009,  0.0154,  0.0723,  0.1256,
                0.1269,  0.1250,  0.1018,  0.0243, -0.0512, -0.0577, -0.0580,
               -0.0993, -0.0641,  0.0276,  0.0590,  0.0295, -0.0157, -0.0387,
               -0.0883, -0.1581, -0.1528, -0.1406, -0.1857, -0.1649, -0.0670,
               -0.0035, -0.0466, -0.0969, -0.1088, -0.1178, -0.1620, -0.1543,
               -0.1365, -0.1771, -0.1710, -0.0715, -0.0335, -0.0816, -0.0999,
               -0.0896, -0.0714, -0.0913, -0.0692, -0.0314, -0.0427, -0.0379,
                0.0234, -0.0383, -0.0662, -0.0613, -0.0358, -0.0140, -0.0335,
               -0.0080,  0.0384,  0.0447,  0.0487,  0.0750, -0.0010, -0.0281,
               -0.0285, -0.0124,  0.0076, -0.0034,  0.0194,  0.0671,  0.0817,
                0.0845,  0.0962]]]),
     tensor([-1.1645, -1.1627, -1.1634, -1.1751, -1.1740, -1.1686, -1.1681, -1.1418,
             -1.1152, -1.0959, -1.1024, -0.9681, -0.9762, -1.0051, -1.0841, -1.1363,
             -1.1272, -1.1664, -1.1984, -1.1495, -1.0753, -1.0647, -0.8923, -0.8989,
             -0.9299, -1.0199, -1.0820, -1.0744, -1.1334, -1.2358, -1.2871, -1.2343,
             -1.1407, -0.9457, -0.9251, -0.9108, -0.9374, -1.0015, -1.0489, -1.0904,
             -1.2480, -1.4866, -1.5192, -1.4070, -1.0441, -1.0086, -0.9627, -0.9811,
             -1.1793, -1.3664, -1.3964, -1.5374, -1.8006, -1.7359, -1.4833, -1.2464,
             -1.2403, -1.2488, -1.3521, -1.6963, -2.0319, -2.0608, -2.0620, -2.2458,
             -2.0891, -1.6816, -1.5424, -1.6733, -1.8743, -1.9763, -2.1971, -2.5070,
             -2.4837, -2.4293, -2.6298, -2.5373, -2.1023, -1.8198, -2.0116, -2.2350,
             -2.2882, -2.3281, -2.5244, -2.4904, -2.4112, -2.5916, -2.5644, -2.1221,
             -1.9533, -2.1671, -2.2486, -2.2027, -2.1218, -2.2101, -2.1122, -1.9441,
             -1.9942, -1.9728, -1.7006, -1.9746, -2.0986, -2.0768, -1.9638, -1.8666,
             -1.9532, -1.8402, -1.6340, -1.6059, -1.5882, -1.4713, -1.8089, -1.9294,
             -1.9313, -1.8594, -1.7709, -1.8195, -1.7184, -1.5062, -1.4412, -1.4291,
             -1.3767]))

The output should be an RGB image of size 224x224, but the returned Image has shape (3, 1, 121).

How do I get the output image from this?

SOLVED:

Instead of learn.predict(img), I called learn_gen.predict(img), where learn_gen is the generator learner I passed to GANLearner.from_learners.
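Why learn.predict misbehaved: the traceback shows that the GAN callback swaps input and target unless gen_mode is true (the line marked --> in gan.py), and that it forwards any attribute it lacks to the Learner through __getattr__. If the callback has no gen_mode of its own, the lookup ends on the GANLearner and fails. Setting learn.gen_mode = True only satisfies that lookup. As far as I can tell from fastai v1's gan.py, the wrapped model (GANModule) keeps its own, separate gen_mode flag that defaults to False, so the model probably still ran the critic, which would explain the odd (3, 1, 121) output. The classes below are a hypothetical, simplified sketch of that behavior, not fastai's real code:

```python
# Hypothetical, simplified sketch of the behavior seen in the traceback;
# these are NOT fastai's real classes.

class GANModuleSketch:
    """Stand-in for the wrapped model: it keeps its OWN gen_mode flag."""
    def __init__(self, generator, critic, gen_mode=False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def __call__(self, x):
        # Runs the generator only when the model's own flag is set
        return self.generator(x) if self.gen_mode else self.critic(x)


class LearnerSketch:
    """Stand-in for GANLearner; note it starts without a gen_mode attribute."""
    def __init__(self, model):
        self.model = model


class GANCallbackSketch:
    def __init__(self, learn):
        self.learn = learn

    def __getattr__(self, k):
        # Like the callback in the traceback: unknown attributes go to the Learner
        return getattr(self.learn, k)

    def on_batch_begin(self, last_input, last_target):
        # Like gan.py: keep input/target in generator mode, swap them otherwise
        if self.gen_mode:
            return {'last_input': last_input, 'last_target': last_target}
        return {'last_input': last_target, 'last_target': last_input}


model = GANModuleSketch(generator=lambda x: 'generator output',
                        critic=lambda x: 'critic output')
learn = LearnerSketch(model)
cb = GANCallbackSketch(learn)

try:
    cb.on_batch_begin('x', 'y')
except AttributeError as e:
    print('AttributeError:', e)     # mirrors the original error

learn.gen_mode = True               # the workaround from the UPDATE
print(cb.on_batch_begin('x', 'y'))  # input/target no longer swapped
print(model('x'))                   # but the model still runs the critic
```

As I understand fastai v1, from_learners wraps learn_gen's own model, so learn_gen.predict runs the same trained generator without the GAN callback or the critic in the way.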


Any idea how we could generate an image without running prediction on an existing image?
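This only applies to a GAN whose generator takes random noise as input (for example a WGAN); the image-to-image generator discussed above needs an input image. For a noise-input generator, a minimal PyTorch sketch is to sample noise and call the generator module directly. It assumes the generator maps noise of shape (n, noise_sz, 1, 1) to images and ends in Tanh, and that training images were normalized to [-1, 1] (mean 0.5, std 0.5). tiny_generator, NOISE_SZ and generate_images are hypothetical stand-ins, not fastai names:

```python
import torch
from torch import nn

NOISE_SZ = 100  # assumption: noise length the generator was trained with

# Hypothetical stand-in for a trained noise-input generator: maps
# (n, NOISE_SZ, 1, 1) noise to (n, 3, 4, 4) images in [-1, 1] via Tanh.
tiny_generator = nn.Sequential(
    nn.ConvTranspose2d(NOISE_SZ, 3, kernel_size=4, stride=1, padding=0),
    nn.Tanh(),
)

def generate_images(generator: nn.Module, n: int, noise_sz: int = NOISE_SZ) -> torch.Tensor:
    """Sample n images from random noise; no input image is needed."""
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n, noise_sz, 1, 1)  # one noise vector per image
        out = generator(z)                  # values in [-1, 1] from Tanh
    return (out + 1) / 2                    # rescale to [0, 1] for display

imgs = generate_images(tiny_generator, n=8)
print(imgs.shape)
```

With a fastai v1 GANLearner you would pass the trained generator instead of tiny_generator; my reading of fastai v1's GANModule is that it is available as learn.model.generator, but check that against your version.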

Thank you for your suggestions. I have tried this, but it's not working for me: it still returns a tensor object instead of an image.

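    from fastai.vision import *  # fastai v1

    img = open_image('/yourpath/000001.jpg')
    x = learn_gen.predict(img)
    x[0].show()  # x[0] is the predicted Image; x[1] and x[2] are tensors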

This should definitely produce an RGB image.