Source code for panoptes.utils.images.plot

from copy import copy

from warnings import warn

from matplotlib import rc
from matplotlib import animation
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable

import numpy as np

from astropy.visualization import LogStretch, ImageNormalize, LinearStretch, MinMaxInterval

# Import-time side effect: set matplotlib's `animation.html` rcParam to
# 'html5', which controls how animations are rendered when shown as HTML.
# NOTE(review): presumably this is for the animation returned by
# `animate_stamp` when displayed in a notebook -- confirm.
rc('animation', html='html5')


def get_palette(cmap='inferno'):
    """Build a drawing palette from a named colormap.

    The colormap is looked up by name on `matplotlib.cm` and copied, so the
    color overrides applied here do not modify the colormap object held by
    matplotlib itself.

    Args:
        cmap (str, optional): Name of the colormap attribute to fetch from
            `matplotlib.cm`, default 'inferno'.

    Returns:
        The copied colormap with its over, under, and bad colors set.
    """
    base_colormap = getattr(cm, cmap)
    marked = copy(base_colormap)

    # Give out-of-range and invalid pixels (e.g. saturated ones) fixed, fully
    # opaque colors so they stand out when a normalizer with vmin/vmax is used.
    marked.set_bad('g', 1.0)
    marked.set_under('k', 1.0)
    marked.set_over('w', 1.0)

    return marked


def add_colorbar(axes_image, size='5%', pad=0.05, orientation='vertical'):
    """Add a colorbar to the image.

    This is a simple convenience function to add a colorbar to a plot generated
    by `matplotlib.pyplot.imshow`. The colorbar is drawn in a new axes carved
    out of the image's axes, so it always matches the image's extent.

    .. plot::

        >>> from matplotlib import pyplot as plt
        >>> import numpy as np
        >>> from panoptes.utils.images.plot import add_colorbar
        >>>
        >>> x = np.arange(0.0, 100.0)
        >>> y = np.arange(0.0, 100.0)
        >>> X, Y = np.meshgrid(x, y)
        >>>
        >>> func = lambda x, y: x**2 + y**2
        >>>
        >>> z = func(X, Y)
        >>>
        >>> fig, ax = plt.subplots()
        >>> im1 = ax.imshow(z, origin='lower')
        >>> add_colorbar(im1)
        >>> fig.show()

    Args:
        axes_image (`matplotlib.image.AxesImage`): A matplotlib AxesImage.
        size (str or float, optional): Size of the colorbar axes, passed to
            `append_axes`, default '5%'.
        pad (float, optional): Padding between the image axes and the
            colorbar axes, passed to `append_axes`, default 0.05.
        orientation (str, optional): Colorbar orientation passed to
            `Figure.colorbar`, default 'vertical'. A 'horizontal' colorbar is
            placed below the image; anything else is placed to the right.
    """
    # A horizontal bar needs a wide, short axes, so attach it underneath the
    # image rather than squeezing it into a narrow strip on the right.
    position = 'bottom' if orientation == 'horizontal' else 'right'

    divider = make_axes_locatable(axes_image.axes)
    cax = divider.append_axes(position, size=size, pad=pad)
    axes_image.figure.colorbar(axes_image, cax=cax, orientation=orientation)


def add_pixel_grid(ax1, grid_height, grid_width, show_axis_labels=True, show_superpixel=False,
                   major_alpha=0.5, minor_alpha=0.25):
    """Overlay a dashed red pixel grid on an axes.

    Minor grid lines are drawn every pixel, starting at -0.5 so that they fall
    on pixel boundaries of an `imshow`-style image. When `show_superpixel` is
    set, thicker major grid lines are additionally drawn every 2 pixels;
    otherwise the major ticks are removed.

    Args:
        ax1: The axes to draw on. It must provide `set_xticks`, `set_yticks`,
            `grid`, `set_xticklabels`, and `set_yticklabels`.
        grid_height (int): Upper bound (exclusive) for the y tick positions.
        grid_width (int): Upper bound (exclusive) for the x tick positions.
        show_axis_labels (bool, optional): If exactly `False`, the x and y tick
            labels are cleared, default True.
        show_superpixel (bool, optional): Draw the major (every 2 pixels)
            grid, default False.
        major_alpha (float, optional): Alpha of the major grid, default 0.5.
        minor_alpha (float, optional): Alpha of the minor grid, default 0.25.
    """
    # Major ticks every 2, minor ticks every 1.
    if show_superpixel:
        x_major_ticks = np.arange(-0.5, grid_width, 2)
        y_major_ticks = np.arange(-0.5, grid_height, 2)

        ax1.set_xticks(x_major_ticks)
        ax1.set_yticks(y_major_ticks)

        ax1.grid(which='major', color='r', linestyle='--', lw=3, alpha=major_alpha)
    else:
        ax1.set_xticks([])
        ax1.set_yticks([])

    x_minor_ticks = np.arange(-0.5, grid_width, 1)
    y_minor_ticks = np.arange(-0.5, grid_height, 1)

    ax1.set_xticks(x_minor_ticks, minor=True)
    ax1.set_yticks(y_minor_ticks, minor=True)

    # Line width is numeric, consistent with the major grid above.
    ax1.grid(which='minor', color='r', lw=2, linestyle='--', alpha=minor_alpha)

    if show_axis_labels is False:
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])


def animate_stamp(d0):
    """Create an animation that steps through a sequence of images.

    A new figure (with an Agg canvas) is created, the first image is shown
    with the axis ticks removed, and each animation frame swaps in the next
    image and updates the title with the zero-padded frame number.

    Args:
        d0: Sized, indexable sequence of images accepted by `imshow`. It is
            indexed at 0 up front, so it must contain at least one image.

    Returns:
        `matplotlib.animation.FuncAnimation`: The animation, with one frame
        per entry in `d0` and a 500 ms interval between frames.
    """
    fig = Figure()
    FigureCanvas(fig)

    ax = fig.add_subplot(111)
    ax.set_xticks([])
    ax.set_yticks([])

    image = ax.imshow(d0[0])
    # Use the same zero-padded format as the per-frame titles set in `animate`.
    ax.set_title(f'Frame {0:03d}')

    def animate(i):
        image.set_data(d0[i])  # update the data
        ax.set_title(f'Frame {i:03d}')
        return image,

    # Init only required for blitting to give a clean slate.
    def init():
        image.set_data(d0[0])
        return image,

    ani = animation.FuncAnimation(fig, animate, np.arange(0, len(d0)), init_func=init,
                                  interval=500, blit=True)

    return ani


def show_stamps(pscs, frame_idx=None, stamp_size=11, aperture_position=None,
                show_residual=False, stretch=None, save_name=None, show_max=False,
                show_pixel_grid=False, **kwargs):
    """Plot a target stamp next to a comparison stamp, optionally with residual.

    The first two entries of `pscs` are drawn side by side ('Target' and
    'Comparison'), sharing a normalization computed from the target. With
    `show_residual`, a third panel shows `target - comparison` with its own
    linear normalization and the residual's standard deviation in the title.

    Args:
        pscs: Indexable collection whose first two entries are the target and
            the comparison. At least two entries are required.
        frame_idx (optional): If not None, each of the two entries is indexed
            with this value to select the stamp to draw; otherwise the entries
            are drawn as they are.
        stamp_size (int, optional): Size passed as both height and width to
            `add_pixel_grid`, default 11.
        aperture_position (optional): Accepted for backward compatibility;
            not used by this function.
        show_residual (bool, optional): Add the residual panel, default False.
        stretch (str, optional): 'log' selects a `LogStretch`; any other value
            selects a `LinearStretch`.
        save_name (str, optional): If given, the figure is saved to this path.
            A failure to save produces a warning instead of an exception.
        show_max (bool, optional): Accepted for backward compatibility; not
            used by this function.
        show_pixel_grid (bool, optional): Overlay the per-pixel grid on every
            panel, default False.
        **kwargs: Accepted for backward compatibility; not used.

    Returns:
        `matplotlib.figure.Figure`: The figure containing the panels.
    """
    ncols = len(pscs)
    if show_residual:
        ncols += 1
    nrows = 1

    fig = Figure()
    FigureCanvas(fig)
    fig.set_figheight(4)
    fig.set_figwidth(8)

    if frame_idx is not None:
        s0 = pscs[0][frame_idx]
        s1 = pscs[1][frame_idx]
    else:
        s0 = pscs[0]
        s1 = pscs[1]

    stretch = LogStretch() if stretch == 'log' else LinearStretch()

    # Both stamps share the target's normalization so they compare directly.
    norm = ImageNormalize(s0, interval=MinMaxInterval(), stretch=stretch)

    # Target
    ax1 = fig.add_subplot(nrows, ncols, 1)
    im = ax1.imshow(s0, cmap=get_palette(), norm=norm)
    add_colorbar(im)
    ax1.set_title('Target')

    # Comparison
    ax2 = fig.add_subplot(nrows, ncols, 2)
    im = ax2.imshow(s1, cmap=get_palette(), norm=norm)
    add_colorbar(im)
    ax2.set_title('Comparison')

    if show_pixel_grid:
        add_pixel_grid(ax1, stamp_size, stamp_size, show_superpixel=False)
        add_pixel_grid(ax2, stamp_size, stamp_size, show_superpixel=False)

    if show_residual:
        ax3 = fig.add_subplot(nrows, ncols, 3)

        # Residual, normalized on its own range rather than the target's.
        residual = s0 - s1
        residual_norm = ImageNormalize(residual, interval=MinMaxInterval(),
                                       stretch=LinearStretch())
        im = ax3.imshow(residual, cmap=get_palette(), norm=residual_norm)
        add_colorbar(im)
        ax3.set_title(f'Residual RMS: {residual.std():.01%}')

        # The grid belongs on the residual panel itself; the target and
        # comparison panels already received theirs above.
        if show_pixel_grid:
            add_pixel_grid(ax3, stamp_size, stamp_size, show_superpixel=False)

        ax3.set_yticklabels([])
        ax3.set_xticklabels([])

    # Turn off tick labels
    ax1.set_yticklabels([])
    ax1.set_xticklabels([])
    ax2.set_yticklabels([])
    ax2.set_xticklabels([])

    if save_name:
        # Saving is best-effort: report the problem but still return the figure.
        try:
            fig.savefig(save_name)
        except Exception as e:
            warn(f"Can't save figure: {e}")

    return fig