Source code for mpl_toolkits.axisartist.grid_helper_curvelinear

"""
Experimental support for curvilinear grids.
"""
from itertools import chain

import numpy as np

from matplotlib.path import Path
from matplotlib.transforms import Affine2D, IdentityTransform
from .axislines import AxisArtistHelper, GridHelperBase
from .axis_artist import AxisArtist
from .grid_finder import GridFinder, _deprecate_factor_none


class FixedAxisArtistHelper(AxisArtistHelper.Fixed):
    """
    Helper class for a fixed axis.

    Tick positions and labels are not computed here; they are requested
    from ``grid_helper.get_tick_iterator`` for the side this axis sits on.
    """

    def __init__(self, grid_helper, side, nth_coord_ticks=None):
        """
        nth_coord = along which coordinate value varies.
         nth_coord = 0 ->  x axis, nth_coord = 1 -> y axis

        Parameters
        ----------
        grid_helper
            Object providing ``update_lim(axes)`` and
            ``get_tick_iterator(nth_coord, side, minor=...)``.
        side
            Forwarded to the base class as ``loc``.  `get_tick_iterators`
            expects one of "left", "right", "top" or "bottom" when the
            limits are inverted.
        nth_coord_ticks : optional
            Index (0 or 1) of the coordinate whose ticks are requested
            first; defaults to ``self.nth_coord`` as left by the base-class
            initializer.
        """
        super().__init__(loc=side)

        self.grid_helper = grid_helper
        if nth_coord_ticks is None:
            nth_coord_ticks = self.nth_coord
        self.nth_coord_ticks = nth_coord_ticks

        self.side = side
        self._limits_inverted = False

    def update_lim(self, axes):
        """
        Refresh the grid helper and record whether the limits are inverted.

        The limits inspected are the y-limits when ``self.nth_coord == 0``
        and the x-limits otherwise.
        """
        self.grid_helper.update_lim(axes)

        if self.nth_coord == 0:
            xy1, xy2 = axes.get_ylim()
        else:
            xy1, xy2 = axes.get_xlim()

        # Inverted limits make get_tick_iterators query the opposite side.
        self._limits_inverted = bool(xy1 > xy2)

    def change_tick_coord(self, coord_number=None):
        """
        Select which coordinate supplies the first batch of ticks.

        Parameters
        ----------
        coord_number : {0, 1} or None
            ``None`` toggles between the two coordinates; 0 or 1 selects
            one explicitly.

        Raises
        ------
        ValueError
            If *coord_number* is neither ``None``, 0 nor 1.
        """
        if coord_number is None:
            self.nth_coord_ticks = 1 - self.nth_coord_ticks
        elif coord_number in [0, 1]:
            self.nth_coord_ticks = coord_number
        else:
            # ValueError is still an Exception, so existing handlers that
            # catch the previous bare Exception keep working.
            raise ValueError("wrong coord number")

    def get_tick_transform(self, axes):
        """Return ``axes.transData``, the transform used for the ticks."""
        return axes.transData

    def get_tick_iterators(self, axes):
        """tick_loc, tick_angle, tick_label

        Returns a pair of iterators.  The first chains the ticks of
        ``self.nth_coord_ticks`` with those of the other coordinate
        requested with ``minor=True``; the second is always empty.
        """
        if self._limits_inverted:
            side = {"left": "right", "right": "left",
                    "top": "bottom", "bottom": "top"}[self.side]
        else:
            side = self.side

        g = self.grid_helper
        ti1 = g.get_tick_iterator(self.nth_coord_ticks, side)
        ti2 = g.get_tick_iterator(1-self.nth_coord_ticks, side, minor=True)
        return chain(ti1, ti2), iter([])
class FloatingAxisArtistHelper(AxisArtistHelper.Floating):
    """
    Helper for an axis drawn along a line where one grid coordinate is
    held at a constant value while the other one varies.
    """

    def __init__(self, grid_helper, nth_coord, value, axis_direction=None):
        """
        nth_coord = along which coordinate value varies.
         nth_coord = 0 ->  x axis, nth_coord = 1 -> y axis

        *value* is the constant assigned to coordinate *nth_coord*.
        *axis_direction* is accepted but not used by this initializer.
        """
        super().__init__(nth_coord, value)
        self.value = value
        self.grid_helper = grid_helper
        self._extremes = -np.inf, np.inf
        self._get_line_path = None  # a method that returns a Path.
        self._line_num_points = 100  # number of points to create a line

    def set_extremes(self, e1, e2):
        """
        Restrict the range of the varying coordinate to ``(e1, e2)``.

        ``None`` for *e1* means no lower bound (``-inf``); ``None`` for
        *e2* means no upper bound (``+inf``).
        """
        lower = -np.inf if e1 is None else e1
        upper = np.inf if e2 is None else e2
        self._extremes = lower, upper

    def update_lim(self, axes):
        """
        Recompute ``self.grid_info`` for the current limits of *axes*.

        The stored dict holds the (clamped) coordinate extremes, the
        locator output and formatted labels for both coordinates, and the
        transformed sample points of the axis line under ``"line_xy"``.
        """
        self.grid_helper.update_lim(axes)

        x1, x2 = axes.get_xlim()
        y1, y2 = axes.get_ylim()
        finder = self.grid_helper.grid_finder
        lon_min, lon_max, lat_min, lat_max = finder.extreme_finder(
            finder.inv_transform_xy, x1, y1, x2, y2)

        # The user-set extremes limit the coordinate that varies along
        # the line, i.e. the one that is *not* nth_coord.
        e_min, e_max = self._extremes
        if self.nth_coord == 0:
            lat_min = max(e_min, lat_min)
            lat_max = min(e_max, lat_max)
        elif self.nth_coord == 1:
            lon_min = max(e_min, lon_min)
            lon_max = min(e_max, lon_max)

        lon_levs, lon_n, lon_factor = finder.grid_locator1(lon_min, lon_max)
        lat_levs, lat_n, lat_factor = finder.grid_locator2(lat_min, lat_max)

        # Sample the line: one coordinate constant, the other spanning
        # its (clamped) range.
        n_pts = self._line_num_points
        if self.nth_coord == 0:
            xx0 = np.full(n_pts, self.value, type(self.value))
            yy0 = np.linspace(lat_min, lat_max, n_pts)
            xx, yy = finder.transform_xy(xx0, yy0)
        elif self.nth_coord == 1:
            xx0 = np.linspace(lon_min, lon_max, n_pts)
            yy0 = np.full(n_pts, self.value, type(self.value))
            xx, yy = finder.transform_xy(xx0, yy0)

        self.grid_info = {
            "extremes": (lon_min, lon_max, lat_min, lat_max),
            "lon_info": (lon_levs, lon_n,
                         _deprecate_factor_none(lon_factor)),
            "lat_info": (lat_levs, lat_n,
                         _deprecate_factor_none(lat_factor)),
            "lon_labels": finder.tick_formatter1(
                "bottom", _deprecate_factor_none(lon_factor), lon_levs),
            "lat_labels": finder.tick_formatter2(
                "bottom", _deprecate_factor_none(lat_factor), lat_levs),
            "line_xy": (xx, yy),
        }

    def get_axislabel_transform(self, axes):
        """Return a fresh identity ``Affine2D`` for the axis label."""
        return Affine2D()  # axes.transData

    def get_axislabel_pos_angle(self, axes):
        """
        Return ``(position, angle)`` for the axis label.

        The candidate position is the midpoint of the varying coordinate's
        extremes.  If it lies outside the unit square in axes coordinates,
        ``(None, None)`` is returned.  Otherwise the position is given in
        ``axes.transData`` output coordinates and the angle, in degrees, is
        measured towards a nearby point further along the line.
        """
        extremes = self.grid_info["extremes"]

        if self.nth_coord == 0:
            xx0 = self.value
            yy0 = (extremes[2] + extremes[3]) / 2
            dxx = 0
            dyy = abs(extremes[2] - extremes[3]) / 1000
        elif self.nth_coord == 1:
            xx0 = (extremes[0] + extremes[1]) / 2
            yy0 = self.value
            dxx = abs(extremes[0] - extremes[1]) / 1000
            dyy = 0

        finder = self.grid_helper.grid_finder
        (xx1,), (yy1,) = finder.transform_xy([xx0], [yy0])

        data_to_axes = axes.transData - axes.transAxes
        p = data_to_axes.transform([xx1, yy1])

        if not (0 <= p[0] <= 1 and 0 <= p[1] <= 1):
            return None, None

        xx1c, yy1c = axes.transData.transform([xx1, yy1])
        (xx2,), (yy2,) = finder.transform_xy([xx0 + dxx], [yy0 + dyy])
        xx2c, yy2c = axes.transData.transform([xx2, yy2])
        angle = np.rad2deg(np.arctan2(yy2c-yy1c, xx2c-xx1c))
        return (xx1c, yy1c), angle

    def get_tick_transform(self, axes):
        """Return an ``IdentityTransform`` for the ticks."""
        return IdentityTransform()  # axes.transData

    def get_tick_iterators(self, axes):
        """tick_loc, tick_angle, tick_label, (optionally) tick_label

        Returns a pair of iterators.  The first lazily yields
        ``([x, y], angle_normal, angle_tangent, label)`` for every tick
        that falls inside the axes (with a small tolerance), angles in
        degrees; the second is always empty.
        """
        finder = self.grid_helper.grid_finder

        lat_levs, lat_n, lat_factor = self.grid_info["lat_info"]
        lat_levs = np.asarray(lat_levs)
        yy0 = lat_levs / _deprecate_factor_none(lat_factor)
        dy = 0.01 / _deprecate_factor_none(lat_factor)

        lon_levs, lon_n, lon_factor = self.grid_info["lon_info"]
        lon_levs = np.asarray(lon_levs)
        xx0 = lon_levs / _deprecate_factor_none(lon_factor)
        dx = 0.01 / _deprecate_factor_none(lon_factor)

        # Order the extremes unless one of them is None (None cannot be
        # sorted); a None bound is then treated as unbounded.
        bounds = self._extremes
        if None not in bounds:
            bounds = sorted(bounds)
        e0, e1 = bounds
        e0 = -np.inf if e0 is None else e0
        e1 = np.inf if e1 is None else e1

        def to_display(x, y):
            # grid coordinates -> grid_finder transform -> transData
            x1, y1 = finder.transform_xy(x, y)
            pts = axes.transData.transform(np.array([x1, y1]).transpose())
            x2, y2 = pts.transpose()
            return x2, y2

        # For each tick compute its position plus two short segments whose
        # directions give the "normal" and "tangent" angles.
        if self.nth_coord == 0:
            mask = (e0 <= yy0) & (yy0 <= e1)
            yy0 = yy0[mask]
            xx0 = np.full_like(yy0, self.value)

            tick_x, tick_y = to_display(xx0, yy0)

            xx00 = xx0.copy()
            xx00[xx0 + dx > e1] -= dx
            nax, nay = to_display(xx00, yy0)
            nbx, nby = to_display(xx00+dx, yy0)

            tax, tay = to_display(xx0, yy0)
            tbx, tby = to_display(xx0, yy0+dy)

            labels = self.grid_info["lat_labels"]
            labels = [text for text, keep in zip(labels, mask) if keep]

        elif self.nth_coord == 1:
            mask = (e0 <= xx0) & (xx0 <= e1)
            xx0 = xx0[mask]
            yy0 = np.full_like(xx0, self.value)

            tick_x, tick_y = to_display(xx0, yy0)

            nax, nay = to_display(xx0, yy0)
            nbx, nby = to_display(xx0, yy0+dy)

            xx00 = xx0.copy()
            xx00[xx0 + dx > e1] -= dx
            tax, tay = to_display(xx00, yy0)
            tbx, tby = to_display(xx00+dx, yy0)

            labels = self.grid_info["lon_labels"]
            labels = [text for text, keep in zip(labels, mask) if keep]

        def iter_ticks():
            normal = np.arctan2(nby-nay, nbx-nax)  # angle normal
            tangent = np.arctan2(tby-tay, tbx-tax)  # angle tangent
            # Where the normal segment is degenerate its angle is not
            # defined; fall back to tangent + 90 degrees.
            undefined = (nby == nay) & (nbx == nax)
            normal[undefined] = tangent[undefined] + np.pi / 2

            tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
            delta = 0.00001
            lo, hi = 0-delta, 1+delta
            for x, y, d, d2, lab in zip(
                    tick_x, tick_y, normal, tangent, labels):
                c2 = tick_to_axes.transform((x, y))
                if not (lo <= c2[0] <= hi and lo <= c2[1] <= hi):
                    continue
                deg_normal, deg_tangent = np.rad2deg([d, d2])
                yield [x, y], deg_normal, deg_tangent, lab

        return iter_ticks(), iter([])

    def get_line_transform(self, axes):
        """Return ``axes.transData``, the transform for the axis line."""
        return axes.transData

    def get_line(self, axes):
        """
        Return the axis line for *axes*.

        Limits are refreshed first.  When ``self._get_line_path`` is set,
        it is called as ``self._get_line_path(axes, x, y)`` and its result
        returned; otherwise a `Path` through the sampled points is built.
        """
        self.update_lim(axes)
        x, y = self.grid_info["line_xy"]

        if self._get_line_path is not None:
            return self._get_line_path(axes, x, y)
        return Path(np.column_stack([x, y]))
class GridHelperCurveLinear(GridHelperBase):
    """
    Grid helper whose grid lines, ticks and labels are produced by a
    `GridFinder` built around an auxiliary transform.
    """

    def __init__(self, aux_trans,
                 extreme_finder=None,
                 grid_locator1=None,
                 grid_locator2=None,
                 tick_formatter1=None,
                 tick_formatter2=None):
        """
        aux_trans : a transform from the source (curved) coordinate to
        target (rectilinear) coordinate. An instance of MPL's Transform
        (inverse transform should be defined) or a tuple of two callable
        objects which defines the transform and its inverse. The callables
        need take two arguments of array of source coordinates and
        should return two target coordinates.

        e.g., ``x2, y2 = trans(x1, y1)``

        The remaining arguments are forwarded, in order, to `GridFinder`.
        """
        super().__init__()

        self.grid_info = None
        self._old_values = None
        self._aux_trans = aux_trans
        self.grid_finder = GridFinder(aux_trans,
                                      extreme_finder,
                                      grid_locator1,
                                      grid_locator2,
                                      tick_formatter1,
                                      tick_formatter2)

    def update_grid_finder(self, aux_trans=None, **kw):
        """
        Update the grid finder and invalidate this helper.

        A non-``None`` *aux_trans* is passed to
        ``grid_finder.update_transform``; *kw* goes to
        ``grid_finder.update``.
        """
        # NOTE(review): self._aux_trans is not refreshed here, so it keeps
        # the transform given to __init__ -- confirm that this is intended.
        if aux_trans is not None:
            self.grid_finder.update_transform(aux_trans)

        self.grid_finder.update(**kw)
        self.invalidate()

    def _update(self, x1, x2, y1, y2):
        """bbox in 0-based image coordinates"""
        # Skip the recomputation when the helper is still valid and the
        # limits are the ones used for the last update.
        limits = (x1, x2, y1, y2)
        if self.valid() and self._old_values == limits:
            return

        # _update_grid takes the corners as (x1, y1, x2, y2).
        self._update_grid(x1, y1, x2, y2)

        self._old_values = limits
        self._force_update = False

    def new_fixed_axis(self, loc,
                       nth_coord=None,
                       axis_direction=None,
                       offset=None,
                       axes=None):
        """
        Create an `AxisArtist` for the fixed axis at *loc*.

        *axes* defaults to ``self.axes`` and *axis_direction* to *loc*.
        *nth_coord* is handed to `FixedAxisArtistHelper` as
        ``nth_coord_ticks``.  *offset* is accepted but not used here.
        """
        if axes is None:
            axes = self.axes

        if axis_direction is None:
            axis_direction = loc

        helper = FixedAxisArtistHelper(self, loc, nth_coord_ticks=nth_coord)
        # Why is clip not set on axisline, unlike in new_floating_axis or in
        # the floating_axes.GridHelperCurveLinear subclass?
        return AxisArtist(axes, helper, axis_direction=axis_direction)

    def new_floating_axis(self, nth_coord,
                          value,
                          axes=None,
                          axis_direction="bottom"
                          ):
        """
        Create an `AxisArtist` for a floating axis on which coordinate
        *nth_coord* equals *value*.

        *axes* defaults to ``self.axes``.  The axis line is clipped to the
        bounding box of its axes.
        """
        if axes is None:
            axes = self.axes

        helper = FloatingAxisArtistHelper(
            self, nth_coord, value, axis_direction)

        axisline = AxisArtist(axes, helper)

        line = axisline.line
        line.set_clip_on(True)
        line.set_clip_box(axisline.axes.bbox)

        return axisline

    def _update_grid(self, x1, y1, x2, y2):
        """Store the grid finder's grid info for the given corners."""
        self.grid_info = self.grid_finder.get_grid_info(x1, y1, x2, y2)

    def get_gridlines(self, which="major", axis="both"):
        """
        Return a flat list of the grid lines held in ``self.grid_info``.

        *axis* selects the "lon" lines ("x"), the "lat" lines ("y") or
        both ("both").  *which* is accepted but not consulted.
        """
        grid_lines = []
        for axis_name, key in (("x", "lon"), ("y", "lat")):
            if axis in ("both", axis_name):
                for gl in self.grid_info[key]["lines"]:
                    grid_lines.extend(gl)
        return grid_lines

    def get_tick_iterator(self, nth_coord, axis_side, minor=False):
        """
        Yield ``(xy, angle_normal, angle_tangent, label)`` for the ticks of
        coordinate *nth_coord* (0 -> "lon", 1 -> "lat") on *axis_side*.

        With a true *minor* the same ticks are produced, but every label
        is the empty string.
        """
        angle_tangent = dict(left=90, right=90, bottom=0, top=0)[axis_side]
        lon_or_lat = ["lon", "lat"][nth_coord]

        locs = self.grid_info[lon_or_lat]["tick_locs"][axis_side]
        labels = self.grid_info[lon_or_lat]["tick_labels"][axis_side]
        for (xy, angle_normal), label in zip(locs, labels):
            yield xy, angle_normal, angle_tangent, "" if minor else label