# Copyright 2021 John Harwell, All rights reserved.
#
# SPDX-License-Identifier: MIT
"""SIERRA plugin management to make SIERRA OPEN/CLOSED.
Also contains checks that selected plugins implement the necessary classes and
functions. Currently checkes: ``--storage``, ``--execenv``, and ``--engine``.
"""
# Core packages
# Core packages
import importlib.util
import importlib
import typing as tp
import sys
import logging
import pathlib
import inspect
# 3rd party packages
# 3rd party packages
import json
# Project packages
from sierra.core import types
class BasePluginManager:
"""Base class for common functionality."""
def __init__(self) -> None:
self.logger = logging.getLogger(__name__)
self.loaded = {} # type: tp.Dict[str, tp.Dict]
def available_plugins(self):
raise NotImplementedError
def loaded_plugins(self):
return self.loaded
def load_plugin(self, name: str) -> None:
"""Load a plugin module."""
plugins = self.available_plugins()
if name not in plugins:
self.logger.fatal("Cannot locate plugin %s", name)
self.logger.fatal(
"Loaded plugins: %s\n",
json.dumps(self.loaded, default=lambda x: "<ModuleSpec>", indent=4),
)
raise RuntimeError(f"Cannot locate plugin '{name}'")
init = importlib.util.module_from_spec(plugins[name]["init_spec"])
plugins[name]["init_spec"].loader.exec_module(init)
if not hasattr(init, "sierra_plugin_type"):
self.logger.warning(
"Cannot load plugin %s: __init__.py does not define sierra_plugin_type()",
name,
)
return
plugin_type = init.sierra_plugin_type()
# The name of the module is only needed for pipeline plugins, not
# project plugins.
if plugins[name]["module_spec"] is None and plugin_type == "pipeline":
if not hasattr(init, "sierra_plugin_module"):
self.logger.warning(
"Cannot load plugin %s: __init__.py does not define sierra_plugin_module()",
name,
)
return
modname = init.sierra_plugin_module()
fpath = (
plugins[name]["parent_dir"] / name.replace(".", "/") / f"{modname}.py"
)
plugins[name]["module_spec"] = importlib.util.spec_from_file_location(
modname, fpath
)
if plugin_type == "pipeline":
self._load_pipeline_plugin(name)
elif plugin_type == "project":
self._load_project_plugin(name)
elif plugin_type == "model":
if not hasattr(init, "sierra_models"):
self.logger.warning(
"Cannot load plugin %s: __init__.py does not define sierra_models()",
name,
)
return
self._load_model_plugin(name)
else:
self.logger.warning(
"Unknown plugin type '%s' for %s: cannot load", plugin_type, name
)
def get_plugin(self, name: str) -> dict:
try:
return self.loaded[name]
except KeyError:
self.logger.fatal("No such plugin %s", name)
self.logger.fatal(
"Loaded plugins: %s\n",
json.dumps(self.loaded, default=lambda x: "<ModuleSpec>", indent=4),
)
raise
def get_plugin_module(self, name: str) -> types.ModuleType:
try:
return self.loaded[name]["module"]
except KeyError:
self.logger.fatal("No such plugin %s", name)
self.logger.fatal(
"Loaded plugins: %s\n",
json.dumps(self.loaded, default=lambda x: "<ModuleSpec>", indent=4),
)
raise
def has_plugin(self, name: str) -> bool:
return name in self.loaded
def _load_pipeline_plugin(self, name: str) -> None:
if name in self.loaded:
self.logger.warning("Pipeline plugin %s already loaded", name)
return
plugins = self.available_plugins()
# The parent directory of the plugin must be on sys.path so it can be
# imported, so we put in on there if it isn't.
new = str(plugins[name]["parent_dir"])
if new not in sys.path:
sys.path = [new, *sys.path[0:]]
self.logger.debug("Updated sys.path with %s", [new])
module = importlib.util.module_from_spec(plugins[name]["module_spec"])
plugins[name]["module_spec"].loader.exec_module(module)
# When importing with importlib, the module is not automatically added
# to sys.modules. This means that trying to pickle anything in it will
# fail with a rather cryptic 'AttributeError', so we explicitly add the
# last path component of the plugin name--which is the actual name of
# the module the plugin lives in--to sys.modules so that pickling will
# work.
sys_modname = name.split(".")[1]
if sys_modname not in sys.modules:
sys.modules[sys_modname] = module
self.loaded[name] = {
"spec": plugins[name]["module_spec"],
"parent_dir": plugins[name]["parent_dir"],
"module": module,
"type": "pipeline",
}
self.logger.debug(
"Loaded pipeline plugin %s from %s -> %s",
name,
plugins[name]["parent_dir"],
name,
)
def _load_project_plugin(self, name: str) -> None:
if name in self.loaded:
self.logger.warning("Project plugin %s already loaded", name)
return
plugins = self.available_plugins()
# The parent directory of the plugin must be on sys.path so it can be
# imported, so we put in on there if it isn't.
new = str(plugins[name]["parent_dir"])
if new not in sys.path:
sys.path = [new, *sys.path[0:]]
self.logger.debug("Updated sys.path with %s", [new])
self.loaded[name] = {
"spec": plugins[name]["module_spec"],
"parent_dir": plugins[name]["parent_dir"],
"type": "project",
}
self.logger.debug(
("Loaded project plugin %s from %s -> %s"),
name,
plugins[name]["parent_dir"],
name,
)
def _load_model_plugin(self, name: str) -> None:
if name in self.loaded:
self.logger.warning("Model plugin %s already loaded", name)
return
plugins = self.available_plugins()
# The parent directory of the plugin must be on sys.path so it can be
# imported, so we put in on there if it isn't.
new = str(plugins[name]["parent_dir"])
if new not in sys.path:
sys.path = [new, *sys.path[0:]]
self.logger.debug("Updated sys.path with %s", [new])
self.loaded[name] = {
"spec": plugins[name]["module_spec"],
"parent_dir": plugins[name]["parent_dir"],
"type": "model",
}
self.logger.debug(
("Loaded model plugin %s from %s -> %s"),
name,
plugins[name]["parent_dir"],
name,
)
[docs]
class DirectoryPluginManager(BasePluginManager):
"""Container for managing directory-based plugins."""
def __init__(self) -> None:
super().__init__()
self.plugins = {} # type: tp.Dict[str, tp.Dict]
def initialize(self, project: str, search_path: list[pathlib.Path]) -> None:
self.logger.debug(
"Initializing with plugin search path %s", [str(p) for p in search_path]
)
for path in search_path:
if not path.exists():
self.logger.warning(
"Non-existent path '%s' on SIERRA_PLUGIN_PATH", path
)
continue
self.logger.debug("Searching for plugins in '%s'", path)
def recursive_search(root: pathlib.Path) -> None:
for f in root.iterdir():
if not f.is_dir():
continue
recursive_search(f)
plugin = f / "plugin.py"
init = f / "__init__.py"
cookie = f / ".sierraplugin"
# 2025-11-24 [JRH]: The cookie is ALWAYS required. We used
# to just recognize a directory containing
# plugin.py+__init__.py as a SIERRA plugin, but that is far
# too generic, and caused conflicts with other python
# packages installed in the same environment.
if not (cookie.exists() and (plugin.exists() or init.exists())):
continue
name = f"{f.parent.name}.{f.name}"
try:
if plugin.exists():
module_spec = importlib.util.spec_from_file_location(
f.name, plugin
)
else:
module_spec = None
init_spec = importlib.util.spec_from_file_location(
"__init__", init
)
except FileNotFoundError:
self.logger.warning(
"Malformed plugin in %s: not loading", f.relative_to(root)
)
self.logger.debug("Found plugin in '%s' -> %s", f, name)
self.plugins[name] = {
"parent_dir": root.parent,
"module_spec": module_spec,
"init_spec": init_spec,
}
recursive_search(path)
def available_plugins(self):
return self.plugins
[docs]
def module_exists(name: str) -> bool:
"""
Check if a module exists before trying to import it.
"""
try:
_ = __import__(name)
except ImportError:
return False
return True
[docs]
def module_load(name: str) -> types.ModuleType:
"""
Import the specified module.
"""
return __import__(name, fromlist=["*"])
[docs]
def bc_load(cmdopts: types.Cmdopts, category: str):
"""
Load the specified :term:`Batch Criteria`.
"""
path = f"variables.{category}"
return module_load_tiered(
project=cmdopts["project"], engine=cmdopts["engine"], path=path
)
[docs]
def module_load_tiered(
path: str, project: tp.Optional[str] = None, engine: tp.Optional[str] = None
) -> types.ModuleType:
"""Attempt to load the specified python module with tiered precedence.
Generally, the precedence is project -> project submodule -> engine module
-> SIERRA core module, to allow users to override SIERRA core functionality
with ease. Specifically:
#. Check if the requested module directly exists. If it does, return
it.
#. Check if the requested module is a part of a project (i.e.,
``<project>.<path>`` exists). If it does, return it. This requires
that :envvar:`SIERRA_PLUGIN_PATH` to be set properly.
#. Check if the requested module is provided by the engine plugin (i.e.,
``sierra.engine.<engine>.<path>`` exists). If it does, return it.
#. Check if the requested module is part of the SIERRA core (i.e.,
``sierra.core.<path>`` exists). If it does, return it.
If no match was found using any of these, throw an error.
"""
# First, see if the requested module is a project/directly exists as
# specified.
if module_exists(path):
logging.trace("Using direct path %s", path)
return module_load(path)
# Next, check if the requested module is part of the project plugin
if project is not None:
component_path = f"{project}.{path}"
if module_exists(component_path):
logging.trace("Using project component path %s", component_path)
return module_load(component_path)
logging.trace(
"Project component path %s does not exist",
component_path,
)
# If that didn't work, check the engine plugin
if engine is not None:
engine_path = f"{engine}.{path}"
if module_exists(engine_path):
logging.trace("Using engine component path %s", engine_path)
return module_load(engine_path)
# If that didn't work, then check the SIERRA core
core_path = f"sierra.core.{path}"
if module_exists(core_path):
logging.trace("Using SIERRA core path %s", core_path)
return module_load(core_path)
logging.trace("SIERRA core path %s does not exist", core_path)
# Module does not exist
error = (
f"project: '{project}' "
f"engine: '{engine}' "
f"path: '{path}' "
f"sys.path: {sys.path}"
)
raise ImportError(error)
[docs]
def storage_sanity_checks(medium: str, module) -> None:
"""
Check the selected ``--storage`` plugin.
"""
logging.trace("Verifying --storage=%s plugin interface", medium)
functions = ["supports_input", "supports_output"]
in_module = inspect.getmembers(module, inspect.isfunction)
for f in functions:
assert any(
f == name for (name, _) in in_module
), f"Storage medium {medium} does not define {f}()"
def expdef_sanity_checks(expdef: str, module) -> None:
"""
Check the selected ``--expdef`` plugin.
"""
logging.trace("Verifying --expdef=%s plugin interface", expdef)
functions = ["root_querypath", "unpickle"]
module_funcs = inspect.getmembers(module, inspect.isfunction)
module_classes = inspect.getmembers(module, inspect.isclass)
classes = ["ExpDef", "Writer"]
for c in classes:
assert any(
c == name for (name, _) in module_classes
), f"Expdef plugin {expdef} does not define {c}"
for f in functions:
assert any(
f == name for (name, _) in module_funcs
), f"Expdef {expdef} does not define {f}()"
[docs]
def proc_sanity_checks(proc: str, module) -> None:
"""
Check the selected ``--proc`` plugins.
"""
logging.trace("Verifying --proc=%s plugin interface", proc)
functions = ["proc_batch_exp"]
in_module = inspect.getmembers(module, inspect.isfunction)
for f in functions:
assert any(
f == name for (name, _) in in_module
), f"Processing plugin {proc} does not define {f}()"
[docs]
def prod_sanity_checks(prod: str, module) -> None:
"""
Check the selected ``--prod`` plugins.
"""
logging.trace("Verifying --prod=%s plugin interface", prod)
functions = ["proc_batch_exp"]
in_module = inspect.getmembers(module, inspect.isfunction)
for f in functions:
assert any(
f == name for (name, _) in in_module
), f"Product plugin {prod} does not define {f}()"
[docs]
def compare_sanity_checks(compare: str, module) -> None:
"""
Check the selected ``--compare`` plugins.
"""
logging.trace("Verifying --compare=%s plugin interface", compare)
functions = ["proc_exps"]
in_module = inspect.getmembers(module, inspect.isfunction)
for f in functions:
assert any(
f == name for (name, _) in in_module
), f"Comparison plugin {compare} does not define {f}()"
[docs]
def execenv_sanity_checks(execenv: str, module) -> None:
"""
Check the selected ``--execenv`` plugin.
"""
logging.trace("Verifying --execenv=%s plugin interface", execenv)
in_module = inspect.getmembers(module, inspect.isclass)
opt_functions = ["cmdline_postparse_configure", "execenv_check"]
opt_classes = ["ExpRunShellCmdsGenerator", "ExpShellCmdsGenerator"]
for c in opt_classes:
if not any(c == name for (name, _) in in_module):
logging.debug(
(
"Execution environment plugin %s does not define "
"%s--some SIERRA functionality may not be "
"available. See docs for details."
),
execenv,
c,
)
for f in opt_functions:
if not any(f in name for (name, _) in in_module):
logging.debug(
("Execution environment plugin %s does not define %s()."),
execenv,
f,
)
[docs]
def engine_sanity_checks(engine: str, module) -> None:
"""
Check the selected ``--engine`` plugin.
"""
logging.trace("Verifying --engine=%s plugin interface", engine)
req_classes = [
"ExpConfigurer",
]
req_functions = [] # type: list[str]
opt_classes = ["ExpRunShellCmdsGenerator", "ExpShellCmdsGenerator"]
opt_functions = [
"cmdline_postparse_configure",
"execenv_check",
"agent_prefix_extract",
"arena_dims_from_criteria",
"population_size_from_def",
"population_size_from_pickle",
]
in_module = inspect.getmembers(module, inspect.isclass)
for c in req_classes:
assert any(
c == name for (name, _) in in_module
), f"Engine plugin {engine} does not define {c}"
for f in opt_classes:
if not any(f in name for (name, _) in in_module):
logging.debug(
(
"Engine plugin %s does not define %s"
"--some SIERRA functionality may not be available. "
"See docs for details."
),
engine,
f,
)
in_module = inspect.getmembers(module, inspect.isfunction)
for f in req_functions:
assert any(
f == name for (name, _) in in_module
), f"Engine plugin {engine} does not define {f}()"
for f in opt_functions:
if not any(f == name for (name, _) in in_module):
logging.debug(
(
"Engine plugin %s does not define %s()"
"--some SIERRA functionality may not be available. "
"See docs for details."
),
engine,
f,
)
# Singletons
pipeline = DirectoryPluginManager()
models = DirectoryPluginManager()
__all__ = [
"DirectoryPluginManager",
"bc_load",
"compare_sanity_checks",
"engine_sanity_checks",
"execenv_sanity_checks",
"module_exists",
"module_load",
"module_load_tiered",
"proc_sanity_checks",
"prod_sanity_checks",
"storage_sanity_checks",
]