Source code for sierra.core.graphs.stacked_line_graph

# Copyright 2018 John Harwell, All rights reserved.
#
#  SPDX-License-Identifier: MIT
#

# Core packages
import typing as tp
import copy
import logging
import pathlib

# 3rd party packages
import pandas as pd
import matplotlib.pyplot as plt

# Project packages
from sierra.core import config, utils, storage


class StackedLineGraph:
    """Generates a line graph from a set of columns in a CSV file.

    If the necessary data file does not exist, the graph is not generated.

    If the .stddev file that goes with the .mean does not exist, then no error
    bars are plotted.

    If the .model file that goes with the .mean does not exist, then no model
    predictions are plotted.

    Ideally, model predictions/stddev calculations would be in derived classes,
    but I can't figure out a good way to easily pull that stuff out of here.
    """

    def __init__(self,
                 stats_root: pathlib.Path,
                 input_stem: str,
                 output_fpath: pathlib.Path,
                 title: str,
                 xlabel: tp.Optional[str] = None,
                 ylabel: tp.Optional[str] = None,
                 large_text: bool = False,
                 legend: tp.Optional[tp.List[str]] = None,
                 cols: tp.Optional[tp.List[str]] = None,
                 linestyles: tp.Optional[tp.List[str]] = None,
                 dashstyles: tp.Optional[tp.List[str]] = None,
                 logyscale: bool = False,
                 stddev_fpath: tp.Optional[pathlib.Path] = None,
                 stats: str = 'none',
                 model_root: tp.Optional[pathlib.Path] = None) -> None:
        """Store the graph configuration; nothing is read or plotted here.

        Args:
            stats_root: Directory containing the input statistics files.
            input_stem: File name (without extension) of the input files
                inside ``stats_root`` (and ``model_root``, if given).
            output_fpath: Path the rendered figure is saved to.
            title: Graph title.
            xlabel: X axis label; omitted from the graph if ``None``.
            ylabel: Y axis label; omitted from the graph if ``None``.
            large_text: Selects ``config.kGraphTextSizeLarge`` instead of
                ``config.kGraphTextSizeSmall`` for all text sizes.
            legend: Legend entries for the plotted columns. If ``None``, the
                graph has no legend.
            cols: Columns of the input dataframe to plot. If ``None``, all
                columns are plotted.
            linestyles: Per-column matplotlib line styles.
            dashstyles: Per-column matplotlib dash patterns.
            logyscale: If ``True``, the Y axis uses the ``symlog`` scale.
            stddev_fpath: Stored on the instance; not read by any method of
                this class.
            stats: Statistics to overlay. ``'conf95'`` and ``'all'`` cause
                the confidence-interval files to be read; any other value
                disables them.
            model_root: Directory containing model files. If ``None``, no
                models are plotted.
        """
        # Required arguments
        self.stats_root = stats_root
        self.input_stem = input_stem
        self.output_fpath = output_fpath
        self.title = title

        # Optional arguments
        if large_text:
            self.text_size = config.kGraphTextSizeLarge
        else:
            self.text_size = config.kGraphTextSizeSmall

        self.xlabel = xlabel
        self.ylabel = ylabel
        self.legend = legend
        self.cols = cols
        self.logyscale = logyscale
        self.linestyles = linestyles
        self.dashstyles = dashstyles
        self.stats = stats
        self.model_root = model_root
        self.stddev_fpath = stddev_fpath

        self.logger = logging.getLogger(__name__)

    def generate(self) -> None:
        """Read the input data and write the graph to ``output_fpath``.

        Returns without doing anything (apart from a debug log message) if
        the ``.mean`` input file does not exist.
        """
        ext = config.kStats['mean'].exts['mean']
        input_fpath = self.stats_root / (self.input_stem + ext)

        if not utils.path_exists(input_fpath):
            self.logger.debug("Not generating %s: %s does not exist",
                              self.output_fpath,
                              input_fpath)
            return

        data_df = storage.DataFrameReader('storage.csv')(input_fpath)
        model = self._read_models()
        stat_dfs = self._read_stats()

        # Plot specified columns from dataframe (all of them by default).
        cols = data_df.columns if self.cols is None else self.cols
        ncols = max(1, int(len(cols) / 2.0))
        ax = self._plot_selected_cols(data_df, stat_dfs, cols, model)
        fig = ax.get_figure()

        # The figure must be closed even if decorating/saving fails, otherwise
        # matplotlib keeps it alive and memory accumulates across graphs.
        try:
            self._plot_ticks(ax)
            self._plot_legend(ax, model[1], ncols)

            # Add title
            ax.set_title(self.title, fontsize=self.text_size['title'])

            # Add X,Y labels
            if self.xlabel is not None:
                ax.set_xlabel(self.xlabel,
                              fontsize=self.text_size['xyz_label'])

            if self.ylabel is not None:
                ax.set_ylabel(self.ylabel,
                              fontsize=self.text_size['xyz_label'])

            # Output figure
            fig.set_size_inches(config.kGraphBaseSize, config.kGraphBaseSize)
            fig.savefig(self.output_fpath,
                        bbox_inches='tight',
                        dpi=config.kGraphDPI)
        finally:
            # Prevent memory accumulation (fig.clf() does not close
            # everything)
            plt.close(fig)

    def _plot_ticks(self, ax) -> None:
        """Configure the Y scale and tick label size of ``ax``."""
        if self.logyscale:
            # Set the scale on the axes being built rather than on whatever
            # pyplot considers the "current" axes.
            ax.set_yscale('symlog')

        ax.tick_params(labelsize=self.text_size['tick_label'])

    def _plot_selected_cols(self,
                            data_df: pd.DataFrame,
                            stat_dfs: tp.Dict[str, pd.DataFrame],
                            cols: tp.List[str],
                            model: tp.Tuple[tp.Optional[pd.DataFrame],
                                            tp.List[str]]):
        """
        Plot the selected columns in a dataframe.

        This may include:

        - Errorbars
        - Models
        - Custom line styles
        - Custom dash styles

        Args:
            data_df: Dataframe holding the data to plot.
            stat_dfs: Statistics dataframes keyed by name; only ``'stddev'``
                is used.
            cols: Columns of ``data_df`` to plot.
            model: ``(model dataframe or None, model legend)`` as returned by
                ``_read_models()``.

        Returns:
            The matplotlib axes everything was drawn on.
        """
        # Create the axes explicitly so that (1) every artist lands on the
        # same, fresh figure regardless of pyplot's global state, and (2) an
        # axes object always exists, even if no styled column gets plotted.
        fig, ax = plt.subplots()

        try:
            # Always plot the data
            if self.linestyles is None and self.dashstyles is None:
                data_df[cols].plot(ax=ax)
            elif self.dashstyles is None:
                for col, style in zip(cols, self.linestyles):
                    data_df[col].plot(ax=ax, linestyle=style)
            elif self.linestyles is None:
                # Dash styles without line styles.
                for col, dashes in zip(cols, self.dashstyles):
                    data_df[col].plot(ax=ax, dashes=dashes)
            else:
                for col, style, dashes in zip(cols,
                                              self.linestyles,
                                              self.dashstyles):
                    data_df[col].plot(ax=ax, linestyle=style, dashes=dashes)

            # Plot models if they have been computed
            model_df = model[0]
            if model_df is not None:
                model_df.plot(ax=ax)

            # Plot stddev if it has been computed
            if 'stddev' in stat_dfs:
                for col in cols:
                    self._plot_col_errorbars(data_df,
                                             stat_dfs['stddev'],
                                             col,
                                             ax)
        except Exception:
            # Do not leak the half-built figure if plotting fails.
            plt.close(fig)
            raise

        return ax

    def _plot_col_errorbars(self,
                            data_df: pd.DataFrame,
                            stddev_df: pd.DataFrame,
                            col: str,
                            ax=None) -> None:
        """
        Plot the errorbars for a specific column in a dataframe.

        The bars are drawn as a shaded band of +/- 2 * |stddev| around the
        data.

        Args:
            data_df: Dataframe holding the plotted data.
            stddev_df: Dataframe holding the standard deviation for ``col``.
            col: Column to draw the band for.
            ax: Axes to draw on. If ``None``, pyplot's current axes are used.
        """
        spread = 2 * stddev_df[col].abs()
        target = plt if ax is None else ax
        target.fill_between(data_df.index,
                            data_df[col] - spread,
                            data_df[col] + spread,
                            alpha=0.50)

    def _plot_legend(self,
                     ax,
                     model_legend: tp.List[str],
                     ncols: int) -> None:
        """Attach the configured legend to ``ax``, or remove any legend.

        Args:
            ax: Axes to decorate.
            model_legend: Legend entries for plotted models; appended after
                the entries in ``self.legend``.
            ncols: Unused; the legend always has a single column (see below).
        """
        # Should always use one column: If a small number of lines are plotted,
        # then 1 vs. > 1 column makes no difference. If a large number of lines
        # are plotted, this will make the figures more portrait than landscape,
        # which is more amenable to inclusion in academic papers.

        # If the legend is not specified, then we assume this is not a graph
        # that will contain any models and/or no legend is desired.
        if self.legend is None:
            existing = ax.get_legend()
            if existing is not None:
                existing.remove()
            return

        # Work on a copy so the caller's legend list is never modified.
        legend = copy.deepcopy(self.legend)
        if model_legend:
            legend.extend(model_legend)

        lines, _ = ax.get_legend_handles_labels()
        ax.legend(lines,
                  legend,
                  loc='lower center',
                  bbox_to_anchor=(0.5, -0.5),
                  ncol=1,
                  fontsize=self.text_size['legend_label'])

    def _read_stats(self) -> tp.Dict[str, pd.DataFrame]:
        """Read the statistics files that accompany the input data.

        Returns:
            Dataframes keyed by the keys of ``config.kStats['conf95'].exts``
            for every file that exists. Empty unless ``self.stats`` is
            ``'conf95'`` or ``'all'``.
        """
        if self.stats not in ('conf95', 'all'):
            return {}

        dfs = {}
        reader = storage.DataFrameReader('storage.csv')
        exts = config.kStats['conf95'].exts

        for k in exts:
            ipath = self.stats_root / (self.input_stem + exts[k])

            if utils.path_exists(ipath):
                dfs[k] = reader(ipath)
            else:
                self.logger.warning("%s file not found for '%s'",
                                    exts[k],
                                    self.input_stem)

        return dfs

    def _read_models(self) -> tp.Tuple[tp.Optional[pd.DataFrame],
                                       tp.List[str]]:
        """Read the model dataframe and its legend, if they exist.

        Returns:
            ``(model, legend)``. ``(None, [])`` if ``self.model_root`` is
            ``None`` or the model file does not exist. If the model exists
            but its legend file does not, the legend defaults to
            ``['Model Prediction']``.
        """
        if self.model_root is None:
            return (None, [])

        model_fpath = self.model_root / \
            (self.input_stem + config.kModelsExt['model'])
        legend_fpath = self.model_root / \
            (self.input_stem + config.kModelsExt['legend'])

        if not utils.path_exists(model_fpath):
            return (None, [])

        model = storage.DataFrameReader('storage.csv')(model_fpath)

        if utils.path_exists(legend_fpath):
            with utils.utf8open(legend_fpath, 'r') as f:
                model_legend = f.read().splitlines()
        else:
            self.logger.warning("No legend file for model '%s' found",
                                model_fpath)
            model_legend = ['Model Prediction']

        return (model, model_legend)
# Public API of this module.
__api__ = ['StackedLineGraph']