Source code for mpl_toolkits.axisartist.grid_finder

import numpy as np

from matplotlib import cbook, ticker as mticker
from matplotlib.transforms import Bbox, Transform
from .clip_path import clip_line_to_rect


def _deprecate_factor_none(factor):
    # After the deprecation period, calls to _deprecate_factor_none can just be
    # removed.
    if factor is None:
        cbook.warn_deprecated(
            "3.2", message="factor=None is deprecated since %(since)s and "
            "support will be removed %(removal)s; use/return factor=1 instead")
        factor = 1
    return factor


class ExtremeFinderSimple:
    """
    A helper class to figure out the range of grid lines that need to be drawn.
    """

    def __init__(self, nx, ny):
        """
        Parameters
        ----------
        nx, ny : int
            The number of samples in each direction.
        """
        self.nx = nx
        self.ny = ny

    def __call__(self, transform_xy, x1, y1, x2, y2):
        """
        Approximate the bounding box of the image of ``(x1, y1, x2, y2)``
        under *transform_xy*.

        The intended use is to have ``(x1, y1, x2, y2)`` in axes coordinates,
        and have *transform_xy* be the transform from axes coordinates to
        data coordinates; this method then returns the range of data
        coordinates that span the actual axes.

        An ``nx`` by ``ny`` regular grid of points is sampled inside the box
        and transformed, and the extremal transformed coordinates are found.
        Because the sampling is finite, each resulting span is then widened
        on both sides by *1/nx* (resp. *1/ny*) of its length.

        Returns
        -------
        x_min, x_max, y_min, y_max
        """
        grid_x, grid_y = np.meshgrid(np.linspace(x1, x2, self.nx),
                                     np.linspace(y1, y2, self.ny))
        tx, ty = transform_xy(grid_x.ravel(), grid_y.ravel())
        return self._add_pad(tx.min(), tx.max(), ty.min(), ty.max())

    def _add_pad(self, x_min, x_max, y_min, y_max):
        """Perform the padding mentioned in `__call__`."""
        pad_x = (x_max - x_min) / self.nx
        pad_y = (y_max - y_min) / self.ny
        return (x_min - pad_x, x_max + pad_x,
                y_min - pad_y, y_max + pad_y)
class GridFinder:
    """
    Compute the grid lines of a grid defined in (lon, lat) coordinates, drawn
    through a transform, together with the ticks where they meet a box.
    """

    def __init__(self,
                 transform,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        Parameters
        ----------
        transform : `~matplotlib.transforms.Transform` or (callable, callable)
            Maps (lon, lat) grid coordinates to the coordinates in which the
            box later passed to `get_grid_info` is expressed.  Either a
            Transform, or a pair ``(transform_xy, inv_transform_xy)`` of
            functions taking and returning ``(x, y)`` arrays; see
            `update_transform`.
        extreme_finder : callable, optional
            Called as ``extreme_finder(inv_transform_xy, x1, y1, x2, y2)``,
            returning ``(lon_min, lon_max, lat_min, lat_max)``.  Defaults to
            ``ExtremeFinderSimple(20, 20)``.
        grid_locator1, grid_locator2 : callable, optional
            Grid locators for the 1st (lon) and 2nd (lat) coordinate, called
            as ``locator(vmin, vmax)`` and returning ``(levels, n, factor)``.
            Default to `MaxNLocator()`.
        tick_formatter1, tick_formatter2 : callable, optional
            Tick formatters for the 1st and 2nd coordinate, called as
            ``formatter(direction, factor, levels)``.  Default to
            `FormatterPrettyPrint()`.
        """
        if extreme_finder is None:
            extreme_finder = ExtremeFinderSimple(20, 20)
        if grid_locator1 is None:
            grid_locator1 = MaxNLocator()
        if grid_locator2 is None:
            grid_locator2 = MaxNLocator()
        if tick_formatter1 is None:
            tick_formatter1 = FormatterPrettyPrint()
        if tick_formatter2 is None:
            tick_formatter2 = FormatterPrettyPrint()
        self.extreme_finder = extreme_finder
        self.grid_locator1 = grid_locator1
        self.grid_locator2 = grid_locator2
        self.tick_formatter1 = tick_formatter1
        self.tick_formatter2 = tick_formatter2
        self.update_transform(transform)

    def get_grid_info(self, x1, y1, x2, y2):
        """
        Compute the grid lines and tick information for the box
        ``(x1, y1, x2, y2)``.

        Returns
        -------
        dict
            With keys ``"extremes"`` (the lon/lat range found by the extreme
            finder), ``"lon_lines"`` and ``"lat_lines"`` (the unclipped,
            transformed grid lines), and ``"lon"`` and ``"lat"``.  The last
            two are the dicts built by `_clip_grid_lines_and_find_ticks`,
            each with an added ``"tick_labels"`` dict mapping each side
            (``"left"``, ``"bottom"``, ``"right"``, ``"top"``) to the
            formatted labels.
        """
        extremes = self.extreme_finder(self.inv_transform_xy, x1, y1, x2, y2)

        # min & max range of lat (or lon) for each grid line will be drawn.
        # i.e., gridline of lon=0 will be drawn from lat_min to lat_max.

        lon_min, lon_max, lat_min, lat_max = extremes
        lon_levs, lon_n, lon_factor = self.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = self.grid_locator2(lat_min, lat_max)

        # Normalize the deprecated factor=None once, so the same factor is
        # used both to scale the levels and to format the tick labels
        # (previously the formatters still received the raw None).
        lon_factor = _deprecate_factor_none(lon_factor)
        lat_factor = _deprecate_factor_none(lat_factor)

        lon_values = lon_levs[:lon_n] / lon_factor
        lat_values = lat_levs[:lat_n] / lat_factor

        lon_lines, lat_lines = self._get_raw_grid_lines(lon_values,
                                                        lat_values,
                                                        lon_min, lon_max,
                                                        lat_min, lat_max)

        # Enlarge the clipping box by a tiny relative margin (presumably so
        # that lines lying exactly on its boundary survive round-off).
        ddx = (x2-x1)*1.e-10
        ddy = (y2-y1)*1.e-10
        bb = Bbox.from_extents(x1-ddx, y1-ddy, x2+ddx, y2+ddy)

        grid_info = {
            "extremes": extremes,
            "lon_lines": lon_lines,
            "lat_lines": lat_lines,
            "lon": self._clip_grid_lines_and_find_ticks(
                lon_lines, lon_values, lon_levs, bb),
            "lat": self._clip_grid_lines_and_find_ticks(
                lat_lines, lat_values, lat_levs, bb),
        }

        for key, formatter, factor in [
                ("lon", self.tick_formatter1, lon_factor),
                ("lat", self.tick_formatter2, lat_factor)]:
            tck_labels = grid_info[key]["tick_labels"] = {}
            for direction in ["left", "bottom", "right", "top"]:
                levs = grid_info[key]["tick_levels"][direction]
                tck_labels[direction] = formatter(direction, factor, levs)

        return grid_info

    def _get_raw_grid_lines(self,
                            lon_values, lat_values,
                            lon_min, lon_max, lat_min, lat_max):
        """
        Return ``(lon_lines, lat_lines)``, the unclipped grid lines mapped
        through `transform_xy`.

        Each constant-lon line is sampled at 100 points spanning
        ``[lat_min, lat_max]``, and each constant-lat line at 100 points
        spanning ``[lon_min, lon_max]``.
        """
        lons_i = np.linspace(lon_min, lon_max, 100)  # for interpolation
        lats_i = np.linspace(lat_min, lat_max, 100)

        lon_lines = [self.transform_xy(np.full_like(lats_i, lon), lats_i)
                     for lon in lon_values]
        lat_lines = [self.transform_xy(lons_i, np.full_like(lons_i, lat))
                     for lat in lat_values]

        return lon_lines, lat_lines

    def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
        """
        Clip *lines* to the bbox *bb* with `clip_line_to_rect`.

        Lines that are entirely clipped away are dropped.  For the remaining
        lines, the per-side tick locations reported by `clip_line_to_rect`
        are collected, each paired with the corresponding entry of *levs*.
        """
        gi = {
            # NOTE(review): "values" is never filled in here; the values of
            # the kept lines are stored under "levels".
            "values": [],
            "levels": [],
            "tick_levels": dict(left=[], bottom=[], right=[], top=[]),
            "tick_locs": dict(left=[], bottom=[], right=[], top=[]),
            "lines": [],
        }

        tck_levels = gi["tick_levels"]
        tck_locs = gi["tick_locs"]
        for (lx, ly), v, lev in zip(lines, values, levs):
            xy, tcks = clip_line_to_rect(lx, ly, bb)
            if not xy:
                continue
            gi["levels"].append(v)
            gi["lines"].append(xy)

            for tck, direction in zip(tcks,
                                      ["left", "bottom", "right", "top"]):
                for t in tck:
                    tck_levels[direction].append(lev)
                    tck_locs[direction].append(t)

        return gi

    def update_transform(self, aux_trans):
        """
        Set `transform_xy` and `inv_transform_xy` from *aux_trans*.

        *aux_trans* is either a `~matplotlib.transforms.Transform` (whose
        inverse is recomputed on every call of `inv_transform_xy`) or a pair
        ``(transform_xy, inv_transform_xy)`` of callables.
        """
        if isinstance(aux_trans, Transform):
            def transform_xy(x, y):
                ll1 = np.column_stack([x, y])
                ll2 = aux_trans.transform(ll1)
                lon, lat = ll2[:, 0], ll2[:, 1]
                return lon, lat

            def inv_transform_xy(x, y):
                ll1 = np.column_stack([x, y])
                ll2 = aux_trans.inverted().transform(ll1)
                lon, lat = ll2[:, 0], ll2[:, 1]
                return lon, lat

        else:
            transform_xy, inv_transform_xy = aux_trans

        self.transform_xy = transform_xy
        self.inv_transform_xy = inv_transform_xy

    def update(self, **kw):
        """
        Replace any of the extreme finder, grid locators and tick formatters.

        Raises
        ------
        ValueError
            If a keyword is not one of ``extreme_finder``, ``grid_locator1``,
            ``grid_locator2``, ``tick_formatter1``, ``tick_formatter2``.
        """
        for k in kw:
            if k in ["extreme_finder",
                     "grid_locator1",
                     "grid_locator2",
                     "tick_formatter1",
                     "tick_formatter2"]:
                setattr(self, k, kw[k])
            else:
                raise ValueError("Unknown update property '%s'" % k)


@cbook.deprecated("3.2")
class GridFinderBase(GridFinder):
    def __init__(self,
                 extreme_finder,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        super().__init__((None, None), extreme_finder,
                         grid_locator1, grid_locator2,
                         tick_formatter1, tick_formatter2)


class MaxNLocator(mticker.MaxNLocator):
    """
    A `matplotlib.ticker.MaxNLocator` that, when called as ``(v1, v2)``,
    returns ``(levels, number_of_levels, factor)``.
    """

    def __init__(self, nbins=10, steps=None,
                 trim=True,
                 integer=False,
                 symmetric=False,
                 prune=None):
        # trim argument has no effect. It has been left for API compatibility
        super().__init__(nbins, steps=steps, integer=integer,
                         symmetric=symmetric, prune=prune)
        self.create_dummy_axis()
        self._factor = 1

    def __call__(self, v1, v2):
        self.set_bounds(v1 * self._factor, v2 * self._factor)
        locs = super().__call__()
        return np.array(locs), len(locs), self._factor

    @cbook.deprecated("3.3")
    def set_factor(self, f):
        self._factor = _deprecate_factor_none(f)


class FixedLocator:
    """
    Locator returning the subset of the fixed *locs* that lies within the
    requested interval, as ``(levels, number_of_levels, factor)``.
    """

    def __init__(self, locs):
        self._locs = locs
        self._factor = 1

    def __call__(self, v1, v2):
        v1, v2 = sorted([v1 * self._factor, v2 * self._factor])
        locs = np.array([loc for loc in self._locs if v1 <= loc <= v2])
        return locs, len(locs), self._factor

    @cbook.deprecated("3.3")
    def set_factor(self, f):
        self._factor = _deprecate_factor_none(f)


# Tick Formatter

class FormatterPrettyPrint:
    """
    Tick formatter delegating to a `matplotlib.ticker.ScalarFormatter`
    (without offset); *direction* and *factor* are ignored.
    """

    def __init__(self, useMathText=True):
        self._fmt = mticker.ScalarFormatter(
            useMathText=useMathText, useOffset=False)
        self._fmt.create_dummy_axis()

    def __call__(self, direction, factor, values):
        return self._fmt.format_ticks(values)
class DictFormatter:
    """
    Tick formatter that looks the tick values up in a dictionary of labels,
    optionally falling back to another formatter for missing values.
    """

    def __init__(self, format_dict, formatter=None):
        """
        Parameters
        ----------
        format_dict : dict
            Mapping from tick values to the label strings to use for them.
        formatter : callable, optional
            Fall-back formatter, called as
            ``formatter(direction, factor, values)``, whose output is used
            for values that are not keys of *format_dict*.  If not given,
            such values get an empty label.
        """
        super().__init__()
        self._format_dict = format_dict
        self._fallback_formatter = formatter

    def __call__(self, direction, factor, values):
        """
        Return one label per entry of *values*.

        *factor* is only forwarded to the fall-back formatter; it is ignored
        for values found in the dictionary.
        """
        # Compare against None explicitly: a valid formatter object must not
        # be skipped just because it happens to be falsy.
        if self._fallback_formatter is not None:
            fallback_strings = self._fallback_formatter(
                direction, factor, values)
        else:
            fallback_strings = [""] * len(values)
        return [self._format_dict.get(k, v)
                for k, v in zip(values, fallback_strings)]