Source code for sierra.core.graphs.heatmap

# Copyright 2018 John Harwell, All rights reserved.
#  SPDX-License-Identifier: MIT

# Core packages
import textwrap
import typing as tp
import logging
import pathlib

# 3rd party packages
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.axes_grid1
import pandas as pd

# Project packages
from sierra.core import utils, config, storage, types

class Heatmap:
    """
    Generates an X vs. Y vs. Z heatmap plot of a ``.mean`` file.

    If the necessary .mean file does not exist, the graph is not generated.
    """

    @staticmethod
    def set_graph_size(df: pd.DataFrame, fig) -> None:
        """
        Set graph X,Y size based on dataframe dimensions.

        The shorter side of the figure is ``config.kGraphBaseSize``. The
        longer side is scaled by the ratio of the dataframe's row and column
        counts. A dataframe with no rows or no columns gets a square figure of
        the base size, because there is nothing to scale against.

        Args:
            df: The data being plotted; only its shape is used.

            fig: The matplotlib figure to resize.
        """
        nrows = len(df.index)
        ncols = len(df.columns)

        if nrows == 0 or ncols == 0:
            # Avoid ZeroDivisionError on empty data.
            xsize = ysize = config.kGraphBaseSize
        elif nrows > ncols:
            xsize = config.kGraphBaseSize
            ysize = xsize * float(nrows) / float(ncols)
        else:
            ysize = config.kGraphBaseSize
            xsize = ysize * float(ncols) / float(nrows)

        fig.set_size_inches(xsize, ysize)

    def __init__(self,
                 input_fpath: pathlib.Path,
                 output_fpath: pathlib.Path,
                 title: str,
                 xlabel: str,
                 ylabel: str,
                 zlabel: tp.Optional[str] = None,
                 large_text: bool = False,
                 xtick_labels: tp.Optional[tp.List[str]] = None,
                 ytick_labels: tp.Optional[tp.List[str]] = None,
                 transpose: bool = False,
                 interpolation: tp.Optional[str] = None) -> None:
        """
        Args:
            input_fpath: File to read the heatmap data from.

            output_fpath: Path the rendered image is saved to.

            title: Graph title; wrapped at 40 characters per line.

            xlabel: X label. It is placed on the plot's Y axis unless
                    ``transpose`` is set.

            ylabel: Y label. It is placed on the plot's X axis unless
                    ``transpose`` is set.

            zlabel: Colorbar label. If ``None``, the colorbar is unlabeled.

            large_text: Use ``config.kGraphTextSizeLarge`` instead of
                        ``config.kGraphTextSizeSmall``.

            xtick_labels: Swapped onto the plot axes like ``xlabel``.

            ytick_labels: Swapped onto the plot axes like ``ylabel``.

            transpose: Transpose the dataframe before plotting.

            interpolation: Passed to ``imshow()``. Defaults to ``'nearest'``
                           when not given (or empty).
        """
        # Required arguments
        self.input_fpath = input_fpath
        self.output_fpath = output_fpath
        self.title = '\n'.join(textwrap.wrap(title, 40))

        # The caller's X/Y labels (and tick labels below) are swapped onto
        # the plot's Y/X axes, unless the data is transposed before
        # plotting.
        self.xlabel = xlabel if transpose else ylabel
        self.ylabel = ylabel if transpose else xlabel

        # Optional arguments
        if large_text:
            self.text_size = config.kGraphTextSizeLarge
        else:
            self.text_size = config.kGraphTextSizeSmall

        self.transpose = transpose
        self.zlabel = zlabel

        if interpolation:
            self.interpolation = interpolation
        else:
            self.interpolation = 'nearest'

        if self.transpose:
            self.xtick_labels = xtick_labels
            self.ytick_labels = ytick_labels
        else:
            self.xtick_labels = ytick_labels
            self.ytick_labels = xtick_labels

        self.logger = logging.getLogger(__name__)

    def generate(self) -> None:
        """
        Read the input file and write the heatmap image to the output path.

        If the input file does not exist, only a debug message is logged.
        """
        if not utils.path_exists(self.input_fpath):
            self.logger.debug("Not generating heatmap: %s does not exist",
                              self.input_fpath)
            return

        # Read .csv and create raw heatmap from default configuration
        data_df = storage.DataFrameReader('storage.csv')(self.input_fpath)
        self._plot_df(data_df, self.output_fpath)

    def _plot_df(self, df: pd.DataFrame, opath: pathlib.Path) -> None:
        """
        Given a dataframe read from a file, plot it as a heatmap.

        Args:
            df: The data to plot. It is transposed first if ``transpose`` was
                requested.

            opath: Path the figure is saved to.
        """
        fig, ax = plt.subplots(figsize=(config.kGraphBaseSize,
                                        config.kGraphBaseSize))
        try:
            # Transpose if requested
            if self.transpose:
                df = df.transpose()

            # Plot heatmap
            im = ax.imshow(df, interpolation=self.interpolation, aspect='auto')

            # Add labels
            ax.set_xlabel(self.xlabel, fontsize=self.text_size['xyz_label'])
            ax.set_ylabel(self.ylabel, fontsize=self.text_size['xyz_label'])

            # Add X,Y ticks
            self._plot_ticks(ax)

            # Add graph title
            ax.set_title(self.title, fontsize=self.text_size['title'])

            # Add colorbar
            self._plot_colorbar(ax, im)

            # Output figure
            self.set_graph_size(df, fig)
            fig.savefig(opath, bbox_inches='tight', dpi=config.kGraphDPI)
        finally:
            # Prevent memory accumulation (fig.clf() does not close
            # everything), even if plotting/saving failed.
            plt.close(fig)

    def _plot_colorbar(self, ax, im=None) -> None:
        """
        Put the Z-axis colorbar on the plot.

        Args:
            ax: The axes holding the heatmap. The colorbar is drawn in a new
                axes appended to its right.

            im: The image the colorbar describes. If omitted, pyplot's
                current image is used (the previous behavior).
        """
        divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)

        if im is None:
            bar = plt.colorbar(cax=cax)
        else:
            bar = ax.get_figure().colorbar(im, cax=cax)

        if self.zlabel is not None:
            bar.set_label(self.zlabel, fontsize=self.text_size['xyz_label'])

    def _plot_ticks(self, ax) -> None:
        """
        Plot X,Y ticks and their corresponding labels.

        Ticks are placed at integer cell positions. Each axis is only
        relabeled if tick labels for it were provided.
        """
        ax.tick_params(labelsize=self.text_size['tick_label'])

        if self.xtick_labels is not None:
            ax.set_xticks(np.arange(len(self.xtick_labels)))
            ax.set_xticklabels(self.xtick_labels, rotation='vertical')

        if self.ytick_labels is not None:
            ax.set_yticks(np.arange(len(self.ytick_labels)))
            ax.set_yticklabels(self.ytick_labels)
class DualHeatmap:
    """Generates a side-by-side plot of two heatmaps from two CSV files.

    ``.mean`` files must be named as ``<input_stem_fpath>_X.mean``, where `X`
    is a non-negative integer. Input ``.mean`` files must be 2D grids of the
    same cardinality.

    This graph does not plot standard deviation.

    If there are not exactly two file paths passed, the graph is not
    generated.
    """
    kCardinality = 2

    def __init__(self,
                 ipaths: types.PathList,
                 output_fpath: pathlib.Path,
                 title: str,
                 xlabel: tp.Optional[str] = None,
                 ylabel: tp.Optional[str] = None,
                 zlabel: tp.Optional[str] = None,
                 large_text: bool = False,
                 xtick_labels: tp.Optional[tp.List[str]] = None,
                 ytick_labels: tp.Optional[tp.List[str]] = None,
                 legend: tp.Optional[tp.List[str]] = None) -> None:
        """
        Args:
            ipaths: The input files; exactly ``kCardinality`` are required.

            output_fpath: Path the rendered image is saved to.

            title: Overall figure title.

            xlabel: Drawn on the plot's Y axis (left heatmap only).

            ylabel: Drawn on the plot's X axis (both heatmaps).

            zlabel: Label for the shared colorbar. If ``None``, the colorbar
                    is unlabeled.

            large_text: Use ``config.kGraphTextSizeLarge`` instead of
                        ``config.kGraphTextSizeSmall``.

            xtick_labels: Tick labels for the plot's Y axis (left heatmap
                          only). If ``None``, default numeric labels are kept.

            ytick_labels: Tick labels for the plot's X axis. If ``None``,
                          default numeric labels are kept.

            legend: Per-heatmap subtitles; the first two entries are used.
        """
        self.ipaths = ipaths
        self.output_fpath = output_fpath
        self.title = title
        self.legend = legend

        self.xlabel = xlabel
        self.ylabel = ylabel
        self.zlabel = zlabel
        self.xtick_labels = xtick_labels
        self.ytick_labels = ytick_labels

        # Optional arguments
        if large_text:
            self.text_size = config.kGraphTextSizeLarge
        else:
            self.text_size = config.kGraphTextSizeSmall

        self.logger = logging.getLogger(__name__)

    def generate(self) -> None:
        """
        Render both heatmaps with a shared colorbar and save the figure.

        If the number of input paths is not ``kCardinality``, only a debug
        message is logged. No files are read in that case.
        """
        ipaths = list(self.ipaths)

        # Check the count before reading anything, so a wrong number of
        # inputs skips the graph instead of failing while reading.
        if len(ipaths) != DualHeatmap.kCardinality:
            self.logger.debug(("Not generating dual heatmap: wrong # files "
                               "(must be %s)"), DualHeatmap.kCardinality)
            return

        reader = storage.DataFrameReader('storage.csv')
        dfs = [reader(f) for f in ipaths]

        # Scaffold graph. We can use either dataframe for setting the graph
        # size; we assume they have the same dimensions.
        fig, (ax1, ax2) = plt.subplots(ncols=2)
        try:
            Heatmap.set_graph_size(dfs[0], fig)

            # matshow() draws row i at y=i and column j at x=j, so tick
            # positions are integer offsets. The dataframe's index/column
            # values are not used as positions.
            rows = np.arange(len(dfs[0].index))
            cols = np.arange(len(dfs[0].columns))

            # Find min, max so the shared colorbar makes sense
            minval = min(df.min().min() for df in dfs)
            maxval = max(df.max().max() for df in dfs)

            # Plot heatmaps
            im1 = ax1.matshow(dfs[0],
                              interpolation='none',
                              vmin=minval,
                              vmax=maxval)
            im2 = ax2.matshow(dfs[1],
                              interpolation='none',
                              vmin=minval,
                              vmax=maxval)

            # Add titles
            fig.suptitle(self.title, fontsize=self.text_size['title'])

            for ax in (ax1, ax2):
                ax.xaxis.set_ticks_position('bottom')
                ax.yaxis.set_ticks_position('left')

            if self.legend is not None:
                for ax, text in zip((ax1, ax2), self.legend):
                    ax.set_title("\n".join(textwrap.wrap(text, 20)),
                                 size=self.text_size['legend_label'])

            # Add colorbar.
            #
            # Add, then remove the colorbar for the heatmap on the left so
            # that they both end up the same size. Not pythonic, but it works.
            self._plot_colorbar(fig, im1, ax1, remove=True)
            self._plot_colorbar(fig, im2, ax2, remove=False)

            # Add X,Y,Z labels:
            #
            # - X labels are needed on both heatmaps.
            # - Y label only needed on left heatmap.
            self._plot_labels(ax1, xlabel=True, ylabel=True)
            self._plot_labels(ax2, xlabel=True, ylabel=False)

            # Add X,Y ticks:
            #
            # - X tick labels needed on both heatmaps
            # - Y tick labels only needed on left heatmap.
            self._plot_ticks(ax1, rows, cols, xlabels=True, ylabels=True)
            self._plot_ticks(ax2, rows, cols, xlabels=True, ylabels=False)

            # Output figures
            fig.subplots_adjust(wspace=0.0, hspace=0.0)
            fig.savefig(self.output_fpath,
                        bbox_inches='tight',
                        dpi=config.kGraphDPI)
        finally:
            # Prevent memory accumulation (fig.clf() does not close
            # everything), even if plotting/saving failed.
            plt.close(fig)

    def _plot_colorbar(self, fig, im, ax, remove: bool) -> None:
        """
        Plot the Z-axis color bar on the dual heatmap.

        Args:
            fig: The figure containing ``ax``.

            im: The image the colorbar describes.

            ax: The heatmap axes. The colorbar axes is appended to its right.

            remove: If true, delete the colorbar axes right after creating
                    it, so ``ax`` shrinks the same way as a heatmap that
                    keeps its colorbar.
        """
        divider = mpl_toolkits.axes_grid1.make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        bar = fig.colorbar(im, cax=cax)

        if remove:
            # Delete exactly the axes created above, rather than relying on
            # its position in fig.axes.
            fig.delaxes(cax)
            return

        if self.zlabel is not None:
            bar.set_label(self.zlabel, fontsize=self.text_size['xyz_label'])

    def _plot_ticks(self,
                    ax,
                    xvals,
                    yvals,
                    xlabels: bool,
                    ylabels: bool) -> None:
        """Plot ticks and tick labels.

        Like the axis labels, the caller's X/Y tick labels are swapped onto
        the plot. ``ytick_labels`` go on the plot's X axis at positions
        ``yvals``. ``xtick_labels`` go on the plot's Y axis at positions
        ``xvals``. If a tick label list is ``None``, the default numeric
        labels are kept.

        Args:
            ax: The heatmap axes.

            xvals: Tick positions for the plot's Y axis.

            yvals: Tick positions for the plot's X axis.

            xlabels: Whether to draw ticks on the plot's X axis.

            ylabels: Whether to draw ticks on the plot's Y axis.
        """
        ax.tick_params(labelsize=self.text_size['tick_label'])

        if xlabels:
            ax.set_xticks(yvals)
            if self.ytick_labels is not None:
                ax.set_xticklabels(self.ytick_labels, rotation='vertical')
        else:
            ax.set_xticks([])
            ax.set_xticklabels([])

        if ylabels:
            ax.set_yticks(xvals)
            if self.xtick_labels is not None:
                ax.set_yticklabels(self.xtick_labels, rotation='horizontal')
        else:
            ax.set_yticks([])
            ax.set_yticklabels([])

    def _plot_labels(self, ax, xlabel: bool, ylabel: bool) -> None:
        """
        Plot X,Y axis labels.

        The caller's ``ylabel`` goes on the plot's X axis and ``xlabel`` on
        the plot's Y axis, matching the tick-label swap in ``_plot_ticks()``.

        Args:
            ax: The heatmap axes.

            xlabel: Whether to label the plot's X axis.

            ylabel: Whether to label the plot's Y axis.
        """
        if xlabel:
            ax.set_xlabel(self.ylabel, fontsize=self.text_size['xyz_label'])

        if ylabel:
            ax.set_ylabel(self.xlabel, fontsize=self.text_size['xyz_label'])
class HeatmapSet:
    """
    Batch driver that renders one :class:`Heatmap` per input/output path
    pair.

    The i-th input path, output path and title are used together.
    Iteration stops at the end of the shortest of the three sequences.
    """

    def __init__(self,
                 ipaths: types.PathList,
                 opaths: types.PathList,
                 titles: tp.List[str],
                 **kwargs) -> None:
        """
        Args:
            ipaths: Input files, one per heatmap.

            opaths: Output images, parallel to ``ipaths``.

            titles: Graph titles, parallel to ``ipaths``.

            kwargs: Extra keyword arguments forwarded unchanged to every
                    :class:`Heatmap` constructor.
        """
        self.ipaths = ipaths
        self.opaths = opaths
        self.titles = titles
        self.kwargs = kwargs

    def generate(self) -> None:
        """Build and generate each heatmap in turn."""
        jobs = zip(self.ipaths, self.opaths, self.titles)
        for in_path, out_path, heading in jobs:
            Heatmap(input_fpath=in_path,
                    output_fpath=out_path,
                    title=heading,
                    **self.kwargs).generate()
__api__ = [ 'Heatmap', 'DualHeatmap', 'HeatmapSet' ]