Source code for psy_strat.stratplot

"""Module to create new stratigraphic plots

This module defines the :func:`stratplot` function that can be used to create
stratigraphic plots such as pollen diagrams
"""
from __future__ import division
import weakref
import six
from itertools import groupby, chain, islice
import matplotlib as mpl
import matplotlib.transforms as mt
from collections import defaultdict
import psyplot
from psyplot.utils import DefaultOrderedDict
import xarray as xr
import numpy as np
from psy_strat.plotters import StratPlotter, BarStratPlotter
from psyplot.data import ArrayList
import psyplot.project as psy
from docrep import DocstringProcessor

# Shared docrep processor; used below to extract docstring sections of
# ``StratGroup.from_dataset`` and reuse them in subclass docstrings
docstrings = DocstringProcessor()


# Key of the psyplot-gui plugin (looked up in ``mainwindow.plugins``) that
# receives the groupers created by :func:`stratplot`
gui_plugin = 'psy_strat.strat_widget:StratPlotsWidget:stratplots'


# Group name used for every column when no ``group_func`` is given; groups
# with this name do not get a visual grouper
NOGROUP = 'nogroup'


def _no_grouper(col):
    """Assign every column to the same dummy group to disable the grouping

    Parameters
    ----------
    col: str
        The column name (ignored)

    Returns
    -------
    str
        Always :data:`NOGROUP`"""
    return NOGROUP


def _as_name_list(value, all_names):
    """Normalize a group selection parameter into a list of names

    Parameters
    ----------
    value: str, bool or iterable of str
        A single name, an iterable of names or a boolean. A truthy
        non-iterable (e.g. ``True``) selects all of `all_names`, a falsy
        non-iterable selects nothing
    all_names: iterable of str
        The names to use if `value` is a truthy non-iterable

    Returns
    -------
    list of str
        The selected names"""
    if isinstance(value, six.string_types):
        return [value]
    try:
        return list(value)
    except TypeError:
        return list(all_names) if value else []


def stratplot(df, group_func=None, formatoptions=None, ax=None, thresh=0.01,
              percentages=[], exclude=[], widths=None,
              calculate_percentages=True, min_percentage=20.0,
              trunc_height=0.3, fig=None, all_in_one=[], stacked=[],
              summed=[], use_bars=False, subgroups={}):
    """Visualize a dataframe as a stratigraphic plot

    This function takes a :class:`pandas.DataFrame` and transforms it to a
    stratigraphic plot. The columns in the DataFrame may be grouped together
    using the `group_func` and the widths per group should then be specified.

    This function uses matplotlib axes for each subdiagram that all share a
    common vertical axis, the index of `df`. The variables are managed in the
    order of occurrence in the input `df` but are grouped together depending
    on the `group_func`.

    The default is to plot every variable in `df` into separate line plots
    that line up vertically. You can use the `percentages` parameter for area
    plots, the `all_in_one` parameter for groups that should be all in one
    single plot (i.e. axes) and the `stacked` parameter for stacked plots.

    Parameters
    ----------
    df: pandas.DataFrame
        The dataframe containing the data to plot.
    group_func: function
        A function that groups the columns in the input `df` together. It
        must accept the name of a column and return the corresponding group
        name::

            def group_func(col_name: str):
                return "name of its group"

        If this parameter is not specified, each column will be assigned to
        the `'nogroup'` group that can then be used in the other parameters,
        such as `formatoptions` and `percentages`. Each group may also be
        divided into `subgroups` (see below), in this case, the `group_func`
        should return the corresponding subgroup.
    formatoptions: dict
        The formatoptions for each group. Depending on the chosen plot
        method, this contains the formatoptions for the psyplot plotter.
        The given dictionary is not modified.
    ax: matplotlib.axes.Axes
        The matplotlib axes to plot on. New axes will be created that cover
        all the space of the given axes. If this parameter is not specified
        and `fig` is None, a new matplotlib figure is created with a new
        matplotlib axes. A :class:`matplotlib.transforms.Bbox` in figure
        coordinates is accepted as well.
    thresh: float
        A minimum number between 0 and 100 (by default 0.01) that a
        `percentages` column has to exceed in order to be included in the
        plot. If a variable is always below this threshold, it will not be
        included in the figure
    percentages: list of str or bool
        The group names (see `group_func`) that represent percentage values.
        These variables will be visualized using an area plot and can be
        rescaled to sum up to 100% using the `calculate_percentages`
        parameter. This parameter can also be set to True if all groups
        shall be considered as percentage data
    exclude: list of str
        Either group names or column names in `df` that should be excluded
        from the plot
    widths: dict
        A mapping from group name to its relative width in the plot. The
        values of this mapping should sum up to 1, e.g.::

            widths = {'group1': 0.3, 'group2': 0.5, 'group3': 0.2}

        The given dictionary is not modified.
    calculate_percentages: bool or list of str
        If True, rescale the groups mentioned in the `percentages` parameter
        to sum up to 100%. In case of a (list of) str, this parameter
        represents the group (or variable) names that shall be used for the
        normalization
    min_percentage: float
        The minimum percentage (between 0 and 100) that should be covered by
        variables displaying `percentages` data. Each plot in one of the
        `percentages` groups will have at least an xlim from 0 to
        `min_percentage`
    trunc_height: float
        A float between 0 and 1. The fraction of the `ax` that should be
        reserved for the group titles.
    fig: matplotlib.Figure
        The matplotlib figure to draw the plot on. If neither `ax` nor `fig`
        is specified, a new figure will be created.
    all_in_one: list of str
        The groups mentioned in this parameter will all be plotted in one
        single axes whereas the default is to plot each variable in a
        separate plot
    stacked: list of str
        The groups mentioned in this parameter will all be plotted in one
        single axes, stacked onto each other
    summed: list of str
        The groups (or subgroups) mentioned in this parameter will be summed
        and an extra plot will be appended to the right of the stratigraphic
        diagram
    use_bars: list of str or bool
        The variables specified in this parameter (or all variables if
        `use_bars` is ``True``) will be visualized by a bar diagram, instead
        of a line or area plot.
    subgroups: dict
        A mapping from group name to a list of subgroups, e.g.::

            subgroups = {'Pollen': ['Trees', 'Shrubs']}

        to divide an overarching group into subgroups.

    Returns
    -------
    psyplot.project.Project
        The newly created psyplot subproject that contains the displayed data
    list of :class:`StratGroup`
        The groupers that manage the different variables. There is one
        grouper per group

    Raises
    ------
    ValueError
        If no variable is left to plot (e.g. because everything has been
        excluded)"""
    import psyplot.project as psy
    import matplotlib.pyplot as plt

    # The list defaults in the signature are shared between calls and the
    # containers may belong to the caller, so copy what is modified below
    stacked = list(stacked)
    formatoptions = dict(formatoptions or {})
    if widths:
        widths = widths.copy()

    if group_func is None:
        group_func = _no_grouper

    groups = DefaultOrderedDict(list)
    # invert `subgroups` to map each subgroup to its overarching group
    subgroup2group = dict(chain.from_iterable(
        ((sub, group) for sub in subs) for group, subs in subgroups.items()))
    cols = {}
    for col in df.columns:
        group = group_func(col)
        group = subgroup2group.get(group, group)
        groups[group].append(col)
        cols[col] = group

    # Setup percentages
    percentages = _as_name_list(percentages, groups)
    percentage_groups = set(percentages).intersection(groups)
    if calculate_percentages and percentage_groups:
        df = df.copy(True)
        if isinstance(calculate_percentages, six.string_types):
            calculate_percentages = [calculate_percentages]
        try:
            norm_names = list(calculate_percentages)
        except TypeError:  # a plain boolean: normalize each group by itself
            norm_names = None
        for group in percentage_groups:
            members = groups[group]
            if norm_names is None:
                norm_members = members
            else:
                # use .get to not create empty groups in the defaultdict
                norm_members = list(set(chain.from_iterable(
                    [var] if var in df.columns else groups.get(var, [])
                    for var in norm_names)))
            # use the numpy array; multi-dimensional indexing of a pandas
            # Series (``series[:, np.newaxis]``) is not supported anymore
            totals = df[norm_members].fillna(0).sum(axis=1).values
            df[members] *= 100. / np.tile(totals[:, np.newaxis],
                                          (1, len(members)))

    if summed:
        summed = _as_name_list(summed, groups)
        stacked.append('Summed')
        groups['Summed'] = [g + '_summed' for g in summed]
        if widths:
            widths.setdefault('Summed', 0.2)
        summed_fmt = formatoptions['Summed'] = dict(
            formatoptions.get('Summed', {}))
        summed_fmt.setdefault('legendlabels', '%(long_name)s')
        summed_fmt.setdefault('title', '')
    else:
        summed = []

    # by default, every group that does not hold percentages gets the same
    # share of the total width
    widths = widths or defaultdict(
        lambda: 1. / (len(set(groups).difference(percentages)) or 1))

    # Setup use_bars
    use_bars = _as_name_list(use_bars, groups)

    # NOTE: we create the Dataset manually instead of using
    # xarray.Dataset.from_dataframe because that is much faster
    idx = df.index.name or 'y'
    ds = xr.Dataset(
        {col: xr.Variable((idx, ), df[col]) for col in df.columns},
        {idx: xr.Variable((idx, ), df.index)})
    for var, varo in ds.variables.items():
        if var not in ds.coords:
            varo.attrs['group'] = group_func(var)
            varo.attrs['maingroup'] = cols[var]

    for group in summed:
        # `group` may be a subgroup or an overarching (main) group
        variables = [
            col for col in df.columns
            if group in (ds.variables[col].attrs['group'], cols[col])]
        ds[group + '_summed'] = xr.Variable(
            (idx, ), df[variables].sum(axis=1).values,
            attrs={'long_name': group, 'group': 'Summed',
                   'maingroup': 'Summed'})
        cols[group + '_summed'] = 'Summed'

    plot_vars = set(
        var for var, varo in ds.variables.items()
        if (var not in ds.coords and
            var not in exclude and
            varo.attrs['group'] not in exclude and
            cols[var] not in exclude and
            (cols[var] not in percentages or
             ds[var].max().values > thresh)))

    arr_names = []
    if ax is None:
        fig = fig or plt.figure()
        bbox = mt.Bbox.from_extents(
            mpl.rcParams['figure.subplot.left'],
            mpl.rcParams['figure.subplot.bottom'],
            mpl.rcParams['figure.subplot.right'],
            mpl.rcParams['figure.subplot.top'])
    elif isinstance(ax, mpl.axes.Axes):  # subplots are Axes as well
        bbox = ax.get_position()
        fig = ax.figure
    else:  # the bbox is given
        bbox = ax
        fig = fig or plt.gcf()

    x0 = bbox.x0
    y0 = bbox.y0
    orig_height = bbox.height
    height = orig_height * (1 - trunc_height)
    total_width = bbox.width
    x1 = x0 + total_width

    ax0 = None
    x = x0

    mp = psy.gcp(True)

    groupers = []

    with psy.Project.block_signals:
        for group, variables in groups.items():
            variables = [v for v in variables if v in plot_vars]
            if not variables:
                continue
            w = widths[group] * total_width
            if group in all_in_one:
                identifier = 'all_in_one'
            elif group in stacked:
                identifier = 'stacked'
            elif group in percentages:
                identifier = 'percentages'
            else:
                identifier = 'default'
            grouper_cls = strat_groupers[identifier]
            grouper = grouper_cls.from_dataset(
                fig, mt.Bbox.from_bounds(x, y0, w, height), ds, variables,
                fmt=dict(formatoptions.get(group, {})), project=mp,
                ax0=ax0, use_bars=use_bars, group=group)
            if identifier == 'percentages':
                # make sure every percentage plot covers `min_percentage`
                resize = False
                for plotter in grouper.plotters:
                    if plotter.ax.get_xlim()[1] < min_percentage:
                        plotter.update(xlim=(0, min_percentage))
                        resize = True
                if resize:
                    grouper.resize_axes(grouper.axes)
            if group != NOGROUP:
                grouper.group_plots(trunc_height / height)
            ds[group] = xr.Variable(tuple(), '',
                                    attrs={'identifier': identifier})

            ax0 = ax0 or grouper.axes[0]
            x += w
            arr_names.extend(
                arr.psy.arr_name for arr in grouper.plotter_arrays)
            groupers.append(grouper)

    if ax0 is None:
        raise ValueError("No variables left to plot! Check the `exclude` "
                         "and `thresh` parameters.")

    if psyplot.with_gui:
        from psyplot_gui.main import mainwindow
        mainwindow.plugins[gui_plugin].add_tree(groupers)

    # invert the vertical axis
    ax0.invert_yaxis()

    sp = psy.gcp(True)(arr_name=arr_names)
    sp[0].psy.update(
        ylabel='%(name)s',
        ytickprops={'left': True, 'labelleft': True}, draw=False)

    # draw dotted axis lines on the inner edges of the diagram
    for sub_ax, p in sp.axes.items():
        ax_bbox = sub_ax.get_position()
        d = {}
        if ax_bbox.x0 != x0:
            d['left'] = ':'
        if ax_bbox.x1 != x1:
            d['right'] = ':'
        p.update(axislinestyle=d, draw=False)

    psy.scp(sp.main)
    psy.scp(sp)

    return sp, groupers
class StratGroup(object):
    """Base class for visualizing stratigraphic plots"""

    #: list of weakref. Weak references to the created arrays
    _refs = []

    _arrays = None

    _plotter_arrays = None

    #: The height of the grouper, set by :meth:`group_plots`
    grouper_height = None

    #: The default formatoptions for the plots
    default_fmt = {
        'ytickprops': {'left': False, 'labelleft': False},
        }

    #: The default formatoptions for bar plots
    bar_default_fmt = default_fmt.copy()
    bar_default_fmt['categorical'] = False

    @property
    def plotter_arrays(self):
        """The data objects that contain the plotters"""
        return self._plotter_arrays or ArrayList([ref() for ref in self._refs])

    @plotter_arrays.setter
    def plotter_arrays(self, value):
        self._plotter_arrays = value

    @property
    def arrays(self):
        """The arrays managed by this :class:`StratGroup`. One array for each
        variable"""
        return self.plotter_arrays

    @property
    def all_arrays(self):
        """All variables of this group in the dataset"""
        arr = self.arrays[0]
        group = arr.group
        ds = arr.psy.base
        return [ds.psy[name] for name, v in ds.variables.items()
                if v.attrs.get('group') == group]

    @property
    def plotters(self):
        """The plotters of the :attr:`arrays`"""
        return list(filter(lambda p: p is not None,
                           [arr.psy.plotter for arr in self.plotter_arrays]))

    @property
    def axes(self):
        """The axes of the :attr:`plotters`"""
        return [plotter.ax for plotter in self.plotters]

    @property
    def arr_names(self):
        """The ``arr_name`` of each array in :attr:`plotter_arrays`"""
        return [arr.psy.arr_name for arr in self.plotter_arrays]

    def __init__(self, arrays, bbox=None, use_weakref=True, group=None):
        """
        Parameters
        ----------
        arrays: list of xarray.DataArray
            The data arrays that are plotted by this :class:`StratGroup`
            instance
        bbox: matplotlib.transforms.Bbox
            The bounding box for the axes. If None, it is computed from the
            axes positions of the given `arrays`
        use_weakref: bool
            If True, only weak references are used
        group: str
            The groupname of this grouper. If not given, it will be taken from
            the ``'maingroup'`` attribute of the first array
        """
        if use_weakref:
            self._refs = [weakref.ref(arr) for arr in arrays]
        else:
            self.plotter_arrays = arrays
        if bbox is None:
            boxes = [arr.psy.ax.get_position() for arr in arrays]
            x0 = min(box.x0 for box in boxes)
            y0 = min(box.y0 for box in boxes)
            w = sum(box.width for box in boxes)
            bbox = boxes[0].from_bounds(x0, y0, w, boxes[0].height)
        self.bbox = bbox
        self.group = group or arrays[0].attrs.get('maingroup')

    def resize_axes(self, axes):
        """Resize the axes in this group

        The width of :attr:`bbox` is distributed evenly among the given
        `axes`, from left to right. Nothing happens if `axes` is empty (e.g.
        when all arrays are hidden)."""
        if not axes:
            return
        w = self.bbox.width / len(axes)
        x0 = self.bbox.x0
        for ax in axes:
            ax_bbox = ax.get_position()
            ax.set_position([x0, ax_bbox.y0, w, ax_bbox.height])
            x0 += w

    def group_plots(self, height=None):
        """Group the variables visually

        Parameters
        ----------
        height: float
            The height of the grouper. If not specified, the previous
            :attr:`grouper_height` attribute will be used"""
        for plotter in self.plotters:
            # only visible plotters that lead the shared grouper are updated
            if not plotter.ax.get_visible() or plotter.grouper.shared_by:
                continue
            height = height or self.grouper_height
            if height is None and plotter['grouper']:
                height = plotter['grouper'][0]
            elif height is None:
                return
            self.grouper_height = height
            plotter.update(grouper=(height, '%(group)s'), draw=False,
                           force=True)
            # keep the formatoption value of the followers in sync
            for fmto in plotter.grouper.shared:
                with fmto.plotter.no_validation:
                    fmto.plotter['grouper'] = (height, '%(group)s')

    @property
    def figure(self):
        """The figure that contains the plots"""
        return self.axes[0].figure

    def is_visible(self, arr):
        """Check if the given `arr` is shown"""
        return arr.psy.plotter.ax.get_visible()

    @classmethod
    @docstrings.get_sectionsf('StratGroup.from_dataset',
                              sections=['Parameters', 'Returns'])
    def from_dataset(cls, fig, bbox, ds, variables, fmt=None, project=None,
                     ax0=None, use_bars=False, group=None):
        """
        Create :class:`StratGroup` while creating a stratigraphic plot

        Create a stratigraphic plot within the given `bbox` of `fig`.

        Parameters
        ----------
        fig: matplotlib.figure.Figure
            The figure to plot in
        bbox: matplotlib.transforms.Bbox
            The bounding box for the newly created axes
        ds: xarray.Dataset
            The dataset
        variables: list
            The variables that shall be plot in the given `ds`
        fmt: dict
            The formatoptions for the plots. Missing entries are filled with
            the class defaults; the given dictionary is not modified
        project: psyplot.project.Project
            The mother project. If given, only weak references are stored in
            the returned :class:`StratGroup` and each array is appended to the
            `project`.
        ax0: matplotlib.axes.Axes
            The first subplot to share the y-axis with
        use_bars: bool or list of str
            Whether to use a bar plot or a line/area plot. A list specifies
            the (sub)group names that should use bars
        group: str
            The name of the group

        Returns
        -------
        StratGroup
            The newly created instance with the arrays
        """
        if ax0 is None:
            ax0 = fig.add_axes(bbox.from_bounds(*bbox.bounds), label='ax0')
            axes = [ax0] + [
                fig.add_axes(bbox.from_bounds(*bbox.bounds), sharey=ax0,
                             label='ax%i' % i)
                for i in range(1, len(variables))]
        else:
            axes = [
                fig.add_axes(bbox.from_bounds(*bbox.bounds), sharey=ax0,
                             label='ax%i' % i)
                for i in range(len(variables))]

        # collect the variables per (sub)group
        grouped = DefaultOrderedDict(list)
        for name in variables:
            grouped[ds[name].attrs.get('group', 'group')].append(name)

        # Use group specific bars
        if use_bars:
            if isinstance(use_bars, six.string_types):
                use_bars = [use_bars]
            try:
                use_bars = list(use_bars)
            except TypeError:
                use_bars = list(grouped)
        else:
            use_bars = []

        sp = None
        axes_it = iter(axes)
        for subgroup, names in grouped.items():
            formatoptions = dict(fmt or {})
            if subgroup in use_bars or group in use_bars:
                plotter_cls = BarStratPlotter
                defaults = cls.bar_default_fmt
            else:
                plotter_cls = StratPlotter
                defaults = cls.default_fmt
            for key, val in six.iteritems(defaults):
                formatoptions.setdefault(key, val)
            sp2 = psy.Project()._add_data(
                plotter_cls, ds, name=names, draw=False, fmt=formatoptions,
                prefer_list=False, ax=islice(axes_it, len(names)),
                share='grouper', attrs=dict(maingroup=group))
            if project is not None:
                project.extend(sp2, new_name=True)
            sp = sp2 if sp is None else sp + sp2
        ret = cls(list(sp), bbox, use_weakref=project is not None,
                  group=group)
        ret.resize_axes(axes)
        return ret

    def hide_array(self, name):
        """Hide the variable of the given `name`

        Parameters
        ----------
        name: str
            The variable name"""
        arr = next(iter(self.plotter_arrays(name=name)), None)
        if arr is None or not arr.psy.ax.get_visible():
            # array does not exist or isn't plotted
            return
        group = arr.group
        first_visible = next(
            (a for a in self.plotter_arrays
             if a.psy.ax.get_visible() and a.group == group), None)
        if arr is first_visible:
            # `arr` leads the shared grouper, so let the others share among
            # themselves
            p = psy.Project(self.plotter_arrays)(group=group)
            p.unshare(keys='grouper', draw=False)
            p(name=set(p.names) - {name}).share(keys='grouper', draw=False)
        arr.psy.ax.set_visible(False)
        self.resize_axes([ax for ax in self.axes if ax.get_visible()])
        self.group_plots()

    def show_array(self, name):
        """Show the variable of the given `name`

        Parameters
        ----------
        name: str
            The variable name"""
        arrays = self.plotter_arrays
        arr = next(iter(arrays(name=name)), None)
        if arr is None or arr.psy.ax.get_visible():
            # array does not exist or is already plotted
            return
        group = arr.group
        # the leading run of arrays of this group sharing the same visibility
        key, first_invisibles = next(
            groupby((a for a in arrays if a.group == group),
                    lambda a: a.psy.ax.get_visible()))
        if key:  # first plot is visible
            first_invisibles = []
        if any(arr is invisible_arr for invisible_arr in first_invisibles):
            # `arr` becomes the first visible array and leads the grouper
            p = psy.Project(arrays)(group=group)
            p.unshare(keys='grouper', draw=False)
            p(name=set(p.names) - {name}).share(arr, keys='grouper',
                                                draw=False)
        arr.psy.ax.set_visible(True)
        self.resize_axes([ax for ax in self.axes if ax.get_visible()])
        self.group_plots()

    def reorder(self, names):
        """Reorder the plot objects

        Parameters
        ----------
        names: list of str
            The variable names that should be the first. Unknown names are
            ignored, arrays that are not mentioned keep their relative order
            after the mentioned ones"""
        arrays = self._plotter_arrays or self._refs
        old = list(arrays)
        old_da = list(self.plotter_arrays)
        nold = len(old)
        project = self.plotters[0].project
        reorder_project = False
        if project is not None:
            project = project.main
            # position of this group within the main project
            start = project.arr_names.index(old_da[0].psy.arr_name)
            reorder_project = (
                project.arr_names[start:start + nold] == self.arr_names)
        arrays.clear()
        for name in names:
            arr = next((old[j] for j, da in enumerate(old_da)
                        if str(da.name) == name), None)
            if arr is not None:
                arrays.append(arr)
        # now add the ones that are not mentioned in `names`
        for j, da in enumerate(old_da):
            if str(da.name) not in names:
                arrays.append(old[j])

        if reorder_project:
            project[start:start + nold] = self.plotter_arrays
            if project.is_csp or project.is_cmp:
                project.oncpchange.emit(project)

        self.resize_axes([ax for ax in self.axes if ax.get_visible()])

        self.plotter_arrays.update(force=['grouper'], draw=False)
class StratPercentages(StratGroup):
    """A :class:`StratGroup` for percentages plots"""

    default_fmt = StratGroup.default_fmt.copy()
    default_fmt['xlim'] = (0, 'rounded')
    default_fmt['xticks'] = np.arange(10, 100, 20)

    bar_default_fmt = default_fmt.copy()
    bar_default_fmt['categorical'] = False

    # set after copying into `bar_default_fmt`, so that bar plots do not use
    # the area plot
    default_fmt['plot'] = 'areax'

    def resize_axes(self, axes):
        """Resize the axes in this group

        The width of each axes is proportional to its upper x-limit, such
        that one percent covers the same width in every plot. Nothing
        happens if `axes` is empty (e.g. when all arrays are hidden)."""
        if not axes:
            return
        width = self.bbox.width
        width /= sum(ax.get_xlim()[1] for ax in axes) / 100.
        x0 = self.bbox.x0
        for ax in axes:
            w = width * ax.get_xlim()[1] / 100.
            ax_bbox = ax.get_position()
            ax.set_position([x0, ax_bbox.y0, w, ax_bbox.height])
            x0 += w
class StratAllInOne(StratGroup):
    """A :class:`StratGroup` for single plots"""

    default_fmt = StratGroup.default_fmt.copy()
    default_fmt['title'] = '%(group)s'
    default_fmt['titleprops'] = {}
    default_fmt['legend'] = True

    bar_default_fmt = default_fmt.copy()
    bar_default_fmt['categorical'] = False

    @property
    def arrays(self):
        """The arrays of the one and only plot object of this group"""
        return self.plotter_arrays[0]

    def group_plots(self, height=None):
        """Reimplemented to do nothing because all variables are in one axes
        """
        pass

    def is_visible(self, arr):
        """Check if the given `arr` is shown"""
        return arr.name in self.plotter_arrays[0].names

    def _plot_list(self, plotter):
        """Get the ``'plot'`` formatoption of `plotter` as a new list

        The returned list has one entry per array in :attr:`arrays`, where
        ``None`` marks a hidden array. It is a copy, so modifying it does not
        touch the value stored in the plotter."""
        v = plotter['plot']
        if v is None or isinstance(v, six.string_types):
            return [v] * len(self.arrays)
        return list(v)

    @classmethod
    @docstrings.dedent
    def from_dataset(cls, fig, bbox, ds, variables, fmt=None, project=None,
                     ax0=None, use_bars=False, group=None):
        """
        Create :class:`StratGroup` while creating a stratigraphic plot

        Create a stratigraphic plot within the given `bbox` of `fig`.

        Parameters
        ----------
        %(StratGroup.from_dataset.parameters)s

        Returns
        -------
        %(StratGroup.from_dataset.returns)s
        """
        # copy to not modify the dictionary of the caller
        fmt = dict(fmt or {})
        if use_bars:
            if isinstance(use_bars, six.string_types):
                use_bars = [use_bars]
            try:
                use_bars = group in list(use_bars)
            except TypeError:  # e.g. True
                use_bars = True
        defaults = cls.bar_default_fmt if use_bars else cls.default_fmt
        for key, val in six.iteritems(defaults):
            fmt.setdefault(key, val)
        if ax0 is None:
            ax = fig.add_axes(bbox.from_bounds(*bbox.bounds), label='ax0')
        else:
            ax = fig.add_axes(bbox.from_bounds(*bbox.bounds), sharey=ax0,
                              label='ax0')
        plotter_cls = BarStratPlotter if use_bars else StratPlotter
        sp = psy.Project()._add_data(
            plotter_cls, ds, name=variables, draw=False, fmt=fmt,
            prefer_list=True, ax=ax, share='grouper',
            attrs=dict(maingroup=group))
        if project is not None:
            project.extend(sp, new_name=True)
        return cls(list(sp), bbox, use_weakref=project is not None,
                   group=group)

    def hide_array(self, name):
        """Hide the variable of the given `name`

        Parameters
        ----------
        name: str
            The variable name"""
        i, arr = next(((i, arr) for i, arr in enumerate(self.arrays)
                       if arr.name == name), (None, None))
        plotter = self.plotters[0]
        v = self._plot_list(plotter)
        if arr is None or v[i] is None:
            # array isn't plotted
            return
        v[i] = None
        plotter.update(plot=v, force=True)

    def show_array(self, name):
        """Show the variable of the given `name`

        Parameters
        ----------
        name: str
            The variable name"""
        i, arr = next(((i, arr) for i, arr in enumerate(self.arrays)
                       if arr.name == name), (None, None))
        plotter = self.plotters[0]
        v = self._plot_list(plotter)
        if arr is None or v[i] is not None:
            # array is plotted
            return
        v[i] = self.default_fmt.get('plot', '-')
        plotter.update(plot=v, force=True)

    def reorder(self, names):
        """Reorder the plot objects

        Parameters
        ----------
        names: list of str
            The variable names that should be the first. Unknown names are
            ignored, arrays that are not mentioned keep their relative order
            after the mentioned ones. Hidden arrays stay hidden."""
        plotter = self.plotters[0]
        data = plotter.data
        current = list(data)
        # evaluate before clearing `data`, because that may affect the checks
        old_plot = self._plot_list(plotter)
        visibilities = list(map(self.is_visible, current))

        order = []
        for name in names:
            i = next((i for i, arr in enumerate(current)
                      if str(arr.name) == name), None)
            if i is not None and i not in order:
                order.append(i)
        order.extend(i for i in range(len(current)) if i not in order)

        data.clear()
        plot = []
        for i in order:
            data.append(current[i])
            plot.append(old_plot[i] if visibilities[i] else None)
        plotter.update(plot=plot, replot=True, draw=False)
class StackedGroup(StratAllInOne):
    """A grouper for stacked plots

    Like :class:`StratAllInOne`, but the ``'plot'`` formatoption defaults to
    ``'stacked'`` for both, line/area and bar plots."""

    #: Defaults for line/area plots, plotted onto each other
    default_fmt = dict(StratAllInOne.default_fmt, plot='stacked')

    #: Defaults for bar plots, plotted onto each other
    bar_default_fmt = dict(StratAllInOne.bar_default_fmt, plot='stacked')
# Mapping from the grouper identifier chosen in :func:`stratplot` to the
# class that creates and manages the plots of one group
strat_groupers = dict(
    all_in_one=StratAllInOne,
    percentages=StratPercentages,
    default=StratGroup,
    stacked=StackedGroup,
)