# Source code for sierra.core.graphs.heatmap

# Copyright 2018 John Harwell, All rights reserved.
#
#  SPDX-License-Identifier: MIT
#
"""
Heatmap graph generation functions for stage{4,5}.
"""

# Core packages
import textwrap
import typing as tp
import logging
import pathlib

# 3rd party packages
import numpy as np
import matplotlib.pyplot as plt
import holoviews as hv
import bokeh
import polars as pl

# Project packages
from sierra.core import utils, config, storage, types
from . import pathset as _pathset

_logger = logging.getLogger(__name__)


def generate_confusion(  # noqa: PLR0913
    pathset: _pathset.PathSet,
    input_stem: str,
    output_stem: str,
    medium: str,
    title: str,
    backend: str,
    truth_col: str,
    predicted_col: str,
    xlabels_rotate: bool = False,
    large_text: bool = False,
) -> bool:
    """
    Generate a confusion matrix from a ``.mean`` file.

    If the necessary ``.mean`` file does not exist, the graph is not
    generated. Dataframe must be constructed with {truth,predicted} columns;
    e.g.::

        truth,predicted
        a,a
        a,q
        b,b
        c,c
        d,f
        e,e
        ...

    The column names are given by ``truth_col`` and ``predicted_col``. The
    matrix is normalized by row: each cell holds the fraction of rows with a
    given truth value that were predicted as a given label. A category which
    only ever appears as a prediction has no true samples, so its row is
    filled with 0.0 rather than NaN.

    The graph is written to ``<output_root>/CM-<output_stem>.<ext>``, where
    the extension depends on ``backend`` (``matplotlib`` or ``bokeh``).
    ``xlabels_rotate`` rotates the X tick labels by 90 degrees;
    ``large_text`` selects the large text size configuration.

    Returns ``True`` if the graph was written, ``False`` if the input file
    does not exist. Raises ``ValueError`` for an unknown ``backend``.
    """
    hv.extension(backend, inline=False, logo=False)
    ofile_ext = _ofile_ext(backend)

    input_fpath = pathset.input_root / (
        input_stem + config.STATS["mean"].exts["mean"]
    )
    output_fpath = pathset.output_root / f"CM-{output_stem}.{ofile_ext}"

    if not utils.path_exists(input_fpath):
        _logger.debug(
            "Not generating <batchroot>/%s: <batchroot>/%s does not exist",
            output_fpath.relative_to(pathset.batchroot.resolve()),
            input_fpath.relative_to(pathset.batchroot.resolve()),
        )
        return False

    title = "\n".join(textwrap.wrap(title, 40))
    text_size = (
        config.GRAPHS["text_size_large"]
        if large_text
        else config.GRAPHS["text_size_small"]
    )

    # Read .csv and get counts of each <truth, predicted> pair.
    df = storage.df_read(input_fpath, medium)
    confusion_df = df.group_by([truth_col, predicted_col]).agg(
        pl.len().alias("count")
    )

    # Get category names. Need union in case the sets aren't the same.
    categories = sorted(
        set(df[truth_col].unique().to_list())
        | set(df[predicted_col].unique().to_list())
    )

    # Build every <truth, predicted> combination so that pairs which never
    # occur still get a cell (with a count of 0) in the matrix.
    all_combinations = pl.DataFrame(
        {
            truth_col: [t for t in categories for _ in categories],
            predicted_col: [p for _ in categories for p in categories],
        }
    )
    confusion_df = all_combinations.join(
        confusion_df, on=[truth_col, predicted_col], how="left"
    ).with_columns(pl.col("count").fill_null(0))

    # Normalize by row to get fractions rather than counts. Categories that
    # only appear in the predicted column have a row total of 0; guard the
    # division so their cells are 0.0 instead of NaN (0/0).
    row_totals = confusion_df.group_by(truth_col).agg(
        pl.col("count").sum().alias("row_total")
    )
    confusion_df = (
        confusion_df.join(row_totals, on=truth_col)
        .with_columns(
            pl.when(pl.col("row_total") > 0)
            .then(pl.col("count") / pl.col("row_total"))
            .otherwise(0.0)
            .alias("fraction")
        )
        .drop("row_total")
    )

    # Convert to pandas for holoviews
    dataset = hv.Dataset(
        confusion_df.to_pandas(),
        kdims=[predicted_col, truth_col],
        vdims="fraction",
    )

    # Finally, plot the data!
    if backend == "matplotlib":
        plot = hv.HeatMap(dataset).opts(show_values=True, alpha=0.65, cmap="RdYlGn")
    elif backend == "bokeh":
        plot = hv.HeatMap(dataset).opts(
            colorbar=True,
            tools=["hover"],
            alpha=0.65,
            cmap="RdYlGn",
        )
    else:
        raise ValueError(f"Bad value for backend: {backend}")

    # Add labels
    plot.opts(xlabel="Predicted Label")
    plot.opts(ylabel="True Label")

    # Set fontsizes
    plot.opts(
        fontsize={
            "title": text_size["title"],
            "labels": text_size["xyz_label"],
            "ticks": text_size["tick_label"],
        }
    )

    if backend == "matplotlib":
        # Add colorbar.
        # 2025-07-08 [JRH]: backend_opts is a mpl-specific Workaround; doing
        # colorbar_opts={"label": ...} doesn't work for unknown reasons.
        plot.opts(colorbar=True, backend_opts={"colorbar.label": ""})

    # Add title
    plot.opts(title=title)

    if xlabels_rotate:
        plot.opts(xrotation=90)

    _save(plot, output_fpath, backend)
    _logger.debug(
        "Graph written to <batchroot>/%s",
        output_fpath.relative_to(pathset.batchroot),
    )
    return True
def generate_numeric(  # noqa: PLR0913
    pathset: _pathset.PathSet,
    input_stem: str,
    output_stem: str,
    medium: str,
    title: str,
    backend: str,
    colnames: tuple[str, str, str] = ("x", "y", "z"),
    xlabel: tp.Optional[str] = "",
    ylabel: tp.Optional[str] = "",
    zlabel: tp.Optional[str] = "",
    large_text: bool = False,
    xticklabels: tp.Optional[list[str]] = None,
    yticklabels: tp.Optional[list[str]] = None,
    xticks: tp.Optional[list[float]] = None,
    yticks: tp.Optional[list[float]] = None,
    transpose: bool = False,
    ext=config.STATS["mean"].exts["mean"],
) -> bool:
    """
    Generate a X vs. Y vs. Z heatmap plot of a ``.mean`` file.

    If the necessary ``.mean`` file does not exist, the graph is not
    generated. Dataframe must be constructed with {x,y,z} columns; e.g.::

        x,y,z
        0,0,4
        0,1,5
        0,2,6
        0,3,4
        1,0,4
        1,1,4
        ...

    The ``x``, ``y`` columns are the indices, and the ``z`` column is the
    value in that cell. The names of these columns are configurable via
    ``colnames``.

    If ``transpose`` is true, the ``colnames[1]`` column is drawn along the
    X axis and ``colnames[0]`` along the Y axis. ``xlabel``, ``xticks`` and
    ``xticklabels`` always refer to the *displayed* X axis (likewise for Y).

    Custom tick labels are applied only when ``xticklabels``/``yticklabels``
    are given. They are paired in order with ``xticks``/``yticks``, which
    default to the sorted distinct values of the column shown on that axis.

    Returns ``True`` if the graph was written, ``False`` if the input file
    does not exist. Raises ``ValueError`` for an unknown ``backend``.
    """
    hv.extension(backend, inline=False, logo=False)
    ofile_ext = _ofile_ext(backend)

    input_fpath = pathset.input_root / (input_stem + ext)
    output_fpath = pathset.output_root / f"HM-{output_stem}.{ofile_ext}"

    if not utils.path_exists(input_fpath):
        _logger.debug(
            "Not generating <batchroot>/%s: <batchroot>/%s does not exist",
            output_fpath.relative_to(pathset.batchroot.resolve()),
            input_fpath.relative_to(pathset.batchroot.resolve()),
        )
        return False

    title = "\n".join(textwrap.wrap(title, 40))
    text_size = (
        config.GRAPHS["text_size_large"]
        if large_text
        else config.GRAPHS["text_size_small"]
    )

    # Read .csv and convert to pandas for holoviews
    df_pd = storage.df_read(input_fpath, medium).to_pandas()

    # The data is in long (x, y, z) format, so transposing the DataFrame
    # itself would turn its three columns into rows and scramble it. The
    # heatmap equivalent of a transpose is swapping the key dimensions.
    if transpose:
        x_col, y_col = colnames[1], colnames[0]
    else:
        x_col, y_col = colnames[0], colnames[1]
    z_col = colnames[2]

    dataset = hv.Dataset(df_pd, kdims=[x_col, y_col], vdims=z_col)

    # Plot heatmap, without showing the Z-value in each cell, which generally
    # obscures things more than it helps. Plus, statistical significance isn't
    # observable from a heatmap, so numerical values are kind of moot.
    if backend == "matplotlib":
        plot = hv.HeatMap(dataset, kdims=[x_col, y_col], vdims=[z_col]).opts(
            show_values=False
        )
    elif backend == "bokeh":
        plot = hv.HeatMap(dataset, kdims=[x_col, y_col], vdims=[z_col])
    else:
        raise ValueError(f"Bad value for backend: {backend}")

    # Default tick positions are the distinct index values on each axis. The
    # raw columns repeat every index once per cell, which would pile several
    # labels onto the same position when zipped with the tick labels.
    if not xticks:
        xticks = sorted(df_pd[x_col].unique().tolist())
    if not yticks:
        yticks = sorted(df_pd[y_col].unique().tolist())

    # Add X,Y ticks
    if xticklabels:
        plot.opts(xticks=list(zip(xticks, xticklabels)))
    if yticklabels:
        plot.opts(yticks=list(zip(yticks, yticklabels)))

    # Add labels
    plot.opts(xlabel=xlabel)
    plot.opts(ylabel=ylabel)

    # Set fontsizes
    plot.opts(
        fontsize={
            "title": text_size["title"],
            "labels": text_size["xyz_label"],
            "ticks": text_size["tick_label"],
        }
    )

    # Add title
    plot.opts(title=title)

    if backend == "matplotlib":
        # Add colorbar.
        # 2025-07-08 [JRH]: backend_opts is a mpl-specific Workaround; doing
        # colorbar_opts={"label": ...} doesn't work for unknown reasons.
        plot.opts(colorbar=True, backend_opts={"colorbar.label": zlabel})

    _save(plot, output_fpath, backend)
    _logger.debug(
        "Graph written to <batchroot>/%s",
        output_fpath.relative_to(pathset.batchroot),
    )
    return True
def generate_dual_numeric(  # noqa: PLR0913
    pathset: _pathset.PathSet,
    ipaths: types.PathList,
    output_stem: pathlib.Path,
    medium: str,
    title: str,
    xlabel: tp.Optional[str] = None,
    ylabel: tp.Optional[str] = None,
    zlabel: tp.Optional[str] = None,
    large_text: bool = False,
    xticklabels: tp.Optional[list[str]] = None,
    yticklabels: tp.Optional[list[str]] = None,
) -> bool:
    """Generate a side-by-side plot of two heatmaps from two input files.

    ``.mean`` files are expected to be named as ``<input_stem_fpath>_X.mean``,
    where `X` is non-negative integer. Input ``.mean`` files must be 2D grids
    of the same cardinality. This graph does not plot standard deviation.

    The graph is always rendered with matplotlib and written to
    ``<output_root>/HM-<output_stem>.<static_type>``.

    If there are not exactly two file paths passed, or any of them does not
    exist, the graph is not generated and ``False`` is returned; otherwise
    ``True`` is returned.
    """
    hv.extension("matplotlib", inline=False, logo=False)
    output_fpath = (
        pathset.output_root / f"HM-{output_stem}.{config.GRAPHS['static_type']}"
    )

    # Validate the inputs before reading anything: there is no point loading
    # files for a graph which will not be generated.
    ipaths = list(ipaths) if ipaths else []
    if len(ipaths) != 2:
        _logger.debug(
            ("Not generating dual heatmap: wrong # files %s (must be 2)"),
            len(ipaths),
        )
        return False

    # Consistent with the other generators: skip, rather than fail, when an
    # input file is missing.
    missing = [p for p in ipaths if not utils.path_exists(p)]
    if missing:
        _logger.debug(
            "Not generating dual heatmap %s: input(s) %s do not exist",
            output_fpath.name,
            missing,
        )
        return False

    # Optional arguments
    text_size = (
        config.GRAPHS["text_size_large"]
        if large_text
        else config.GRAPHS["text_size_small"]
    )

    # Convert polars DataFrames to pandas for holoviews
    dfs_pd = [storage.df_read(f, medium).to_pandas() for f in ipaths]

    yticks = np.arange(len(dfs_pd[0].columns))
    xticks = dfs_pd[0].index

    # Plot heatmaps
    plot = hv.Image(dfs_pd[0]) + hv.Image(dfs_pd[1])

    # Add X,Y ticks
    if xticklabels:
        plot.opts(xformatter=lambda x: xticklabels[list(xticks).index(x)])
    if yticklabels:
        plot.opts(yformatter=lambda y: yticklabels[list(yticks).index(y)])

    # Add labels
    plot.opts(xlabel=xlabel)
    plot.opts(ylabel=ylabel)

    # Set fontsizes
    plot.opts(
        fontsize={
            "title": text_size["title"],
            "labels": text_size["xyz_label"],
            "ticks": text_size["tick_label"],
        }
    )

    # Add title
    plot.opts(title=title)

    # Add colorbar.
    plot.opts(
        hv.opts.Layout(shared_axes=False),
        hv.opts.Image(
            colorbar=True,
            colorbar_position="right",
            backend_opts={"colorbar.label": zlabel},
        ),
    )

    # Output figures
    plot.opts(fig_inches=config.GRAPHS["base_size"])
    hv.save(
        plot,
        output_fpath,
        fig=config.GRAPHS["static_type"],
        dpi=config.GRAPHS["dpi"],
    )
    plt.close("all")
    _logger.debug(
        "Graph written to <batchroot>/%s",
        output_fpath.relative_to(pathset.batchroot),
    )
    return True
def _ofile_ext(backend: str) -> tp.Optional[str]:
    """Return the output file extension for ``backend``.

    ``matplotlib`` maps to ``GRAPHS["static_type"]`` and ``bokeh`` to
    ``GRAPHS["interactive_type"]``; any other backend yields ``None``.
    """
    if backend == "matplotlib":
        return str(config.GRAPHS["static_type"])
    if backend == "bokeh":
        return str(config.GRAPHS["interactive_type"])
    return None


def _save(plot: hv.Element, output_fpath: pathlib.Path, backend: str) -> None:
    """Render ``plot`` with ``backend`` and write it to ``output_fpath``.

    matplotlib output is a static image at the configured size/DPI; bokeh
    output is a standalone HTML page with inlined resources.

    Raises ``ValueError`` for an unknown ``backend``. The public generators
    already reject such backends before reaching here, so this only guards
    against a silent no-op.
    """
    if backend == "matplotlib":
        hv.save(
            plot.opts(fig_inches=config.GRAPHS["base_size"]),
            output_fpath,
            fig=config.GRAPHS["static_type"],
            dpi=config.GRAPHS["dpi"],
        )
        plt.close("all")
    elif backend == "bokeh":
        # Import the submodules explicitly: a bare ``import bokeh`` does not
        # guarantee that ``bokeh.embed``/``bokeh.resources`` are loaded.
        from bokeh.embed import file_html
        from bokeh.resources import INLINE

        fig = hv.render(plot)

        # 2025-12-02 [JRH]: We don't set dimensions, because that makes the
        # interactive plots fixed size, which makes them unsuitable for
        # embedding into webpages.
        fig.sizing_mode = "scale_width"

        html = file_html(fig, resources=INLINE)

        # Bokeh's HTML template declares UTF-8, so write it as such rather
        # than with the platform's default encoding.
        output_fpath.write_text(html, encoding="utf-8")
    else:
        raise ValueError(f"Bad value for backend: {backend}")


__all__ = ["generate_confusion", "generate_dual_numeric", "generate_numeric"]