# Source code for sierra.core.pipeline.pipeline

# Copyright 2018 London Lowmanstone, John Harwell, All rights reserved.
#
#  SPDX-License-Identifier: MIT
"""The 5 pipeline stages implemented by SIERRA.

See :ref:`concepts/pipeline` for high-level documentation.

"""

# Core packages
import typing as tp
import logging
import argparse
import pathlib
import os

# 3rd party packages
import yaml

# Project packages
import sierra.core.plugin as pm
from sierra.core import config, utils, batchroot, types

from sierra.core.pipeline.stage1.pipeline_stage1 import PipelineStage1
from sierra.core.pipeline.stage2.pipeline_stage2 import PipelineStage2
from sierra.core.pipeline.stage3.pipeline_stage3 import PipelineStage3
from sierra.core.pipeline.stage4.pipeline_stage4 import PipelineStage4
from sierra.core.pipeline.stage5.pipeline_stage5 import PipelineStage5


class Pipeline:
    """Implements SIERRA's 5 stage pipeline.

    Construction normalizes the parsed command line into ``self.cmdopts``
    (merging options contributed by the engine, execenv, expdef, proc, prod,
    compare, storage and project plugins), loads the project's main YAML
    config, and builds the batch criteria when any of stages 1-4 are
    requested. :meth:`run` then executes the requested stages in order.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        controller: tp.Optional[str],
        pathset: tp.Optional["batchroot.PathSet"] = None,
    ) -> None:
        """
        Arguments:
            args: Parsed arguments for the SIERRA core and all plugins.

            controller: Controller name, passed through to stage 1.

            pathset: Run-time directory tree. Its ``input_root`` is used
                     when building batch criteria.

        Raises:
            AssertionError: If ``args.pipeline`` contains a stage outside
                            1-5 (note: ``assert`` is stripped under ``-O``).

            RuntimeError: If --scenario, --controller or --batch-criteria
                          contain '+'.
        """
        self.logger = logging.getLogger(__name__)
        self.pathset = pathset
        self.logger.info("Computed run-time tree:\n%s", self.pathset)

        assert all(
            stage in [1, 2, 3, 4, 5] for stage in args.pipeline
        ), f"Invalid pipeline stage in {args.pipeline}: Only 1-5 valid"

        # The namespace passed in contains arguments for the core and all
        # plugins, so it's OK to handle shortforms which aren't in the SIERRA
        # core at this point. Converted shortforms are merged into cmdopts
        # LAST in _init_cmdopts(), so a shortform overrides its longform if
        # both are passed.
        self.args = args
        shortforms = self._handle_shortforms()

        # Check for problematic characters in arguments used to create
        # directory paths. --batch-criteria may be a list of strings, so each
        # element must be searched (a bare `"+" in some_list` only matches an
        # element that IS "+").
        if any(
            self._contains_plus(arg)
            for arg in (
                self.args.scenario,
                self.args.controller,
                self.args.batch_criteria,
            )
        ):
            raise RuntimeError(
                "{--scenario, --controller, --batch-criteria} cannot contain '+'."
            )

        self.cmdopts = self._init_cmdopts(shortforms)
        self._load_config()

        # Stages 1-4 all consume batch criteria in run(); stage 5 does not.
        # Always define the attribute so run() never hits an AttributeError
        # (e.g. for `--pipeline 4 5`).
        self.batch_criteria = None
        if any(stage in self.args.pipeline for stage in (1, 2, 3, 4)):
            bc = pm.module_load_tiered(
                project=self.cmdopts["project"], path="variables.batch_criteria"
            )
            # NOTE(review): requires a non-None pathset here, despite the
            # Optional default on the constructor.
            self.batch_criteria = bc.factory(
                self.main_config, self.cmdopts, self.pathset.input_root, self.args
            )

        self.controller = controller

    def run(self) -> None:
        """
        Run pipeline stages 1-5 as configured, in ascending order.
        """
        if 1 in self.args.pipeline:
            PipelineStage1(
                self.cmdopts,
                self.pathset,
                self.controller,
                self.batch_criteria,
            ).run()

        if 2 in self.args.pipeline:
            PipelineStage2(self.cmdopts, self.pathset).run(self.batch_criteria)

        if 3 in self.args.pipeline:
            PipelineStage3(self.main_config, self.cmdopts, self.pathset).run(
                self.batch_criteria
            )

        if 4 in self.args.pipeline:
            PipelineStage4(self.main_config, self.cmdopts, self.pathset).run(
                self.batch_criteria
            )

        # not part of default pipeline
        if 5 in self.args.pipeline:
            PipelineStage5(self.main_config, self.cmdopts).run(self.args)

    @staticmethod
    def _contains_plus(value: tp.Any) -> bool:
        """Return True if ``value`` (a string, or an iterable of strings)
        contains '+'.

        Falsy values (None, "", []) never contain '+'.
        """
        if not value:
            return False
        if isinstance(value, str):
            return "+" in value
        return any("+" in str(item) for item in value)

    def _init_cmdopts(self, shortforms: "types.Cmdopts") -> "types.Cmdopts":
        """Build the cmdopts dictionary used by all pipeline stages.

        Starts from the core longform arguments, then merges in options from
        plugin ``cmdline`` modules in a fixed order (later merges win):
        engine, execenv, expdef, each --proc, each --prod, each --compare,
        storage, project. Finally merges ``shortforms`` and adds the derived
        ``project_*_root`` paths.
        """
        cmdopts = {
            # multistage
            "pipeline": self.args.pipeline,
            "sierra_root": pathlib.Path(self.args.sierra_root).expanduser(),
            "scenario": self.args.scenario,
            "expdef_template": self.args.expdef_template,
            "project": self.args.project,
            "execenv": self.args.execenv,
            "engine_vc": self.args.engine_vc,
            "n_runs": self.args.n_runs,
            "exp_overwrite": self.args.exp_overwrite,
            "exp_range": self.args.exp_range,
            "engine": self.args.engine,
            "processing_parallelism": self.args.processing_parallelism,
            "exec_parallelism_paradigm": self.args.exec_parallelism_paradigm,
            "expdef": self.args.expdef,
            # stage 1
            "preserve_seeds": self.args.preserve_seeds,
            # stage 2
            "nodefile": self.args.nodefile,
            # stage 3
            "proc": self.args.proc,
            "df_verify": self.args.df_verify,
            "df_homogenize": self.args.df_homogenize,
            "processing_mem_limit": self.args.processing_mem_limit,
            "storage": self.args.storage,
            # stage 4
            "prod": self.args.prod,
            # stage 5
            "compare": self.args.compare,
        }

        # Load additional cmdline options from --engine (always present,
        # resolved via tiered lookup keyed on the engine).
        self.logger.debug("Updating cmdopts from --engine=%s", cmdopts["engine"])
        module = pm.module_load_tiered("cmdline", engine=cmdopts["engine"])
        cmdopts.update(module.to_cmdopts(self.args))

        # Optional cmdline extensions. Order matters: later plugins override
        # keys set by earlier ones.
        for flag in ("execenv", "expdef"):
            self._update_cmdopts_from_plugin(cmdopts, flag, cmdopts[flag])

        for flag in ("proc", "prod", "compare"):
            for plugin in cmdopts[flag]:
                self._update_cmdopts_from_plugin(cmdopts, flag, plugin)

        self._update_cmdopts_from_plugin(cmdopts, "storage", cmdopts["storage"])

        # Load additional cmdline options from project. This is mandatory,
        # because all projects have to define --controller and --scenario
        # at a minimum.
        self.logger.debug("Updating cmdopts from --project=%s", cmdopts["project"])
        module = pm.module_load("{}.cmdline".format(cmdopts["project"]))
        cmdopts.update(module.to_cmdopts(self.args))

        # This has to be AFTER loading cmdopts from all plugins so that any
        # unset/defaulted options don't override the shortform. This also
        # means that shortforms override longforms if both are passed.
        cmdopts.update(shortforms)

        # Projects are specified as X.Y on the cmdline; every dotted
        # component becomes a path segment under the plugin's parent_dir.
        project = pm.pipeline.get_plugin(cmdopts["project"])
        path = project["parent_dir"] / "/".join(cmdopts["project"].split("."))
        cmdopts["project_root"] = str(path)
        cmdopts["project_config_root"] = str(path / "config")
        cmdopts["project_model_root"] = str(path / "models")

        return cmdopts

    def _update_cmdopts_from_plugin(
        self, cmdopts: "types.Cmdopts", flag: str, plugin: str
    ) -> None:
        """Merge ``<plugin>.cmdline`` options into ``cmdopts`` in place, if
        that module exists.

        ``flag`` is the cmdline option (without ``--``) that selected the
        plugin; it is used only for logging.
        """
        path = "{}.cmdline".format(plugin)
        if pm.module_exists(path):
            self.logger.debug("Updating cmdopts from --%s=%s", flag, plugin)
            module = pm.module_load_tiered(path)
            cmdopts.update(module.to_cmdopts(self.args))

    def _handle_shortforms(self) -> "types.Cmdopts":
        """
        Replace all shortform arguments with their longform counterparts.

        SIERRA always references arguments internally via longform if needed,
        so this is required.

        Returns:
            Dictionary mapping ``<longform>_<arg>`` keys to parsed values.
        """
        shortform_map = {
            "p": "plot",
            "e": "exp",
            "x": "exec",
            "s": "skip",
        }
        ret = {}
        for short, longform in shortform_map.items():
            passed = getattr(self.args, short, None)
            if not passed:
                self.logger.trace(
                    ("No shortform args for -%s -> --%s passed to SIERRA"),
                    short,
                    longform,
                )
                continue

            self.logger.trace(
                "Collected shortform args for -%s -> --%s: %s",
                short,
                longform,
                passed,
            )

            for tokens in passed:
                if not tokens:
                    continue
                key, value = self._shortform_to_longform(longform, tokens)
                ret[key] = value

        return ret

    @staticmethod
    def _shortform_to_longform(
        longform: str, tokens: tp.Sequence[str]
    ) -> tp.Tuple[str, tp.Any]:
        """Convert one shortform occurrence into a ``(key, value)`` pair.

        There are 3 ways to pass shortform arguments, assuming a shortform
        of 'X':

        1. ``-Xarg``      -> boolean; ``-Xno-arg`` stores False for ``arg``.
        2. ``-Xarg=foo``  -> ``"foo"`` (only the first '=' splits).
        3. ``-Xarg foo``  -> ``"foo"``; several values give a list.

        Dashes in the argument name become underscores in the key.
        """
        name = tokens[0]

        if len(tokens) == 1 and "=" in name:
            arg, _, value = name.partition("=")
            return "{}_{}".format(longform, arg.replace("-", "_")), value

        if len(tokens) == 1:
            # Boolean shortform flags store False only if they are explicitly
            # negated with a leading "no-"/"no_", not merely containing "no".
            flag = name.replace("-", "_")
            negated = flag.startswith("no_")
            return (
                "{}_{}".format(longform, flag.removeprefix("no_")),
                not negated,
            )

        values = list(tokens[1:])
        key = "{}_{}".format(longform, name.replace("-", "_"))
        return key, values[0] if len(values) == 1 else values

    def _load_config(self) -> None:
        """Load the project's main YAML config into ``self.main_config``.

        Raises:
            FileNotFoundError: If the main config file does not exist.
        """
        self.logger.debug(
            "Loading project config from '%s'", self.cmdopts["project_config_root"]
        )

        main_path = pathlib.Path(
            self.cmdopts["project_config_root"], config.PROJECT_YAML.main
        )
        try:
            with utils.utf8open(main_path) as f:
                # NOTE(review): FullLoader is used on project-authored config.
                # Consider yaml.safe_load if configs can come from untrusted
                # sources.
                self.main_config = yaml.load(f, yaml.FullLoader)
        except FileNotFoundError:
            self.logger.fatal("%s must exist!", main_path)
            raise
# Public API of this module: only the Pipeline driver class is exported via
# `from ... import *`.
__all__ = ["Pipeline"]