I trained a GANLearner, saved the weights, and loaded them again. Now I'm running it on some test images:
img = open_image(fn); img.shape
_,img_hr,b = learn.predict(img)
Then, when I try to display the generated image with `Image(img_hr)`, I get:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/IPython/core/formatters.py in __call__(self, obj)
332 pass
333 else:
--> 334 return printer(obj)
335 # Finally look for special method names
336 method = get_real_method(obj, self.print_method)
/usr/local/lib/python3.6/dist-packages/IPython/core/pylabtools.py in <lambda>(fig)
239
240 if 'png' in formats:
--> 241 png_formatter.for_type(Figure, lambda fig: print_figure(fig, 'png', **kwargs))
242 if 'retina' in formats or 'png2x' in formats:
243 png_formatter.for_type(Figure, lambda fig: retina_figure(fig, **kwargs))
/usr/local/lib/python3.6/dist-packages/IPython/core/pylabtools.py in print_figure(fig, fmt, bbox_inches, **kwargs)
123
124 bytes_io = BytesIO()
--> 125 fig.canvas.print_figure(bytes_io, **kw)
126 data = bytes_io.getvalue()
127 if fmt == 'svg':
/usr/local/lib/python3.6/dist-packages/matplotlib/backend_bases.py in print_figure(self, filename, dpi, facecolor, edgecolor, orientation, format, **kwargs)
2214 orientation=orientation,
2215 dryrun=True,
-> 2216 **kwargs)
2217 renderer = self.figure._cachedRenderer
2218 bbox_inches = self.figure.get_tightbbox(renderer)
/usr/local/lib/python3.6/dist-packages/matplotlib/backends/backend_agg.py in print_png(self, filename_or_obj, *args, **kwargs)
505
506 def print_png(self, filename_or_obj, *args, **kwargs):
--> 507 FigureCanvasAgg.draw(self)
508 renderer = self.get_renderer()
509 original_dpi = renderer.dpi
/usr/local/lib/python3.6/dist-packages/matplotlib/backends/backend_agg.py in draw(self)
428 # if toolbar:
429 # toolbar.set_cursor(cursors.WAIT)
--> 430 self.figure.draw(self.renderer)
431 finally:
432 # if toolbar:
/usr/local/lib/python3.6/dist-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
/usr/local/lib/python3.6/dist-packages/matplotlib/figure.py in draw(self, renderer)
1297
1298 mimage._draw_list_compositing_images(
-> 1299 renderer, self, artists, self.suppressComposite)
1300
1301 renderer.close_group('figure')
/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
136 if not_composite or not has_images:
137 for a in artists:
--> 138 a.draw(renderer)
139 else:
140 # Composite any adjacent images together
/usr/local/lib/python3.6/dist-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
/usr/local/lib/python3.6/dist-packages/matplotlib/axes/_base.py in draw(self, renderer, inframe)
2435 renderer.stop_rasterizing()
2436
-> 2437 mimage._draw_list_compositing_images(renderer, self, artists)
2438
2439 renderer.close_group('axes')
/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in _draw_list_compositing_images(renderer, parent, artists, suppress_composite)
136 if not_composite or not has_images:
137 for a in artists:
--> 138 a.draw(renderer)
139 else:
140 # Composite any adjacent images together
/usr/local/lib/python3.6/dist-packages/matplotlib/artist.py in draw_wrapper(artist, renderer, *args, **kwargs)
53 renderer.start_filter()
54
---> 55 return draw(artist, renderer, *args, **kwargs)
56 finally:
57 if artist.get_agg_filter() is not None:
/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in draw(self, renderer, *args, **kwargs)
564 else:
565 im, l, b, trans = self.make_image(
--> 566 renderer, renderer.get_image_magnification())
567 if im is not None:
568 renderer.draw_image(gc, l, b, im)
/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in make_image(self, renderer, magnification, unsampled)
791 return self._make_image(
792 self._A, bbox, transformed_bbox, self.axes.bbox, magnification,
--> 793 unsampled=unsampled)
794
795 def _check_unsampled_image(self, renderer):
/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification, unsampled, round_to_pixel_border)
482 # (of int or float)
483 # or an RGBA array of re-sampled input
--> 484 output = self.to_rgba(output, bytes=True, norm=False)
485 # output is now a correctly sized RGBA array of uint8
486
/usr/local/lib/python3.6/dist-packages/matplotlib/cm.py in to_rgba(self, x, alpha, bytes, norm)
255 if xx.dtype.kind == 'f':
256 if norm and xx.max() > 1 or xx.min() < 0:
--> 257 raise ValueError("Floating point image RGB values "
258 "must be in the 0..1 range.")
259 if bytes:
ValueError: Floating point image RGB values must be in the 0..1 range.
<matplotlib.figure.Figure at 0x7f0e346ff390>
When I instead trained the learner for one more epoch, I didn't get this error!
Any clue how I could solve this?