# Source code for sierra.plugins.expdef.xml.plugin

# Copyright 2024 John Harwell, All rights reserved.
#
#  SPDX-License-Identifier: MIT
"""Plugin for parsing and manipulating template input files in XML format."""

# Core packages
import pathlib
import logging
import xml.etree.ElementTree as ET
import typing as tp

# 3rd party packages

# Project packages
from sierra.core.experiment import definition
from sierra.core import types


class Writer:
    """Write the XML experiment to the filesystem according to configuration.

    More than one file may be written, as specified: each entry of the writer
    configuration selects a subtree of the parsed definition, optionally
    edits it (root rename, new children, grafts), and writes it to its own
    output path.
    """

    def __init__(self, tree: ET.ElementTree) -> None:
        self.tree = tree
        self.root = tree.getroot()
        self.logger = logging.getLogger(__name__)

    def __call__(
        self, write_config: definition.WriterConfig, base_opath: pathlib.Path
    ) -> None:
        """Write one output file per entry in ``write_config.values``.

        Args:
            write_config: Writer configuration; each entry in its ``values``
                is a dict describing one subtree to write.
            base_opath: Base output path; an entry may append a leaf suffix
                via its ``opath_leaf`` key.
        """
        for config in write_config.values:
            self._write_with_config(base_opath, config)

    @staticmethod
    def _is_set(config: dict, key: str) -> bool:
        """Return True if ``key`` is present in ``config`` and not ``None``."""
        return config.get(key) is not None

    def _write_with_config(
        self, base_opath: tp.Union[pathlib.Path, str], config: dict
    ) -> None:
        """Write the single subtree described by ``config``.

        Optional keys (skipped when absent or ``None``):

        - ``rename_to``: new tag for the subtree root.
        - ``new_children_parent`` + ``new_children``: child elements to create.
        - ``child_grafts_parent`` + ``child_grafts``: elements from the full
          definition to append into the subtree.

        The edits are applied in place to the selected subtree, which is part
        of ``self.tree``.
        """
        tree, src_root, opath = self._prepare_tree(pathlib.Path(base_opath), config)

        if tree is None:
            self.logger.warning(
                "Cannot write non-existent tree@'%s' to '%s'", src_root, opath
            )
            return

        self.logger.trace("Write tree@%s to %s", src_root, opath)

        # Renaming tree root is not required
        if self._is_set(config, "rename_to"):
            tree.tag = config["rename_to"]
            self.logger.trace("Rename tree root -> %s", config["rename_to"])

        # Adding new children not required
        if self._is_set(config, "new_children_parent") and self._is_set(
            config, "new_children"
        ):
            self._add_new_children(config, tree)

        # Grafts are not required
        if self._is_set(config, "child_grafts_parent") and self._is_set(
            config, "child_grafts"
        ):
            self._add_grafts(config, tree)

        to_write = ET.ElementTree(tree)

        # A single indent pass suffices; ET.indent accepts the tree directly.
        ET.indent(to_write, space="\t", level=0)
        to_write.write(opath, encoding="utf-8")

    def _add_grafts(self, config: dict, tree: ET.Element) -> None:
        """Append elements from the full definition under a parent in ``tree``.

        Each entry of ``config["child_grafts"]`` is an ElementPath query that
        is resolved against ``self.root``. The element it finds is appended
        under the element that ``config["child_grafts_parent"]`` resolves to
        in ``tree``.

        NOTE(review): ``append()`` does not detach the element from its
        original parent, so the grafted element is shared between both
        locations. Confirm this is intended.
        """
        parent_path = config["child_grafts_parent"]
        graft_parent = tree.find(parent_path)
        assert graft_parent is not None, f"Bad parent '{parent_path}' for grafting"

        for g in config["child_grafts"]:
            self.logger.trace("Graft tree@'%s' as child under '%s'", g, graft_parent)
            elt = self.root.find(g)
            # Fail loudly here instead of with an opaque TypeError from append().
            assert elt is not None, f"Could not find graft '{g}' in the definition"
            graft_parent.append(elt)

    def _add_new_children(self, config: dict, tree: ET.Element) -> None:
        """Given the experiment definition, add new children as configured.

        We operate on the whole definition in-situ, rather than creating a new
        subtree with all the children because that is less error prone in terms
        of grafting the new subtree back into the experiment definition.

        Each spec in ``config["new_children"]`` has its ``path`` resolved
        under the ``config["new_children_parent"]`` element. A ``spec.tag``
        child is then created there with attributes ``spec.attr``.
        """
        parent_path = config["new_children_parent"]
        parent = tree.find(parent_path)

        assert (
            parent is not None
        ), f"Could not find parent '{parent_path}' of new children"

        for spec in config["new_children"]:
            if spec.as_root_elt:
                # Special case: Adding children to an empty tree
                # NOTE(review): this only rebinds the local name ``tree``. The
                # new element is not attached or returned, so it never reaches
                # the written output. Confirm intent.
                tree = ET.Element(spec.path, spec.attr)
                continue

            elt = parent.find(spec.path)

            assert elt is not None, (
                f"Could not find parent '{spec.path}' of new child element '{spec.tag}' "
                "to add"
            )

            ET.SubElement(elt, spec.tag, spec.attr)

            self.logger.trace(
                "Create child element '%s' under '%s'",
                spec.tag,
                spec.path,
            )

    def _prepare_tree(
        self, base_opath: pathlib.Path, config: dict
    ) -> tuple[tp.Optional[ET.Element], str, pathlib.Path]:
        """Locate the subtree to write and compute its output path.

        Returns:
            A tuple ``(subtree or None, ElementPath of subtree, output path)``.
            The subtree is ``None`` when the path matches nothing.
        """
        assert "src_parent" in config, "'src_parent' key is required"
        assert (
            "src_tag" in config and config["src_tag"] is not None
        ), "'src_tag' key is required"

        if config["src_parent"] is None:
            src_root = config["src_tag"]
        else:
            src_root = "{}/{}".format(config["src_parent"], config["src_tag"])

        tree_out = self.tree.getroot().find(src_root)

        # Customizing the output write path is not required
        opath = base_opath
        if self._is_set(config, "opath_leaf"):
            opath = base_opath.with_name(base_opath.name + str(config["opath_leaf"]))

        self.logger.trace(
            "Preparing subtree write of '%s' to '%s', root='%s'",
            tree_out,
            opath,
            src_root,
        )

        return (tree_out, src_root, opath)


def root_querypath() -> str:
    """Return the ElementPath expression that selects the root element.

    ``"."`` denotes the current node, so calling ``find()`` on the root with
    this query yields the root itself.
    """
    querypath = "."
    return querypath


class ExpDef(definition.BaseExpDef):
    """Read, write, and modify parsed XML files into experiment definitions."""

    def __init__(
        self,
        input_fpath: pathlib.Path,
        write_config: tp.Optional[definition.WriterConfig] = None,
    ) -> None:
        """Parse ``input_fpath`` as XML and start with empty modification logs.

        Args:
            input_fpath: Path to the template XML file.
            write_config: Optional writer configuration used by :meth:`write`.
        """
        self.write_config = write_config
        self.input_fpath = input_fpath
        self.tree = ET.parse(self.input_fpath)
        self.root = self.tree.getroot()
        self.element_adds = definition.ElementAddList()
        self.attr_chgs = definition.AttrChangeSet()
        self.logger = logging.getLogger(__name__)

    def n_mods(self) -> tuple[int, int]:
        """Return ``(#element additions, #attribute changes)`` recorded so far."""
        return len(self.element_adds), len(self.attr_chgs)

    def write_config_set(self, config: definition.WriterConfig) -> None:
        """Set the write config for the object.

        Provided for cases in which the configuration is dependent on whether
        or not certain tags/element are present in the input file.
        """
        self.write_config = config

    def write(self, base_opath: pathlib.Path) -> None:
        """Write the (possibly modified) definition per the write config."""
        assert self.write_config is not None, "Can't write without write config"
        writer = Writer(self.tree)
        writer(self.write_config, base_opath)

    def flatten(self, keys: list[str]) -> None:
        """Not supported for XML definitions; always raises."""
        raise NotImplementedError(
            "The XML expdef plugin does not support flattening"
        )

    def attr_get(self, path: str, attr: str) -> tp.Optional[tp.Union[str, int, float]]:
        """Return attribute ``attr`` of the element at ``path``, or ``None``.

        ``None`` is returned if either the element or the attribute is
        missing. ``path`` is an ElementPath query relative to the root.
        """
        el = self.root.find(path)
        if el is not None and attr in el.attrib:
            return el.attrib[attr]
        return None

    def attr_change(
        self,
        path: str,
        attr: str,
        value: tp.Union[str, int, float],
        noprint: bool = False,
    ) -> bool:
        """Change the value of an existing attribute.

        Args:
            path: ElementPath query (relative to the root) of the element.
            attr: Attribute to change; it must already exist.
            value: New value. It is stored as ``str(value)`` because
                ElementTree can only serialize string attribute values.
            noprint: If ``True``, suppress warnings on failure.

        Returns:
            ``True`` on success. ``False`` if the element or attribute was
            not found.
        """
        el = self.root.find(path)
        if el is None:
            if not noprint:
                self.logger.warning("Parent element '%s' not found", path)
            return False

        if attr not in el.attrib:
            if not noprint:
                self.logger.warning(
                    "Attribute '%s' not found in path '%s'", attr, path
                )
            return False

        el.set(attr, str(value))
        self.logger.trace("Modify attr: '%s/%s' = '%s'", path, attr, value)
        self.attr_chgs.add(definition.AttrChange(path, attr, str(value)))
        return True

    def attr_add(
        self,
        path: str,
        attr: str,
        value: tp.Union[str, int, float],
        noprint: bool = False,
    ) -> bool:
        """Add a new attribute to the element at ``path``.

        Args:
            path: ElementPath query (relative to the root) of the element.
            attr: Attribute to add; it must not already exist.
            value: Attribute value. It is stored as ``str(value)`` because
                ElementTree can only serialize string attribute values.
            noprint: If ``True``, suppress warnings on failure.

        Returns:
            ``True`` on success. ``False`` if the element was not found or
            already has ``attr``.
        """
        el = self.root.find(path)
        if el is None:
            if not noprint:
                self.logger.warning("Parent element '%s' not found", path)
            return False

        if attr in el.attrib:
            if not noprint:
                self.logger.warning(
                    "Attribute '%s' already in path '%s'", attr, path
                )
            return False

        el.set(attr, str(value))
        self.logger.trace("Add new attribute: '%s/%s' = '%s'", path, attr, value)
        self.attr_chgs.add(definition.AttrChange(path, attr, str(value)))
        return True

    def has_element(self, path: str) -> bool:
        """Return True if ``path`` matches at least one element."""
        return self.root.find(path) is not None

    def has_attr(self, path: str, attr: str) -> bool:
        """Return True if the element at ``path`` exists and has ``attr``."""
        el = self.root.find(path)
        if el is None:
            return False
        return attr in el.attrib

    def element_change(self, path: str, tag: str, value: str) -> bool:
        """Rename the first direct child of ``path`` tagged ``tag`` to ``value``.

        The rename is not recorded in the modification logs.

        Returns:
            ``True`` if a child was renamed, ``False`` otherwise.
        """
        el = self.root.find(path)
        if el is None:
            self.logger.warning("Parent element '%s' not found", path)
            return False

        for child in el:
            if child.tag == tag:
                child.tag = value
                self.logger.trace("Modify element: '%s/%s' = '%s'", path, tag, value)
                return True

        self.logger.warning("No such element '%s' found in '%s'", tag, path)
        return False

    def element_remove(self, path: str, tag: str, noprint: bool = False) -> bool:
        """Remove the first matching element in ``path`` matching ``tag``."""
        parent = self.root.find(path)

        if parent is None:
            if not noprint:
                self.logger.warning("Parent node '%s' not found", path)
            return False

        victim = parent.find(tag)
        if victim is None:
            if not noprint:
                self.logger.warning("No victim '%s' found in parent '%s'", tag, path)
            return False

        parent.remove(victim)
        self.logger.trace("Remove element: '%s/%s'", path, tag)
        return True

    def element_remove_all(self, path: str, tag: str, noprint: bool = False) -> bool:
        """Remove every element under ``path`` matching ``tag``.

        Returns:
            ``True`` if at least one element was removed, ``False`` otherwise.
        """
        parent = self.root.find(path)

        if parent is None:
            if not noprint:
                self.logger.warning("Parent element '%s' not found", path)
            return False

        victims = parent.findall(tag)
        if not victims:
            if not noprint:
                self.logger.warning(
                    "No victims matching '%s' found in parent '%s'", tag, path
                )
            return False

        for victim in victims:
            parent.remove(victim)
            self.logger.trace("Remove matching element: '%s/%s'", path, tag)

        return True

    def element_add(
        self,
        path: str,
        tag: str,
        attr: tp.Optional[types.StrDict] = None,
        allow_dup: bool = True,
        noprint: bool = False,
    ) -> bool:
        """Add tag name as a child element of enclosing parent.

        Args:
            path: ElementPath query (relative to the root) of the parent.
            tag: Tag of the new child.
            attr: Optional attributes for the new child.
            allow_dup: If ``False``, refuse to add when the parent already
                has a matching ``tag`` child.
            noprint: If ``True``, suppress warnings on failure.

        Returns:
            ``True`` if the child was added (and recorded), ``False`` if not.
        """
        parent = self.root.find(path)
        if parent is None:
            if not noprint:
                self.logger.warning("Parent element '%s' not found", path)
            return False

        # SubElement copies the attribute dict, so the shared {} is safe.
        attrib = attr if attr else {}

        if allow_dup:
            ET.SubElement(parent, tag, attrib=attrib)
            self.logger.trace(
                "Add new element: '%s/%s' = '%s'", path, tag, str(attr)
            )
        else:
            if parent.find(tag) is not None:
                if not noprint:
                    self.logger.warning(
                        "Child element '%s' already in parent '%s'", tag, path
                    )
                return False

            ET.SubElement(parent, tag, attrib=attrib)
            self.logger.trace(
                "Add new unique element: '%s/%s' = '%s'",
                path,
                tag,
                str(attr),
            )

        self.element_adds.append(definition.ElementAdd(path, tag, attr, allow_dup))
        return True
def unpickle(
    fpath: pathlib.Path,
) -> tp.Optional[tp.Union[definition.AttrChangeSet, definition.ElementAddList]]:
    """Unpickle the XML modifications stored in the pickle file at ``fpath``.

    The kind of modifications stored in the file is not known in advance. The
    file is first loaded as an attribute change set. If that raises
    ``EOFError`` or ``TypeError``, it is loaded as an element addition list
    instead.

    Args:
        fpath: Path to the pickle file.

    Returns:
        The loaded ``AttrChangeSet`` or ``ElementAddList``.

    Raises:
        NotImplementedError: If neither kind could be loaded from ``fpath``.
    """
    # NOTE(review): the helpers presumably use pickle, so only load files
    # produced by trusted runs.
    try:
        return definition.AttrChangeSet.unpickle(fpath)
    except (EOFError, TypeError):
        pass

    try:
        return definition.ElementAddList.unpickle(fpath)
    except EOFError:
        pass

    raise NotImplementedError(f"Could not unpickle XML modifications from '{fpath}'")
# Public API of this plugin module.
__all__ = [
    "ExpDef",
    "unpickle",
]