#!/usr/bin/env python3
# Source code for proplot.subplots
"""
The starting point for creating custom ProPlot figures. Includes
pyplot-inspired functions for creating figures and related classes.
"""
# NOTE: Importing backend causes issues with sphinx, and anyway not sure it's
# always included, so make it optional
import os
import numpy as np
import functools
import inspect
import matplotlib.pyplot as plt
import matplotlib.figure as mfigure
import matplotlib.transforms as mtransforms
import matplotlib.gridspec as mgridspec
from .rctools import rc
from .utils import _warn_proplot, _notNone, _counter, units
from . import projs, axes
# Public names exported by ``from proplot.subplots import *``
__all__ = [
    'subplot_grid', 'close', 'show', 'subplots', 'Figure',
    'GridSpec', 'SubplotSpec',
]

# Translation of single-character side keys into full side names
SIDE_TRANSLATE = {
    'l': 'left',
    'r': 'right',
    'b': 'bottom',
    't': 'top',
}

# Dimensions of figures for common journals. Values are either a width or a
# (width, height) pair; strings carry explicit units, bare numbers do not.
# NOTE(review): the assignment target was lost from this listing; the name
# JOURNAL_SPECS is assumed -- confirm against the code that reads this table.
JOURNAL_SPECS = {
    'aaas1': '5.5cm',
    'aaas2': '12cm',
    'agu1': ('95mm', '115mm'),
    'agu2': ('190mm', '115mm'),
    'agu3': ('95mm', '230mm'),
    'agu4': ('190mm', '230mm'),
    'ams1': 3.2,
    'ams2': 4.5,
    'ams3': 5.5,
    'ams4': 6.5,
    'nat1': '89mm',
    'nat2': '183mm',
    'pnas1': '8.7cm',
    'pnas2': '11.4cm',
    'pnas3': '17.8cm',
}

def close(*args, **kwargs):
    """
    Pass the input arguments to `matplotlib.pyplot.close`. This is included
    so you don't have to import `~matplotlib.pyplot`.

    Parameters
    ----------
    *args, **kwargs
        Forwarded unchanged to `matplotlib.pyplot.close`.
    """
    plt.close(*args, **kwargs)
def show():
    """
    Call `matplotlib.pyplot.show`. This is included so you don't have to
    import `~matplotlib.pyplot`. Note this command should *not be necessary*
    if you are working in an iPython session and :rcraw:`matplotlib` is
    non-empty -- when you create a new figure, it will be automatically
    displayed.
    """
    # The call itself was missing, which left this function a silent no-op
    plt.show()
class subplot_grid(list):
    """
    List subclass and pseudo-2D array that is used as a container for the
    list of axes returned by `subplots`. The shape of the array is stored in
    the ``shape`` attribute. See the `~subplot_grid.__getattr__` and
    `~subplot_grid.__getitem__` methods for details.
    """
    def __init__(self, objs, n=1, order='C'):
        """
        Parameters
        ----------
        objs : list-like
            1D iterable of `~proplot.axes.Axes` instances.
        n : int, optional
            The length of the fastest-moving dimension, i.e. the number of
            columns when `order` is ``'C'``, and the number of rows when
            `order` is ``'F'``. Used to treat lists as pseudo-2D arrays.
        order : {'C', 'F'}, optional
            Whether 1D indexing returns results in row-major (C-style) or
            column-major (Fortran-style) order, respectively. Used to treat
            lists as pseudo-2D arrays.

        Raises
        ------
        ValueError
            If any item of `objs` is not an `~proplot.axes.Axes` instance.
        """
        if not all(isinstance(obj, axes.Axes) for obj in objs):
            raise ValueError(
                f'Axes grid must be filled with Axes instances, got {objs!r}.')
        self._n = n
        self._order = order
        super().__init__(objs)
        # (slow dimension, fast dimension), flipped for column-major order
        self.shape = (len(self) // n, n)[::(1 if order == 'C' else -1)]

    def __repr__(self):
        return 'subplot_grid([' + ', '.join(str(ax) for ax in self) + '])'

    def __setitem__(self, key, value):
        """Pseudo immutability. Raises error."""
        raise LookupError('subplot_grid is immutable.')

    def __getitem__(self, key):
        """
        If an integer is passed, the item is returned, and if a slice is
        passed, a `subplot_grid` of the items is returned. You can also use
        2D indexing, and the corresponding axes in the axes grid will be
        chosen.

        Raises
        ------
        IndexError
            If `key` is a tuple whose length is not 1 or 2, or if an integer
            in a 2D key falls outside the grid.

        Example
        -------

        >>> import proplot as plot
        ... f, axs = plot.subplots(nrows=3, ncols=3, colorbars='b', bstack=2)
        ... axs[0]  # the subplot in the top-left corner
        ... axs[3]  # the first subplot in the second row
        ... axs[1,2]  # the subplot in the second row, third from the left
        ... axs[:,0]  # the subplots in the first column

        """
        # Allow 2D specification
        if isinstance(key, tuple) and len(key) == 1:
            key = key[0]
        # do not expand single slice to list of integers or we get recursion!
        # len() operator uses __getitem__!
        if not isinstance(key, tuple):
            axlist = isinstance(key, slice)
            objs = list.__getitem__(self, key)
        elif len(key) == 2:
            axlist = any(isinstance(ikey, slice) for ikey in key)
            # Expand each key into an explicit list of integer indices
            keys = []
            order = self._order
            for i, ikey in enumerate(key):
                # Size of the dimension addressed by this key
                if (i == 1 and order == 'C') or (i == 0 and order != 'C'):
                    n = self._n
                else:
                    n = len(self) // self._n
                if isinstance(ikey, slice):
                    # slice.indices() clamps to the dimension size and copes
                    # with negative steps. A hand-rolled range() could emit
                    # indices >= n, which wrap into the neighbouring row or
                    # column once flattened.
                    ikeys = [*range(*ikey.indices(n))]
                else:
                    if ikey < 0:
                        ikey = n + ikey
                    if not 0 <= ikey < n:
                        raise IndexError(
                            f'Index {key[i]!r} is out of bounds for '
                            f'dimension {i} with size {n}.')
                    ikeys = [ikey]
                keys.append(ikeys)
            # Get index pairs and get objects
            # Note that in double for loop, right loop varies fastest, so
            # e.g. axs[:,:] delivers (0,0), (0,1), ..., (0,N), (1,0), ...
            # Remember for order == 'F', subplot_grid was sent a list unfurled
            # in column-major order, so we replicate row-major indexing syntax
            # by reversing the order of the keys.
            if self._order == 'C':
                idxs = [key0 * self._n + key1
                        for key0 in keys[0] for key1 in keys[1]]
            else:
                idxs = [key1 * self._n + key0
                        for key1 in keys[1] for key0 in keys[0]]
            objs = [list.__getitem__(self, idx) for idx in idxs]
            if not axlist:  # objs will always be length 1
                objs = objs[0]
        else:
            raise IndexError

        # Return
        if axlist:
            return subplot_grid(objs)
        else:
            return objs

    def __getattr__(self, attr):
        """
        If the attribute is *callable*, returns a dummy function that loops
        through each identically named method, calls them in succession, and
        returns a tuple of the results. This lets you call arbitrary methods
        on multiple axes at once! If the `subplot_grid` has length ``1``,
        just returns the single result. If the attribute is *not callable*,
        returns a tuple of attributes for every object in the list.

        Raises
        ------
        AttributeError
            If the grid is empty, or if the attribute is callable on some
            items but not on others.

        Example
        -------

        >>> import proplot as plot
        ... f, axs = plot.subplots(nrows=2, ncols=2)
        ... axs.format(...)  # calls "format" on all subplots in the list
        ... paxs = axs.panel_axes('r')
        ... paxs.format(...)  # calls "format" on all panels

        """
        if not self:
            raise AttributeError(
                f'Invalid attribute {attr!r}, axes grid {self!r} is empty.')
        objs = (*(getattr(ax, attr) for ax in self),)  # may raise error

        # Objects
        if not any(callable(_) for _ in objs):
            if len(self) == 1:
                return objs[0]
            else:
                return objs

        # Methods
        # NOTE: Must manually copy docstring because help() cannot inherit it
        elif all(callable(_) for _ in objs):
            @functools.wraps(objs[0])
            def _iterator(*args, **kwargs):
                ret = (*(func(*args, **kwargs) for func in objs),)
                if len(self) == 1:
                    return ret[0]
                elif all(res is None for res in ret):
                    return None
                elif all(isinstance(res, axes.Axes) for res in ret):
                    return subplot_grid(ret, n=self._n, order=self._order)
                else:
                    return ret
            _iterator.__doc__ = inspect.getdoc(objs[0])
            return _iterator

        # Mixed
        raise AttributeError(f'Found mixed types for attribute {attr!r}.')
class SubplotSpec(mgridspec.SubplotSpec):
    """
    Adds two helper methods to `~matplotlib.gridspec.SubplotSpec` that return
    the geometry *excluding* rows and columns allocated for spaces.
    """
    def get_active_geometry(self):
        """Returns the number of rows, number of columns, and 1D subplot
        location indices, ignoring rows and columns allocated for spaces."""
        nrows, ncols, row1, row2, col1, col2 = self.get_active_rows_columns()
        num1 = row1 * ncols + col1
        num2 = row2 * ncols + col2
        return nrows, ncols, num1, num2

    def get_active_rows_columns(self):
        """Returns the number of rows, number of columns, first subplot row,
        last subplot row, first subplot column, and last subplot column,
        ignoring rows and columns allocated for spaces."""
        gridspec = self.get_gridspec()
        nrows, ncols = gridspec.get_geometry()
        row1, col1 = divmod(self.num1, ncols)
        if self.num2 is not None:
            row2, col2 = divmod(self.num2, ncols)
        else:
            row2 = row1
            col2 = col1
        # The gridspec interleaves subplot slots with space slots, so it has
        # 2*n - 1 slots for n active rows or columns: recover n with
        # (slots + 1) // 2 (plain slots // 2 under-counts by one). Subplot
        # slots sit at even indices, so halving maps them to active indices.
        return (
            (nrows + 1) // 2, (ncols + 1) // 2,
            row1 // 2, row2 // 2, col1 // 2, col2 // 2)
class GridSpec(mgridspec.GridSpec):
    """
    `~matplotlib.gridspec.GridSpec` generalization that allows for grids with
    *variable spacing* between successive rows and columns of axes.

    Accomplishes this by actually drawing ``nrows*2 - 1`` and ``ncols*2 - 1``
    `~matplotlib.gridspec.GridSpec` rows and columns, setting `wspace`
    and `hspace` to ``0``, and masking out every other row and column
    of the `~matplotlib.gridspec.GridSpec`, so they act as "spaces".
    These "spaces" are then allowed to vary in width using the builtin
    `width_ratios` and `height_ratios` properties.
    """
    def __init__(self, figure, nrows=1, ncols=1, **kwargs):
        """
        Parameters
        ----------
        figure : `Figure`
            The figure instance filled by this gridspec. Unlike
            `~matplotlib.gridspec.GridSpec`, this argument is required.
        nrows, ncols : int, optional
            The number of rows and columns on the subplot grid.
        hspace, wspace : float or list of float
            The vertical and horizontal spacing between rows and columns of
            subplots, respectively. In `~proplot.subplots.subplots`, ``wspace``
            and ``hspace`` are in physical units. When calling
            `GridSpec` directly, values are scaled relative to
            the average subplot height or width.

            If float, the spacing is identical between all rows and columns. If
            list of float, the length of the lists must equal ``nrows-1``
            and ``ncols-1``, respectively.
        height_ratios, width_ratios : list of float
            Ratios for the relative heights and widths for rows and columns
            of subplots, respectively. For example, ``width_ratios=(1,2)``
            scales a 2-column gridspec so that the second column is twice as
            wide as the first column.
        left, right, top, bottom : float or str
            Passed to `~matplotlib.gridspec.GridSpec`, denotes the margin
            positions in figure-relative coordinates.
        **kwargs
            Passed to `~matplotlib.gridspec.GridSpec`.
        """
        # Geometry must be set before _spaces_as_ratios, which reads it
        self._nrows = nrows * 2 - 1  # used with get_geometry
        self._ncols = ncols * 2 - 1
        self._nrows_active = nrows
        self._ncols_active = ncols
        wratios, hratios, kwargs = self._spaces_as_ratios(**kwargs)
        super().__init__(
            self._nrows, self._ncols,
            hspace=0, wspace=0,  # replaced with "hidden" slots
            width_ratios=wratios, height_ratios=hratios,
            figure=figure, **kwargs,
        )

    def __getitem__(self, key):
        """Magic obfuscation that renders `~matplotlib.gridspec.GridSpec`
        rows and columns designated as 'spaces' inaccessible."""
        nrows, ncols = self.get_geometry()
        nrows_active, ncols_active = self.get_active_geometry()
        if not isinstance(key, tuple):
            # Flat index into the *active* grid. Split it into active
            # (row, col) pairs first: doubling the flat active index directly
            # only lands on a subplot slot for the first row.
            num1, num2 = self._normalize(key, nrows_active * ncols_active)
            num1, num2 = np.unravel_index(
                (num1, num2), (nrows_active, ncols_active))
        else:  # usage gridspec[1,2]
            if len(key) == 2:
                k1, k2 = key
            else:
                raise ValueError(f'Invalid index {key!r}.')
            num1 = self._normalize(k1, nrows_active)
            num2 = self._normalize(k2, ncols_active)
        # num1 holds (first row, last row), num2 (first col, last col).
        # Raveling against the full geometry and then doubling gives
        # 2*row*ncols + 2*col, i.e. the slot at (2*row, 2*col).
        num1, num2 = np.ravel_multi_index((num1, num2), (nrows, ncols))
        num1 = self._positem(num1)
        num2 = self._positem(num2)
        return SubplotSpec(self, num1, num2)

    @staticmethod
    def _positem(size):
        """Account for negative indices."""
        if size < 0:
            # want -1 to stay -1, -2 becomes -3, etc.
            return 2 * (size + 1) - 1
        else:
            return size * 2

    @staticmethod
    def _normalize(key, size):
        """Transform gridspec index into standardized form, i.e. an
        inclusive ``(first, last)`` pair. Raises `IndexError` for empty
        slices and out-of-range integers."""
        if isinstance(key, slice):
            start, stop, _ = key.indices(size)
            if stop > start:
                return start, stop - 1
        else:
            if key < 0:
                key += size
            if 0 <= key < size:
                return key, key
        raise IndexError(f'Invalid index: {key} with size {size}.')

    def _spaces_as_ratios(
            self, hspace=None, wspace=None,  # spacing between axes
            height_ratios=None, width_ratios=None,
            **kwargs):
        """For keyword arg usage, see `GridSpec`. Returns the interleaved
        width ratios, interleaved height ratios, and the unused keyword
        args."""
        # Parse flexible input
        nrows, ncols = self.get_active_geometry()
        hratios = np.atleast_1d(_notNone(height_ratios, 1))
        wratios = np.atleast_1d(_notNone(width_ratios, 1))
        # this is relative to axes
        hspace = np.atleast_1d(_notNone(hspace, np.mean(hratios) * 0.10))
        wspace = np.atleast_1d(_notNone(wspace, np.mean(wratios) * 0.10))
        if len(wspace) == 1:
            wspace = np.repeat(wspace, (ncols - 1,))  # note: may be length 0
        if len(hspace) == 1:
            hspace = np.repeat(hspace, (nrows - 1,))
        if len(wratios) == 1:
            wratios = np.repeat(wratios, (ncols,))
        if len(hratios) == 1:
            hratios = np.repeat(hratios, (nrows,))

        # Verify input ratios and spacings
        # Translate height/width spacings, implement as extra columns/rows
        if len(hratios) != nrows:
            raise ValueError(f'Got {nrows} rows, but {len(hratios)} hratios.')
        if len(wratios) != ncols:
            raise ValueError(
                f'Got {ncols} columns, but {len(wratios)} wratios.')
        if len(wspace) != ncols - 1:
            raise ValueError(
                f'Require {ncols-1} width spacings for {ncols} columns, '
                f'got {len(wspace)}.')
        if len(hspace) != nrows - 1:
            raise ValueError(
                f'Require {nrows-1} height spacings for {nrows} rows, '
                f'got {len(hspace)}.')

        # Assign spacing as ratios: even slots are subplots, odd are spaces
        nrows, ncols = self.get_geometry()
        wratios_final = [None] * ncols
        wratios_final[::2] = [*wratios]
        if ncols > 1:
            wratios_final[1::2] = [*wspace]
        hratios_final = [None] * nrows
        hratios_final[::2] = [*hratios]
        if nrows > 1:
            hratios_final[1::2] = [*hspace]
        return wratios_final, hratios_final, kwargs  # bring extra kwargs back

    def get_margins(self):
        """Returns left, bottom, right, top values. Not sure why this method
        doesn't already exist on `~matplotlib.gridspec.GridSpec`."""
        return self.left, self.bottom, self.right, self.top

    def get_hspace(self):
        """Returns row ratios allocated for spaces."""
        return self.get_height_ratios()[1::2]

    def get_wspace(self):
        """Returns column ratios allocated for spaces."""
        return self.get_width_ratios()[1::2]

    def get_active_height_ratios(self):
        """Returns height ratios excluding slots allocated for spaces."""
        return self.get_height_ratios()[::2]

    def get_active_width_ratios(self):
        """Returns width ratios excluding slots allocated for spaces."""
        return self.get_width_ratios()[::2]

    def get_active_geometry(self):
        """Returns the number of active rows and columns, i.e. the rows and
        columns that aren't skipped by `~GridSpec.__getitem__`."""
        return self._nrows_active, self._ncols_active

    def update(self, **kwargs):
        """
        Updates the width ratios, height ratios, gridspec margins, and spacing
        allocated between subplot rows and columns.

        The default `~matplotlib.gridspec.GridSpec.update` tries to update
        positions for axes on all active figures -- but this can fail after
        successive figure edits if it has been removed from the figure
        manager. So, we explicitly require that the gridspec is dedicated to
        a particular `~matplotlib.figure.Figure` instance, and just edit axes
        positions for axes on that instance.

        Raises
        ------
        ValueError
            If an unrecognized keyword argument is passed.
        """
        # Convert spaces to ratios
        wratios, hratios, kwargs = self._spaces_as_ratios(**kwargs)
        self.set_width_ratios(wratios)
        self.set_height_ratios(hratios)

        # Validate args
        kwargs.pop('ncols', None)
        kwargs.pop('nrows', None)
        self.left = kwargs.pop('left', None)
        self.right = kwargs.pop('right', None)
        self.bottom = kwargs.pop('bottom', None)
        self.top = kwargs.pop('top', None)
        if kwargs:
            raise ValueError(f'Unknown keyword arg(s): {kwargs}.')

        # Apply to figure and all axes
        fig = self.figure
        fig.subplotpars.update(self.left, self.bottom, self.right, self.top)
        for ax in fig.axes:
            ax.update_params()
            ax.set_position(ax.figbox)
        fig.stale = True
def _canvas_preprocess(canvas, method): """Return a pre-processer that can be used to override instance-level canvas draw() and print_figure() methods. This applies tight layout and aspect ratio-conserving adjustments and aligns labels. Required so that the canvas methods instantiate renderers with the correct dimensions. Note that MacOSX currently `cannot be resized \ <>`__.""" # NOTE: This is by far the most robust approach. Renderer must be (1) # initialized with the correct figure size or (2) changed inplace during # draw, but vector graphic renderers *cannot* be changed inplace. # Options include (1) monkey patch canvas.get_width_height, overriding # figure.get_size_inches, and exploit the FigureCanvasAgg.get_renderer() # implementation (because FigureCanvasAgg queries the bbox directly # rather than using get_width_height() so requires a workaround), or (2) # override bbox and bbox_inches as *properties*, but these are really # complicated, dangerous, and result in unnecessary extra draws. def _preprocess(self, *args, **kwargs): if method == 'draw_idle' and self._is_idle_drawing: return # copied from source code fig = self.figure # update even if not stale! needed after saves if fig.stale and method == 'print_figure': # Needed for displaying already-drawn inline figures, for # some reason tight layout algorithm gets it wrong otherwise. # Concerned that draw_idle() might wait until after # print_figure() is done, so we use draw(). self.draw() renderer = fig._get_renderer() # any renderer will do for now for ax in fig._iter_axes(): ax._draw_auto_legends_colorbars() # may insert panels if rc['backend'] != 'nbAgg': fig._adjust_aspect() # resizes figure if fig._auto_tight_layout: fig._align_axislabels(False) # get proper label offset only fig._align_labels(renderer) # position labels and suptitle fig._adjust_tight_layout(renderer) fig._align_axislabels(True) # slide spanning labels across fig._align_labels(renderer) # update figure-relative coordinates! 
res = getattr(type(self), method)(self, *args, **kwargs) return res return _preprocess.__get__(canvas) # ...I don't get it either def _panels_kwargs( side, share=None, width=None, space=None, filled=False, figure=False): """Converts global keywords like `space` and `width` to side-local keywords like `lspace` and `lwidth`, and applies default settings.""" # Return values # NOTE: Make default legend width same as default colorbar width, in # case user draws legend and colorbar panel in same row or column! s = side[0] if s not in 'lrbt': raise ValueError(f'Invalid panel spec {side!r}.') space_orig = units(space) if filled: default = rc['colorbar.width'] else: default = rc['subplots.panelwidth'] share = _notNone(share, (not filled)) width = units(_notNone(width, default)) space = _notNone(units(space), units(rc['subplots.' + ( 'panel' if share and not figure else 'xlab' if s == 'b' else 'ylab' if s == 'l' else 'inner' if figure else 'panel') + 'space'])) return share, width, space, space_orig def _subplots_geometry(**kwargs): """Save arguments passed to `subplots`, calculates gridspec settings and figure size necessary for requested geometry, and returns keyword args necessary to reconstruct and modify this configuration. 
Note that `wspace`, `hspace`, `left`, `right`, `top`, and `bottom` always have fixed physical units, then we scale figure width, figure height, and width and height ratios to accommodate spaces.""" # Dimensions and geometry nrows, ncols = kwargs['nrows'], kwargs['ncols'] aspect, xref, yref = kwargs['aspect'], kwargs['xref'], kwargs['yref'] width, height = kwargs['width'], kwargs['height'] axwidth, axheight = kwargs['axwidth'], kwargs['axheight'] # Gridspec settings wspace, hspace = kwargs['wspace'], kwargs['hspace'] wratios, hratios = kwargs['wratios'], kwargs['hratios'] left, bottom = kwargs['left'], kwargs['bottom'] right, top = kwargs['right'], kwargs['top'] # Panel string toggles, lists containing empty strings '' (indicating a # main axes), or one of 'l', 'r', 'b', 't' (indicating axes panels) or # 'f' (indicating figure panels) wpanels, hpanels = kwargs['wpanels'], kwargs['hpanels'] # Checks, important now that we modify gridspec geometry if len(hratios) != nrows: raise ValueError( f'Expected {nrows} width ratios for {nrows} rows, ' f'got {len(hratios)}.') if len(wratios) != ncols: raise ValueError( f'Expected {ncols} width ratios for {ncols} columns, ' f'got {len(wratios)}.') if len(hspace) != nrows - 1: raise ValueError( f'Expected {nrows - 1} hspaces for {nrows} rows, ' f'got {len(hspace)}.') if len(wspace) != ncols - 1: raise ValueError( f'Expected {ncols - 1} wspaces for {ncols} columns, ' f'got {len(wspace)}.') if len(hpanels) != nrows: raise ValueError( f'Expected {nrows} hpanel toggles for {nrows} rows, ' f'got {len(hpanels)}.') if len(wpanels) != ncols: raise ValueError( f'Expected {ncols} wpanel toggles for {ncols} columns, ' f'got {len(wpanels)}.') # Get indices corresponding to main axes or main axes space slots idxs_ratios, idxs_space = [], [] for panels in (hpanels, wpanels): # Ratio indices mask = np.array([bool(s) for s in panels]) ratio_idxs, = np.where(~mask) idxs_ratios.append(ratio_idxs) # Space indices space_idxs = [] for idx in 
ratio_idxs[:-1]: # exclude last axes slot offset = 1 while panels[idx + offset] not in 'rbf': # main space next to this offset += 1 space_idxs.append(idx + offset - 1) idxs_space.append(space_idxs) # Separate the panel and axes ratios hratios_main = [hratios[idx] for idx in idxs_ratios[0]] wratios_main = [wratios[idx] for idx in idxs_ratios[1]] hratios_panels = [ratio for idx, ratio in enumerate( hratios) if idx not in idxs_ratios[0]] wratios_panels = [ratio for idx, ratio in enumerate( wratios) if idx not in idxs_ratios[1]] hspace_main = [hspace[idx] for idx in idxs_space[0]] wspace_main = [wspace[idx] for idx in idxs_space[1]] # Reduced geometry nrows_main = len(hratios_main) ncols_main = len(wratios_main) # Get reference properties, account for panel slots in space and ratios # TODO: Shouldn't panel space be included in these calculations? (x1, x2), (y1, y2) = xref, yref dx, dy = x2 - x1 + 1, y2 - y1 + 1 rwspace = sum(wspace_main[x1:x2]) rhspace = sum(hspace_main[y1:y2]) rwratio = ( ncols_main * sum(wratios_main[x1:x2 + 1])) / (dx * sum(wratios_main)) rhratio = ( nrows_main * sum(hratios_main[y1:y2 + 1])) / (dy * sum(hratios_main)) if rwratio == 0 or rhratio == 0: raise RuntimeError( f'Something went wrong, got wratio={rwratio!r} ' f'and hratio={rhratio!r} for reference axes.') if np.iterable(aspect): aspect = aspect[0] / aspect[1] # Determine figure and axes dims from input in width or height dimenion. # For e.g. common use case [[1,1,2,2],[0,3,3,0]], make sure we still scale # the reference axes like square even though takes two columns of gridspec! 
auto_width = (width is None and height is not None) auto_height = (height is None and width is not None) if width is None and height is None: # get stuff directly from axes if axwidth is None and axheight is None: axwidth = units(rc['subplots.axwidth']) if axheight is not None: auto_width = True axheight_all = (nrows_main * (axheight - rhspace)) / (dy * rhratio) height = axheight_all + top + bottom + \ sum(hspace) + sum(hratios_panels) if axwidth is not None: auto_height = True axwidth_all = (ncols_main * (axwidth - rwspace)) / (dx * rwratio) width = axwidth_all + left + right + \ sum(wspace) + sum(wratios_panels) if axwidth is not None and axheight is not None: auto_width = auto_height = False else: if height is not None: axheight_all = height - top - bottom - \ sum(hspace) - sum(hratios_panels) axheight = (axheight_all * dy * rhratio) / nrows_main + rhspace if width is not None: axwidth_all = width - left - right - \ sum(wspace) - sum(wratios_panels) axwidth = (axwidth_all * dx * rwratio) / ncols_main + rwspace # Automatically figure dim that was not specified above if auto_height: axheight = axwidth / aspect axheight_all = (nrows_main * (axheight - rhspace)) / (dy * rhratio) height = axheight_all + top + bottom + \ sum(hspace) + sum(hratios_panels) elif auto_width: axwidth = axheight * aspect axwidth_all = (ncols_main * (axwidth - rwspace)) / (dx * rwratio) width = axwidth_all + left + right + sum(wspace) + sum(wratios_panels) if axwidth_all < 0: raise ValueError( f'Not enough room for axes (would have width {axwidth_all}). ' 'Try using tight=False, increasing figure width, or decreasing ' "'left', 'right', or 'wspace' spaces.") if axheight_all < 0: raise ValueError( f'Not enough room for axes (would have height {axheight_all}). 
' 'Try using tight=False, increasing figure height, or decreasing ' "'top', 'bottom', or 'hspace' spaces.") # Reconstruct the ratios array with physical units for subplot slots # The panel slots are unchanged because panels have fixed widths wratios_main = axwidth_all * np.array(wratios_main) / sum(wratios_main) hratios_main = axheight_all * np.array(hratios_main) / sum(hratios_main) for idx, ratio in zip(idxs_ratios[0], hratios_main): hratios[idx] = ratio for idx, ratio in zip(idxs_ratios[1], wratios_main): wratios[idx] = ratio # Convert margins to figure-relative coordinates left = left / width bottom = bottom / height right = 1 - right / width top = 1 - top / height # Return gridspec keyword args gridspec_kw = { 'ncols': ncols, 'nrows': nrows, 'wspace': wspace, 'hspace': hspace, 'width_ratios': wratios, 'height_ratios': hratios, 'left': left, 'bottom': bottom, 'right': right, 'top': top, } return (width, height), gridspec_kw, kwargs class _hidelabels(object): """Hide objects temporarily so they are ignored by the tight bounding box algorithm.""" def __init__(self, *args): self._labels = args def __enter__(self): for label in self._labels: label.set_visible(False) def __exit__(self, *args): for label in self._labels: label.set_visible(True) class _unlocker(object): """Suppress warning message when adding subplots, and cleanly reset lock setting if exception raised.""" def __init__(self, fig): self._fig = fig def __enter__(self): self._fig._locked = False def __exit__(self, *args): self._fig._locked = True
[docs]class Figure(mfigure.Figure): """The `~matplotlib.figure.Figure` class returned by `subplots`. At draw-time, an improved tight layout algorithm is employed, and the space around the figure edge, between subplots, and between panels is changed to accommodate subplot content. Figure dimensions may be automatically scaled to preserve subplot aspect ratios.""" def __init__(self, tight=None, pad=None, axpad=None, panelpad=None, includepanels=False, autoformat=True, ref=1, order='C', # documented in subplots but needed here subplots_kw=None, gridspec_kw=None, subplots_orig_kw=None, tight_layout=None, constrained_layout=None, **kwargs): """ Parameters ---------- tight : bool, optional Toggles automatic tight layout adjustments. Default is :rc:`tight`. pad : float or str, optional Padding around edge of figure. Units are interpreted by `~proplot.utils.units`. Default is :rc:`subplots.pad`. axpad : float or str, optional Padding between subplots in adjacent columns and rows. Units are interpreted by `~proplot.utils.units`. Default is :rc:`subplots.axpad`. panelpad : float or str, optional Padding between subplots and axes panels, and between "stacked" panels. Units are interpreted by `~proplot.utils.units`. Default is :rc:`subplots.panelpad`. includepanels : bool, optional Whether to include panels when centering *x* axis labels, *y* axis labels, and figure "super titles" along the edge of the subplot grid. Default is ``False``. autoformat : bool, optional Whether to automatically configure *x* axis labels, *y* axis labels, axis formatters, axes titles, colorbar labels, and legend labels when a `~pandas.Series`, `~pandas.DataFrame` or `~xarray.DataArray` with relevant metadata is passed to a plotting command. gridspec_kw, subplots_kw, subplots_orig_kw Keywords used for initializing the main gridspec, for initializing the figure, and original spacing keyword args used for initializing the figure that override tight layout spacing. 
Other parameters ---------------- ref, order Documented in `subplots`. tight_layout, constrained_layout Ignored, because ProPlot uses its own tight layout algorithm. **kwargs Passed to `matplotlib.figure.Figure`. See also -------- `~matplotlib.figure.Figure` """ # Initialize first, because need to provide fully initialized figure # as argument to gridspec, because matplotlib tight_layout does that if tight_layout or constrained_layout: _warn_proplot( f'Ignoring tight_layout={tight_layout} and ' f'contrained_layout={constrained_layout}. ProPlot uses its ' 'own tight layout algorithm, activated by default or with ' 'tight=True.') super().__init__(**kwargs) self._locked = False self._pad = units(_notNone(pad, rc['subplots.pad'])) self._axpad = units(_notNone(axpad, rc['subplots.axpad'])) self._panelpad = units(_notNone(panelpad, rc['subplots.panelpad'])) self._auto_format = autoformat self._auto_tight_layout = _notNone(tight, rc['tight']) self._include_panels = includepanels self._order = order # used for configuring panel subplot_grids self._ref_num = ref self._axes_main = [] self._subplots_orig_kw = subplots_orig_kw self._subplots_kw = subplots_kw self._bpanels = [] self._tpanels = [] self._lpanels = [] self._rpanels = [] gridspec = GridSpec(self, **(gridspec_kw or {})) nrows, ncols = gridspec.get_active_geometry() self._barray = np.empty((0, ncols), dtype=bool) self._tarray = np.empty((0, ncols), dtype=bool) self._larray = np.empty((0, nrows), dtype=bool) self._rarray = np.empty((0, nrows), dtype=bool) self._gridspec_main = gridspec self.suptitle('') # add _suptitle attribute @_counter def _add_axes_panel(self, ax, side, filled=False, **kwargs): """Hidden method that powers `~proplot.axes.panel_axes`.""" # Interpret args # NOTE: Axis sharing not implemented for figure panels, 99% of the # time this is just used as construct for adding global colorbars and # legends, really not worth implementing axis sharing s = side[0] if s not in 'lrbt': raise 
ValueError(f'Invalid side {side!r}.') ax = ax._panel_parent or ax # redirect to main axes side = SIDE_TRANSLATE[s] share, width, space, space_orig = _panels_kwargs( s, filled=filled, figure=False, **kwargs) # Get gridspec and subplotspec indices subplotspec = ax.get_subplotspec() *_, row1, row2, col1, col2 = subplotspec.get_active_rows_columns() pgrid = getattr(ax, '_' + s + 'panels') offset = (len(pgrid) * bool(pgrid)) + 1 if s in 'lr': iratio = (col1 - offset if s == 'l' else col2 + offset) idx1 = slice(row1, row2 + 1) idx2 = iratio else: iratio = (row1 - offset if s == 't' else row2 + offset) idx1 = iratio idx2 = slice(col1, col2 + 1) gridspec_prev = self._gridspec_main gridspec = self._insert_row_column( side, iratio, width, space, space_orig, figure=False) if gridspec is not gridspec_prev: if s == 't': idx1 += 1 elif s == 'l': idx2 += 1 # Draw and setup panel with self._unlock(): pax = self.add_subplot( gridspec[idx1, idx2], sharex=ax._sharex_level, sharey=ax._sharey_level, projection='xy') getattr(ax, '_' + s + 'panels').append(pax) pax._panel_side = side pax._panel_share = share pax._panel_parent = ax # Axis sharing and axis setup only for non-legend or colorbar axes if not filled: ax._share_setup() axis = (pax.yaxis if side in ('left', 'right') else pax.xaxis) # sets tick and tick label positions intelligently getattr(axis, 'tick_' + side)() axis.set_label_position(side) return pax def _add_figure_panel(self, side, span=None, row=None, col=None, rows=None, cols=None, **kwargs): """Adds figure panels. 
Also modifies the panel attribute stored on the figure to include these panels.""" # Interpret args and enforce sensible keyword args s = side[0] if s not in 'lrbt': raise ValueError(f'Invalid side {side!r}.') side = SIDE_TRANSLATE[s] _, width, space, space_orig = _panels_kwargs( s, filled=True, figure=True, **kwargs) if s in 'lr': for key, value in (('col', col), ('cols', cols)): if value is not None: raise ValueError( f'Invalid keyword arg {key!r} for figure panel ' f'on side {side!r}.') span = _notNone(span, row, rows, None, names=('span', 'row', 'rows')) else: for key, value in (('row', row), ('rows', rows)): if value is not None: raise ValueError( f'Invalid keyword arg {key!r} for figure panel ' f'on side {side!r}.') span = _notNone(span, col, cols, None, names=('span', 'col', 'cols')) # Get props subplots_kw = self._subplots_kw if s in 'lr': panels, nacross = subplots_kw['hpanels'], subplots_kw['ncols'] else: panels, nacross = subplots_kw['wpanels'], subplots_kw['nrows'] array = getattr(self, '_' + s + 'array') npanels, nalong = array.shape # Check span array span = _notNone(span, (1, nalong)) if not np.iterable(span) or len(span) == 1: span = 2 * np.atleast_1d(span).tolist() if len(span) != 2: raise ValueError(f'Invalid span {span!r}.') if span[0] < 1 or span[1] > nalong: raise ValueError( f'Invalid coordinates in span={span!r}. 
Coordinates ' f'must satisfy 1 <= c <= {nalong}.') start, stop = span[0] - 1, span[1] # zero-indexed # See if there is room for panel in current figure panels # The 'array' is an array of boolean values, where each row corresponds # to another figure panel, moving toward the outside, and boolean # True indicates the slot has been filled iratio = (-1 if s in 'lt' else nacross) # default vals for i in range(npanels): if not any(array[i, start:stop]): array[i, start:stop] = True if s in 'lt': # descending array moves us closer to 0 # npanels=1, i=0 --> iratio=0 # npanels=2, i=0 --> iratio=1 # npanels=2, i=1 --> iratio=0 iratio = npanels - 1 - i else: # descending array moves us closer to nacross-1 # npanels=1, i=0 --> iratio=nacross-1 # npanels=2, i=0 --> iratio=nacross-2 # npanels=2, i=1 --> iratio=nacross-1 iratio = nacross - (npanels - i) break if iratio in (-1, nacross): # add to array iarray = np.zeros((1, nalong), dtype=bool) iarray[0, start:stop] = True array = np.concatenate((array, iarray), axis=0) setattr(self, '_' + s + 'array', array) # Get gridspec and subplotspec indices idxs, = np.where(np.array(panels) == '') if len(idxs) != nalong: raise RuntimeError('Wut?') if s in 'lr': idx1 = slice(idxs[start], idxs[stop - 1] + 1) idx2 = max(iratio, 0) else: idx1 = max(iratio, 0) idx2 = slice(idxs[start], idxs[stop - 1] + 1) gridspec = self._insert_row_column( side, iratio, width, space, space_orig, figure=True) # Draw and setup panel with self._unlock(): pax = self.add_subplot(gridspec[idx1, idx2], projection='xy') getattr(self, '_' + s + 'panels').append(pax) pax._panel_side = side pax._panel_share = False pax._panel_parent = None return pax def _adjust_aspect(self): """Adjust average aspect ratio used for gridspec calculations. This fixes grids with identically fixed aspect ratios, e.g. 
identically zoomed-in cartopy projections and imshow images.""" # Get aspect ratio if not self._axes_main: return ax = self._axes_main[self._ref_num - 1] mode = ax.get_aspect() aspect = None if mode == 'equal': xscale, yscale = ax.get_xscale(), ax.get_yscale() if xscale == 'linear' and yscale == 'linear': aspect = 1.0 / ax.get_data_ratio() elif xscale == 'log' and yscale == 'log': aspect = 1.0 / ax.get_data_ratio_log() else: pass # matplotlib issues warning, forces aspect == 'auto' # Apply aspect if aspect is not None: aspect = round(aspect * 1e10) * 1e-10 subplots_kw = self._subplots_kw aspect_prev = round(subplots_kw['aspect'] * 1e10) * 1e-10 if aspect != aspect_prev: subplots_kw['aspect'] = aspect figsize, gridspec_kw, _ = _subplots_geometry(**subplots_kw) self.set_size_inches(figsize, manual=False) self._gridspec_main.update(**gridspec_kw) def _adjust_tight_layout(self, renderer): """Applies tight layout scaling that permits flexible figure dimensions and preserves panel widths and subplot aspect ratios. The `renderer` should be a `~matplotlib.backend_bases.RendererBase` instance.""" # Initial stuff axs = self._iter_axes() obox = self.bbox_inches # original bbox bbox = self.get_tightbbox(renderer) gridspec = self._gridspec_main subplots_kw = self._subplots_kw subplots_orig_kw = self._subplots_orig_kw # tight layout overrides if not axs or not subplots_kw or not subplots_orig_kw: return # Tight box *around* figure # Get bounds from old bounding box pad = self._pad left = bbox.xmin bottom = bbox.ymin right = obox.xmax - bbox.xmax top = obox.ymax - bbox.ymax # Apply new bounds, permitting user overrides # TODO: Account for bounding box NaNs? 
for key, offset in zip( ('left', 'right', 'top', 'bottom'), (left, right, top, bottom) ): previous = subplots_orig_kw[key] current = subplots_kw[key] subplots_kw[key] = _notNone(previous, current - offset + pad) # Get arrays storing gridspec spacing args axpad = self._axpad panelpad = self._panelpad nrows, ncols = gridspec.get_active_geometry() wspace, hspace = subplots_kw['wspace'], subplots_kw['hspace'] wspace_orig = subplots_orig_kw['wspace'] hspace_orig = subplots_orig_kw['hspace'] # Get new subplot spacings, axes panel spacing, figure panel spacing spaces = [] for (w, x, y, nacross, ispace, ispace_orig) in zip( 'wh', 'xy', 'yx', (nrows, ncols), (wspace, hspace), (wspace_orig, hspace_orig), ): # Determine which rows and columns correspond to panels panels = subplots_kw[w + 'panels'] jspace = [*ispace] ralong = np.array([ax._range_gridspec(x) for ax in axs]) racross = np.array([ax._range_gridspec(y) for ax in axs]) for i, (space, space_orig) in enumerate(zip(ispace, ispace_orig)): # Figure out whether this is a normal space, or a # panel stack space/axes panel space pad = axpad if (panels[i] in ('l', 't') and panels[i + 1] in ('l', 't', '') or panels[i] in ('', 'r', 'b') and panels[i + 1] in ('r', 'b') or panels[i] == 'f' and panels[i + 1] == 'f'): pad = panelpad # Find axes that abutt aginst this space on each row groups = [] # i.e. right/bottom edge abutts against this space filt1 = ralong[:, 1] == i # i.e. left/top edge abutts against this space filt2 = ralong[:, 0] == i + 1 for j in range(nacross): # e.g. each row # Get indices filt = (racross[:, 0] <= j) & (j <= racross[:, 1]) if sum(filt) < 2: # no interface here continue idx1, = np.where(filt & filt1) idx2, = np.where(filt & filt2) if idx1.size > 1 or idx2.size > 2: _warn_proplot('This should never happen.') continue # raise RuntimeError('This should never happen.') elif not idx1.size or not idx2.size: continue idx1, idx2 = idx1[0], idx2[0] # Put these axes into unique groups. 
Store groups as # (left axes, right axes) or (bottom axes, top axes) pairs. ax1, ax2 = axs[idx1], axs[idx2] if x != 'x': # order bottom-to-top ax1, ax2 = ax2, ax1 newgroup = True for (group1, group2) in groups: if ax1 in group1 or ax2 in group2: newgroup = False group1.add(ax1) group2.add(ax2) break if newgroup: groups.append([{ax1}, {ax2}]) # form new group # Get spaces # Layout is lspace, lspaces[0], rspaces[0], wspace, ... # so panels spaces are located where i % 3 is 1 or 2 jspaces = [] for (group1, group2) in groups: x1 = max(ax._range_tightbbox(x)[1] for ax in group1) x2 = min(ax._range_tightbbox(x)[0] for ax in group2) jspaces.append((x2 - x1) / self.dpi) if jspaces: # TODO: why max 0? space = max(0, space - min(jspaces) + pad) space = _notNone(space_orig, space) # user input overwrite jspace[i] = space spaces.append(jspace) # Apply new spaces subplots_kw.update({ 'wspace': spaces[0], 'hspace': spaces[1], }) figsize, gridspec_kw, _ = _subplots_geometry(**subplots_kw) self._gridspec_main.update(**gridspec_kw) self.set_size_inches(figsize, manual=False) def _align_axislabels(self, b=True): """Align spanning *x* and *y* axis labels in the perpendicular direction and, if `b` is ``True``, the parallel direction.""" # TODO: Ensure this is robust to complex panels and shared axes # NOTE: Need to turn off aligned labels before _adjust_tight_layout # call, so cannot put this inside Axes draw tracker = {*()} for ax in self._axes_main: if not isinstance(ax, axes.XYAxes): continue for x, axis in zip('xy', (ax.xaxis, ax.yaxis)): # top or bottom, left or right s = axis.get_label_position()[0] span = getattr(ax, '_span' + x + '_on') align = getattr(ax, '_align' + x + '_on') if s not in 'bl' or axis in tracker: continue axs = ax._get_side_axes(s) for _ in range(2): axs = [getattr(ax, '_share' + x) or ax for ax in axs] # Align axis label offsets axises = [getattr(ax, x + 'axis') for ax in axs] tracker.update(axises) if span or align: grp = getattr(self, '_align_' + x + 
'label_grp', None) if grp is not None: for ax in axs[1:]: # copied from source code, add to grouper grp.join(axs[0], ax) elif align: _warn_proplot( f'Aligning *x* and *y* axis labels required ' f'matplotlib >=3.1.0') if not span: continue # Get spanning label position c, spanax = self._get_align_coord(s, axs) spanaxis = getattr(spanax, x + 'axis') spanlabel = spanaxis.label if not hasattr(spanlabel, '_orig_transform'): spanlabel._orig_transform = spanlabel.get_transform() spanlabel._orig_position = spanlabel.get_position() if not b: # toggle off, done before tight layout spanlabel.set_transform(spanlabel._orig_transform) spanlabel.set_position(spanlabel._orig_position) for axis in axises: axis.label.set_visible(True) else: # toggle on, done after tight layout if x == 'x': position = (c, 1) transform = mtransforms.blended_transform_factory( self.transFigure, mtransforms.IdentityTransform()) else: position = (1, c) transform = mtransforms.blended_transform_factory( mtransforms.IdentityTransform(), self.transFigure) for axis in axises: axis.label.set_visible((axis is spanaxis)) spanlabel.update( {'position': position, 'transform': transform}) def _align_labels(self, renderer): """Adjusts position of row and column labels, and aligns figure super title accounting for figure margins and axes and figure panels.""" # Offset using tight bounding boxes # TODO: Super labels fail with popup backend!! 
Fix this # NOTE: Must use get_tightbbox so (1) this will work if tight layout # mode if off and (2) actually need *two* tight bounding boxes when # labels are present: 1 not including the labels, used to position # them, and 1 including the labels, used to determine figure borders suptitle = self._suptitle suptitle_on = suptitle.get_text().strip() width, height = self.get_size_inches() for s in 'lrbt': # Get axes and offset the label to relevant panel x = ('x' if s in 'lr' else 'y') axs = self._get_align_axes(s) axs = [ax._reassign_suplabel(s) for ax in axs] labels = [getattr(ax, '_' + s + 'label') for ax in axs] coords = [None] * len(axs) if s == 't' and suptitle_on: supaxs = axs with _hidelabels(*labels): for i, (ax, label) in enumerate(zip(axs, labels)): label_on = label.get_text().strip() if not label_on: continue # Get coord from tight bounding box # Include twin axes and panels along the same side extra = ('bt' if s in 'lr' else 'lr') icoords = [] for iax in ax._iter_panels(extra): bbox = iax.get_tightbbox(renderer) if s == 'l': jcoords = (bbox.xmin, 0) elif s == 'r': jcoords = (bbox.xmax, 0) elif s == 't': jcoords = (0, bbox.ymax) else: jcoords = (0, bbox.ymin) c = self.transFigure.inverted().transform(jcoords) c = (c[0] if s in 'lr' else c[1]) icoords.append(c) # Offset, and offset a bit extra for left/right labels # See: # fontsize = label.get_fontsize() if s in 'lr': scale1, scale2 = 0.6, width else: scale1, scale2 = 0.3, height if s in 'lb': coords[i] = min(icoords) - ( scale1 * fontsize / 72) / scale2 else: coords[i] = max(icoords) + ( scale1 * fontsize / 72) / scale2 # Assign coords coords = [i for i in coords if i is not None] if coords: if s in 'lb': c = min(coords) else: c = max(coords) for label in labels: label.update({x: c}) # Update super title position # If no axes on the top row are visible, do not try to align! 
if suptitle_on and supaxs: ys = [] for ax in supaxs: bbox = ax.get_tightbbox(renderer) _, y = self.transFigure.inverted().transform((0, bbox.ymax)) ys.append(y) x, _ = self._get_align_coord('t', supaxs) y = max(ys) + (0.3 * suptitle.get_fontsize() / 72) / height kw = {'x': x, 'y': y, 'ha': 'center', 'va': 'bottom', 'transform': self.transFigure} suptitle.update(kw) def _insert_row_column( self, side, idx, ratio, space, space_orig, figure=False): """Helper function that "overwrites" the main figure gridspec to make room for a panel. The `side` is the panel side, the `idx` is the slot you want the panel to occupy, and the remaining args are the panel widths and spacings.""" # Constants and stuff # Insert spaces to the left of right panels or to the right of # left panels. And note that since .insert() pushes everything in # that column to the right, actually must insert 1 slot farther to # the right when inserting left panels/spaces s = side[0] if s not in 'lrbt': raise ValueError(f'Invalid side {side}.') idx_space = idx - 1 * bool(s in 'br') idx_offset = 1 * bool(s in 'tl') if s in 'lr': w, ncols = 'w', 'ncols' else: w, ncols = 'h', 'nrows' # Load arrays and test if we need to insert subplots_kw = self._subplots_kw subplots_orig_kw = self._subplots_orig_kw panels = subplots_kw[w + 'panels'] ratios = subplots_kw[w + 'ratios'] spaces = subplots_kw[w + 'space'] spaces_orig = subplots_orig_kw[w + 'space'] # Slot already exists entry = ('f' if figure else s) exists = (idx not in (-1, len(panels)) and panels[idx] == entry) if exists: # already exists! 
if spaces_orig[idx_space] is None: spaces_orig[idx_space] = units(space_orig) spaces[idx_space] = _notNone(spaces_orig[idx_space], space) # Make room for new panel slot else: # Modify basic geometry idx += idx_offset idx_space += idx_offset subplots_kw[ncols] += 1 # Original space, ratio array, space array, panel toggles spaces_orig.insert(idx_space, space_orig) spaces.insert(idx_space, space) ratios.insert(idx, ratio) panels.insert(idx, entry) # Reference ax location array # TODO: For now do not need to increment, but need to double # check algorithm for fixing axes aspect! # ref = subplots_kw[x + 'ref'] # ref[:] = [val + 1 if val >= idx else val for val in ref] # Update figure figsize, gridspec_kw, _ = _subplots_geometry(**subplots_kw) self.set_size_inches(figsize, manual=False) if exists: gridspec = self._gridspec_main gridspec.update(**gridspec_kw) else: # New gridspec gridspec = GridSpec(self, **gridspec_kw) self._gridspec_main = gridspec # Reassign subplotspecs to all axes and update positions # May seem inefficient but it literally just assigns a hidden, # attribute, and the creation time for subpltospecs is tiny axs = [iax for ax in self._iter_axes() for iax in (ax, *ax.child_axes)] for ax in axs: # Get old index # NOTE: Endpoints are inclusive, not exclusive! if not hasattr(ax, 'get_subplotspec'): continue if s in 'lr': inserts = (None, None, idx, idx) else: inserts = (idx, idx, None, None) subplotspec = ax.get_subplotspec() igridspec = subplotspec.get_gridspec() topmost = subplotspec.get_topmost_subplotspec() # Apply new subplotspec! 
_, _, *coords = topmost.get_active_rows_columns() for i in range(4): # if inserts[i] is not None and coords[i] >= inserts[i]: if inserts[i] is not None and coords[i] >= inserts[i]: coords[i] += 1 (row1, row2, col1, col2) = coords subplotspec_new = gridspec[row1:row2 + 1, col1:col2 + 1] if topmost is subplotspec: ax.set_subplotspec(subplotspec_new) elif topmost is igridspec._subplot_spec: igridspec._subplot_spec = subplotspec_new else: raise ValueError( f'Unexpected GridSpecFromSubplotSpec nesting.') # Update parent or child position ax.update_params() ax.set_position(ax.figbox) return gridspec def _get_align_coord(self, side, axs): """Returns figure coordinate for spanning labels and super title. The `x` can be ``'x'`` or ``'y'``.""" # Get position in figure relative coordinates s = side[0] x = ('y' if s in 'lr' else 'x') extra = ('tb' if s in 'lr' else 'lr') if self._include_panels: axs = [iax for ax in axs for iax in ax._iter_panels(extra)] ranges = np.array([ax._range_gridspec(x) for ax in axs]) min_, max_ = ranges[:, 0].min(), ranges[:, 1].max() axlo = axs[np.where(ranges[:, 0] == min_)[0][0]] axhi = axs[np.where(ranges[:, 1] == max_)[0][0]] lobox = axlo.get_subplotspec().get_position(self) hibox = axhi.get_subplotspec().get_position(self) if x == 'x': pos = (lobox.x0 + hibox.x1) / 2 else: # 'lo' is actually on top, highest up in gridspec pos = (lobox.y1 + hibox.y0) / 2 # Return axis suitable for spanning position spanax = axs[(np.argmin(ranges[:, 0]) + np.argmax(ranges[:, 1])) // 2] spanax = spanax._panel_parent or spanax return pos, spanax def _get_align_axes(self, side): """Returns main axes along the left, right, bottom, or top sides of the figure.""" # Initial stuff s = side[0] idx = (0 if s in 'lt' else 1) if s in 'lr': x, y = 'x', 'y' else: x, y = 'y', 'x' # Get edge index axs = self._axes_main if not axs: return [] ranges = np.array([ax._range_gridspec(x) for ax in axs]) min_, max_ = ranges[:, 0].min(), ranges[:, 1].max() edge = (min_ if s in 'lt' else 
max_) # Return axes on edge sorted by order of appearance axs = [ax for ax in self._axes_main if ax._range_gridspec(x)[ idx] == edge] ranges = [ax._range_gridspec(y)[0] for ax in axs] return [ax for _, ax in sorted(zip(ranges, axs)) if ax.get_visible()] def _get_renderer(self): """Get a renderer at all costs, even if it means generating a brand new one! Used for updating the figure bounding box when it is accessed and calculating centered-row legend bounding boxes. This is copied from in matplotlib.""" if self._cachedRenderer: renderer = self._cachedRenderer else: canvas = self.canvas if canvas and hasattr(canvas, 'get_renderer'): renderer = canvas.get_renderer() else: from matplotlib.backends.backend_agg import FigureCanvasAgg canvas = FigureCanvasAgg(self) renderer = canvas.get_renderer() return renderer def _unlock(self): """Prevent warning message when adding subplots one-by-one. Used internally.""" return _unlocker(self) def _update_figtitle(self, title, **kwargs): """Assign figure "super title".""" if title is not None and self._suptitle.get_text() != title: self._suptitle.set_text(title) if kwargs: self._suptitle.update(kwargs) def _update_labels(self, ax, side, labels, **kwargs): """Assigns side labels, updates label settings.""" s = side[0] if s not in 'lrbt': raise ValueError(f'Invalid label side {side!r}.') # Get main axes on the edge axs = self._get_align_axes(s) if not axs: return # occurs if called while adding axes # Update label text for axes on the edge if labels is None or isinstance(labels, str): # common during testing labels = [labels] * len(axs) if len(labels) != len(axs): raise ValueError( f'Got {len(labels)} {s}labels, but there are {len(axs)} axes ' 'along that side.') for ax, label in zip(axs, labels): obj = getattr(ax, '_' + s + 'label') if label is not None and obj.get_text() != label: obj.set_text(label) if kwargs: obj.update(kwargs)
def add_subplot(self, *args, **kwargs):
    """
    Add a subplot through the parent class, warning first if the figure
    is currently "locked".

    The warning is aimed at users calling this method by hand; it is
    emitted whenever ``self._locked`` is true. All positional and keyword
    arguments are forwarded unchanged to the parent ``add_subplot``
    implementation, and the axes it creates is returned.
    """
    locked = self._locked
    if locked:
        message = (
            'Using "fig.add_subplot()" with ProPlot figures may result in '
            'unexpected behavior. Use "proplot.subplots()" instead.')
        _warn_proplot(message)
    return super().add_subplot(*args, **kwargs)
def colorbar(
    self, *args,
    loc='r', width=None, space=None,
    row=None, col=None, rows=None, cols=None, span=None,
    **kwargs
):
    """
    Draws a colorbar along the left, right, bottom, or top side
    of the figure, centered between the leftmost and rightmost (or
    topmost and bottommost) main axes.

    Parameters
    ----------
    loc : str, optional
        The colorbar location. Valid location keys are as follows.

        =========== =====================
        Location    Valid keys
        =========== =====================
        left edge   ``'l'``, ``'left'``
        right edge  ``'r'``, ``'right'``
        bottom edge ``'b'``, ``'bottom'``
        top edge    ``'t'``, ``'top'``
        =========== =====================

    row, rows : optional
        Aliases for `span` for panels on the left or right side.
    col, cols : optional
        Aliases for `span` for panels on the top or bottom side.
    span : int or (int, int), optional
        Describes how the colorbar spans rows and columns of subplots.
        For example, ``fig.colorbar(loc='b', col=1)`` draws a colorbar
        beneath the leftmost column of subplots, and
        ``fig.colorbar(loc='b', cols=(1,2))`` draws a colorbar beneath the
        left two columns of subplots. By default, the colorbar will span
        all rows and columns.
    space : float or str, optional
        The space between the main subplot grid and the colorbar, or the
        space between successively stacked colorbars. Units are interpreted
        by `~proplot.utils.units`. By default, this is determined by
        the "tight layout" algorithm, or is :rc:`subplots.panelspace`
        if "tight layout" is off.
    width : float or str, optional
        The colorbar width. Units are interpreted by
        `~proplot.utils.units`. Default is :rc:`colorbar.width`.
    *args, **kwargs
        Passed to `~proplot.axes.Axes.colorbar`.

    Other parameters
    ----------------
    cax : optional
        If passed and not ``None``, the call is handed to the parent class
        ``colorbar`` method with all remaining arguments; `loc`, `width`,
        `space`, and the span arguments are not used.
    ax : optional
        If passed and not ``None``, the colorbar is drawn by calling
        ``ax.colorbar`` with `space` and `width`; `loc` and the span
        arguments are not forwarded.

    Returns
    -------
    The object returned by whichever ``colorbar`` method was called.
    """
    # NOTE: Dispatch on the *values* rather than on key membership. An
    # explicit ``ax=None`` or ``cax=None`` means "not provided" (these are
    # the defaults in matplotlib's own signature) and must fall through to
    # the figure panel instead of raising AttributeError on ``None``.
    cax = kwargs.pop('cax', None)
    if cax is not None:
        # Any 'ax' keyword is left in kwargs for the parent implementation
        return super().colorbar(*args, cax=cax, **kwargs)
    ax = kwargs.pop('ax', None)
    if ax is not None:
        return ax.colorbar(*args, space=space, width=width, **kwargs)
    # Default: allocate a figure panel and fill it with the colorbar
    panel = self._add_figure_panel(
        loc, space=space, width=width, span=span,
        row=row, col=col, rows=rows, cols=cols)
    return panel.colorbar(*args, loc='_fill', **kwargs)
def legend(
    self, *args,
    loc='r', width=None, space=None,
    row=None, col=None, rows=None, cols=None, span=None,
    **kwargs
):
    """
    Draws a legend along the left, right, bottom, or top side of the
    figure, centered between the leftmost and rightmost (or
    topmost and bottommost) main axes.

    Parameters
    ----------
    loc : str, optional
        The legend location. Valid location keys are as follows.

        =========== =====================
        Location    Valid keys
        =========== =====================
        left edge   ``'l'``, ``'left'``
        right edge  ``'r'``, ``'right'``
        bottom edge ``'b'``, ``'bottom'``
        top edge    ``'t'``, ``'top'``
        =========== =====================

    row, rows : optional
        Aliases for `span` for panels on the left or right side.
    col, cols : optional
        Aliases for `span` for panels on the top or bottom side.
    span : int or (int, int), optional
        Describes how the legend spans rows and columns of subplots.
        For example, ``fig.legend(loc='b', col=1)`` draws a legend
        beneath the leftmost column of subplots, and
        ``fig.legend(loc='b', cols=(1,2))`` draws a legend beneath the
        left two columns of subplots. By default, the legend will span
        all rows and columns.
    space : float or str, optional
        The space between the main subplot grid and the legend, or the
        space between successively stacked figure panels. Units are
        interpreted by `~proplot.utils.units`. By default, this is
        adjusted automatically in the "tight layout" calculation, or is
        :rc:`subplots.panelspace` if "tight layout" is turned off.
    width : float or str, optional
        Forwarded as the ``width`` of the panel that holds the legend.
    *args, **kwargs
        Passed to `~proplot.axes.Axes.legend`.

    Other parameters
    ----------------
    ax : optional
        If passed and not ``None``, the legend is drawn by calling
        ``ax.legend`` with `space` and `width`; `loc` and the span
        arguments are not forwarded.

    Returns
    -------
    The object returned by whichever ``legend`` method was called.
    """
    # NOTE: An explicit ``ax=None`` means "not provided" and must fall
    # through to the figure panel rather than calling ``None.legend``.
    ax = kwargs.pop('ax', None)
    if ax is not None:
        return ax.legend(*args, space=space, width=width, **kwargs)
    # Default: allocate a figure panel and fill it with the legend
    panel = self._add_figure_panel(
        loc, space=space, width=width, span=span,
        row=row, col=col, rows=rows, cols=cols)
    return panel.legend(*args, loc='_fill', **kwargs)
def save(self, filename, **kwargs):
    # Alias for `~Figure.savefig` because ``fig.savefig`` is redundant.
    # User paths, e.g. the tilde ``'~'``, are expanded by `savefig`.
    # Returns whatever `savefig` returns.
    return self.savefig(filename, **kwargs)


def savefig(self, filename, **kwargs):
    # Automatically expand user paths because matplotlib does not do this
    # itself. Undocumented (comments rather than a docstring) because we do
    # not want to overwrite the matplotlib docstring.
    # NOTE: Only path-like targets (str, bytes, os.PathLike) are expanded.
    # Anything else, e.g. an open file or an io.BytesIO buffer, is passed
    # through untouched; os.path.expanduser raises TypeError on such
    # objects, which previously made saving to a buffer impossible.
    if isinstance(filename, (str, bytes, os.PathLike)):
        filename = os.path.expanduser(filename)
    super().savefig(filename, **kwargs)


def set_canvas(self, canvas):
    # Set the canvas and add monkey patches to the instance-level
    # `~matplotlib.backend_bases.FigureCanvasBase.draw_idle` and
    # `~matplotlib.backend_bases.FigureCanvasBase.print_figure`
    # methods. The latter is called by save() and by the inline backend.
    # See `_canvas_preprocess` for details.
    # NOTE: Use draw_idle() rather than draw() because the latter is not
    # always called! For example, MacOSX uses _draw() and nbAgg does
    # not call draw() *or* _draw()! Not sure how it works actually.
    # Should be same because we piggyback draw() which *itself* defers
    # the event. Just make sure to check _is_idle_drawing!
    canvas.draw_idle = _canvas_preprocess(canvas, 'draw_idle')
    canvas.print_figure = _canvas_preprocess(canvas, 'print_figure')
    super().set_canvas(canvas)


def set_size_inches(self, w, h=None, forward=True, manual=True):
    # Set the figure size and, if this is being called manually or from
    # an interactive backend, override the geometry tracker so users can
    # use interactive backends. See #76. Undocumented because this is
    # only relevant internally.
    #
    # `w` is either the width or, when `h` is None, a (width, height)
    # pair. Raises ValueError if either dimension is not finite. Internal
    # callers pass manual=False so that automatic resizing never
    # overwrites the stored geometry.
    #
    # NOTE: Bitmap renderers use int(Figure.bbox.[width|height]) which
    # rounds to whole pixels. So when renderer resizes the figure
    # internally there may be roundoff error! Always compare to *both*
    # Figure.get_size_inches() and the truncated bbox dimensions times dpi.
    # Comparison is critical because most renderers call set_size_inches()
    # before any resizing interaction!
    if h is None:
        width, height = w
    else:
        width, height = w, h
    if not all(np.isfinite(_) for _ in (width, height)):
        raise ValueError('Figure size must be finite, not '
                         f'({width}, {height}).')
    width_true, height_true = self.get_size_inches()
    width_trunc = int(self.bbox.width) / self.dpi
    height_trunc = int(self.bbox.height) / self.dpi
    # Only record the size as a user override when *both* dimensions differ
    # from the current size; have actually seen (width_true, height_trunc)!
    if (manual
            and width not in (width_true, width_trunc)
            and height not in (height_true, height_trunc)):
        self._subplots_kw.update(width=width, height=height)
    super().set_size_inches(width, height, forward=forward)


def _iter_axes(self):
    """Return a list of all visible axes and panels in the figure: the
    main axes, the figure panels on each side, and the axes panels
    attached to any of them. Excludes inset and twin axes."""
    axs = []
    for ax in (*self._axes_main, *self._lpanels, *self._rpanels,
               *self._bpanels, *self._tpanels):
        if not ax or not ax.get_visible():
            continue
        axs.append(ax)
    # NOTE: The list deliberately grows while it is being iterated, so the
    # panels appended here are themselves visited and their own panels
    # (if any) are collected too.
    for ax in axs:
        for s in 'lrbt':
            for iax in getattr(ax, '_' + s + 'panels'):
                if not iax or not iax.get_visible():
                    continue
                axs.append(iax)
    return axs
def _journals(journal): """Journal sizes for figures.""" # Get dimensions for figure from common journals. value = JOURNAL_SPECS.get(journal, None) if value is None: raise ValueError( f'Unknown journal figure size specifier {journal!r}. ' 'Current options are: ' + ', '.join(map(repr, JOURNAL_SPECS.keys()))) # Return width, and optionally also the height width, height = None, None try: width, height = value except (TypeError, ValueError): width = value return width, height def _axes_dict(naxs, value, kw=False, default=None): """Build a dictionary that looks like ``{1:value1, 2:value2, ...}`` or ``{1:{key1:value1, ...}, 2:{key2:value2, ...}, ...}`` for storing standardized axes-specific properties or keyword args.""" # First build up dictionary # 1) 'string' or {1:'string1', (2,3):'string2'} if not kw: if np.iterable(value) and not isinstance(value, (str, dict)): value = {num + 1: item for num, item in enumerate(value)} elif not isinstance(value, dict): value = {range(1, naxs + 1): value} # 2) {'prop':value} or {1:{'prop':value1}, (2,3):{'prop':value2}} else: nested = [isinstance(value, dict) for value in value.values()] if not any(nested): # any([]) == False value = {range(1, naxs + 1): value.copy()} elif not all(nested): raise ValueError( 'Pass either of dictionary of key value pairs or ' 'a dictionary of dictionaries of key value pairs.') # Then *unfurl* keys that contain multiple axes numbers, i.e. are meant # to indicate properties for multiple axes at once kwargs = {} for nums, item in value.items(): nums = np.atleast_1d(nums) for num in nums.flat: if not kw: kwargs[num] = item else: kwargs[num] = item.copy() # Fill with default values for num in range(1, naxs + 1): if num not in kwargs: if kw: kwargs[num] = {} else: kwargs[num] = default # Verify numbers if {*range(1, naxs + 1)} != {*kwargs.keys()}: raise ValueError( f'Have {naxs} axes, but {value} has properties for axes ' + ', '.join(repr(i) for i in sorted(kwargs.keys())) + '.') return kwargs
def subplots(
    array=None, ncols=1, nrows=1,
    ref=1, order='C',
    aspect=1, figsize=None,
    width=None, height=None, journal=None,
    axwidth=None, axheight=None,
    hspace=None, wspace=None, space=None,
    hratios=None, wratios=None,
    width_ratios=None, height_ratios=None,
    flush=None, wflush=None, hflush=None,
    left=None, bottom=None, right=None, top=None,
    tight=None, pad=None, axpad=None, panelpad=None,
    span=None, spanx=None, spany=None,
    align=None, alignx=None, aligny=None,
    share=None, sharex=None, sharey=None,
    basemap=False, proj=None, projection=None,
    proj_kw=None, projection_kw=None,
    autoformat=True, includepanels=False
):
    """
    Analogous to `matplotlib.pyplot.subplots`, creates a figure with a single
    axes or arbitrary grids of axes, any of which can be map projections.

    Parameters
    ----------
    array : 2D array-like of int, optional
        Array specifying complex grid of subplots. Think of
        this array as a "picture" of your figure. For example, the array
        ``[[1, 1], [2, 3]]`` creates one long subplot in the top row, two
        smaller subplots in the bottom row. Integers must range from 1 to the
        number of plots.

        ``0`` (or ``None``) indicates an empty space. For example,
        ``[[1, 1, 1], [2, 0, 3]]`` creates one long subplot in the top row
        with two subplots in the bottom row separated by a space.
    ncols, nrows : int, optional
        Number of columns, rows. Ignored if `array` is not ``None``.
        Use these arguments for simpler subplot grids.
    order : {'C', 'F'}, optional
        Whether subplots are numbered in row-major (``'C'``) or column-major
        (``'F'``) order. Analogous to `numpy.array` ordering. This controls
        the order axes appear in the `axs` list, and the order of subplot
        a-b-c labeling (see `~proplot.axes.Axes.format`).
    figsize : length-2 tuple, optional
        Tuple specifying the figure `(width, height)`.
    width, height : float or str, optional
        The figure width and height. Units are interpreted by
        `~proplot.utils.units`.
    journal : str, optional
        String name corresponding to an academic journal standard that is used
        to control the figure width (and height, if specified). See below
        table.

        ===========  ====================  ====================================================================
        Key          Size description      Organization
        ===========  ====================  ====================================================================
        ``'aaas1'``  1-column              American Association for the Advancement of Science (e.g. *Science*)
        ``'aaas2'``  2-column              ”
        ``'agu1'``   1-column              American Geophysical Union
        ``'agu2'``   2-column              ”
        ``'agu3'``   full height 1-column  ”
        ``'agu4'``   full height 2-column  ”
        ``'ams1'``   1-column              American Meteorological Society
        ``'ams2'``   small 2-column        ”
        ``'ams3'``   medium 2-column       ”
        ``'ams4'``   full 2-column         ”
        ``'nat1'``   1-column              Nature Research
        ``'nat2'``   2-column              ”
        ``'pnas1'``  1-column              Proceedings of the National Academy of Sciences
        ``'pnas2'``  2-column              ”
        ``'pnas3'``  landscape page        ”
        ===========  ====================  ====================================================================

    ref : int, optional
        The reference axes number. The `axwidth`, `axheight`, and `aspect`
        keyword args are applied to this axes, and aspect ratio is conserved
        for this axes in tight layout adjustment.
    axwidth, axheight : float or str, optional
        Sets the average width, height of your axes. Units are interpreted by
        `~proplot.utils.units`. Default is :rc:`subplots.axwidth`.

        These arguments are convenient where you don't care about the figure
        dimensions and just want your axes to have enough "room".
    aspect : float or length-2 list of floats, optional
        The (average) axes aspect ratio, in numeric form (width divided by
        height) or as (width, height) tuple. If you do not provide
        the `hratios` or `wratios` keyword args, all axes will have
        identical aspect ratios.
    hratios, wratios
        Aliases for `height_ratios`, `width_ratios`.
    width_ratios, height_ratios : float or list thereof, optional
        Passed to `GridSpec`, denotes the width
        and height ratios for the subplot grid. Length of `width_ratios`
        must match the number of columns, and length of `height_ratios` must
        match the number of rows.
    wspace, hspace, space : float or str or list thereof, optional
        Passed to `GridSpec`, denotes the
        spacing between grid columns, rows, and both, respectively. If float
        or string, expanded into lists of length ``ncols-1`` (for `wspace`)
        or length ``nrows-1`` (for `hspace`).

        Units are interpreted by `~proplot.utils.units` for each element of
        the list. By default, these are determined by the "tight
        layout" algorithm.
    left, right, top, bottom : float or str, optional
        Passed to `GridSpec`, denotes the width of padding between the
        subplots and the figure edge. Units are interpreted by
        `~proplot.utils.units`. By default, these are determined by the
        "tight layout" algorithm.
    sharex, sharey, share : {3, 2, 1, 0}, optional
        The "axis sharing level" for the *x* axis, *y* axis, or both axes.
        Default is ``3``. This can considerably reduce redundancy in your
        figure. The options are as follows:

        0. No axis sharing. Also sets the default `spanx` and `spany` values
           to ``False``.
        1. Only draw *axis label* on the leftmost column (*y*) or
           bottommost row (*x*) of subplots. Axis tick labels
           still appear on every subplot.
        2. As in 1, but forces the axis limits to be identical. Axis
           tick labels still appear on every subplot.
        3. As in 2, but only show the *axis tick labels* on the
           leftmost column (*y*) or bottommost row (*x*) of subplots.

    spanx, spany, span : bool or {0, 1}, optional
        Default is ``False`` if `sharex`, `sharey`, or `share` are ``0``,
        ``True`` otherwise. Toggles "spanning" axis labels for the *x* axis,
        *y* axis, or both axes. When ``True``, a single, centered axis label
        is used for all axes with bottom and left edges in the same
        row or column. This can considerably reduce redundancy in your
        figure.

        "Spanning" labels integrate with "shared" axes. For example,
        for a 3-row, 3-column figure, with ``sharey > 1`` and ``spany=1``,
        your figure will have 1 ylabel instead of 9.
    alignx, aligny, align : bool or {0, 1}, optional
        Default is ``False``. Whether to align axis labels for the *x* axis,
        *y* axis, or both axes. Only has an effect when `spanx`, `spany`,
        or `span` are ``False``.
    proj, projection : str or dict-like, optional
        The map projection name. The argument is interpreted as follows.

        * If string, this projection is used for all subplots. For valid
          names, see the `~proplot.projs.Proj` documentation.
        * If list of string, these are the projections to use for each
          subplot in their `array` order.
        * If dict-like, keys are integers or tuple integers that indicate
          the projection to use for each subplot. If a key is not provided,
          that subplot will be a `~proplot.axes.XYAxes`. For example,
          in a 4-subplot figure, ``proj={2:'merc', (3,4):'stere'}``
          draws a Cartesian axes for the first subplot, a Mercator
          projection for the second subplot, and a Stereographic projection
          for the third and fourth subplots.

    proj_kw, projection_kw : dict-like, optional
        Keyword arguments passed to `~mpl_toolkits.basemap.Basemap` or
        cartopy projection classes on instantiation.
        If dictionary of properties, applies globally. If *dictionary of
        dictionaries* of properties, applies to specific subplots, as
        with `proj`.

        For example, with ``ncols=2`` and
        ``proj_kw={1:{'lon_0':0}, 2:{'lon_0':180}}``, the projection in
        the left subplot is centered on the prime meridian, and the projection
        in the right subplot is centered on the international dateline.
    basemap : bool or dict-like, optional
        Whether to use `~mpl_toolkits.basemap.Basemap` or cartopy
        for map projections. Default is ``False``. If boolean, applies to all
        subplots. If dictionary, values apply to specific subplots, as
        with `proj`.

    Other parameters
    ----------------
    tight : bool, optional
        Toggles automatic tight layout adjustments. Default is
        :rc:`tight`.

        If you manually specify a spacing, it will be used
        to override the tight layout spacing -- for example, with ``left=0.1``,
        the left margin is set to 0.1 inches wide, while the remaining margin
        widths are calculated automatically.
    pad, axpad, panelpad : float or str, optional
        Padding for automatic tight layout adjustments. See `Figure` for
        details.
    includepanels : bool, optional
        Whether to include panels when calculating the position of certain
        spanning labels. See `Figure` for details.
    autoformat : bool, optional
        Whether to automatically format axes when special datasets are
        passed to plotting commands. See `Figure` for details.
    flush, wflush, hflush
        Accepted but currently unused by this function.

    Returns
    -------
    f : `Figure`
        The figure instance.
    axs : `subplot_grid`
        A special list of axes instances. See `subplot_grid`.

    Raises
    ------
    ValueError
        If `order`, `array`, `ref`, or the sharing levels are invalid, or if
        the number of spacings or ratios does not match the grid shape.
    """  # noqa
    rc._getitem_mode = 0
    # Build array
    if order not in ('C', 'F'):  # better error message
        raise ValueError(
            f'Invalid order {order!r}. Choose from "C" (row-major, default) '
            f'and "F" (column-major).')
    if array is None:
        array = np.arange(1, nrows * ncols + 1)
        array = array.reshape((nrows, ncols), order=order)
    # Standardize array
    try:
        # Swap None placeholders for zeros *before* the integer cast, because
        # casting an array that contains None raises an error
        array = np.array(array, dtype=object)
        array[array == None] = 0  # use zero for placeholder  # noqa
        array = array.astype(int)  # enforce array type
        if array.ndim == 1:  # interpret as single row or column
            array = array[None, :] if order == 'C' else array[:, None]
        elif array.ndim != 2:
            raise ValueError
    except (TypeError, ValueError):
        raise ValueError(
            f'Invalid subplot array {array!r}. '
            'Must be 1D or 2D array of integers.')
    # Get other props
    nums = np.unique(array[array != 0])
    naxs = len(nums)
    if {*nums.flat} != {*range(1, naxs + 1)}:
        raise ValueError(
            f'Invalid subplot array {array!r}. Numbers must span integers '
            '1 to naxs (i.e. cannot skip over numbers), with 0 representing '
            'empty spaces.')
    if ref not in nums:
        raise ValueError(
            f'Invalid reference number {ref!r}. For array {array!r}, must be '
            f'one of {nums}.')
    nrows, ncols = array.shape

    # Axis sharing and axis label spanning settings
    sharex = int(_notNone(sharex, share, rc['share']))
    sharey = int(_notNone(sharey, share, rc['share']))
    if sharex not in range(4) or sharey not in range(4):
        raise ValueError(
            'Axis sharing level can be 0 (no sharing), '
            '1 (sharing, but keep all tick labels), '
            '2 (sharing, keep one set of tick labels), '
            'or 3 (sharing, keep one axis label and one set of tick labels). '
            f'Got sharex={sharex} and sharey={sharey}.')
    spanx = _notNone(spanx, span, 0 if sharex == 0 else None, rc['span'])
    spany = _notNone(spany, span, 0 if sharey == 0 else None, rc['span'])
    alignx = _notNone(alignx, align)
    aligny = _notNone(aligny, align)
    if (spanx and alignx) or (spany and aligny):
        _warn_proplot(
            'The "alignx" and "aligny" args have no effect when '
            '"spanx" and "spany" are True.')
    alignx = _notNone(alignx, rc['align'])
    aligny = _notNone(aligny, rc['align'])

    # Figure out rows and columns "spanned" by each axes in list, where
    # locations are sorted by axes id.
    # NOTE: These ranges are endpoint *inclusive*! The gridspec slices
    # below add one to each upper bound.
    axids = [np.where(array == i) for i in np.sort(
        np.unique(array)) if i > 0]  # 0 stands for empty
    xrange = np.array([[x.min(), x.max()] for _, x in axids])
    yrange = np.array([[y.min(), y.max()] for y, _ in axids])
    xref = xrange[ref - 1, :]  # range for reference axes
    yref = yrange[ref - 1, :]

    # Get basemap.Basemap or cartopy projection instances for map
    proj = _notNone(projection, proj, None, names=('projection', 'proj'))
    proj_kw = _notNone(projection_kw, proj_kw, {},
                       names=('projection_kw', 'proj_kw'))
    proj = _axes_dict(naxs, proj, kw=False, default='xy')
    proj_kw = _axes_dict(naxs, proj_kw, kw=True)
    basemap = _axes_dict(naxs, basemap, kw=False, default=False)
    axes_kw = {num: {}
               for num in range(1, naxs + 1)}  # stores add_subplot arguments
    for num, name in proj.items():
        # The default is XYAxes
        if name is None or name == 'xy':
            axes_kw[num]['projection'] = 'xy'
        # Builtin matplotlib polar axes, just use my overridden version
        elif name == 'polar':
            axes_kw[num]['projection'] = 'polar'
            if num == ref:
                aspect = 1
        # Custom Basemap and Cartopy axes
        else:
            package = 'basemap' if basemap[num] else 'geo'
            obj, iaspect = projs.Proj(
                name, basemap=basemap[num], **proj_kw[num])
            if num == ref:
                aspect = iaspect
            axes_kw[num].update({'projection': package, 'map_projection': obj})

    # Figure and/or axes dimensions
    # The names/values pairs list the arguments overridden by `spec`
    names, values = (), ()
    if journal:
        # NOTE(review): a user-supplied `height` is replaced here without a
        # warning, even when the journal spec defines no height -- confirm
        # whether that is intended.
        figsize = _journals(journal)
        spec = f'journal={journal!r}'
        names = ('axwidth', 'axheight', 'width')
        values = (axwidth, axheight, width)
        width, height = figsize
    elif figsize:
        spec = f'figsize={figsize!r}'
        names = ('axwidth', 'axheight', 'width', 'height')
        values = (axwidth, axheight, width, height)
        width, height = figsize
    elif width is not None or height is not None:
        spec = []
        if width is not None:
            spec.append(f'width={width!r}')
        if height is not None:
            spec.append(f'height={height!r}')
        spec = ', '.join(spec)
        names = ('axwidth', 'axheight')
        values = (axwidth, axheight)
    # Raise warning
    for name, value in zip(names, values):
        if value is not None:
            _warn_proplot(
                f'You specified both {spec} and {name}={value!r}. '
                f'Ignoring {name!r}.')

    # Standardized dimensions
    width, height = units(width), units(height)
    axwidth, axheight = units(axwidth), units(axheight)
    # Standardized user input border spaces
    left, right = units(left), units(right)
    bottom, top = units(bottom), units(top)
    # Standardized user input spaces
    wspace = np.atleast_1d(units(_notNone(wspace, space)))
    hspace = np.atleast_1d(units(_notNone(hspace, space)))
    if len(wspace) == 1:
        wspace = np.repeat(wspace, (ncols - 1,))
    if len(wspace) != ncols - 1:
        raise ValueError(
            f'Require {ncols-1} width spacings for {ncols} columns, '
            f'got {len(wspace)}.')
    if len(hspace) == 1:
        hspace = np.repeat(hspace, (nrows - 1,))
    if len(hspace) != nrows - 1:
        raise ValueError(
            f'Require {nrows-1} height spacings for {nrows} rows, '
            f'got {len(hspace)}.')
    # Standardized user input ratios
    wratios = np.atleast_1d(_notNone(width_ratios, wratios, 1,
                                     names=('width_ratios', 'wratios')))
    hratios = np.atleast_1d(_notNone(height_ratios, hratios, 1,
                                     names=('height_ratios', 'hratios')))
    if len(wratios) == 1:
        wratios = np.repeat(wratios, (ncols,))
    if len(hratios) == 1:
        hratios = np.repeat(hratios, (nrows,))
    if len(wratios) != ncols:
        raise ValueError(f'Got {ncols} columns, but {len(wratios)} wratios.')
    if len(hratios) != nrows:
        raise ValueError(f'Got {nrows} rows, but {len(hratios)} hratios.')

    # Fill subplots_orig_kw with user input values
    # NOTE: 'Ratios' are only fixed for panel axes, but we store entire array
    wspace, hspace = wspace.tolist(), hspace.tolist()
    wratios, hratios = wratios.tolist(), hratios.tolist()
    subplots_orig_kw = {
        'left': left, 'right': right, 'top': top, 'bottom': bottom,
        'wspace': wspace, 'hspace': hspace,
    }

    # Default border spaces
    left = _notNone(left, units(rc['subplots.ylabspace']))
    right = _notNone(right, units(rc['subplots.innerspace']))
    top = _notNone(top, units(rc['subplots.titlespace']))
    bottom = _notNone(bottom, units(rc['subplots.xlabspace']))
    # Default spaces between axes
    wratios, hratios = [*wratios], [*hratios]  # copies
    wspace, hspace = np.array(wspace), np.array(hspace)  # also copies!
    wspace[wspace == None] = (  # noqa
        units(rc['subplots.innerspace']) if sharey == 3
        else units(rc['subplots.ylabspace']) - units(rc['subplots.titlespace'])
        if sharey in (1, 2)
        else units(rc['subplots.ylabspace']))
    hspace[hspace == None] = (  # noqa
        units(rc['subplots.titlespace']) + units(rc['subplots.innerspace'])
        if sharex == 3
        else units(rc['subplots.xlabspace'])
        if sharex in (1, 2)
        else units(rc['subplots.titlespace']) + units(rc['subplots.xlabspace'])
    )
    wspace, hspace = wspace.tolist(), hspace.tolist()

    # Parse arguments, fix dimensions in light of desired aspect ratio
    figsize, gridspec_kw, subplots_kw = _subplots_geometry(
        nrows=nrows, ncols=ncols,
        aspect=aspect, xref=xref, yref=yref,
        left=left, right=right, bottom=bottom, top=top,
        width=width, height=height, axwidth=axwidth, axheight=axheight,
        wratios=wratios, hratios=hratios, wspace=wspace, hspace=hspace,
        wpanels=[''] * ncols, hpanels=[''] * nrows,
    )
    fig = plt.figure(
        FigureClass=Figure, tight=tight, figsize=figsize, ref=ref,
        pad=pad, axpad=axpad, panelpad=panelpad, autoformat=autoformat,
        includepanels=includepanels,
        subplots_orig_kw=subplots_orig_kw, subplots_kw=subplots_kw,
        gridspec_kw=gridspec_kw)
    gridspec = fig._gridspec_main

    # Draw main subplots
    axs = naxs * [None]  # list of axes
    for idx in range(naxs):
        # Get figure gridspec ranges
        num = idx + 1
        x0, x1 = xrange[idx, 0], xrange[idx, 1]
        y0, y1 = yrange[idx, 0], yrange[idx, 1]
        # Draw subplot
        subplotspec = gridspec[y0:y1 + 1, x0:x1 + 1]
        with fig._unlock():
            axs[idx] = fig.add_subplot(
                subplotspec, number=num,
                spanx=spanx, spany=spany, alignx=alignx, aligny=aligny,
                sharex=sharex, sharey=sharey,
                main=True,
                **axes_kw[num])

    # Shared axes setup
    # TODO: Figure out how to defer this to drawtime in #50
    # For some reason just adding _share_setup() to draw() doesn't work
    for ax in axs:
        ax._share_setup()

    # Return figure and axes
    n = (ncols if order == 'C' else nrows)
    return fig, subplot_grid(axs, n=n, order=order)