Source code for sierra.core.graphs.stacked_line

# Copyright 2018 John Harwell, All rights reserved.
#
#  SPDX-License-Identifier: MIT
#
"""
Intra-experiment line graph generation classes for stage{4,5}.
"""

# Core packages
import typing as tp
import logging
import pathlib

# 3rd party packages
import polars as pl
import holoviews as hv
import matplotlib.pyplot as plt
import bokeh

# Project packages
from sierra.core import config, utils, storage, models
from . import pathset

# Logger for this module, named after the module's import path.
_logger = logging.getLogger(__name__)


def _ofile_ext(backend: str) -> tp.Optional[str]:
    """Return the output file extension configured for a plotting backend.

    ``"matplotlib"`` maps to the ``static_type`` entry of ``config.GRAPHS`` and
    ``"bokeh"`` maps to its ``interactive_type`` entry. Any other backend has
    no configured extension and yields ``None``.
    """
    graphs_key = {
        "matplotlib": "static_type",
        "bokeh": "interactive_type",
    }.get(backend)

    # Unknown backend: nothing to look up in the configuration.
    if graphs_key is None:
        return None

    return str(config.GRAPHS[graphs_key])


def generate(  # noqa: PLR0913
    paths: pathset.PathSet,
    input_stem: str,
    output_stem: str,
    title: str,
    medium: str,
    backend: str,
    xticks: tp.Optional[list[float]] = None,
    stats: tp.Optional[str] = None,
    xlabel: tp.Optional[str] = None,
    ylabel: tp.Optional[str] = None,
    points: tp.Optional[bool] = False,
    large_text: bool = False,
    legend: tp.Optional[list[str]] = None,
    xticklabels: tp.Optional[list[str]] = None,
    cols: tp.Optional[list[str]] = None,
    logyscale: bool = False,
    ext: str = config.STATS["mean"].exts["mean"],
) -> bool:
    """Generate a line graph from a set of columns in a file.

    If the necessary data file does not exist, the graph is not generated.

    If the .stddev file that goes with the .mean does not exist, then no error
    bars are plotted.

    If the .model file that goes with the .mean does not exist, then no model
    predictions are plotted.

    Ideally, model predictions/stddev calculations would be in derived classes,
    but I can't figure out a good way to easily pull that stuff out of here.

    Arguments:

        paths: Provides the input, output, model and batch root directories.

        input_stem: Name of the input file under ``paths.input_root``, without
                    its extension.

        output_stem: The graph is written to
                     ``paths.output_root / "SLN-<output_stem>.<extension>"``.

        title: Graph title.

        medium: Storage medium, passed through to ``storage.df_read()``.

        backend: Holoviews backend to plot with. Only ``matplotlib`` and
                 ``bokeh`` are handled by the helpers in this module.

        xticks: Optional per-row values, attached to the data as an ``xticks``
                column; must have one entry per data row. The plotted key
                dimension is the row index.

        stats: Which statistics to overlay. ``conf95`` (or ``all``) shades a
               band around each curve when the stddev data is available;
               ``bw`` is recognized but not implemented for this graph type.

        xlabel: Optional X axis label.

        ylabel: Optional Y axis label.

        points: Always draw a marker at each data point (they are drawn
                regardless for short curves).

        large_text: Use the large rather than the small configured text sizes.

        legend: Optional legend labels, one per plotted column, in order.

        xticklabels: Accepted for interface compatibility; not referenced by
                     this function.

        cols: Columns to plot. If omitted, every column in the file is plotted.

        logyscale: Plot the Y axis on a log scale.

        ext: Extension of the input file.

    Returns:

        True if the graph was written, False if the input file did not exist.

    Raises:

        ValueError: If ``xticks`` is given and its length differs from the
                    number of data rows.
    """
    hv.extension(backend, inline=False, logo=False)

    input_fpath = paths.input_root / (input_stem + ext)
    output_fpath = paths.output_root / "SLN-{}.{}".format(
        output_stem, _ofile_ext(backend)
    )
    text_size = (
        config.GRAPHS["text_size_large"]
        if large_text
        else config.GRAPHS["text_size_small"]
    )

    if not utils.path_exists(input_fpath):
        _logger.debug(
            "Not generating <batchroot>/%s: <batchroot>/%s does not exist",
            output_fpath.relative_to(paths.batchroot),
            input_fpath.relative_to(paths.batchroot),
        )
        return False

    df = storage.df_read(input_fpath, medium)

    # Remember the data columns before any bookkeeping columns are added, so
    # that only real data is plotted by default.
    dfcols = df.columns

    # Validate BEFORE attaching the column: once attached, the column is by
    # construction as long as the frame, so a later comparison can never fail.
    if xticks is not None and len(xticks) != len(df):
        raise ValueError(
            "Length mismatch between xticks,# data points: {} vs {}".format(
                len(xticks), len(df)
            )
        )

    # Add row index first
    df = df.with_row_index("index")

    # Use xticks if provided, otherwise default to using row indices as xticks
    if xticks is not None:
        df = df.with_columns(pl.Series("xticks", xticks))
    else:
        df = df.with_columns(pl.col("index").cast(pl.Float64).alias("xticks"))

    # Convert to pandas for holoviews compatibility
    df_pd = df.to_pandas()

    dataset = hv.Dataset(
        data=df_pd,
        kdims=["index"],
        vdims=cols if cols else list(dfcols),
    )

    model = _read_models(paths.model_root, input_stem, medium)
    stat_dfs = _read_stats(stats, paths.input_root, input_stem, medium)

    # "all" requests every statistic, so it must enable each individual kind;
    # a plain substring test on "all" matches neither "conf95" nor "bw".
    want_all = stats == "all"
    want_conf95 = bool(stats) and (want_all or "conf95" in stats)
    want_bw = bool(stats) and (want_all or "bw" in stats)

    # Plot stats if they have been computed FIRST, so they appear behind the
    # actual data.
    if want_conf95 and "stddev" in stat_dfs:
        plot = _plot_stats_stddev(dataset, stat_dfs["stddev"])
        plot *= _plot_selected_cols(dataset, model, legend, points, backend)
    else:
        if want_bw and all(k in stat_dfs for k in config.STATS["bw"].exts):
            # 2025-10-06 [JRH]: This is a limitation of hv (I think). Manually
            # specifying bw plots around each datapoint on a graph can easily
            # exceed the max # of things that can be in a single overlay.
            _logger.warning("bw statistics not implemented for stacked_line graphs")

        # Plot specified columns from dataframe.
        plot = _plot_selected_cols(dataset, model, legend, points, backend)

    # Let the backend decide # of columns; can override with
    # legend_cols=N in the future if desired.
    plot.opts(legend_position="bottom")

    # Add title
    plot.opts(title=title)

    # Add X,Y labels
    if xlabel is not None:
        plot.opts(xlabel=xlabel)

    if ylabel is not None:
        plot.opts(ylabel=ylabel)

    # Set fontsizes
    plot.opts(
        fontsize={
            "title": text_size["title"],
            "labels": text_size["xyz_label"],
            "ticks": text_size["tick_label"],
            "legend": text_size["legend_label"],
        },
    )

    if logyscale:
        # Pad the Y limits by 10% around the extremes of the plotted columns.
        _min = min(dataset[vdim].min() for vdim in dataset.vdims)
        _max = max(dataset[vdim].max() for vdim in dataset.vdims)

        plot.opts(
            logy=True,
            ylim=(
                _min * 0.9,
                _max * 1.1,
            ),
        )

    _save(plot, output_fpath, backend)
    _logger.debug(
        "Graph written to <batchroot>/%s",
        output_fpath.relative_to(paths.batchroot),
    )
    return True
def _save(plot: hv.Overlay, output_fpath: pathlib.Path, backend: str) -> None:
    """Write ``plot`` to ``output_fpath`` using the given backend.

    For ``matplotlib`` the plot is saved as a static image via ``hv.save()``
    and all open matplotlib figures are closed afterwards. For ``bokeh`` the
    plot is rendered and written out as a standalone HTML document. Any other
    backend results in nothing being written.
    """
    if backend == "matplotlib":
        hv.save(
            plot.opts(fig_inches=config.GRAPHS["base_size"]),
            output_fpath,
            fig=config.GRAPHS["static_type"],
            dpi=config.GRAPHS["dpi"],
        )
        plt.close("all")

    elif backend == "bokeh":
        fig = hv.render(plot)

        # 2025-12-02 [JRH]: We don't set dimensions, because that makes the
        # interactive plots fixed size, which makes them unsuitable for
        # embedding into webpages.
        fig.sizing_mode = "scale_width"

        html = bokeh.embed.file_html(fig, resources=bokeh.resources.INLINE)
        with utils.utf8open(output_fpath, "w") as f:
            f.write(html)


def _plot_selected_cols(
    dataset: hv.Dataset,
    model_info: models.ModelInfo,
    legend: tp.Optional[list[str]],
    show_points: tp.Optional[bool],
    backend: str,
) -> hv.Overlay:
    """
    Plot the selected columns in a dataframe.

    One curve is drawn per value dimension of ``dataset``, labelled from
    ``legend`` (by position) when it is given. If ``model_info`` carries a
    dataset, its columns are overlaid as dashed curves labelled from
    ``model_info.legend``.
    """
    xdim = dataset.kdims[0]

    # Always plot the data
    plot = hv.Overlay(
        [
            hv.Curve(
                dataset,
                xdim,
                vdim.name,
                label=legend[i] if legend else "",
            )
            for i, vdim in enumerate(dataset.vdims)
        ]
    )

    # Plot the points for each curve if configured to do so, OR if there aren't
    # that many. If you print them and there are a lot, you essentially get
    # really fat lines which doesn't look good.
    plot *= hv.Overlay(
        [
            hv.Points((dataset[xdim], dataset[v]))
            for v in dataset.vdims
            if len(dataset[v]) <= 50 or show_points
        ]
    )

    # Model curves are dashed; the option name is backend-specific. Fall back
    # to no extra styling for any other backend rather than leaving the name
    # unbound.
    if backend == "matplotlib":
        model_opts = {"linestyle": "--"}
    elif backend == "bokeh":
        model_opts = {"line_dash": [6, 3]}
    else:
        model_opts = {}

    # Plot models if they have been computed
    if model_info.dataset:
        mdataset = model_info.dataset
        mxdim = mdataset.kdims[0]

        plot *= hv.Overlay(
            [
                hv.Curve(
                    mdataset,
                    mxdim,
                    vdim.name,
                    label=model_info.legend[i],
                ).opts(**model_opts)
                for i, vdim in enumerate(mdataset.vdims)
            ]
        )

        # Plot the points for each curve
        plot *= hv.Overlay(
            [
                hv.Points((mdataset[mxdim], mdataset[v]))
                for v in mdataset.vdims
                if len(mdataset[v]) <= 50 or show_points
            ]
        )

    return plot


def _plot_stats_stddev(dataset: hv.Dataset, stddev_df: pl.DataFrame) -> hv.Overlay:
    """Plot the stddev for all columns in the dataset.

    For each value dimension, ``<name>_stddev_l`` and ``<name>_stddev_u``
    columns (the value -/+ twice the absolute stddev) are added to
    ``dataset.data`` IN PLACE, and a semi-transparent area is drawn between
    them.
    """
    for vdim in dataset.vdims:
        name = vdim.name
        spread = 2 * stddev_df[name].abs().to_numpy()

        dataset.data[f"{name}_stddev_l"] = dataset.data[name] - spread
        dataset.data[f"{name}_stddev_u"] = dataset.data[name] + spread

    return hv.Overlay(
        [
            hv.Area(
                dataset, vdims=[f"{vdim.name}_stddev_l", f"{vdim.name}_stddev_u"]
            ).opts(
                alpha=0.5,
            )
            for vdim in dataset.vdims
        ]
    )


def _read_stats(
    setting: tp.Optional[str], stats_root: pathlib.Path, input_stem: str, medium: str
) -> dict[str, pl.DataFrame]:
    """Read the statistics files which accompany ``input_stem``.

    Arguments:

        setting: Which statistics to read: ``None``/``none`` for nothing,
                 ``all`` for both ``conf95`` and ``bw``, otherwise a single key
                 of ``config.STATS``.

        stats_root: Directory containing the statistics files.

        input_stem: File name without extension.

        medium: Storage medium, passed through to ``storage.df_read()``.

    Returns:

        Mapping from each key of the selected ``config.STATS[...].exts`` to
        the dataframe read for it. Files which do not exist are skipped with a
        warning.
    """
    dfs = {}  # type: tp.Dict[str, pl.DataFrame]

    # No statistics requested. None is the default upstream, and is not a
    # valid config.STATS key, so it must not fall through to the lookup.
    if setting is None or setting == "none":
        return dfs

    settings = ["conf95", "bw"] if setting == "all" else [setting]

    # Read every statistic selected, so that "all" really expands to each of
    # its constituent kinds.
    for name in settings:
        exts = config.STATS[name].exts

        for k in exts:
            ipath = stats_root / (input_stem + exts[k])

            if utils.path_exists(ipath):
                dfs[k] = storage.df_read(ipath, medium)
            else:
                _logger.warning("%s not found for '%s'", exts[k], input_stem)

    return dfs


# 2024/09/13 [JRH]: The union is for compatibility with type checkers in
# python {3.8,3.11}.
def _read_models(
    model_root: tp.Optional[pathlib.Path], input_stem: str, medium: str
) -> models.ModelInfo:
    """Read the model data and legend which accompany ``input_stem``, if any.

    Returns an empty ``models.ModelInfo`` if ``model_root`` is None or the
    model file does not exist. Otherwise the returned info has its ``dataset``
    keyed on the row index with one value dimension per model column, and its
    ``legend`` set to the lines of the legend file.
    """
    if model_root is None:
        return models.ModelInfo()

    modelf = model_root / (input_stem + config.MODELS_EXT["model"])
    legendf = model_root / (input_stem + config.MODELS_EXT["legend"])

    if not utils.path_exists(modelf):
        _logger.trace("Model file %s missing for graph", str(modelf))
        return models.ModelInfo()

    info = models.ModelInfo()

    df = storage.df_read(modelf, medium)

    # Capture the model columns before the index column is added.
    cols = list(df.columns)

    # Add index and convert to pandas for holoviews
    df = df.with_row_index("index")
    df_pd = df.to_pandas()

    info.dataset = hv.Dataset(data=df_pd, kdims=["index"], vdims=cols)

    with utils.utf8open(legendf, "r") as f:
        info.legend = f.read().splitlines()

    _logger.trace(
        "Loaded model='%s',legend='%s'",
        modelf.relative_to(model_root),
        legendf.relative_to(model_root),
    )

    return info


__all__ = ["generate"]