"""
End-to-end enzymatic reaction workflow: extract, scan, MEP, TS, IRC, freq, DFT.

Example:
    mlmm all -i R.pdb P.pdb -c 'GPP,MMT' -l 'GPP:-3,MMT:-1'

For detailed documentation, see: docs/all.md

Table of contents (top-level definitions; refresh manually after structural edits):
    class _EchoState
    def _echo
    def _echo_detail
    def _echo_section
    def _emit_final_summary
    def _run_cli_main
    def _append_cli_arg
    def _append_toggle_arg
    def _resolve_override_dir
    def _build_effective_args_yaml
    def _inject_coord_type_into_args_yaml
    def _write_ml_region_definition
    def _write_bfactor_ml_subset
    def _summarize_existing_bfactor_layers
    def _mm_charge_mapping
    def _mm_mult_mapping
    def _build_mm_parm7
    def _parse_atom_key_from_line
    def _key_variants
    def _parse_scan_lists_literals
    def _format_scan_stage
    def _round_charge_with_note
    def _derive_charge_from_ligand_charge_when_extract_skipped
    def _derive_ml_charge_from_layered_pdb
    def _pdb_needs_elem_fix
    def _enrich_summary
    def _json_safe
    def _copy_structures_to_seg_dir
    def _read_summary
    def _pdb_models_to_coords_and_elems
    def _geom_from_angstrom
    def _load_segment_end_geoms
    def _irc_and_match
    def _save_single_geom_for_tools
    def _run_tsopt_on_hei
    def _write_segment_energy_diagram
    def _build_global_segment_labels
    def _merge_irc_trajectories_to_single_plot
    def _run_freq_for_state
    def _run_opt_for_state
    def _dft_succeeded
    def _dft_energy_ha
    def _run_dft_for_state
    def _show_advanced_help
    def _configure_all_help_visibility
    def cli
"""

from __future__ import annotations

from pathlib import Path
from typing import List, Sequence, Optional, Tuple, Dict, Any
import shutil
import tempfile

import gc
import logging
import sys
import math
import click
from mlmm.cli.common_options import add_coord_type_option, add_precision_option, add_backend_model_option, add_calc_file_option, add_deterministic_option, add_allow_charge_mult_mismatch_option
from mlmm.cli.decorators import make_is_param_explicit
import time
import json
import yaml
import numpy as np
import torch

logger = logging.getLogger(__name__)

# Biopython for PDB parsing (post-processing helpers)
from Bio import PDB

# pysisyphus helpers/constants
from pysisyphus.helpers import geom_loader
from pysisyphus.constants import BOHR2ANG, AU2KCALPERMOL

# Local imports from the package
from mlmm.workflows.extract import extract_api, compute_charge_summary, log_charge_summary
from mlmm.workflows import path_search as _path_search
from mlmm.workflows import path_opt as _path_opt
from mlmm.workflows import opt as _opt_cli
from mlmm.workflows import tsopt as _ts_opt
from mlmm.workflows import freq as _freq_cli
from mlmm.workflows import irc as _irc_cli

from mlmm.io.trj2fig import run_trj2fig
from mlmm.io.summary import write_summary_log
from mlmm.workflows.align_freeze import align_and_refine_sequence_inplace
from mlmm.core.defaults import (
    OUT_DIR_ALL,
    SEGMENTS_DIRNAME,
    THRESH_CHOICES,
    WORK_DIRNAME,
)
from mlmm.core.utils import (
    apply_ref_pdb_override,
    build_energy_diagram,
    close_matplotlib_figures,
    convert_xyz_to_pdb,
    ensure_dir,
    format_elapsed,
    prepare_input_structure,
    collect_single_option_values,
    load_yaml_dict,
    load_pdb_atom_metadata,
    parse_scan_list_triples,
    read_bfactors_from_pdb,
    read_xyz_as_blocks,
    read_xyz_first_last,
    verbose_level,
    xyz_blocks_first_last,
)
from mlmm.cli.decorators import resolve_yaml_sources, load_merged_yaml_cfg
from mlmm.cli.preflight import validate_existing_files, ensure_commands_available
from mlmm.workflows import scan as _scan_cli
from mlmm.domain.add_elem_info import assign_elements as _assign_elem_info
from mlmm.workflows.define_layer import define_layers as _define_layers
from mlmm.backends.mlmm_calc import mlmm as _mlmm_calc
from mlmm.workflows.mm_parm import (
    Args as _AutoMMArgs,
    parse_ligand_charge as _mm_parse_ligand_charge,
    parse_ligand_mult as _mm_parse_ligand_mult,
    run_pipeline as _mm_run,
)

# Structural identity of one PDB atom record, in this field order:
# (chain, resname, resseq, icode, atomname, altloc). All fields are
# whitespace-stripped strings; blank optional fields are "".
AtomKey = Tuple[str, str, str, str, str, str]

class _EchoState:
    """Encapsulate CLI output state for section-spacing logic."""

    def __init__(self) -> None:
        self._started = False

    def reset(self) -> None:
        self._started = False

    def echo(self, *args, **kwargs) -> None:
        click.echo(*args, **kwargs)
        self._started = True

    def section(self, message: str, **kwargs) -> None:
        # Section banners form the narrative backbone of the pipeline log, so
        # they default to narrative (visible at default verbosity). The leading
        # blank carries the same flag to preserve spacing around a shown banner.
        narrative = kwargs.setdefault("narrative", True)
        if self._started:
            click.echo(narrative=narrative)
        click.echo(message, **kwargs)
        self._started = True


_echo_state = _EchoState()


def _echo(*args, **kwargs) -> None:
    """Print one line through the shared echo state (keeps section spacing right).

    No verbosity tag is added here, so inside ``all`` the line is only
    visible at ``-v 3``. Use ``_echo_detail`` for default ``-v 2`` stage
    details, or pass ``narrative=True`` for milestone lines.
    """
    state = _echo_state
    state.echo(*args, **kwargs)


def _echo_detail(*args, **kwargs) -> None:
    """Print a level-2 ("detail") line through the shared echo state.

    ``detail=True`` is filled in unless the caller passed ``detail`` itself.
    """
    if "detail" not in kwargs:
        kwargs["detail"] = True
    _echo_state.echo(*args, **kwargs)


def _echo_section(message: str, **kwargs) -> None:
    """Print a narrative section banner, separated by a blank line from any
    earlier output (no blank line when it is the first thing printed)."""
    return _echo_state.section(message, **kwargs)


def _emit_final_summary(out_dir: Path | None, time_start: float) -> None:
    """Print a visual `====== Pipeline summary ======` block + Elapsed line.

    Reads ``summary.json`` from ``out_dir`` if present and lifts the
    most-asked-for numbers (status, rate-limiting barrier, reaction energy,
    reactive-segment count, output dir). The user thus sees them at the
    bottom of the log without scrolling back through `[diagram] Wrote ...` /
    `[time] Elapsed Time for X:` clutter.

    Falls back to just the Elapsed line in these cases:

    * summary.json is absent (dry-run, early failure, TSOPT-only without
      aggregation);
    * summary.json is unreadable or undecodable;
    * summary.json does not hold a JSON object.

    Energy fields that are not numeric are skipped, so this end-of-run report
    never aborts the pipeline.

    Parameters
    ----------
    out_dir : Path | None
        Pipeline output directory passed by the user (``-o``); ``None`` skips
        the summary block.
    time_start : float
        Pipeline start timestamp, forwarded to ``format_elapsed``.
    """

    def _as_float(value: Any) -> Optional[float]:
        # summary.json is produced elsewhere; tolerate malformed entries.
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return None

    summary: Dict[str, Any] = {}
    if out_dir is not None:
        summary_path = Path(out_dir) / "summary.json"
        if summary_path.exists():
            try:
                loaded = json.loads(summary_path.read_text(encoding="utf-8"))
            except (OSError, ValueError):
                # ValueError covers both json.JSONDecodeError and
                # UnicodeDecodeError (a non-UTF-8 file).
                loaded = None
            if isinstance(loaded, dict):
                summary = loaded
    if summary:
        _echo_section("====== Pipeline summary ======")
        status = summary.get("status")
        if status:
            _echo(f"Status: {status}", narrative=True)
        rls = summary.get("rate_limiting_step")
        if isinstance(rls, dict):
            barrier = _as_float(rls.get("barrier_kcal"))
            seg_idx = rls.get("segment")
            method = rls.get("method", "?")
            if barrier is not None:
                _echo(
                    f"Rate-limiting barrier: {barrier:.2f} kcal/mol (segment {seg_idx}, method {method})",
                    narrative=True,
                )
        rxn_e = _as_float(summary.get("overall_reaction_energy_kcal"))
        if rxn_e is not None:
            _echo(f"Reaction energy: {rxn_e:.2f} kcal/mol", narrative=True)
        n_reactive = summary.get("n_segments_reactive")
        if n_reactive is not None:
            _echo(f"Reactive segments: {n_reactive}", narrative=True)
        # Report the pipeline out-dir the user passed (-o), not the stage
        # sub-dir that summary.json happens to record (e.g. <out>/path_search).
        out_dir_show = str(out_dir) if out_dir is not None else summary.get("out_dir")
        if out_dir_show:
            _echo(f"Output dir: {out_dir_show}", narrative=True)
        _echo(narrative=True)
    _echo(format_elapsed("[all] Elapsed for Whole Pipeline", time_start), narrative=True)


def _run_cli_main(
    cmd_name: str,
    cli_obj,
    args: Sequence[str],
    *,
    on_nonzero: str = "warn",
    on_exception: str = "raise",
    prefix: Optional[str] = None,
) -> None:
    """Run a Click command in-process with temporary argv and consistent error handling.

    Parameters
    ----------
    cmd_name : str
        Subcommand name; placed in ``sys.argv[1]`` and used in messages.
    cli_obj
        Click command whose ``main`` is invoked with ``standalone_mode=False``.
    args : Sequence[str]
        Arguments forwarded to the subcommand.
    on_nonzero : str
        ``"raise"`` turns a ``SystemExit`` with a code other than 0/None into a
        ``click.ClickException``; any other value only prints a warning.
    on_exception : str
        ``"raise"`` wraps any other exception in ``click.ClickException`` (the
        original is kept as ``__cause__``); any other value only warns.
    prefix : str, optional
        Message label used instead of ``cmd_name``.

    Child mode is enabled for the duration of the call. ``sys.argv`` is
    restored afterwards, and garbage collection plus a CUDA cache flush run
    between stages.
    """
    saved = list(sys.argv)
    label = prefix or cmd_name
    # In-proc subcommand dispatch — flag the child's banner / device echo to
    # stay silent so a 4-stage `all` pipeline doesn't reprint the same lines
    # `mlmm-toolkit ver. X` / `[calc] Resolved device: cuda` once per stage.
    from mlmm.core.utils import set_child_mode
    set_child_mode(True)
    try:
        sys.argv = ["mlmm", cmd_name] + list(args)
        _echo("")
        cli_obj.main(args=list(args), standalone_mode=False)
    except SystemExit as e:
        code = getattr(e, "code", 1)
        if code not in (None, 0):
            if on_nonzero == "raise":
                raise click.ClickException(f"[{label}] {cmd_name} exit code {code}.") from e
            _echo(f"[{label}] WARNING: {cmd_name} exited with code {code}")
    except Exception as e:
        if on_exception == "raise":
            # Chain the original so its traceback survives the wrapping.
            raise click.ClickException(f"[{label}] {cmd_name} failed: {e}") from e
        _echo(f"[{label}] WARNING: {cmd_name} failed: {e}")
    finally:
        sys.argv = saved
        set_child_mode(False)
        # Release GPU memory between pipeline stages to prevent OOM.
        # Subcommand finally blocks unbind their heavy locals (= None).
        # gc.collect() is needed to break cyclic refs inside torch.nn.Module,
        # then empty_cache() reclaims the CUDA allocator cache.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        _echo("")



def _append_cli_arg(args: List[str], flag: str, value: Any | None) -> None:
    """Append ``flag`` and ``value`` (converted to string) to ``args`` when ``value`` is not ``None``."""
    if value is None:
        return
    if isinstance(value, bool):
        args.extend([flag, "True" if value else "False"])
    else:
        args.extend([flag, str(value)])


def _append_toggle_arg(args: List[str], flag: str, value: Any | None) -> None:
    """Append Click bool-toggle option as ``--flag`` / ``--no-flag`` when value is not ``None``."""
    if value is None:
        return
    if not isinstance(value, bool):
        raise TypeError(f"Toggle flag '{flag}' requires bool value, got {type(value).__name__}.")
    base = flag if not flag.startswith("--no-") else f"--{flag[5:]}"
    neg = f"--no-{base[2:]}"
    args.append(base if value else neg)


def _resolve_override_dir(default: Path, override: Path | None) -> Path:
    """Return ``override`` when provided (respecting absolute paths); otherwise ``default``."""
    if override is None:
        return default
    if override.is_absolute():
        return override
    return default.parent / override


def _build_effective_args_yaml(
    config_yaml: Optional[Path],
    override_yaml: Optional[Path],
    *,
    tmp_prefix: str,
) -> Tuple[Optional[Path], Dict[str, Any]]:
    """
    Resolve the args-yaml file handed to sub-stages, together with its content.

    Layering precedence (later wins): ``config_yaml`` < ``override_yaml``.

    * Neither file given: ``(None, {})``.
    * Exactly one given: that file is reused as-is, with its parsed dict.
    * Both given: the merged dict is dumped to a new temporary YAML (name
      prefixed with ``tmp_prefix``). That file is deleted at interpreter exit.
    """
    merged, base_cfg, override_cfg = load_merged_yaml_cfg(config_yaml, override_yaml)

    has_config = config_yaml is not None
    has_override = override_yaml is not None
    if not (has_config and has_override):
        if has_override:
            return override_yaml, override_cfg
        if has_config:
            return config_yaml, base_cfg
        return None, {}

    import atexit

    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        suffix=".yaml",
        prefix=tmp_prefix,
        delete=False,
    ) as handle:
        yaml.safe_dump(merged, handle, sort_keys=False, allow_unicode=True)
        effective = Path(handle.name).resolve()

    # Remove the merged temp file when the process exits.
    atexit.register(lambda p=effective: p.unlink(missing_ok=True))
    return effective, merged


def _inject_coord_type_into_args_yaml(
    args_yaml: Optional[Path],
    coord_type: Optional[str],
    precision: Optional[str] = None,
    backend_model: Optional[str] = None,
    calc_file: Optional[str] = None,
    calc_factory: str = "get_calculator",
) -> Optional[Path]:
    """Inject ``geom.coord_type`` (and optionally ``calc.uma_precision``) into
    the all-pipeline args YAML.

    Used by ``mlmm all --coord-type cart|dlc`` / ``--precision fp32|fp64`` to
    propagate the choice through the all-pipeline args YAML. Only the opt/tsopt
    stages honour ``coord_type`` (DLC is meaningful there via microiteration);
    freq/scan/path stages are fixed to cartesian and ignore it. Returns the
    original ``args_yaml`` unchanged when both ``coord_type`` and ``precision``
    are None.
    """
    if coord_type is None and precision is None and backend_model is None and calc_file is None:
        return args_yaml

    cfg = {} if args_yaml is None else load_yaml_dict(args_yaml)
    if not isinstance(cfg, dict):
        cfg = {}
    if coord_type is not None:
        geom_cfg = cfg.get("geom")
        if not isinstance(geom_cfg, dict):
            geom_cfg = {}
        geom_cfg = dict(geom_cfg)
        geom_cfg["coord_type"] = coord_type
        cfg["geom"] = geom_cfg
    if precision is not None or backend_model is not None or calc_file is not None:
        calc_cfg = cfg.get("calc")
        if not isinstance(calc_cfg, dict):
            calc_cfg = {}
        calc_cfg = dict(calc_cfg)
        # Translate --precision into the per-backend NATIVE kwarg
        # (uma_precision / orb_precision / mace_dtype) HERE and write THAT into
        # the args YAML. Writing the raw ``precision`` token instead leaks an
        # unknown kwarg into a sub-stage's Calculator(**calc_cfg) whenever that
        # stage's own --precision is unset (the stage does not re-translate the
        # YAML token), raising "Calculator.__init__() got an unexpected keyword
        # argument 'precision'". apply_precision_to_calc_cfg also pops any stray
        # raw ``precision`` key, keeping calc_cfg Calculator-clean.
        from mlmm.backends import apply_precision_to_calc_cfg, apply_backend_model_to_calc_cfg, apply_calc_file_to_calc_cfg
        apply_precision_to_calc_cfg(calc_cfg, precision)
        apply_backend_model_to_calc_cfg(calc_cfg, backend_model)
        apply_calc_file_to_calc_cfg(calc_cfg, calc_file, calc_factory)
        cfg["calc"] = calc_cfg

    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        suffix=".yaml",
        prefix="mlmm_all_coord_type_",
        delete=False,
    ) as tf:
        yaml.safe_dump(cfg, tf, sort_keys=False, allow_unicode=True)
        new_path = Path(tf.name).resolve()
    import atexit
    atexit.register(lambda p=new_path: p.unlink(missing_ok=True))
    return new_path


def _write_ml_region_definition(pocket_pdb: Path, dest: Path) -> Path:
    """
    Copy ``pocket_pdb`` to ``dest`` for downstream ML/MM commands.

    The copy preserves whatever link-hydrogen policy was used during extraction; set ``--add-linkh False``
    if you need a link-free ML-region definition.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    try:
        shutil.copyfile(pocket_pdb, dest)
    except FileNotFoundError:
        raise click.ClickException(f"[all] Pocket PDB not found while building ML region: {pocket_pdb}")
    return dest.resolve()


def _write_bfactor_ml_subset(src_pdb: Path, dest: Path) -> Optional[Path]:
    """Write only the ML-layer (B-factor ≈ 0) ATOM/HETATM records of *src_pdb* to *dest*.

    Used as the ``--model-pdb`` ML-region definition when extraction is skipped but
    the input carries B-factor layers (detect-layer). Without it ml_region.pdb is the
    FULL input (skip_extract copies the whole system), so any downstream stage whose
    detect-layer can't read B-factors (e.g. an XYZ geometry) falls back to --model-pdb
    and treats the ENTIRE system as the ML/QM region (sum_Z huge → electron-count
    error). Returns ``None`` if the input has no B≈0 atoms (caller falls back).
    """
    try:
        ml_lines: List[str] = []
        for ln in open(src_pdb, "r", encoding="utf-8", errors="ignore"):
            if ln.startswith(("ATOM", "HETATM")):
                try:
                    bf = float(ln[60:66])
                except ValueError:
                    continue
                if abs(bf) < 0.5:
                    ml_lines.append(ln)
        if not ml_lines:
            return None
        dest.parent.mkdir(parents=True, exist_ok=True)
        with open(dest, "w", encoding="utf-8") as fh:
            fh.writelines(ml_lines)
            fh.write("END\n")
        return dest.resolve()
    except Exception:
        return None


def _summarize_existing_bfactor_layers(pdb_path: Path) -> Dict[str, int]:
    """Count atoms per B-factor layer (ML=0 / MovableMM=10 / FrozenMM=20).

    Atoms whose B-factor is none of these landmark values are reported under
    ``"other"`` so users can spot non-layered inputs quickly.
    """
    counts = {"ml": 0, "movable": 0, "frozen": 0, "other": 0}
    try:
        with open(pdb_path, "r") as fh:
            for ln in fh:
                if not (ln.startswith("ATOM") or ln.startswith("HETATM")):
                    continue
                try:
                    bf = float(ln[60:66])
                except ValueError:
                    counts["other"] += 1
                    continue
                if abs(bf) < 0.5:
                    counts["ml"] += 1
                elif abs(bf - 10.0) < 0.5:
                    counts["movable"] += 1
                elif abs(bf - 20.0) < 0.5:
                    counts["frozen"] += 1
                else:
                    counts["other"] += 1
    except FileNotFoundError:
        pass
    return counts


def _mm_charge_mapping(expr: Optional[str]) -> Dict[str, int]:
    """Return a ligand-charge mapping for mm_parm when ``expr`` uses RES=Q or RES:Q syntax."""
    if not expr:
        return {}
    if ("=" not in expr) and (":" not in expr):
        return {}
    try:
        return _mm_parse_ligand_charge(expr)
    except Exception as exc:  # pragma: no cover - defensive
        raise click.ClickException(f"[all] Invalid --ligand-charge mapping for mm_parm: {exc}")


def _mm_mult_mapping(expr: Optional[str]) -> Dict[str, int]:
    """Return a ligand-multiplicity mapping for mm_parm when ``expr`` uses RES=M or RES:M syntax."""
    if not expr:
        return {}
    try:
        return _mm_parse_ligand_mult(expr)
    except Exception as exc:  # pragma: no cover - defensive
        raise click.ClickException(f"[all] Invalid --auto-mm-ligand-mult mapping for mm_parm: {exc}")


def _build_mm_parm7(
    pdb: Path,
    ligand_charge_expr: Optional[str],
    ligand_mult_expr: Optional[str],
    out_dir: Path,
    ff_set: str,
    add_ter: bool,
    keep_temp: bool,
) -> Tuple[Path, Path]:
    """Run mm_parm on ``pdb`` and return (parm7, rst7).

    Outputs are written under ``out_dir`` with the PDB's stem as prefix
    (``<stem>.parm7`` / ``<stem>.rst7``) and returned as resolved paths. mm_parm
    is called with ``add_h=False`` and ``ph=7.0``. Ligand charge/multiplicity
    expressions go through ``_mm_charge_mapping`` / ``_mm_mult_mapping``.

    Raises
    ------
    click.ClickException
        If mm_parm raises or exits with a code other than 0/None, or if either
        output file is missing afterwards. A clean ``SystemExit`` (code 0 or
        None) counts as success, consistent with ``_run_cli_main``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    out_prefix = (out_dir / pdb.stem).resolve()
    out_prefix.parent.mkdir(parents=True, exist_ok=True)
    args = _AutoMMArgs(
        pdb=pdb.resolve(),
        out_prefix=str(out_prefix),
        ligand_charge=_mm_charge_mapping(ligand_charge_expr),
        ligand_mult=_mm_mult_mapping(ligand_mult_expr),
        keep_temp=bool(keep_temp),
        add_ter=bool(add_ter),
        add_h=False,
        ph=7.0,
        ff_set=str(ff_set),
        out_prefix_given=True,
    )
    try:
        _mm_run(args)
    except SystemExit as exc:  # pragma: no cover - click exit translation
        code = getattr(exc, "code", 1)
        # sys.exit() / sys.exit(0) signal success; only real failures abort.
        # The output-file checks below still guard against missing results.
        if code not in (None, 0):
            raise click.ClickException(f"[all] mm_parm exited with code {code}.") from exc
    except Exception as exc:
        raise click.ClickException(f"[all] mm_parm failed: {exc}") from exc

    parm7 = Path(f"{args.out_prefix}.parm7").resolve()
    rst7 = Path(f"{args.out_prefix}.rst7").resolve()
    if not parm7.exists():
        raise click.ClickException(f"[all] mm_parm did not produce parm7 at {parm7}")
    if not rst7.exists():
        raise click.ClickException(f"[all] mm_parm did not produce rst7 at {rst7}")
    return parm7, rst7


def _parse_atom_key_from_line(line: str) -> Optional[AtomKey]:
    """Extract a structural identity key from a PDB ATOM/HETATM record."""
    if not (line.startswith("ATOM") or line.startswith("HETATM")):
        return None
    atomname = line[12:16].strip()
    altloc = (line[16] if len(line) > 16 else " ").strip()
    resname = line[17:20].strip()
    chain = (line[21] if len(line) > 21 else " ").strip()
    resseq = line[22:26].strip()
    icode = (line[26] if len(line) > 26 else " ").strip()
    return (chain, resname, resseq, icode, atomname, altloc)


def _key_variants(key: AtomKey) -> List[AtomKey]:
    """Return key variants with progressively relaxed identity fields (deduplicated)."""
    chain, resn, resseq, icode, atom, alt = key
    raw_variants = [
        (chain, resn, resseq, icode, atom, alt),
        (chain, resn, resseq, icode, atom, ""),
        (chain, resn, resseq, "", atom, alt),
        (chain, resn, resseq, "", atom, ""),
    ]
    seen: set[AtomKey] = set()
    variants: List[AtomKey] = []
    for variant in raw_variants:
        if variant in seen:
            continue
        seen.add(variant)
        variants.append(variant)
    return variants


def _parse_scan_lists_literals(
    scan_lists_raw: Sequence[str],
    atom_meta: Optional[Sequence[Dict[str, Any]]] = None,
    one_based: bool = True,
) -> List[List[Tuple[int, int, float]]]:
    """Parse ``--scan-lists`` literals without re-basing atom indices.

    Parameters
    ----------
    one_based : bool, default True
        Honour the CLI ``--scan-one-based`` toggle so users can pass 0-based
        indices via ``all`` and have them forwarded unchanged to ``scan``.
    """
    stages: List[List[Tuple[int, int, float]]] = []
    for idx_stage, literal in enumerate(scan_lists_raw, start=1):
        tuples, _ = parse_scan_list_triples(
            literal,
            one_based=one_based,
            atom_meta=atom_meta,
            option_name=f"--scan-lists #{idx_stage}",
            return_one_based=one_based,
        )
        if not tuples:
            raise click.BadParameter(
                f"--scan-lists #{idx_stage} must contain at least one (i,j,target) triple."
            )
        stages.append(tuples)
    return stages


def _format_scan_stage(stage: List[Tuple[int, int, float]]) -> str:
    """Serialize a scan stage back into a Python-like literal string."""
    return "[" + ", ".join(f"({i},{j},{target})" for (i, j, target) in stage) + "]"


def _round_charge_with_note(q: float) -> int:
    """
    Cast the extractor's total charge (float) to an integer suitable for the path search.
    If it is not already an integer within 1e-6, round to the nearest integer with a console note.
    """
    q_rounded = int(round(float(q)))
    if not math.isfinite(q):
        raise click.BadParameter(f"Computed total charge is non-finite: {q!r}")
    if abs(float(q) - q_rounded) > 1e-6:
        click.echo(f"[all] NOTE: extractor total charge = {q:g} → rounded to integer {q_rounded} for the path search.")
    return q_rounded


def _derive_charge_from_ligand_charge_when_extract_skipped(
    pdb_path: Path,
    ligand_charge: Optional[str],
) -> Optional[int]:
    """Derive total charge from a PDB using extract-style charge summary.

    *pdb_path* may be a full-complex PDB or a --model-pdb pocket.
    """
    if ligand_charge is None:
        return None
    try:
        parser = PDB.PDBParser(QUIET=True)
        complex_struct = parser.get_structure("complex", str(pdb_path))
        selected_ids = {res.get_full_id() for res in complex_struct.get_residues()}
        summary = compute_charge_summary(complex_struct, selected_ids, set(), ligand_charge)
        log_charge_summary("[all]", summary)
        q_total = float(summary.get("total_charge", 0.0))
        click.echo(f"[all] Charge summary from {pdb_path.name} (--ligand-charge without extraction):")
        click.echo(
            f"  Protein: {summary.get('protein_charge', 0.0):+g},  "
            f"Ligand: {summary.get('ligand_total_charge', 0.0):+g},  "
            f"Ions: {summary.get('ion_total_charge', 0.0):+g},  "
            f"Total: {q_total:+g}"
        )
        return _round_charge_with_note(q_total)
    except Exception as e:
        click.echo(
            f"[all] NOTE: failed to derive total charge from --ligand-charge: {e}",
            err=True,
        )
        return None


def _derive_ml_charge_from_layered_pdb(
    pdb_path: Path,
    ligand_charge: Optional[str],
) -> Optional[int]:
    """Derive the ML-region (B≈0) charge from a B-factor-layered PDB when extraction
    is skipped (ts-only / no -c), reusing extract's ``compute_charge_summary`` WITH
    terminal-cap correction at the ML/MM boundary.

    Needed because ML ⊊ system in ONIOM. Summing the whole input gives the total
    system charge (mis-applied as the ML model charge — the ts-only charge bug).
    Summing only the B≈0 atoms misses the backbone-cut terminal caps (off by
    the number of cut termini). The cut residues (peptide neighbor not in the ML
    set) are flagged as N-/C-caps, exactly as the extract path does. Validated:
    CM r4.0 ML=0, R90A ML=-2. Returns ``None`` on any failure so the caller can
    fall back to the full-input derivation.

    Parameters
    ----------
    pdb_path : Path
        Layered PDB. A residue is in the ML set when at least one of its atoms
        has |B-factor| < 0.5.
    ligand_charge : str, optional
        ``--ligand-charge`` expression forwarded to ``compute_charge_summary``;
        ``None`` returns ``None`` immediately.

    Returns
    -------
    int or None
        The integer ML charge (rounded via ``_round_charge_with_note``).
        ``None`` when ``ligand_charge`` is ``None``, when no residue is in the
        ML layer, or when any step raises (a note is printed to stderr then).
    """
    if ligand_charge is None:
        return None
    try:
        from mlmm.workflows.extract import (
            compute_charge_summary,
            are_peptide_adjacent,
            AMINO_ACIDS,
        )

        parser = PDB.PDBParser(QUIET=True)
        st = parser.get_structure("complex", str(pdb_path))
        # ML set: full ids of residues with at least one B≈0 atom (the whole
        # residue counts as ML even if only some atoms are B≈0).
        ml_ids = {
            r.get_full_id()
            for r in st.get_residues()
            if any(abs(a.get_bfactor()) < 0.5 for a in r.get_atoms())
        }
        if not ml_ids:
            return None
        # Full ids of ML amino-acid residues whose N-side / C-side neighbour is
        # cut off by the ML/MM boundary; compute_charge_summary caps these.
        keep_ncap = set()
        keep_ccap = set()
        for res in st.get_residues():
            fid = res.get_full_id()
            if fid not in ml_ids or res.get_resname() not in AMINO_ACIDS:
                continue
            # NOTE(review): the chain's residue list is rebuilt and searched
            # with .index() for every ML residue (O(chain length) each);
            # acceptable while the ML region is small — revisit if it grows.
            chain_res = list(res.get_parent().get_residues())
            idx = chain_res.index(res)
            # Nearest residue before / after this one in chain order whose
            # name is in AMINO_ACIDS (other residues in between are skipped).
            prev_aa = next(
                (chain_res[j] for j in range(idx - 1, -1, -1) if chain_res[j].get_resname() in AMINO_ACIDS),
                None,
            )
            next_aa = next(
                (chain_res[j] for j in range(idx + 1, len(chain_res)) if chain_res[j].get_resname() in AMINO_ACIDS),
                None,
            )
            # N-cap unless the preceding amino acid exists, is adjacent per
            # are_peptide_adjacent, and is itself in the ML set.
            if not (prev_aa is not None and are_peptide_adjacent(prev_aa, res) and prev_aa.get_full_id() in ml_ids):
                keep_ncap.add(fid)
            # C-cap: the same test against the following amino acid.
            if not (next_aa is not None and are_peptide_adjacent(res, next_aa) and next_aa.get_full_id() in ml_ids):
                keep_ccap.add(fid)
        summary = compute_charge_summary(
            st, ml_ids, set(), ligand_charge,
            keep_ncap_ids=keep_ncap, keep_ccap_ids=keep_ccap,
        )
        q_total = float(summary.get("total_charge", 0.0))
        click.echo(
            f"[all] ML-region charge from {pdb_path.name} "
            f"(detect-layer, extraction skipped; cap-corrected): "
            f"Protein: {summary.get('protein_charge', 0.0):+g},  "
            f"Ligand: {summary.get('ligand_total_charge', 0.0):+g},  "
            f"Total: {q_total:+g}"
        )
        return _round_charge_with_note(q_total)
    except Exception as e:
        # Best-effort: report and let the caller fall back to the full-input
        # derivation.
        click.echo(
            f"[all] NOTE: cap-aware ML-region charge derivation failed: {e}",
            err=True,
        )
        return None


def _pdb_needs_elem_fix(p: Path) -> bool:
    """
    Return True if the PDB has at least one ATOM/HETATM record whose element field (cols 77–78) is empty.
    This is a light-weight check to decide whether to run add_elem_info.
    """
    try:
        with p.open("r", encoding="utf-8", errors="ignore") as fh:
            for line in fh:
                if line.startswith("ATOM") or line.startswith("HETATM"):
                    if len(line) < 78 or not line[76:78].strip():
                        return True
        return False
    except Exception:
        # On I/O errors, skip fixing (use original)
        return False


# ---------- Post-processing helpers (minimal, reuse internals) ----------


def _enrich_summary(
    summary: dict,
    *,
    version: str,
    pipeline_mode: str,
    mlip_backend: str,
    charge: int,
    spin: int,
    command: str = "",
    post_segments: Optional[list] = None,
    config: Optional[dict] = None,
    freeze_atoms: Optional[str] = None,
    out_dir: Optional[Path] = None,
) -> dict:
    """Add machine-readable metadata to summary dict for AI agent consumption.

    The resulting dict is written as summary.json and is intended to be the
    single machine-readable output consumed by MCP tools and AI agents.
    It should contain ALL information present in summary.log.
    """
    try:
        from mlmm._version import __version__
    except Exception:
        __version__ = "unknown"

    segments = summary.get("segments", [])
    reactive = [s for s in segments if s.get("kind") != "bridge"]
    n_reactive = len(reactive)

    has_diagrams = bool(summary.get("energy_diagrams"))
    status = "success" if has_diagrams else ("partial" if segments else "failed")

    best_method = None
    rls = None
    if reactive:
        for diag in reversed(summary.get("energy_diagrams", [])):
            name = diag.get("name", "")
            if "G_UMA" in name and "all" in name:
                best_method = "UMA_Gibbs"; break
            elif "UMA" in name and "all" in name:
                best_method = "UMA"; break
        if best_method is None:
            best_method = "MEP"

        max_barrier = -1e9
        for s in reactive:
            b = s.get("barrier_kcal", 0) or 0
            if b > max_barrier:
                max_barrier = b
                rls = {"segment": s.get("index"), "barrier_kcal": round(b, 2), "method": best_method}

    overall_rxn_e = None
    for diag in reversed(summary.get("energy_diagrams", [])):
        name = diag.get("name", "")
        if "all" in name:
            energies = diag.get("energies_kcal", [])
            if len(energies) >= 2:
                overall_rxn_e = round(energies[-1] - energies[0], 2)
                break

    summary["mlmm_toolkit_version"] = __version__
    summary["pipeline_mode"] = pipeline_mode
    summary["status"] = status
    summary["mlip_backend"] = mlip_backend
    summary["charge"] = charge
    summary["spin"] = spin
    summary["n_segments_reactive"] = n_reactive
    if rls:
        summary["rate_limiting_step"] = rls
    if overall_rxn_e is not None:
        summary["overall_reaction_energy_kcal"] = overall_rxn_e
    if command:
        summary["command"] = command
    if config:
        summary["config"] = config
    if freeze_atoms:
        summary["freeze_atoms"] = freeze_atoms
    if post_segments:
        summary["post_segments"] = _json_safe(post_segments)

    # Key output file paths for AI agent consumption
    if "out_dir" in summary:
        # Real pipeline root. Fall back to the legacy module_dir.parent for
        # any caller that does not pass out_dir explicitly.
        root = Path(out_dir) if out_dir is not None else Path(summary["out_dir"]).parent
        key_files: Dict[str, Any] = {}
        # Root-level deliverables (MEP products + authored/mirrored summaries live at root)
        for name, desc in [
            ("summary.log", "Human-readable results summary"),
            ("summary.json", "Machine-readable results summary"),
            ("mep_trj.xyz", "Full MEP trajectory"),
            ("mep.pdb", "Full MEP as PDB"),
            ("energy_diagram_MEP.png", "MEP energy plot"),
            ("mep_plot.png", "MEP energy plot (trj2fig)"),
            ("irc_plot_all.png", "Aggregated IRC plot"),
        ]:
            if (root / name).exists():
                key_files.setdefault(name, desc)
        # Per-segment deliverables under segments/seg_NN/
        seg_parent = root / SEGMENTS_DIRNAME
        if seg_parent.exists():
            for child in sorted(seg_parent.iterdir()):
                if child.is_dir() and child.name.startswith("seg_"):
                    seg_files = [f.name for f in sorted(child.iterdir()) if f.is_file()]
                    key_files[child.name] = {
                        "description": f"Per-segment results for {child.name}",
                        "files": seg_files,
                    }
        if key_files:
            summary["key_output_files"] = key_files

    try:
        from mlmm.core.utils import _collect_environment_info
        summary.setdefault("environment", _collect_environment_info())
    except Exception:
        pass

    return summary


def _json_safe(obj):
    """Recursively convert Path objects to strings for JSON serialization."""
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, dict):
        return {k: _json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_json_safe(item) for item in obj]
    return obj


def _copy_structures_to_seg_dir(
    state_structs, out_dir, seg_idx, input_suffix,
    prepared_input=None, ref_pdb_path=None,
):
    """
    Copy per-state structures into ``out_dir/segments/seg_XX/`` and return that directory.

    *state_structs* maps state keys to XYZ paths. Missing sources are skipped.
    The keys "R", "TS" and "P" become "reactant", "ts" and "product"; any other
    key is lower-cased.

    Output format follows *input_suffix*:
      - ".pdb": copy the sibling ``.pdb`` of the source if it exists, else the XYZ.
      - ".gjf": convert to Gaussian input when *prepared_input* carries a
        ``gjf_template``; if there is no template or conversion fails, copy the XYZ.
      - anything else: copy the XYZ.

    *ref_pdb_path* is accepted for interface compatibility but not used.
    """
    seg_dir = out_dir / SEGMENTS_DIRNAME / f"seg_{seg_idx:02d}"
    seg_dir.mkdir(parents=True, exist_ok=True)
    stem_for = {"R": "reactant", "TS": "ts", "P": "product"}

    for state, src_raw in state_structs.items():
        src = Path(src_raw)
        if not src.exists():
            continue
        stem = stem_for.get(state, state.lower())

        if input_suffix == ".pdb":
            pdb_src = src.with_suffix(".pdb")
            if pdb_src.exists():
                shutil.copy2(pdb_src, seg_dir / f"{stem}.pdb")
                continue
        elif (
            input_suffix == ".gjf"
            and prepared_input is not None
            and getattr(prepared_input, "gjf_template", None) is not None
        ):
            try:
                from mlmm.core.utils import convert_xyz_to_gjf
                convert_xyz_to_gjf(src, prepared_input.gjf_template, seg_dir / f"{stem}.gjf")
                continue
            except Exception:
                pass  # conversion is best-effort; fall back to the raw XYZ below

        # Default / fallback: ship the XYZ as-is.
        shutil.copy2(src, seg_dir / f"{stem}.xyz")
    return seg_dir


def _read_summary(summary_json: Path) -> List[Dict[str, Any]]:
    """
    Read path_search/summary.json and return segments list (empty if not found).
    """
    try:
        if not summary_json.exists():
            return []
        data = json.loads(summary_json.read_text(encoding="utf-8")) or {}
        segs = data.get("segments", []) or []
        if not isinstance(segs, list):
            return []
        return segs
    except Exception:
        return []


def _pdb_models_to_coords_and_elems(pdb_path: Path) -> Tuple[List[np.ndarray], List[str]]:
    """
    Parse a (multi-)model PDB into per-model coordinate arrays (Å) and one element list.

    Elements come from the first model. When an atom has no element column,
    the element is derived from the alphabetic characters of its atom name
    (first two, title-cased), defaulting to "C".

    Raises:
        click.ClickException: if the file has no models, or if any model's atom
        count differs from the first model's.
    """
    structure = PDB.PDBParser(QUIET=True).get_structure("seg", str(pdb_path))
    models = list(structure.get_models())
    if not models:
        raise click.ClickException(f"[post] No MODEL found in PDB: {pdb_path}")

    def _element_of(atom) -> str:
        symbol = (atom.element or "").strip()
        if symbol:
            return symbol
        letters = "".join(ch for ch in atom.get_name().strip() if ch.isalpha())
        return letters[:2].title() or "C"

    reference_atoms = list(models[0].get_atoms())
    elems: List[str] = [_element_of(atom) for atom in reference_atoms]
    n_ref = len(reference_atoms)

    coords_list: List[np.ndarray] = []
    for model in models:
        model_atoms = list(model.get_atoms())
        if len(model_atoms) != n_ref:
            raise click.ClickException(f"[post] Atom count mismatch across models in {pdb_path}")
        coords_list.append(np.array([atom.get_coord() for atom in model_atoms], dtype=float))
    return coords_list, elems


def _geom_from_angstrom(elems: Sequence[str],
                        coords_ang: np.ndarray,
                        freeze_atoms: Sequence[int]) -> Any:
    """
    Build a Cartesian Geometry from Å coordinates.

    ``_path_search._new_geom_from_coords`` expects Bohr, so the coordinates are
    divided by ``BOHR2ANG`` before the call.
    """
    return _path_search._new_geom_from_coords(
        elems,
        np.asarray(coords_ang, dtype=float) / BOHR2ANG,
        coord_type="cart",
        freeze_atoms=freeze_atoms,
    )


def _load_segment_end_geoms(seg_pdb: Path, freeze_atoms: Sequence[int]) -> Tuple[Any, Any]:
    """
    Return Geometries for the first and last model of a per-segment PDB.

    Both geometries share the element list of the first model and the given
    *freeze_atoms*.
    """
    frames, elems = _pdb_models_to_coords_and_elems(seg_pdb)
    first_frame, last_frame = frames[0], frames[-1]
    left = _geom_from_angstrom(elems, first_frame, freeze_atoms)
    right = _geom_from_angstrom(elems, last_frame, freeze_atoms)
    return left, right


def _irc_and_match(seg_idx: int,
                   seg_dir: Path,
                   ref_pdb_for_seg: Path,
                   seg_pocket_pdb: Path,
                   g_ts: Any,
                   q_int: int,
                   spin: int,
                   mep_dir: Optional[Path] = None,
                   real_parm7: Optional[Path] = None,
                   model_pdb: Optional[Path] = None,
                   detect_layer: bool = False,
                   backend: Optional[str] = None,
                   embedcharge: bool = False,
                   embedcharge_cutoff: Optional[float] = None,
                   embedcharge_explicit: bool = False,
                   link_atom_method: Optional[str] = None,
                   mm_backend: Optional[str] = None,
                   use_cmap: Optional[bool] = None,
                   args_yaml: Optional[Path] = None) -> Dict[str, Any]:
    """
    Run the ``irc`` CLI from a TS geometry and orient its endpoints as (left, right).

    Steps:
      1. Release carried-over memory with ``gc.collect()`` and, when CUDA is
         available, ``torch.cuda.empty_cache()``.
      2. Run the ``irc`` subcommand with output in ``seg_dir / "irc"``. On any
         exception, including ``BaseException`` subclasses such as
         KeyboardInterrupt, a notice is written to stderr and the exception is
         re-raised.
      3. Read the first and last frames of ``finished_irc_trj.xyz``. Raises
         ``click.ClickException`` if that file is missing. ``finished_irc.pdb``
         is produced from it, using *seg_pocket_pdb* as reference, when not
         already present.
      4. Build one ML/MM calculator and attach it to both endpoint geometries,
         then evaluate their energies. *g_ts* receives the same calculator only
         if it has none. Its energy is evaluated either way.
      5. Orientation. By default the first IRC frame is "left" (tag
         ``"backward"``) and the last frame is "right" (tag ``"forward"``).
         If ``mep_seg_{seg_idx:02d}.pdb`` exists under *mep_dir* (or under
         ``seg_dir.parent`` when *mep_dir* is None), its first and last models
         serve as reference endpoints:
           - Each IRC endpoint is compared with each reference endpoint via
             ``_path_search._has_bond_change``. A pair "matches" when no bond
             change is reported; errors count as "no match".
           - If the crossed pairing fully matches and the direct one does not,
             left/right geometries and tags are swapped.
           - If neither pairing fully matches, the pairing with the smaller
             summed coordinate RMS deviation is chosen.
         Without the reference file (e.g. TS-only runs) the raw IRC file order
         is kept.

    Args:
        seg_idx: Segment number. Used in messages and in the reference file name.
        seg_dir: Segment directory; the IRC runs in its ``irc`` subdirectory.
        ref_pdb_for_seg: Structure passed as ``-i`` to the CLI and as
            ``input_pdb`` to the calculator.
        seg_pocket_pdb: Reference PDB for converting the IRC trajectory to PDB.
        g_ts: TS geometry; an energy is evaluated on it and it is returned as-is.
        q_int, spin: Model charge and multiplicity.
        mep_dir: Directory holding ``mep_seg_NN.pdb``, if any.
        real_parm7, model_pdb: Topology / ML-region files for CLI and calculator.
        detect_layer: Whether to use B-factor layer detection.
        backend, embedcharge, link_atom_method, mm_backend, use_cmap: Forwarded
            to both the CLI and the calculator.
        embedcharge_cutoff, embedcharge_explicit, args_yaml: Forwarded to the
            CLI only.

    Returns:
        Dict with ``left_min_geom``, ``right_min_geom``, ``ts_geom``,
        ``left_tag``, ``right_tag``, ``irc_trj`` / ``irc_plot`` (path strings,
        or None when the file is absent) and ``reverse_irc`` (True when
        left/right were swapped relative to the IRC file order).

    NOTE(review): *real_parm7* and *model_pdb* default to None but are passed
    through ``str()`` into the CLI arguments, which would yield the literal
    ``"None"``. Callers presumably always supply them; confirm.
    """
    # Fix C: free GPU memory carried over from the preceding TS-opt stage
    # before IRC. IRC's initial_displacement eigh needs a large contiguous
    # block; ~9 GiB of TS-stage allocator residency otherwise leaves too little
    # free even after the Fix-A ML-macro Hessian reduction. The per-stage
    # _run_cli_main finally also does this, but orchestrator locals (g_ts, etc.)
    # still pin memory here at the TS->IRC boundary.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    irc_dir = seg_dir / "irc"
    ensure_dir(irc_dir)

    # Build irc CLI arguments
    irc_args: List[str] = [
        "-i", str(ref_pdb_for_seg),
        "--parm", str(real_parm7),
        "--model-pdb", str(model_pdb),
        "-q", str(int(q_int)),
        "-m", str(int(spin)),
        "--out-dir", str(irc_dir),
    ]
    irc_args.append("--detect-layer" if detect_layer else "--no-detect-layer")
    from mlmm.workflows._all_helpers import append_backend_forwarding_args
    append_backend_forwarding_args(
        irc_args,
        backend=backend,
        embedcharge=embedcharge,
        embedcharge_cutoff=embedcharge_cutoff,
        embedcharge_explicit=embedcharge_explicit,
        link_atom_method=link_atom_method,
        mm_backend=mm_backend,
        use_cmap=use_cmap,
        args_yaml=args_yaml,
    )

    _echo_detail(f"[irc] Running EulerPC IRC → out={irc_dir}")
    try:
        _run_cli_main("irc", _irc_cli.cli, irc_args, on_nonzero="raise", prefix="irc")
    except BaseException:
        # Make the consequence explicit instead of dying mid-recovery with a
        # bare stack: IRC is a hard prerequisite for this segment's IRC
        # endpoint Hessians and the subsequent freq/thermo/DFT, so those are
        # not produced for this segment when IRC fails.
        _echo(
            f"[all] IRC failed for segment {seg_idx}; freq/thermochemistry/DFT "
            f"post-processing for this segment is skipped (the pipeline will "
            f"now abort with the IRC error above).",
            err=True,
        )
        raise

    # Read IRC endpoints
    finished_trj = irc_dir / "finished_irc_trj.xyz"
    finished_pdb = irc_dir / "finished_irc.pdb"
    irc_plot = irc_dir / "irc_plot.png"

    if not finished_trj.exists():
        raise click.ClickException(f"[irc] IRC trajectory not found: {finished_trj}")

    # Convert to PDB if not already done
    if not finished_pdb.exists():
        _path_search._maybe_convert_to_pdb(finished_trj, ref_pdb_path=seg_pocket_pdb, out_path=finished_pdb)

    # First/last frames of the IRC trajectory. They are divided by BOHR2ANG
    # below, so they are presumably in Å.
    elems, c_first, c_last = read_xyz_first_last(finished_trj)

    # Create geometries from IRC endpoints
    _irc_calc_kwargs = dict(
        model_charge=int(q_int),
        model_mult=int(spin),
        input_pdb=str(ref_pdb_for_seg),
        real_parm7=str(real_parm7) if real_parm7 else None,
        model_pdb=str(model_pdb) if model_pdb else None,
        use_bfactor_layers=detect_layer,
        backend=backend,
        embedcharge=embedcharge,
    )
    # Only forward optional settings that were set, so calculator defaults apply otherwise.
    if link_atom_method is not None:
        _irc_calc_kwargs["link_atom_method"] = link_atom_method
    if mm_backend is not None:
        _irc_calc_kwargs["mm_backend"] = mm_backend
    if use_cmap is not None:
        _irc_calc_kwargs["use_cmap"] = use_cmap
    calc = _mlmm_calc(**_irc_calc_kwargs)

    g_left = _path_search._new_geom_from_coords(
        elems, c_first / BOHR2ANG, coord_type="cart", freeze_atoms=[])
    g_right = _path_search._new_geom_from_coords(
        elems, c_last / BOHR2ANG, coord_type="cart", freeze_atoms=[])
    g_left.set_calculator(calc)
    g_right.set_calculator(calc)
    # Touch .energy so both endpoints carry an evaluated energy.
    _ = float(g_left.energy)
    _ = float(g_right.energy)

    # Reload TS geometry with energy
    if g_ts.calculator is None:
        g_ts.set_calculator(calc)
    _ = float(g_ts.energy)

    # Default orientation: first IRC frame = left/"backward", last = right/"forward".
    left_tag = "backward"
    right_tag = "forward"
    reverse_irc = False

    # Try to load segment endpoints for mapping.
    # mep_seg_NN.pdb is written by the MEP engine under path_dir (now _work/path_*);
    # seg_dir moved to segments/, so read from mep_dir when provided.
    gL_end = None
    gR_end = None
    mep_root = mep_dir if mep_dir is not None else seg_dir.parent
    seg_pocket_path = mep_root / f"mep_seg_{seg_idx:02d}.pdb"
    if seg_pocket_path.exists():
        try:
            gL_end, gR_end = _load_segment_end_geoms(seg_pocket_path, [])
        except Exception as e:
            click.echo(f"[post] WARNING: failed to load segment endpoints: {e}", err=True)

    # Map IRC endpoints to left/right using bond-change analysis
    if gL_end is not None and gR_end is not None:
        bond_cfg = dict(_path_search.BOND_KW)

        # True when no bond change is detected between x and y; any error counts as no match.
        def _matches(x, y) -> bool:
            try:
                chg, _ = _path_search._has_bond_change(x, y, bond_cfg)
                return not chg
            except Exception:
                return False

        # Root-mean-square over all Cartesian components of the first n atoms.
        # This is not the per-atom RMSD, but it scales uniformly, so comparing
        # sums of it still ranks the two pairings consistently.
        def _rmsd_cart(g1, g2) -> float:
            c1 = np.asarray(g1.coords).reshape(-1, 3)
            c2 = np.asarray(g2.coords).reshape(-1, 3)
            n = min(len(c1), len(c2))
            return float(np.sqrt(np.mean((c1[:n] - c2[:n]) ** 2)))

        # Check if IRC endpoints need swapping
        match_LL = _matches(g_left, gL_end)
        match_LR = _matches(g_left, gR_end)
        match_RL = _matches(g_right, gL_end)
        match_RR = _matches(g_right, gR_end)

        if match_LR and match_RL and not (match_LL and match_RR):
            # Swap: IRC backward→right, forward→left
            g_left, g_right = g_right, g_left
            left_tag, right_tag = right_tag, left_tag
            reverse_irc = True
        elif not (match_LL and match_RR):
            # Neither pairing fully matches by bonds: pick the geometrically closer pairing.
            d_LL = _rmsd_cart(g_left, gL_end)
            d_LR = _rmsd_cart(g_left, gR_end)
            d_RL = _rmsd_cart(g_right, gL_end)
            d_RR = _rmsd_cart(g_right, gR_end)
            if (d_LR + d_RL) < (d_LL + d_RR):
                g_left, g_right = g_right, g_left
                left_tag, right_tag = right_tag, left_tag
                reverse_irc = True

    return {
        "left_min_geom": g_left,
        "right_min_geom": g_right,
        "ts_geom": g_ts,
        "left_tag": left_tag,
        "right_tag": right_tag,
        "irc_trj": str(finished_trj) if finished_trj.exists() else None,
        "irc_plot": str(irc_plot) if irc_plot.exists() else None,
        "reverse_irc": reverse_irc,
    }


def _save_single_geom_for_tools(g: Any, ref_pdb: Path, out_dir: Path, name: str) -> Tuple[Path, Path]:
    """
    Write one geometry as ``<name>.xyz``, ``<name>_trj.xyz`` and ``<name>.pdb`` under *out_dir*.

    The plain XYZ keeps full precision. The single-frame trajectory carries the
    energy and is the source for the PDB companion, which is built against
    *ref_pdb*. Returns ``(xyz_path, pdb_path)``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    xyz_out, xyz_trj, pdb_out = (
        out_dir / f"{name}{suffix}" for suffix in (".xyz", "_trj.xyz", ".pdb")
    )

    with open(xyz_out, "w") as handle:
        handle.write(g.as_xyz() + "\n")

    energy = float(g.energy)
    _path_search._write_xyz_trj_with_energy([g], [energy], xyz_trj)
    _path_search._maybe_convert_to_pdb(xyz_trj, ref_pdb_path=ref_pdb, out_path=pdb_out)
    return xyz_out, pdb_out


def _run_tsopt_on_hei(hei_pdb: Path,
                      charge: int,
                      spin: int,
                      real_parm7: Path,
                      model_pdb: Path,
                      detect_layer: bool,
                      args_yaml: Optional[Path],
                      out_dir: Path,
                      opt_mode_default: str,
                      overrides: Optional[Dict[str, Any]] = None,
                      backend: Optional[str] = None,
                      embedcharge: bool = False,
                      embedcharge_cutoff: Optional[float] = None,
                      embedcharge_explicit: bool = False,
                      link_atom_method: Optional[str] = None,
                      mm_backend: Optional[str] = None,
                      use_cmap: Optional[bool] = None,
                      ref_pdb: Optional[Path] = None) -> Tuple[Path, Any]:
    """
    Run the tsopt CLI on a HEI structure; return (final_ts_pdb_path, ts_geom).

    When *ref_pdb* (a layered PDB with B-factor layer info) is given and the
    HEI has a sibling ``.xyz``, that XYZ is used as input and *ref_pdb* is
    passed via ``--ref-pdb``. This lets the calculator detect ML/MM layers from
    B-factors. Otherwise *hei_pdb* itself is the input and topology.

    *overrides* may set: ``out_dir``, ``opt_mode``, ``max_cycles``, ``dump``,
    ``convert_files``, ``thresh``, ``flatten``, ``hessian_calc_mode`` and
    ``skip_final_freq``.

    The returned geometry is loaded from ``final_geometry.xyz`` when present
    (otherwise the PDB), has an ML/MM calculator attached, and has its energy
    evaluated.

    NOTE(review): the returned PDB path is ``<ts_dir>/final_geometry.pdb``;
    if only the XYZ was produced and PDB conversion did not create the file,
    this path may not exist. Confirm against ``_maybe_convert_to_pdb``.

    Raises:
        click.ClickException: if neither TS output file is produced.
    """
    overrides = overrides or {}
    # Prefer HEI XYZ (full precision) + layered ref-pdb (B-factor layer info)
    hei_xyz = hei_pdb.with_suffix(".xyz")
    if ref_pdb is not None and hei_xyz.exists():
        input_file = hei_xyz
        topology_pdb = ref_pdb
    else:
        input_file = hei_pdb
        topology_pdb = hei_pdb
    use_ref_pdb = input_file.suffix.lower() == ".xyz" and ref_pdb is not None

    prepared_input = prepare_input_structure(input_file)
    try:
        # Inside the try so the prepared input is cleaned up even if the override fails.
        if use_ref_pdb:
            apply_ref_pdb_override(prepared_input, ref_pdb)

        ts_dir = _resolve_override_dir(out_dir / "ts", overrides.get("out_dir"))
        ensure_dir(ts_dir)

        opt_mode = overrides.get("opt_mode", opt_mode_default)

        ts_args: List[str] = [
            "-i", str(prepared_input.geom_path),
        ]
        if use_ref_pdb:
            ts_args.extend(["--ref-pdb", str(ref_pdb)])
        ts_args.extend([
            "--parm", str(real_parm7),
            "--model-pdb", str(model_pdb),
            "-q", str(int(charge)),
            "-m", str(int(spin)),
            "--out-dir", str(ts_dir),
        ])
        ts_args.append("--detect-layer" if detect_layer else "--no-detect-layer")

        if opt_mode is not None:
            ts_args.extend(["--opt-mode", str(opt_mode)])

        _append_cli_arg(ts_args, "--max-cycles", overrides.get("max_cycles"))
        _append_toggle_arg(ts_args, "--dump", overrides.get("dump"))
        _append_toggle_arg(ts_args, "--convert-files", overrides.get("convert_files"))
        _append_cli_arg(ts_args, "--thresh", overrides.get("thresh"))
        _append_toggle_arg(ts_args, "--flatten", overrides.get("flatten"))

        hess_mode = overrides.get("hessian_calc_mode")
        if hess_mode:
            ts_args.extend(["--hessian-calc-mode", str(hess_mode)])

        if args_yaml is not None:
            ts_args.extend(["--config", str(args_yaml)])

        from mlmm.workflows._all_helpers import append_backend_forwarding_args
        append_backend_forwarding_args(
            ts_args,
            backend=backend,
            embedcharge=embedcharge,
            embedcharge_cutoff=embedcharge_cutoff,
            embedcharge_explicit=embedcharge_explicit,
            link_atom_method=link_atom_method,
            mm_backend=mm_backend,
            use_cmap=use_cmap,
        )
        if overrides.get("skip_final_freq"):
            ts_args.append("--skip-final-freq")

        _echo_detail(f"[tsopt] Running tsopt on HEI → out={ts_dir}")
        _run_cli_main("tsopt", _ts_opt.cli, ts_args, on_nonzero="raise", prefix="tsopt")

        # Prefer XYZ (full precision) for geometry loading; PDB for topology
        final_xyz = ts_dir / "final_geometry.xyz"
        ts_pdb = ts_dir / "final_geometry.pdb"
        if not ts_pdb.exists() and final_xyz.exists():
            _path_search._maybe_convert_to_pdb(final_xyz, topology_pdb, ts_pdb)
        if not final_xyz.exists() and not ts_pdb.exists():
            raise click.ClickException("[tsopt] TS outputs not found.")
        geom_src = final_xyz if final_xyz.exists() else ts_pdb
        g_ts = geom_loader(geom_src, coord_type="cart")

        # Ensure calculator to have energy on g_ts
        _ts_calc_kwargs = dict(
            model_charge=int(charge),
            model_mult=int(spin),
            input_pdb=str(topology_pdb),
            real_parm7=str(real_parm7),
            model_pdb=str(model_pdb),
            use_bfactor_layers=detect_layer,
            backend=backend,
            embedcharge=embedcharge,
        )
        # Only forward optional settings that were set, so calculator defaults apply otherwise.
        if link_atom_method is not None:
            _ts_calc_kwargs["link_atom_method"] = link_atom_method
        if mm_backend is not None:
            _ts_calc_kwargs["mm_backend"] = mm_backend
        if use_cmap is not None:
            _ts_calc_kwargs["use_cmap"] = use_cmap
        calc = _mlmm_calc(**_ts_calc_kwargs)
        g_ts.set_calculator(calc)
        _ = float(g_ts.energy)

        return ts_pdb, g_ts
    finally:
        prepared_input.cleanup()


def _write_segment_energy_diagram(
    prefix: Path,
    labels: List[str],
    energies_eh: List[float],
    title_note: str,
    ylabel: str = "ΔE (kcal/mol)",
    write_html: bool = False,
) -> Optional[Dict[str, Any]]:
    """
    Render a relative-energy diagram for a sequence of states and save it as PNG.

    Energies are converted from hartree to kcal/mol relative to the first entry
    and plotted with ``build_energy_diagram``.

    Args:
        prefix: Output stem; the image goes to ``prefix.with_suffix(".png")``.
        labels: State labels, one per energy.
        energies_eh: Absolute energies in hartree. The first one is the zero reference.
        title_note: Figure title. Applied and stored in the payload only when non-empty.
        ylabel: Y-axis label.
        write_html: Accepted for interface compatibility. Currently unused; no
            HTML file is written.

    Returns:
        A payload dict (name, labels, energies in kcal/mol and hartree, ylabel,
        image path, optional title), or ``None`` when *energies_eh* is empty.
        The ``image`` path is recorded even if PNG export failed (e.g. kaleido
        is not installed); that failure is reported as a note, not raised.
    """
    if not energies_eh:
        return None
    e0 = energies_eh[0]
    energies_kcal = [(e - e0) * AU2KCALPERMOL for e in energies_eh]
    fig = build_energy_diagram(
        energies=energies_kcal,
        labels=labels,
        ylabel=ylabel,
        baseline=True,
        showgrid=False,
    )
    if title_note:
        fig.update_layout(title=title_note)
    png = prefix.with_suffix(".png")
    try:
        fig.write_image(str(png), scale=2)
    except Exception as e:
        click.echo(f"[diagram] NOTE: PNG export skipped (install 'kaleido' to enable): {e}", err=True)
    else:
        # click.echo() has no ``detail`` keyword (it would raise TypeError);
        # route verbose-only output through the module's detail echo helper.
        _echo_detail(f"[diagram] Wrote energy diagram → {png.name}")

    payload: Dict[str, Any] = {
        "name": prefix.stem,
        "labels": labels,
        "energies_kcal": energies_kcal,
        "ylabel": ylabel,
        "energies_au": list(energies_eh),
        "image": str(png),
    }
    if title_note:
        payload["title"] = title_note
    return payload


def _build_global_segment_labels(n_segments: int) -> List[str]:
    """
    Build GSM-like labels for aggregated R/TS/P diagrams over multiple segments.

    Pattern:
      - n = 1: ["R", "TS1", "P"]
      - n >= 2: R, TS1, IM1_1, IM1_2, TS2, IM2_1, IM2_2, ..., TSN, P
    """
    if n_segments <= 0:
        return []
    if n_segments == 1:
        return ["R", "TS1", "P"]

    labels: List[str] = []
    for seg_idx in range(1, n_segments + 1):
        if seg_idx == 1:
            labels.extend(["R", "TS1", "IM1_1"])
        elif seg_idx == n_segments:
            labels.extend([f"IM{seg_idx - 1}_2", f"TS{seg_idx}", "P"])
        else:
            labels.extend(
                [f"IM{seg_idx - 1}_2", f"TS{seg_idx}", f"IM{seg_idx}_1"]
            )
    return labels


def _merge_irc_trajectories_to_single_plot(
    trj_and_flags: Sequence[Tuple[Path, bool]],
    out_png: Path,
) -> None:
    """
    Concatenate per-segment IRC trajectories and plot them as one curve via trj2fig.

    Each ``(trajectory, reverse)`` pair contributes its XYZ frames in file
    order, or in reverse order when ``reverse`` is true. The following entries
    are skipped: non-``Path`` entries, missing files, empty trajectories, and
    trajectories whose reader raises ``click.ClickException`` (reported on
    stderr).

    The concatenated trajectory is written next to *out_png* as
    ``<stem>_trj.xyz``, plotted, then deleted. Write and plot failures are
    reported as warnings, not raised.
    """
    frames_text: List[str] = []
    for trj_path, reverse in trj_and_flags:
        if not (isinstance(trj_path, Path) and trj_path.exists()):
            continue
        try:
            blocks = read_xyz_as_blocks(trj_path)
        except click.ClickException as exc:
            click.echo(str(exc), err=True)
            continue
        if not blocks:
            continue
        ordered = list(reversed(blocks)) if reverse else blocks
        frames_text += ["\n".join(block) for block in ordered]

    if not frames_text:
        return

    concat_trj = out_png.with_name(f"{out_png.stem}_trj.xyz")
    ensure_dir(concat_trj.parent)
    try:
        concat_trj.write_text("\n".join(frames_text) + "\n", encoding="utf-8")
    except Exception as exc:
        click.echo(f"[irc_all] WARNING: Failed to write concatenated IRC trajectory: {exc}", err=True)
        return

    try:
        run_trj2fig(concat_trj, [out_png], unit="kcal", reference="init", reverse_x=False)
        click.echo(f"[irc_all] Wrote aggregated IRC plot → {out_png}")
    except Exception as exc:
        click.echo(f"[irc_all] WARNING: failed to plot concatenated IRC trajectory: {exc}", err=True)
    finally:
        # The concatenated file is only a plotting intermediate.
        try:
            concat_trj.unlink()
        except Exception:
            logger.debug("Failed to unlink temp trajectory file", exc_info=True)


def _run_freq_for_state(pdb_path: Path,
                        q_int: int,
                        spin: int,
                        real_parm7: Path,
                        model_pdb: Path,
                        detect_layer: bool,
                        out_dir: Path,
                        args_yaml: Optional[Path],
                        overrides: Optional[Dict[str, Any]] = None,
                        backend: Optional[str] = None,
                        embedcharge: bool = False,
                        embedcharge_cutoff: Optional[float] = None,
                        embedcharge_explicit: bool = False,
                        link_atom_method: Optional[str] = None,
                        mm_backend: Optional[str] = None,
                        use_cmap: Optional[bool] = None,
                        xyz_path: Optional[Path] = None) -> Dict[str, Any]:
    """
    Run the ``freq`` CLI for one state and return its parsed thermochemistry.

    If *xyz_path* exists it is used as input (full-precision coordinates) and
    *pdb_path* is passed as ``--ref-pdb`` for topology; otherwise *pdb_path* is
    the input. ``--dump`` defaults to on unless ``overrides["dump"]`` says
    otherwise. A nonzero CLI exit is only warned about; exceptions propagate.

    Returns:
        The contents of ``<out_dir>/thermoanalysis.yaml``, or ``{}`` when the
        file is missing, empty, or cannot be parsed.
    """
    freq_dir = out_dir
    ensure_dir(freq_dir)
    opts = overrides or {}

    # Mode dumping stays on unless explicitly disabled.
    dump_flag = opts.get("dump")
    if dump_flag is None:
        dump_flag = True

    if xyz_path is not None and xyz_path.exists():
        # Full-precision XYZ coordinates; the PDB supplies topology.
        cli_args: List[str] = ["-i", str(xyz_path), "--ref-pdb", str(pdb_path)]
    else:
        cli_args = ["-i", str(pdb_path)]
    cli_args += [
        "--parm", str(real_parm7),
        "--model-pdb", str(model_pdb),
        "-q", str(int(q_int)),
        "-m", str(int(spin)),
        "--out-dir", str(freq_dir),
    ]
    cli_args.append("--detect-layer" if detect_layer else "--no-detect-layer")

    for flag, key in (("--max-write", "max_write"),
                      ("--amplitude-ang", "amplitude_ang"),
                      ("--n-frames", "n_frames")):
        _append_cli_arg(cli_args, flag, opts.get(key))
    sort_choice = opts.get("sort")
    if sort_choice is not None:
        cli_args += ["--sort", str(sort_choice)]
    for flag, key in (("--temperature", "temperature"), ("--pressure", "pressure")):
        _append_cli_arg(cli_args, flag, opts.get(key))
    _append_toggle_arg(cli_args, "--dump", dump_flag)
    _append_toggle_arg(cli_args, "--convert-files", opts.get("convert_files"))

    hessian_mode = opts.get("hessian_calc_mode")
    if hessian_mode:
        cli_args += ["--hessian-calc-mode", str(hessian_mode)]

    from mlmm.workflows._all_helpers import append_backend_forwarding_args
    append_backend_forwarding_args(
        cli_args,
        backend=backend,
        embedcharge=embedcharge,
        embedcharge_cutoff=embedcharge_cutoff,
        embedcharge_explicit=embedcharge_explicit,
        link_atom_method=link_atom_method,
        mm_backend=mm_backend,
        use_cmap=use_cmap,
        args_yaml=args_yaml,
    )
    _run_cli_main("freq", _freq_cli.cli, cli_args, on_nonzero="warn", on_exception="raise", prefix="freq")

    thermo_yaml = freq_dir / "thermoanalysis.yaml"
    if not thermo_yaml.exists():
        return {}
    try:
        return yaml.safe_load(thermo_yaml.read_text(encoding="utf-8")) or {}
    except Exception:
        return {}


def _run_opt_for_state(
    pdb_path: Path,
    q_int: int,
    spin: int,
    real_parm7: Path,
    model_pdb: Path,
    detect_layer: bool,
    out_dir: Path,
    args_yaml: Optional[Path],
    opt_mode_default: str,
    convert_files: Optional[bool] = None,
    backend: Optional[str] = None,
    embedcharge: bool = False,
    embedcharge_cutoff: Optional[float] = None,
    embedcharge_explicit: bool = False,
    link_atom_method: Optional[str] = None,
    mm_backend: Optional[str] = None,
    use_cmap: Optional[bool] = None,
    thresh: Optional[str] = None,
    xyz_path: Optional[Path] = None,
) -> Tuple[Any, Path]:
    """
    Run the opt CLI for a single endpoint; return (optimized Geometry, final geometry path).

    When *xyz_path* exists it is used as ``-i`` (full coordinate precision),
    with *pdb_path* applied as the reference PDB and passed via ``--ref-pdb``.
    Otherwise *pdb_path* is the input. *opt_mode_default* is lower-cased and
    falls back to ``"heavy"`` when empty.

    The returned geometry is loaded from ``final_geometry.xyz`` if present
    (else ``final_geometry.pdb``), has an ML/MM calculator attached, and has
    its energy evaluated.

    Raises:
        click.ClickException: if neither final geometry file was produced.
    """
    opt_dir = out_dir
    ensure_dir(opt_dir)

    # Use XYZ (full precision) when available; fall back to PDB
    use_xyz = xyz_path is not None and xyz_path.exists()
    if use_xyz:
        prepared_input = prepare_input_structure(xyz_path)
        input_label = xyz_path.name
    else:
        prepared_input = prepare_input_structure(pdb_path)
        input_label = pdb_path.name
    try:
        if use_xyz:
            # Inside the try so the prepared input is cleaned up even if the override fails.
            apply_ref_pdb_override(prepared_input, pdb_path)

        opt_mode = str(opt_mode_default or "heavy").lower()
        args = [
            "-i", str(prepared_input.geom_path),
        ]
        # Add --ref-pdb when input is XYZ
        if prepared_input.geom_path.suffix.lower() == ".xyz":
            args.extend(["--ref-pdb", str(prepared_input.source_path)])
        args.extend([
            "--parm", str(real_parm7),
            "--model-pdb", str(model_pdb),
            "-q", str(int(q_int)),
            "-m", str(int(spin)),
            "--out-dir", str(opt_dir),
            "--opt-mode", opt_mode,
        ])
        args.append("--detect-layer" if detect_layer else "--no-detect-layer")
        _append_toggle_arg(args, "--convert-files", convert_files)
        _append_cli_arg(args, "--thresh", thresh)

        if args_yaml is not None:
            args.extend(["--config", str(args_yaml)])

        from mlmm.workflows._all_helpers import append_backend_forwarding_args
        append_backend_forwarding_args(
            args,
            backend=backend,
            embedcharge=embedcharge,
            embedcharge_cutoff=embedcharge_cutoff,
            embedcharge_explicit=embedcharge_explicit,
            link_atom_method=link_atom_method,
            mm_backend=mm_backend,
            use_cmap=use_cmap,
        )

        _echo_detail(f"[endpoint-opt] Running opt on {input_label} (mode={opt_mode}) → out={opt_dir}")
        _run_cli_main("opt", _opt_cli.cli, args, on_nonzero="raise", on_exception="raise", prefix="endpoint-opt")

        final_pdb = opt_dir / "final_geometry.pdb"
        final_xyz = opt_dir / "final_geometry.xyz"
        # Prefer XYZ (full precision) for geometry loading
        if final_xyz.exists():
            final_geom_path = final_xyz
        elif final_pdb.exists():
            final_geom_path = final_pdb
        else:
            raise click.ClickException(f"[endpoint-opt] opt outputs not found under {opt_dir}")

        g_opt = geom_loader(final_geom_path, coord_type="cart")
        calc_input_pdb = final_pdb if final_pdb.exists() else pdb_path
        _opt_calc_kwargs = dict(
            model_charge=int(q_int),
            model_mult=int(spin),
            input_pdb=str(calc_input_pdb),
            real_parm7=str(real_parm7),
            model_pdb=str(model_pdb),
            use_bfactor_layers=detect_layer,
            backend=backend,
            embedcharge=embedcharge,
        )
        # Only forward optional settings that were set, so calculator defaults apply otherwise.
        if link_atom_method is not None:
            _opt_calc_kwargs["link_atom_method"] = link_atom_method
        if mm_backend is not None:
            _opt_calc_kwargs["mm_backend"] = mm_backend
        if use_cmap is not None:
            _opt_calc_kwargs["use_cmap"] = use_cmap
        calc = _mlmm_calc(**_opt_calc_kwargs)
        g_opt.set_calculator(calc)
        _ = float(g_opt.energy)

        return g_opt, final_geom_path
    finally:
        prepared_input.cleanup()


def _dft_succeeded(result: Dict[str, Any]) -> bool:
    """Return True only if DFT converged and produced a valid energy."""
    return bool(result) and not result.get("_dft_failed", True)


def _dft_energy_ha(result: Dict[str, Any]) -> Optional[float]:
    """
    Return ``result["energy"]["hartree"]`` as a float, or None.

    None is returned when the DFT run is not marked successful, or when the
    energy entry is missing or not convertible to float.
    """
    if _dft_succeeded(result):
        try:
            energy_block = result.get("energy") or {}
            return float(energy_block.get("hartree"))
        except (TypeError, ValueError):
            return None
    return None


def _finite_float(value: Any) -> Optional[float]:
    try:
        fval = float(value)
    except (TypeError, ValueError):
        return None
    if not np.isfinite(fval):
        return None
    return fval


def _format_state_values(values: Dict[str, Optional[float]], *, precision: int) -> str:
    parts: List[str] = []
    for label in ("R", "TS", "P"):
        val = values.get(label)
        if val is None:
            parts.append(f"{label}=n/a")
        else:
            parts.append(f"{label}={val:.{precision}f}")
    return " ".join(parts)


def _scale_energy_values(
    values_ha: Dict[str, Optional[float]],
    scale: float,
) -> Dict[str, Optional[float]]:
    return {label: (val * scale if val is not None else None) for label, val in values_ha.items()}


def _relative_energy_values_kcal(values_ha: Dict[str, Optional[float]]) -> Optional[Dict[str, Optional[float]]]:
    """
    Convert hartree energies to kcal/mol relative to the "R" entry.

    Returns None when there is no R reference. Entries that are None stay None.
    """
    reference = values_ha.get("R")
    if reference is None:
        return None
    relative: Dict[str, Optional[float]] = {}
    for label, val in values_ha.items():
        relative[label] = None if val is None else (val - reference) * AU2KCALPERMOL
    return relative


def _echo_energy_triplet(
    tag: str,
    seg_idx: int,
    label: str,
    values: Dict[str, Optional[float]],
    *,
    unit: str,
    precision: int,
) -> None:
    """
    Emit a detail-level line with a segment's R/TS/P values.

    Nothing is printed when every value is None.
    """
    if all(val is None for val in values.values()):
        return
    formatted = _format_state_values(values, precision=precision)
    _echo_detail(f"[{tag}] Segment {seg_idx:02d} {label} ({unit}): {formatted}")


def _thermo_correction_values(
    payloads: Dict[str, Dict[str, Any]],
    key: str,
) -> Dict[str, Optional[float]]:
    """
    Pick *key* from each of the R/TS/P thermo payloads as a finite float.

    A missing payload, a missing key, or a non-finite/non-numeric value gives
    None for that state.
    """
    corrections: Dict[str, Optional[float]] = {}
    for state in ("R", "TS", "P"):
        payload = payloads.get(state) or {}
        corrections[state] = _finite_float(payload.get(key))
    return corrections


def _dft_total_mlmm_energy_ha(result: Dict[str, Any]) -> Optional[float]:
    """Return ``mlmm_energy.E_total_ml_dft_mm_hartree`` as a finite float.

    Yields ``None`` when ``_dft_succeeded(result)`` is falsy or the value is
    absent or not a finite number.
    """
    if _dft_succeeded(result):
        mlmm_block = result.get("mlmm_energy") or {}
        return _finite_float(mlmm_block.get("E_total_ml_dft_mm_hartree"))
    return None


def _run_dft_for_state(pdb_path: Path,
                       q_int: int,
                       spin: int,
                       real_parm7: Path,
                       model_pdb: Path,
                       detect_layer: bool,
                       out_dir: Path,
                       args_yaml: Optional[Path],
                       func_basis: str = "wb97m-v/def2-tzvpd",
                       overrides: Optional[Dict[str, Any]] = None,
                       backend: Optional[str] = None,
                       embedcharge: bool = False,
                       embedcharge_cutoff: Optional[float] = None,
                       embedcharge_explicit: bool = False,
                       link_atom_method: Optional[str] = None,
                       mm_backend: Optional[str] = None,
                       use_cmap: Optional[bool] = None,
                       xyz_path: Optional[Path] = None) -> Dict[str, Any]:
    """
    Run the ``mlmm dft`` CLI in a separate interpreter and return its result.

    When *xyz_path* is given and exists, it supplies full-precision
    coordinates and *pdb_path* is passed as the topology reference
    (``--ref-pdb``). Otherwise *pdb_path* is the input.

    *overrides* may carry ``func_basis``, ``max_cycle``, ``conv_tol``,
    ``grid_level``, ``engine`` and ``convert_files``. These are forwarded
    as the matching CLI flags.

    Returns the mapping parsed from ``<out_dir>/result.yaml``. It is an
    empty dict if the file is missing, unparsable, or not a mapping. Two
    bookkeeping keys are always added:

    * ``_dft_converged``: ``bool(result["energy"]["converged"])``, False
      when unavailable.
    * ``_dft_failed``: True when not converged, when the subprocess exited
      non-zero, or when it could not be launched.
    """
    ddir = out_dir
    ensure_dir(ddir)
    overrides = overrides or {}

    # An override explicitly set to None means "not set"; never forward the
    # literal string "None" as the functional/basis.
    func_basis_use = overrides.get("func_basis") or func_basis

    # Prefer XYZ (full precision) with --ref-pdb for topology
    if xyz_path is not None and xyz_path.exists():
        args = ["-i", str(xyz_path), "--ref-pdb", str(pdb_path)]
    else:
        args = ["-i", str(pdb_path)]
    args.extend([
        "--parm", str(real_parm7),
        "--model-pdb", str(model_pdb),
        "-q", str(int(q_int)),
        "-m", str(int(spin)),
        "--func-basis", str(func_basis_use),
        "--out-dir", str(ddir),
    ])
    args.append("--detect-layer" if detect_layer else "--no-detect-layer")

    _append_cli_arg(args, "--max-cycle", overrides.get("max_cycle"))
    _append_cli_arg(args, "--conv-tol", overrides.get("conv_tol"))
    _append_cli_arg(args, "--grid-level", overrides.get("grid_level"))
    _append_cli_arg(args, "--engine", overrides.get("engine"))
    _append_toggle_arg(args, "--convert-files", overrides.get("convert_files"))

    from mlmm.workflows._all_helpers import append_backend_forwarding_args
    append_backend_forwarding_args(
        args,
        backend=backend,
        embedcharge=embedcharge,
        embedcharge_cutoff=embedcharge_cutoff,
        embedcharge_explicit=embedcharge_explicit,
        link_atom_method=link_atom_method,
        mm_backend=mm_backend,
        use_cmap=use_cmap,
        args_yaml=args_yaml,
    )
    # Run DFT as a real subprocess to avoid libcusolver conflict with torch.
    # The MLIP stack (UMA / ORB / MACE via torch) and gpu4pyscf both link
    # against libcusolver but pin different versions; running DFT in the same
    # Python process triggers a dynamic-loader clash. A fresh interpreter is
    # the only reliable isolation.
    # Free GPU memory before spawning the DFT subprocess so it can claim VRAM.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    import subprocess as _sp
    cmd = [sys.executable, "-m", "mlmm", "dft"] + list(args)
    _echo(f"\n[dft] subprocess: {' '.join(cmd)}")
    try:
        proc = _sp.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # The interpreter could not be spawned at all. Report it and mark this
        # state as failed instead of aborting the whole pipeline.
        _echo(f"[dft] WARNING: failed to launch dft subprocess: {exc}", err=True)
        returncode = -1
    else:
        returncode = proc.returncode
        if proc.stdout:
            _echo(proc.stdout.rstrip())
        if returncode != 0:
            _echo(f"[dft] WARNING: dft exited with code {returncode}", err=True)
            if proc.stderr:
                _echo(proc.stderr.rstrip(), err=True)

    y = out_dir / "result.yaml"
    data: Any = {}
    if y.exists():
        try:
            data = yaml.safe_load(y.read_text(encoding="utf-8")) or {}
        except Exception as exc:
            # Best-effort parse: an unreadable file is reported as a failed run.
            logger.debug("Failed to parse DFT result YAML %s: %s", y, exc)
            data = {}
    if not isinstance(data, dict):
        # A list/scalar top level cannot hold the expected keys.
        logger.debug("DFT result YAML %s is not a mapping; ignoring it.", y)
        data = {}
    energy = data.get("energy")
    converged = energy.get("converged", False) if isinstance(energy, dict) else False
    data["_dft_converged"] = bool(converged)
    data["_dft_failed"] = not bool(converged) or returncode != 0
    return data



_ALL_PRIMARY_HELP_OPTIONS = frozenset(
    {
        "-i",
        "--input",
        "-c",
        "--center",
        "-l",
        "--ligand-charge",
        "-q",
        "--charge",
        "--out-dir",
        "--tsopt",
        "--thermo",
        "--dft",
        "--dft-func-basis",
        "--config",
        "--dry-run",
        "--embedcharge",
        "-s",
        "--scan-lists",
        "-b",
        "--backend",
        "-o",
        "--help-advanced",
    }
)


def _show_advanced_help(
    ctx: click.Context, _param: click.Parameter, value: bool
) -> None:
    """Eager ``--help-advanced`` callback: print the full help text and exit.

    The options recorded on ``ctx.command._advanced_hidden_options`` are
    temporarily unhidden while the help text is rendered. Only those that
    were hidden are re-hidden afterwards, even if rendering raises. Does
    nothing when the flag is unset or click is in resilient-parsing mode.
    """
    if not value:
        return
    if ctx.resilient_parsing:
        return

    advanced = getattr(ctx.command, "_advanced_hidden_options", ())
    unhidden: list[click.Option] = [opt for opt in advanced if opt.hidden]
    for opt in unhidden:
        opt.hidden = False
    try:
        click.echo(ctx.command.get_help(ctx))
    finally:
        for opt in unhidden:
            opt.hidden = True
    ctx.exit()


def _configure_all_help_visibility(command: click.Command) -> None:
    """Hide non-primary options from default ``--help`` while keeping them usable.

    A ``click.Option`` is hidden when it is not already hidden and none of its
    spellings (primary or secondary) appear in ``_ALL_PRIMARY_HELP_OPTIONS``.
    The options hidden here are stored as a tuple on
    ``command._advanced_hidden_options`` so ``--help-advanced`` can reveal them.
    """
    def _is_advanced(param: click.Parameter) -> bool:
        if not isinstance(param, click.Option) or param.hidden:
            return False
        spellings = frozenset(param.opts).union(param.secondary_opts)
        return spellings.isdisjoint(_ALL_PRIMARY_HELP_OPTIONS)

    newly_hidden = tuple(param for param in command.params if _is_advanced(param))
    for opt in newly_hidden:
        opt.hidden = True
    command._advanced_hidden_options = newly_hidden


@click.command(
    help="Run pocket extraction → (optional single-structure staged scan) → MEP search in one shot.\n"
         "If exactly one input is provided: (a) with --scan-lists, stage results feed into path-opt (or path_search with --refine-path); "
         "(b) with --tsopt True and no --scan-lists, run TSOPT-only mode.",
    context_settings={
        "help_option_names": ["-h", "--help"],
        "ignore_unknown_options": True,
        "allow_extra_args": True,
    },
)
@click.option(
    "--help-advanced",
    is_flag=True,
    is_eager=True,
    expose_value=False,
    callback=_show_advanced_help,
    help="Show all options (including advanced settings) and exit.",
)
# ===== Inputs =====
@click.option(
    "-i", "--input", "input_paths",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    multiple=True, required=True,
    help=("Two or more **full** PDBs in reaction order (reactant [intermediates ...] product), "
          "or a single **full** PDB (with --scan-lists or with --tsopt True). "
          "You may pass a single '-i' followed by multiple space-separated files (e.g., '-i A.pdb B.pdb C.pdb').")
)
@click.option(
    "-c", "--center", "center_spec",
    type=str, required=False, default=None,
    help=("Substrate specification for the extractor: "
          "a PDB path, a residue-ID list like '123,124' or 'A:123,B:456' "
          "(insertion codes OK: '123A' / 'A:123A'), "
          "or a residue-name list like 'GPP,MMT'. "
          "When omitted, extraction is skipped and full structures are used directly.")
)
@click.option(
    "-o", "--out-dir", "out_dir",
    type=click.Path(path_type=Path, file_okay=False),
    default=Path(OUT_DIR_ALL), show_default=True,
    help="Top-level output directory for the pipeline."
)
# ===== Extractor knobs (subset of extract.parse_args) =====
@click.option("-r", "--radius", type=float, default=2.6, show_default=True,
              help="Inclusion cutoff (Å) around substrate atoms.")
@click.option("--radius-het2het", type=float, default=0.0, show_default=True,
              help="Independent hetero–hetero cutoff (Å) for non‑C/H pairs.")
@click.option("--include-h2o", "include_h2o", type=click.BOOL, default=True, show_default=True,
              help="Include waters (HOH/WAT/H2O/DOD/TIP/TIP3/SOL) in the pocket.")
@click.option("--exclude-backbone", "exclude_backbone", type=click.BOOL, default=False, show_default=True,
              help="Remove backbone atoms on non‑substrate amino acids (with PRO/HYP safeguards).")
@click.option("--add-linkh", "add_linkh", type=click.BOOL, default=False, show_default=True,
              help="Add link hydrogens for severed bonds (carbon-only) in pockets.")
@click.option("--selected-resn", type=str, default="", show_default=True,
              help="Force-include residues (comma/space separated; chain/insertion codes allowed).")
@click.option("--modified-residue", type=str, default="", show_default=True,
              help=("Comma-separated residue names (with optional charge) to treat as amino acids "
                    "for backbone truncation and charge assignment. "
                    "Examples: 'HD1,HD2,HD3' or 'HD1:0,SEP:-2'."))
@click.option("-l", "--ligand-charge", type=str, default=None,
              help=("Either a total charge (number) to distribute across unknown residues "
                    "or a mapping like 'GPP:-3,MMT:-1'."))
@click.option(
    "-q",
    "--charge",
    "charge_override",
    type=int,
    default=None,
    help="Force total system charge. Highest priority over derived charges.",
)
@click.option(
    "--parm",
    "parm7_override",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    default=None,
    help="Pre-built AMBER parm7 topology file. When provided, mm_parm generation is skipped.",
)
@click.option(
    "--model-pdb",
    "model_pdb_override",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    default=None,
    help="Pre-built ML-region PDB (with B-factor layer info). When provided, ml_region generation is skipped.",
)
@click.option("--auto-mm-ff-set", "mm_ff_set",
              type=click.Choice(["ff19SB", "ff14SB"], case_sensitive=False),
              default="ff19SB", show_default=True,
              help="Force-field set forwarded to mm_parm (ff19SB uses OPC3; ff14SB uses TIP3P).")
@click.option("--auto-mm-add-ter/--auto-mm-no-add-ter", "mm_add_ter",
              default=True, show_default=True,
              help="Control mm_parm TER insertion around ligand/water/ion blocks and disconnected peptide blocks.")
@click.option("--auto-mm-keep-temp", "mm_keep_temp", is_flag=True, default=False, show_default=True,
              help="Keep the mm_parm temporary working directory (for debugging).")
@click.option(
    "--auto-mm-ligand-mult",
    "mm_ligand_mult",
    type=str,
    default=None,
    help=("Spin multiplicity mapping forwarded to mm_parm (e.g., 'GPP:2,SAM:1'). "
          "If omitted, mm_parm defaults to 1 for all ligands.")
)
# ===== Path search knobs (subset of path_search.cli) =====
@click.option("-m", "--multiplicity", "spin", type=int, default=1, show_default=True, help="Multiplicity (2S+1).")
@click.option("--max-nodes", type=int, default=_path_opt.GS_KW["max_nodes"], show_default=True,
              help="Max internal nodes for **segment** GSM (String has max_nodes+2 images including endpoints).")
@click.option("--max-cycles", type=int, default=300, show_default=True, help="Maximum GSM optimization cycles.")
@click.option("--climb", type=click.BOOL, default=True, show_default=True,
              help="Enable transition-state climbing after growth for the **first** segment in each pair.")
@click.option(
    "--opt-mode",
    type=click.Choice(["grad", "hess"], case_sensitive=False),
    default="grad",
    show_default=True,
    help=(
        "Optimizer mode forwarded to scan/path-search and used for single optimizations: "
        "grad (=LBFGS/Dimer) or hess (=RFO/RSIRFO)."
    ),
)
@click.option(
    "--opt-mode-post",
    type=click.Choice(["grad", "hess"], case_sensitive=False),
    default="hess",
    show_default=True,
    help=(
        "Optimizer mode for TSOPT and post-IRC endpoint optimizations. "
        "Takes precedence over --opt-mode for these stages."
    ),
)
@click.option("--dump", type=click.BOOL, default=False, show_default=True,
              help="Dump GSM / single-structure trajectories during the run, forwarding the same flag to scan/tsopt/freq.")
@click.option(
    "--refine-path/--no-refine-path",
    "refine_path",
    default=False,
    show_default=True,
    help=(
        "If False (default), run a single-pass path-opt GSM between each adjacent pair and concatenate the "
        "segments (no path_search); if True, run recursive path_search on the full ordered series for "
        "automatic multistep discovery."
    ),
)
@click.option(
    "--thresh",
    type=click.Choice(THRESH_CHOICES, case_sensitive=False),
    default=None,
    show_default=False,
    help=(
        "Convergence preset (gau_loose|gau|gau_tight|gau_vtight|baker|never). "
        "Defaults to 'gau_loose' for path-opt, 'gau' for scan."
    ),
)
@click.option(
    "--thresh-post",
    type=click.Choice(THRESH_CHOICES, case_sensitive=False),
    default="baker",
    show_default=True,
    help=(
        "Convergence preset for post-IRC endpoint optimizations "
        "(gau_loose|gau|gau_tight|gau_vtight|baker|never)."
    ),
)
@click.option("--config", "config_yaml", type=click.Path(path_type=Path, exists=True, dir_okay=False),
              default=None, help="Base YAML configuration file applied before explicit CLI options.")
@click.option("--show-config/--no-show-config", "show_config", default=False, show_default=True,
              help="Print resolved configuration and continue execution.")
@click.option("--dry-run/--no-dry-run", "dry_run", default=False, show_default=True,
              help="Validate options and print the execution plan without running any stage.")
@click.option("--preopt", "pre_opt", type=click.BOOL, default=True, show_default=True,
              help="If True, run initial single-structure optimizations of the pocket inputs.")
@click.option("--hessian-calc-mode",
              type=click.Choice(["Analytical", "FiniteDifference"], case_sensitive=False),
              default=None,
              help="Common MLIP Hessian calculation mode forwarded to tsopt and freq. Default: 'FiniteDifference'. Use 'Analytical' when VRAM is sufficient.")
@click.option(
    "--detect-layer/--no-detect-layer",
    "detect_layer",
    default=True,
    show_default=True,
    help="Detect ML/MM layers from input PDB B-factors (ML=0, MovableMM=10, FrozenMM=20) in downstream tools. "
         "If disabled, downstream tools require --model-pdb or --model-indices.",
)
# ===== Post-processing toggles =====
@click.option("--tsopt", "do_tsopt", type=click.BOOL, default=False, show_default=True,
              help="TS optimization + EulerPC IRC per reactive segment (or TSOPT-only mode for single-structure), and build energy diagrams.")
@click.option("--thermo", "do_thermo", type=click.BOOL, default=False, show_default=True,
              help="Run freq on (R,TS,P) per reactive segment (or TSOPT-only mode) and build Gibbs free-energy diagram (MLIP).")
@click.option("--dft", "do_dft", type=click.BOOL, default=False, show_default=True,
              help="Run DFT single-point on (R,TS,P) and build DFT energy diagram. With --thermo True, also generate a DFT//MLIP Gibbs diagram.")
@click.option("--tsopt-max-cycles", type=int, default=None,
              help="Override tsopt --max-cycles value.")
@click.option(
    "--flatten/--no-flatten",
    "flatten",
    default=False,
    show_default=True,
    help="Enable the extra-imaginary-mode flattening loop in tsopt (grad: dimer loop, hess: post-RSIRFO); --no-flatten forces flatten_max_iter=0.",
)
@click.option(
    "--skip-final-freq/--no-skip-final-freq",
    "skip_final_freq",
    default=False,
    show_default=True,
    help="Skip post-convergence frequency analysis in tsopt. Useful for large unfrozen systems.",
)
@click.option("--tsopt-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
              help="Override tsopt output subdirectory (relative paths are resolved against the default).")
@click.option("--freq-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
              help="Override freq output base directory (relative paths resolved against the default).")
@click.option("--freq-max-write", type=int, default=None,
              help="Override freq --max-write value.")
@click.option("--freq-amplitude-ang", type=float, default=None,
              help="Override freq --amplitude-ang (Å).")
@click.option("--freq-n-frames", type=int, default=None,
              help="Override freq --n-frames value.")
@click.option("--freq-sort", type=click.Choice(["value", "abs"], case_sensitive=False), default=None,
              help="Override freq mode sorting.")
@click.option("--freq-temperature", type=float, default=None,
              help="Override freq thermochemistry temperature (K).")
@click.option("--freq-pressure", type=float, default=None,
              help="Override freq thermochemistry pressure (atm).")
@click.option("--dft-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
              help="Override dft output base directory (relative paths resolved against the default).")
@click.option("--dft-func-basis", type=str, default=None,
              help="Override dft --func-basis value.")
@click.option("--dft-max-cycle", type=int, default=None,
              help="Override dft --max-cycle value.")
@click.option("--dft-conv-tol", type=float, default=None,
              help="Override dft --conv-tol value.")
@click.option("--dft-grid-level", type=int, default=None,
              help="Override dft --grid-level value.")
@click.option("--dft-engine", type=click.Choice(["gpu", "cpu"]), default=None,
              help="Override dft --engine value.")
# ===== Staged scan specification for single-structure route =====
@click.option(
    "-s", "--scan-lists",
    "scan_lists_raw",
    type=str, multiple=True, required=False,
    help='Scan targets: inline Python literal or a YAML/JSON spec file path. '
         'Multiple inline literals define sequential stages, e.g. '
         '"[(12,45,1.35)]" "[(10,55,2.20),(23,34,1.80)]". '
         'Indices refer to the original full PDB (1-based) or PDB atom selectors like "TYR,285,CA"; '
         'they are auto-mapped to the pocket after extraction.',
)
@click.option("--scan-out-dir", type=click.Path(path_type=Path, file_okay=False), default=None,
              help="Override the scan output directory (default: <out-dir>/scan/). Relative paths are resolved against the default parent.")
@click.option("--scan-one-based", type=click.BOOL, default=None,
              help="Override scan indexing interpretation (True = 1-based, False = 0-based).")
@click.option("--scan-max-step-size", type=float, default=None,
              help="Override scan --max-step-size (Å).")
@click.option("--scan-bias-k", type=float, default=None,
              help="Override scan harmonic bias strength k (eV/Å^2).")
@click.option("--scan-relax-max-cycles", type=int, default=None,
              help="Override scan relaxation max cycles per step.")
@click.option("--scan-preopt", "scan_preopt_override", type=click.BOOL, default=None,
              help="Override scan --preopt flag.")
@click.option("--scan-endopt", "scan_endopt_override", type=click.BOOL, default=None,
              help="Override scan --endopt flag.")
@click.option("--convert-files/--no-convert-files", "convert_files", default=True, show_default=True,
              help="Convert XYZ/TRJ outputs to PDB format using reference topology; forwarded to all subcommands.")
@click.option(
    "--ref-pdb",
    "ref_pdb_cli",
    type=click.Path(path_type=Path, exists=True, dir_okay=False),
    default=None,
    help=(
        "Reference PDB for topology/B-factor layer information when -i provides XYZ inputs. "
        "Used for define-layer, mm_parm, ml_region, and forwarded to downstream tools "
        "(tsopt, irc, freq, path_search) as --ref-pdb."
    ),
)
@click.option(
    "-b", "--backend",
    type=click.Choice(["uma", "orb", "mace", "aimnet2"], case_sensitive=False),
    default=None,
    show_default=False,
    help="ML backend for the ONIOM high-level region (default: uma).",
)
@click.option(
    "--embedcharge/--no-embedcharge",
    "embedcharge",
    default=False,
    show_default=True,
    help="Enable xTB point-charge embedding correction for MM→ML environmental effects (experimental).",
)
@click.option(
    "--embedcharge-cutoff",
    "embedcharge_cutoff",
    type=float,
    default=None,
    show_default=False,
    help="Distance cutoff (Å) from ML region for MM point charges in xTB embedding. "
         "Default: 12.0 Å. Only used when --embedcharge is enabled.",
)
@click.option(
    "--link-atom-method",
    "link_atom_method",
    type=click.Choice(["scaled", "fixed"], case_sensitive=False),
    default=None,
    show_default=False,
    help="Link-atom position mode: scaled (g-factor, default) or fixed (legacy 1.09/1.01 Å).",
)
@click.option(
    "--mm-backend",
    "mm_backend",
    type=click.Choice(["hessian_ff", "openmm"], case_sensitive=False),
    default=None,
    show_default=False,
    help="MM backend: hessian_ff (analytical Hessian, default) or openmm (finite-difference Hessian, slower).",
)
@click.option(
    "--cmap/--no-cmap",
    "use_cmap",
    default=None,
    show_default=False,
    help="Enable CMAP (backbone cross-map) terms in model parm7. Default: disabled (Gaussian ONIOM-compatible).",
)
@add_coord_type_option(choices=("cart", "dlc"))
@add_precision_option()
@add_backend_model_option()
@add_calc_file_option()
@add_deterministic_option()
@add_allow_charge_mult_mismatch_option()
@click.pass_context
def cli(
    ctx: click.Context,
    input_paths: Sequence[Path],
    center_spec: Optional[str],
    out_dir: Path,
    radius: float,
    radius_het2het: float,
    include_h2o: bool,
    exclude_backbone: bool,
    add_linkh: bool,
    selected_resn: str,
    modified_residue: str,
    ligand_charge: Optional[str],
    charge_override: Optional[int],
    parm7_override: Optional[Path],
    model_pdb_override: Optional[Path],
    mm_ff_set: str,
    mm_add_ter: bool,
    mm_keep_temp: bool,
    mm_ligand_mult: Optional[str],
    spin: int,
    max_nodes: int,
    max_cycles: int,
    climb: bool,
    opt_mode: str,
    opt_mode_post: Optional[str],
    dump: bool,
    refine_path: bool,
    thresh: Optional[str],
    thresh_post: str,
    config_yaml: Optional[Path],
    show_config: bool,
    dry_run: bool,
    pre_opt: bool,
    hessian_calc_mode: Optional[str],
    detect_layer: bool,
    do_tsopt: bool,
    do_thermo: bool,
    do_dft: bool,
    scan_lists_raw: Sequence[str],
    scan_out_dir: Optional[Path],
    scan_one_based: Optional[bool],
    scan_max_step_size: Optional[float],
    scan_bias_k: Optional[float],
    scan_relax_max_cycles: Optional[int],
    scan_preopt_override: Optional[bool],
    scan_endopt_override: Optional[bool],
    convert_files: bool,
    ref_pdb_cli: Optional[Path],
    backend: Optional[str],
    embedcharge: bool,
    embedcharge_cutoff: Optional[float],
    link_atom_method: Optional[str],
    mm_backend: Optional[str],
    use_cmap: Optional[bool],
    tsopt_max_cycles: Optional[int],
    flatten: bool,
    skip_final_freq: bool,
    tsopt_out_dir: Optional[Path],
    freq_out_dir: Optional[Path],
    freq_max_write: Optional[int],
    freq_amplitude_ang: Optional[float],
    freq_n_frames: Optional[int],
    freq_sort: Optional[str],
    freq_temperature: Optional[float],
    freq_pressure: Optional[float],
    dft_out_dir: Optional[Path],
    dft_func_basis: Optional[str],
    dft_max_cycle: Optional[int],
    dft_conv_tol: Optional[float],
    dft_grid_level: Optional[int],
    dft_engine: Optional[str],
    cli_coord_type: Optional[str],
    precision: Optional[str],
    backend_model: Optional[str],
    calc_file: Optional[str],
    calc_factory: str,
) -> None:
    """
    The **all** command composes `extract` → (optional `scan` on pocket) → MEP search (single-pass `path-opt` by default,
    or recursive `path_search` with ``--refine-path``) and hides ref-template bookkeeping.
    It also accepts the sloppy `-i A B C` style like `path_search` does. With single input:
      - with --scan-lists: run staged scan on the pocket and use stage results as inputs for path-opt (or path_search),
      - with --tsopt True and no --scan-lists: run TSOPT-only mode (no MEP search).
    """
    # Turn on pipeline-scoped default-verbosity suppression for this `all` run
    # (reset per-invocation in DefaultGroup.parse_args). Standalone leaf/report
    # commands are unaffected and keep full output at default verbosity.
    from mlmm.core.utils import set_pipeline_mode
    set_pipeline_mode(True)
    _echo_state.reset()

    time_start = time.perf_counter()
    command_str = "mlmm all " + " ".join(sys.argv[1:])

    _is_param_explicit = make_is_param_explicit(ctx)
    dump_override_requested = _is_param_explicit("dump")
    opt_mode_set = _is_param_explicit("opt_mode")
    opt_mode_post_set = _is_param_explicit("opt_mode_post")
    # Gate --no-embedcharge forwarding: when CLI default False is not user-supplied,
    # do not emit --no-embedcharge to downstream subprocesses (otherwise a `calc.embedcharge: true`
    # in --config YAML is silently overridden by the CLI default).
    embedcharge_explicit = _is_param_explicit("embedcharge")

    config_yaml, override_yaml, _ = resolve_yaml_sources(config_yaml, None, None)
    args_yaml, merged_yaml_cfg = _build_effective_args_yaml(
        config_yaml=config_yaml,
        override_yaml=None,
        tmp_prefix="mlmm_all_merged_",
    )
    _injected_coord = (
        str(cli_coord_type).lower()
        if _is_param_explicit("cli_coord_type") and cli_coord_type is not None
        else None
    )
    if _injected_coord is not None or precision is not None or backend_model is not None or calc_file is not None:
        args_yaml = _inject_coord_type_into_args_yaml(
            args_yaml, _injected_coord, precision=precision, backend_model=backend_model,
            calc_file=(str(Path(calc_file).resolve()) if calc_file else None), calc_factory=calc_factory,
        )

    mm_ff_set = "ff14SB" if str(mm_ff_set).lower().startswith("ff14") else "ff19SB"

    # --- Robustly accept a single "-i" followed by multiple paths (like path_search.cli) ---
    argv_all = sys.argv[1:]
    i_vals = collect_single_option_values(argv_all, ("-i", "--input"), label="-i/--input")
    if i_vals:
        i_parsed = validate_existing_files(
            i_vals,
            option_name="-i/--input",
            hint="When using '-i', list only existing file paths (multiple paths may follow a single '-i').",
        )
        input_paths = tuple(i_parsed)

    scan_vals = collect_single_option_values(argv_all, ("-s", "--scan-lists"), "--scan-lists")
    if scan_vals:
        scan_lists_raw = tuple(scan_vals)

    is_single = (len(input_paths) == 1)
    has_scan = bool(scan_lists_raw)
    single_tsopt_mode = (is_single and (not has_scan) and do_tsopt)

    if (len(input_paths) < 2) and (not (is_single and (has_scan or do_tsopt))):
        raise click.BadParameter(
            "Provide at least two PDBs with -i/--input in reaction order, "
            "or use a single PDB with --scan-lists, or a single PDB with --tsopt True."
        )

    if single_tsopt_mode:
        all_mode = "tsopt-only"
    elif has_scan:
        all_mode = "scan-to-path-search" if refine_path else "scan-to-path-opt"
    else:
        all_mode = "path-search" if refine_path else "path-opt"
    all_mode_label = "ts-only" if single_tsopt_mode else ("scan-lists" if has_scan else "mep")
    if verbose_level() >= 2:
        _echo(
            f"[mode] all ({all_mode_label}) inputs={len(input_paths)} "
            f"extract={'yes' if center_spec is not None and str(center_spec).strip() else 'no'} "
            f"scan={'yes' if has_scan else 'no'} tsopt={'yes' if do_tsopt else 'no'} "
            f"thermo={'yes' if do_thermo else 'no'} dft={'yes' if do_dft else 'no'} "
            f"dry_run={'yes' if dry_run else 'no'} internal={all_mode}",
            narrative=True,
        )

    _mode_alias = {
        "grad": "grad",
        "hess": "hess",
        "light": "grad",
        "heavy": "hess",
    }
    opt_mode_norm = _mode_alias.get(str(opt_mode).strip().lower(), "grad")
    path_search_opt_mode = opt_mode_norm
    opt_mode_post_norm = (
        None
        if opt_mode_post is None
        else _mode_alias.get(str(opt_mode_post).strip().lower(), "hess")
    )
    endpoint_opt_mode_default = (
        opt_mode_post_norm if (opt_mode_post_set and opt_mode_post_norm is not None)
        else (opt_mode_norm if opt_mode_set else "hess")
    )
    if opt_mode_post_norm in {"grad", "hess"}:
        tsopt_opt_mode_default = opt_mode_post_norm
    elif opt_mode_set:
        tsopt_opt_mode_default = opt_mode_norm
    else:
        tsopt_opt_mode_default = "hess"
    from mlmm.workflows._all_helpers import (
        build_tsopt_overrides as _build_tsopt_overrides,
        build_freq_overrides as _build_freq_overrides,
        build_dft_overrides as _build_dft_overrides,
    )
    tsopt_overrides = _build_tsopt_overrides(
        tsopt_max_cycles=tsopt_max_cycles,
        dump=dump,
        dump_override_requested=dump_override_requested,
        tsopt_out_dir=tsopt_out_dir,
        hessian_calc_mode=hessian_calc_mode,
        opt_mode_post_norm=opt_mode_post_norm,
        opt_mode_set=opt_mode_set,
        tsopt_opt_mode_default=tsopt_opt_mode_default,
        convert_files=convert_files,
        thresh_post=thresh_post,
        flatten_explicit=_is_param_explicit("flatten"),
        flatten=flatten,
        skip_final_freq=skip_final_freq,
    )
    freq_overrides = _build_freq_overrides(
        freq_max_write=freq_max_write,
        freq_amplitude_ang=freq_amplitude_ang,
        freq_n_frames=freq_n_frames,
        freq_sort=freq_sort,
        freq_temperature=freq_temperature,
        freq_pressure=freq_pressure,
        dump_override_requested=dump_override_requested,
        dump=dump,
        hessian_calc_mode=hessian_calc_mode,
        convert_files=convert_files,
    )
    dft_overrides = _build_dft_overrides(
        dft_max_cycle=dft_max_cycle,
        dft_conv_tol=dft_conv_tol,
        dft_grid_level=dft_grid_level,
        dft_engine=dft_engine,
        convert_files=convert_files,
    )

    dft_func_basis_use = dft_func_basis or "wb97m-v/def2-tzvpd"
    dft_method_fallback = dft_func_basis_use

    if show_config or (dry_run and verbose_level() >= 3):
        config_payload: Dict[str, Any] = {
            "yaml": {
                "config": str(config_yaml) if config_yaml else None,
                "override_yaml": str(override_yaml) if override_yaml else None,
                "effective_args_yaml": str(args_yaml) if args_yaml else None,
            },
            "all": {
                "inputs": [str(p) for p in input_paths],
                "center": center_spec,
                "charge_override": charge_override,
                "skip_extract": bool(center_spec is None or str(center_spec).strip() == ""),
                "out_dir": str(out_dir),
                "spin": int(spin),
                "max_nodes": int(max_nodes),
                "max_cycles": int(max_cycles),
                "climb": bool(climb),
                "opt_mode": str(opt_mode),
                "opt_mode_post": (None if opt_mode_post is None else str(opt_mode_post)),
                "path_search_opt_mode": str(path_search_opt_mode),
                "endpoint_opt_mode": str(endpoint_opt_mode_default),
                "dump": bool(dump),
                "refine_path": bool(refine_path),
                "thresh": thresh,
                "thresh_post": thresh_post,
                "flatten": bool(flatten),
                "pre_opt": bool(pre_opt),
                "detect_layer": bool(detect_layer),
                "tsopt": bool(do_tsopt),
                "thermo": bool(do_thermo),
                "dft": bool(do_dft),
            },
            "overrides": {
                "tsopt": tsopt_overrides,
                "freq": freq_overrides,
                "dft": dft_overrides,
            },
        }
        if merged_yaml_cfg:
            config_payload["effective_yaml"] = merged_yaml_cfg
        _echo_section("====== [all] Effective configuration ======")
        # `--show-config` is an explicit output request; dry-run's automatic
        # config dump is level-3 debug context so -v 1/2 stay compact.
        click.echo(
            yaml.safe_dump(config_payload, sort_keys=False, allow_unicode=True).rstrip(),
            narrative=show_config,
        )

    if dry_run:
        _echo("[all] Dry-run mode: no extraction/search/post-processing was executed.", narrative=True)
        _echo(
            "[all] Planned stages: extract -> mm_parm -> optional scan -> path_opt/path_search -> optional tsopt/freq/dft.",
            narrative=True,
        )
        _emit_final_summary(out_dir, time_start)
        return

    out_dir = out_dir.resolve()
    work_dir = out_dir / WORK_DIRNAME  # pipeline-wide scratch (safe to rm -rf)
    pockets_dir = work_dir / "pockets"
    # MEP-engine raw output is scratch under _work/; only its moved products reach root.
    path_dir = work_dir / ("path_search" if refine_path else "path_opt")
    scan_dir = _resolve_override_dir(work_dir / "scan", scan_out_dir)  # for single-structure scan mode
    # One monotonic Stage numbering for the whole pipeline: Stage 1 extraction
    # (+ lettered preparation sub-stages 1b/1c/1d), 2 MEP search, 3 merge,
    # 4 post-processing. Banners read "Stage N/{stage_total}". Stage 4 only
    # runs when at least one of tsopt/thermo/dft is requested, so the
    # denominator drops to 3 on a default run to avoid an "N/4" that never
    # reaches 4.
    stage_total = 4 if (do_tsopt or do_thermo or do_dft) else 3
    ensure_dir(out_dir)
    if not single_tsopt_mode:
        ensure_dir(path_dir)  # path_search might be skipped only in tsopt-only mode

    # Preflight: add_elem_info only for inputs lacking element fields
    # → Create fixed copies under a temporary folder inside out_dir (used ONLY for extraction)
    elem_tmp_dir = work_dir / "add_elem_info"
    inputs_for_extract: List[Path] = []
    elem_fix_echo=False
    for p in input_paths:
        if _pdb_needs_elem_fix(p):
            if elem_fix_echo==False:
                _echo_section("====== [all] Preflight — add_elem_info (only when element fields are missing) ======")
                elem_fix_echo=True
            ensure_dir(elem_tmp_dir)
            out_p = (elem_tmp_dir / p.name).resolve()
            try:
                _assign_elem_info(str(p), str(out_p), overwrite=False)
                _echo(f"[all] add_elem_info: fixed elements → {out_p}")
                inputs_for_extract.append(out_p)
            except SystemExit as e:
                code = getattr(e, "code", 1)
                _echo(f"[all] WARNING: add_elem_info exited with code {code} for {p}; using original.", err=True)
                inputs_for_extract.append(p.resolve())
            except Exception as e:
                _echo(f"[all] WARNING: add_elem_info failed for {p}: {e} — using original file.", err=True)
                inputs_for_extract.append(p.resolve())
        else:
            inputs_for_extract.append(p.resolve())

    extract_inputs = tuple(inputs_for_extract)
    skip_extract = center_spec is None or str(center_spec).strip() == ""

    # OOM hazard guard: skip_extract + --no-detect-layer + no --model-pdb collapses the
    # ML region to the entire input PDB, causing downstream ML/MM ONIOM to scale ML over
    # all atoms (OOM on enzyme-sized systems). Hard-fail with an actionable message
    # rather than silently running the doomed configuration.
    if skip_extract and (not detect_layer) and model_pdb_override is None:
        raise click.ClickException(
            "[all] Skipping extraction (no -c/--center) with --no-detect-layer requires "
            "--model-pdb. Otherwise the ML region collapses to the entire input PDB and "
            "downstream ML/MM ONIOM will treat every atom as ML (OOM hazard on enzyme-sized "
            "systems). Provide --model-pdb <ml_region.pdb>, or enable --detect-layer to use "
            "B-factor layer information from the input PDB."
        )

    # When inputs are XYZ and --ref-pdb is provided, use it for topology-requiring steps
    ref_pdb_for_topology: Optional[Path] = None
    if ref_pdb_cli is not None:
        ref_pdb_for_topology = ref_pdb_cli.resolve()
        _echo(f"[all] --ref-pdb provided: {ref_pdb_for_topology}")

    resolved_charge: Optional[int] = None
    pocket_outputs: List[Path] = []

    if skip_extract:
        _echo_section(
            f"====== [all] Stage 1/{stage_total} — Extraction skipped (no -c/--center); using full structures as pockets ======"
        )
        pocket_outputs = [p.resolve() for p in extract_inputs]
        _echo("[all] Pocket inputs (full structures):")
        for op in pocket_outputs:
            _echo(f"  - {op}")
        # Charge derivation when extraction is skipped:
        #  - --model-pdb provided → derive over that ML pocket PDB.
        #  - detect-layer with a layered input that actually has MM atoms (ML ⊊
        #    system) → derive the ML-region (B≈0) charge WITH cap correction.
        #    Deriving over the full input would mis-apply the whole-system charge
        #    to the ML/QM region (the ts-only / flatten charge bug: e.g. CM total
        #    -15 vs ML +0; R90A total -16 vs ML -2). validate_charge_spin only
        #    catches an electron-parity break, so same-parity cases were silently
        #    wrong. Reuse extract's compute_charge_summary with backbone-cut caps.
        #  - otherwise (the whole input IS the model, p2r-style) → full input PDB.
        resolved_charge = None
        if model_pdb_override is not None:
            resolved_charge = _derive_charge_from_ligand_charge_when_extract_skipped(
                model_pdb_override, ligand_charge
            )
        elif detect_layer:
            _layer_counts = _summarize_existing_bfactor_layers(extract_inputs[0])
            if _layer_counts.get("movable", 0) > 0 or _layer_counts.get("frozen", 0) > 0:
                resolved_charge = _derive_ml_charge_from_layered_pdb(
                    extract_inputs[0], ligand_charge
                )
        if resolved_charge is None:
            resolved_charge = _derive_charge_from_ligand_charge_when_extract_skipped(
                extract_inputs[0], ligand_charge
            )
    else:
        _echo_section(
            f"====== [all] Stage 1/{stage_total} — Active-site pocket extraction ======"
        )
        ensure_dir(pockets_dir)
        for p in extract_inputs:
            pocket_outputs.append((pockets_dir / f"pocket_{p.stem}.pdb").resolve())

        try:
            ex_res = extract_api(
                complex_pdb=[str(p) for p in extract_inputs],
                center=center_spec,
                output=[str(p) for p in pocket_outputs],
                radius=float(radius),
                radius_het2het=float(radius_het2het),
                include_h2o=bool(include_h2o),
                exclude_backbone=bool(exclude_backbone),
                add_linkh=bool(add_linkh),
                selected_resn=selected_resn or "",
                modified_residue=modified_residue or "",
                ligand_charge=ligand_charge,
                verbose=True,  # extractor INFO now gated by the unified global -v level
            )
        except Exception as e:
            raise click.ClickException(f"[all] Extractor failed: {e}")

        _echo("[all] Pocket files:")
        for op in pocket_outputs:
            _echo(f"  - {op}")

        try:
            cs = ex_res.get("charge_summary", {})
            q_total = float(cs.get("total_charge", 0.0))
            q_prot = float(cs.get("protein_charge", 0.0))
            q_lig = float(cs.get("ligand_total_charge", 0.0))
            q_ion = float(cs.get("ion_total_charge", 0.0))
            _echo("")
            _echo("[all] Charge summary from extractor (model #1):")
            _echo(
                f"  Protein: {q_prot:+g},  Ligand: {q_lig:+g},  Ions: {q_ion:+g},  Total: {q_total:+g}"
            )
            resolved_charge = _round_charge_with_note(q_total)
        except Exception as e:
            raise click.ClickException(f"[all] Could not obtain total charge from extractor: {e}")

    if charge_override is not None:
        q_int = int(charge_override)
        override_msg = f"[all] WARNING: -q/--charge override supplied; forcing TOTAL system charge to {q_int:+d}"
        if resolved_charge is not None:
            override_msg += f" (would otherwise use {int(resolved_charge):+d} from workflow)"
        _echo(override_msg)
    else:
        if resolved_charge is None:
            raise click.ClickException(
                "[all] Total charge could not be resolved. Provide -q/--charge, "
                "or provide --ligand-charge when extraction is skipped."
            )
        q_int = int(resolved_charge)

    # Stage 1b: ML-region definition (copy first pocket) and mm_parm on the first full input
    _echo_section("====== [all] Stage 1b — ML/MM preparation — ML region + parm7 ======")
    first_pocket = pocket_outputs[0]
    first_full_input = extract_inputs[0]
    # When --ref-pdb is provided, use it for PDB-requiring topology operations
    pocket_for_ml_region = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_pocket
    pdb_for_mm_parm = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_full_input

    # ML region definition: use --model-pdb if provided, otherwise generate from pocket.
    # When extraction was skipped + detect-layer, the "pocket" is the whole input, so
    # write only the B≈0 ML atoms — otherwise ml_region.pdb is the full system and a
    # downstream stage that can't read B-factors collapses ML to the entire input
    # (sum_Z=45875 electron-count error at freq/dft).
    if model_pdb_override is not None:
        ml_region_pdb = model_pdb_override.resolve()
        _echo_detail(f"[all] ML region definition (--model-pdb override) → {ml_region_pdb}")
    else:
        ml_region_pdb = None
        if skip_extract and detect_layer:
            ml_region_pdb = _write_bfactor_ml_subset(pocket_for_ml_region, out_dir / "ml_region.pdb")
            if ml_region_pdb is not None:
                _echo_detail(f"[all] ML region definition (B≈0 subset from layered input) → {ml_region_pdb}")
        if ml_region_pdb is None:
            ml_region_pdb = _write_ml_region_definition(pocket_for_ml_region, out_dir / "ml_region.pdb")  # reusable deliverable (--model-pdb input for follow-up runs)
            _echo_detail(f"[all] ML region definition → {ml_region_pdb}")

    # mm_parm: use --parm if provided, otherwise run tleap
    if parm7_override is not None:
        real_parm7_path = parm7_override.resolve()
        _echo_detail(f"[all] parm7 (--parm override) → {real_parm7_path}")
    else:
        _echo_detail(f"[all] mm_parm source PDB → {pdb_for_mm_parm}")
        mm_dir = out_dir / "mm_parm"  # reusable deliverable (parm7/rst7 = --parm input for follow-up runs)
        ensure_commands_available(
            ("tleap", "antechamber", "parmchk2"),
            context="mm_parm (AmberTools)",
        )
        real_parm7_path, real_rst7_path = _build_mm_parm7(
            pdb=pdb_for_mm_parm,
            ligand_charge_expr=ligand_charge,
            ligand_mult_expr=mm_ligand_mult,
            out_dir=mm_dir,
            ff_set=mm_ff_set,
            add_ter=mm_add_ter,
            keep_temp=mm_keep_temp,
        )
        _echo_detail(f"[all] mm_parm outputs → parm7: {real_parm7_path.name}, rst7: {real_rst7_path.name}")

    # define-layer: assign 3-layer B-factors to each full-system PDB
    _echo_section("====== [all] Stage 1c — define-layer — assign 3-layer B-factors to full-system PDBs ======")
    layered_dir = out_dir / "layered"  # deliverable (B-factor-layered PDBs for inspection / reuse)
    ensure_dir(layered_dir)
    layered_inputs: List[Path] = []
    # If extraction was skipped AND --detect-layer is True AND --model-pdb was
    # not provided, the user is responsible for B-factor layer encoding in the
    # input PDB (typical workflow: pre-run `mlmm define-layer --radius-freeze
    # X` and feed the resulting layered.pdb to `mlmm all`). Recomputing layers
    # here would silently override radius-freeze with the default and (when
    # extract was skipped) collapse the ML region to the full system, since
    # ml_region_pdb is then the full PDB itself. Honor the input B-factors.
    honor_input_bfactors = bool(skip_extract and detect_layer and model_pdb_override is None)
    if honor_input_bfactors:
        _echo_detail(
            "[all] Extraction skipped and --detect-layer is on; honoring input PDB "
            "B-factor layer encoding (ML=0/MovableMM=10/FrozenMM=20)."
        )
        for idx, full_pdb in enumerate(extract_inputs):
            pdb_for_layer = full_pdb
            if ref_pdb_for_topology is not None and full_pdb.suffix.lower() != ".pdb":
                pdb_for_layer = ref_pdb_for_topology
            counts = _summarize_existing_bfactor_layers(pdb_for_layer)
            _echo_detail(
                f"[all] define-layer [{idx}]: {full_pdb.name} (input B-factor layers honored)  "
                f"(ML={counts['ml']}, MovableMM={counts['movable']}, FrozenMM={counts['frozen']})"
            )
            layered_inputs.append(pdb_for_layer)
    else:
        for idx, full_pdb in enumerate(extract_inputs):
            # When --ref-pdb is given and input is not PDB, use ref_pdb for define-layer
            pdb_for_layer = full_pdb
            if ref_pdb_for_topology is not None and full_pdb.suffix.lower() != ".pdb":
                pdb_for_layer = ref_pdb_for_topology
            out_layered = layered_dir / f"{pdb_for_layer.stem}_layered.pdb"
            try:
                layer_info = _define_layers(
                    input_pdb=pdb_for_layer,
                    output_pdb=out_layered,
                    model_pdb=ml_region_pdb,
                )
                _echo_detail(f"[all] define-layer [{idx}]: {full_pdb.name} → {out_layered.name}  "
                             f"(ML={len(layer_info.get('ml_indices', []))}, "
                             f"MovableMM={len(layer_info.get('movable_mm_indices', []))}, "
                             f"FrozenMM={len(layer_info.get('frozen_indices', []))})")
                layered_inputs.append(out_layered)
            except Exception as e:
                _echo(f"[all] WARNING: define-layer failed for {full_pdb.name}: {e}", err=True)
                _echo(f"[all] Falling back to original PDB (no B-factor layers).", err=True)
                layered_inputs.append(full_pdb)

    # Other path: single-structure + --tsopt True (and NO scan-lists) → TSOPT-only mode
    if single_tsopt_mode:
        _echo_section("====== [all] TSOPT-only single-structure mode ======")
        irc_trj_for_all: List[Tuple[Path, bool]] = []
        tsroot = out_dir / SEGMENTS_DIRNAME / "seg_01"
        ensure_dir(tsroot)

        # Use the layered full-system PDB as TS initial guess
        layered_pdb = layered_inputs[0]
        # When --ref-pdb is given and input is XYZ, copy the XYZ next to the layered PDB
        # so that _run_tsopt_on_hei can use XYZ (full precision) + layered PDB (topology)
        if ref_pdb_for_topology is not None and extract_inputs[0].suffix.lower() != ".pdb":
            xyz_companion = layered_pdb.with_suffix(".xyz")
            if not xyz_companion.exists():
                shutil.copy2(extract_inputs[0], xyz_companion)
                _echo(f"[all] Copied XYZ input → {xyz_companion} (full precision for tsopt)")
        # TS optimization
        ts_pdb, g_ts = _run_tsopt_on_hei(
            layered_pdb,
            q_int,
            spin,
            real_parm7_path,
            ml_region_pdb,
            detect_layer,
            args_yaml,
            tsroot,
            tsopt_opt_mode_default,
            overrides=tsopt_overrides,
            backend=backend,
            embedcharge=embedcharge,
            embedcharge_cutoff=embedcharge_cutoff,
            embedcharge_explicit=embedcharge_explicit,
            link_atom_method=link_atom_method,
            mm_backend=mm_backend,
            use_cmap=use_cmap,
            ref_pdb=layered_pdb,
        )

        # EulerPC IRC & map endpoints (no segment endpoints exist → fallback mapping)
        irc_pocket_ref = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_pocket
        irc_res = _irc_and_match(seg_idx=1,
                                 seg_dir=tsroot,
                                 ref_pdb_for_seg=ts_pdb,
                                 seg_pocket_pdb=irc_pocket_ref,
                                 g_ts=g_ts,
                                 q_int=q_int,
                                 spin=spin,
                                 real_parm7=real_parm7_path,
                                 model_pdb=ml_region_pdb,
                                 detect_layer=detect_layer,
                                 backend=backend,
                                 embedcharge=embedcharge,
                                 embedcharge_cutoff=embedcharge_cutoff,
                                 embedcharge_explicit=embedcharge_explicit,
                                 link_atom_method=link_atom_method,
                                 mm_backend=mm_backend,
                                 use_cmap=use_cmap,
                                 args_yaml=args_yaml)
        gL = irc_res["left_min_geom"]
        gR = irc_res["right_min_geom"]
        gT = irc_res["ts_geom"]
        irc_plot_path = irc_res.get("irc_plot")
        irc_trj_path = irc_res.get("irc_trj")
        if irc_trj_path:
            try:
                irc_trj_for_all.append((Path(irc_trj_path), bool(irc_res.get("reverse_irc", False))))
            except Exception:
                logger.debug("Failed to append IRC trajectory path", exc_info=True)

        # Ensure UMA energies
        eL = float(gL.energy)
        eT = float(gT.energy)
        eR = float(gR.energy)

        # In this mode ONLY: assign Reactant/Product so that higher-energy end is the Reactant
        if eL >= eR:
            g_react, e_react = gL, eL
            g_prod,  e_prod  = gR, eR
        else:
            g_react, e_react = gR, eR
            g_prod,  e_prod  = gL, eL

        # Save XYZ (full precision) + PDB (companion) and run endpoint-opt
        struct_dir = tsroot / "structures"
        ensure_dir(struct_dir)
        pocket_ref = ref_pdb_for_topology if ref_pdb_for_topology is not None else first_pocket
        xR_irc, pR_irc = _save_single_geom_for_tools(g_react, pocket_ref, struct_dir, "reactant_irc")
        xT, pT         = _save_single_geom_for_tools(gT,       pocket_ref, struct_dir, "ts")
        xP_irc, pP_irc = _save_single_geom_for_tools(g_prod,   pocket_ref, struct_dir, "product_irc")

        endpoint_opt_dir = tsroot / "endpoint_opt"
        ensure_dir(endpoint_opt_dir)

        # Map IRC left/right Hessians → R/P endpoint (left=forward, right=backward)
        from mlmm.io.hessian_cache import load as _hess_load, store as _hess_store, clear as _clear_hess_cache
        _react_hk = "irc_left" if eL >= eR else "irc_right"
        _prod_hk  = "irc_right" if eL >= eR else "irc_left"

        _c = _hess_load(_react_hk)
        if _c:
            _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
        try:
            g_react, _ = _run_opt_for_state(
                pR_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                endpoint_opt_dir / "R", args_yaml, endpoint_opt_mode_default,
                convert_files=convert_files,
                backend=backend,
                embedcharge=embedcharge,
                embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method,
                mm_backend=mm_backend,
                use_cmap=use_cmap,
                thresh=thresh_post,
                xyz_path=xR_irc,
            )
        except Exception as e:
            _echo(
                f"[post] WARNING: Reactant endpoint optimization failed in TSOPT-only mode: {e}",
                err=True,
            )

        _c = _hess_load(_prod_hk)
        if _c:
            _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
        try:
            g_prod, _ = _run_opt_for_state(
                pP_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                endpoint_opt_dir / "P", args_yaml, endpoint_opt_mode_default,
                convert_files=convert_files,
                backend=backend,
                embedcharge=embedcharge,
                embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method,
                mm_backend=mm_backend,
                use_cmap=use_cmap,
                thresh=thresh_post,
                xyz_path=xP_irc,
            )
        except Exception as e:
            _echo(
                f"[post] WARNING: Product endpoint optimization failed in TSOPT-only mode: {e}",
                err=True,
            )
        shutil.rmtree(endpoint_opt_dir, ignore_errors=True)
        _echo_detail("[endpoint-opt] Clean endpoint-opt working dir.")

        xR, pR = _save_single_geom_for_tools(g_react, pocket_ref, struct_dir, "reactant")
        xP, pP = _save_single_geom_for_tools(g_prod,   pocket_ref, struct_dir, "product")
        e_react = float(g_react.energy)
        e_prod = float(g_prod.energy)

        # UMA energy diagram (R, TS, P)
        uma_prefix = tsroot / "energy_diagram_UMA"
        uma_diag = _write_segment_energy_diagram(
            uma_prefix,
            labels=["R", "TS", "P"],
            energies_eh=[e_react, eT, e_prod],
            title_note="(UMA, TSOPT/IRC)",
        )
        g_uma_diag = None
        dft_diag = None
        g_dft_diag = None

        # ── Release GPU memory before freq/thermo/DFT ──
        for _g in (gL, gR, gT, g_react, g_prod):
            if _g is not None and hasattr(_g, "calculator"):
                _g.calculator = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Thermochemistry (UMA) Gibbs
        thermo_payloads: Dict[str, Dict[str, Any]] = {}
        GR = GT = GP = None
        eR_dft = eT_dft = eP_dft = None
        GR_dftUMA = GT_dftUMA = GP_dftUMA = None
        freq_root = _resolve_override_dir(tsroot / "freq", freq_out_dir)
        dft_root = _resolve_override_dir(tsroot / "dft", dft_out_dir)

        if do_thermo:
            _echo_detail("[thermo] Single TSOPT: freq on TS/R/P")
            tT = _run_freq_for_state(pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                                     freq_root / "TS", args_yaml, overrides=freq_overrides,
                                     backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                                     embedcharge_explicit=embedcharge_explicit,
                                     link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xT)
            _clear_hess_cache()  # TS Hessian consumed; R/P need exact computation
            tR = _run_freq_for_state(pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                                     freq_root / "R", args_yaml, overrides=freq_overrides,
                                     backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                                     embedcharge_explicit=embedcharge_explicit,
                                     link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xR)
            tP = _run_freq_for_state(pP, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                                     freq_root / "P", args_yaml, overrides=freq_overrides,
                                     backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                                     embedcharge_explicit=embedcharge_explicit,
                                     link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xP)
            thermo_payloads = {"R": tR, "TS": tT, "P": tP}
            try:
                GR = float(tR.get("sum_EE_and_thermal_free_energy_ha", e_react))
                GT = float(tT.get("sum_EE_and_thermal_free_energy_ha", eT))
                GP = float(tP.get("sum_EE_and_thermal_free_energy_ha", e_prod))
                g_uma_diag = _write_segment_energy_diagram(
                    tsroot / "energy_diagram_G_UMA",
                    labels=["R", "TS", "P"],
                    energies_eh=[GR, GT, GP],
                    title_note="(Gibbs, UMA)",
                    ylabel="ΔG (kcal/mol)",
                )
            except Exception as e:
                _echo(f"[thermo] WARNING: failed to build Gibbs diagram: {e}", err=True)

        # DFT & DFT//UMA
        if do_dft:
            # DO NOT INLINE: (single-TS path): freq subprocess parsing may
            # have re-bound heavy refs onto cli()-frame Geometry locals.
            # Two layers (null calculator + rebind local to None) prevent
            # closure/hook capture from resurrecting the model before the
            # DFT subprocess fork. `del locals()[name]` is a CPython no-op
            # (locals() returns a *copy* of the frame namespace), so the
            # rebind to None below is the only mechanism that actually
            # decrements the heavy refs before gc.collect + empty_cache.
            # NOTE: thermo_payloads is deliberately NOT nulled here — the
            # DFT//UMA Gibbs-diagram block below and the
            # segment_log num_imag write both read from it.
            for _g in (gL, gR, gT, g_react, g_prod):
                if _g is not None and hasattr(_g, "calculator"):
                    _g.calculator = None
            gL = gR = gT = g_react = g_prod = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            _echo_detail("[dft] Single TSOPT: DFT on R/TS/P")
            dR = _run_dft_for_state(pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                                     dft_root / "R", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
                                     backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                                     embedcharge_explicit=embedcharge_explicit,
                                     link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xR)
            dT = _run_dft_for_state(pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                                     dft_root / "TS", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
                                     backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                                     embedcharge_explicit=embedcharge_explicit,
                                     link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xT)
            dP = _run_dft_for_state(pP, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                                     dft_root / "P", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
                                     backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                                     embedcharge_explicit=embedcharge_explicit,
                                     link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xP)
            eR_dft = _dft_energy_ha(dR)
            eT_dft = _dft_energy_ha(dT)
            eP_dft = _dft_energy_ha(dP)
            _dft_all_ok = all(e is not None for e in (eR_dft, eT_dft, eP_dft))
            if not _dft_all_ok:
                _failed_states = [s for s, e in zip(["R", "TS", "P"], [eR_dft, eT_dft, eP_dft]) if e is None]
                _echo(f"[dft] WARNING: DFT failed for state(s): {', '.join(_failed_states)}. Skipping DFT diagrams.", err=True)
            if _dft_all_ok:
                try:
                    dft_diag = _write_segment_energy_diagram(
                        tsroot / "energy_diagram_DFT",
                        labels=["R", "TS", "P"],
                        energies_eh=[eR_dft, eT_dft, eP_dft],
                        title_note=f"({dft_method_fallback})",
                    )
                except Exception as e:
                    _echo(f"[dft] WARNING: failed to build DFT diagram: {e}", err=True)

            if do_thermo and _dft_all_ok:
                try:
                    dG_R = float(thermo_payloads.get("R", {}).get("thermal_correction_free_energy_ha", 0.0))
                    dG_T = float(thermo_payloads.get("TS", {}).get("thermal_correction_free_energy_ha", 0.0))
                    dG_P = float(thermo_payloads.get("P", {}).get("thermal_correction_free_energy_ha", 0.0))
                    GR_dftUMA = eR_dft + dG_R
                    GT_dftUMA = eT_dft + dG_T
                    GP_dftUMA = eP_dft + dG_P
                    g_dft_diag = _write_segment_energy_diagram(
                        tsroot / "energy_diagram_G_DFT_plus_UMA",
                        labels=["R", "TS", "P"],
                        energies_eh=[GR_dftUMA, GT_dftUMA, GP_dftUMA],
                        title_note="(Gibbs, DFT//UMA)",
                        ylabel="ΔG (kcal/mol)",
                    )
                except Exception as e:
                    _echo(f"[dft//uma] WARNING: failed to build DFT//UMA Gibbs diagram: {e}", err=True)

        # Summary.yaml / summary.log for TSOPT-only mode
        bond_cfg = dict(_path_search.BOND_KW)
        bond_summary = ""
        try:
            changed, bond_summary = _path_search._has_bond_change(g_react, g_prod, bond_cfg)
            if not changed:
                bond_summary = "(no covalent changes detected)"
        except Exception:
            bond_summary = "(no covalent changes detected)"

        barrier = (eT - e_react) * AU2KCALPERMOL
        delta = (e_prod - e_react) * AU2KCALPERMOL

        from mlmm.workflows._all_helpers import promote_diag_for_root
        energy_diagrams: List[Dict[str, Any]] = []
        for diag, stem in (
            (uma_diag, "energy_diagram_UMA"),
            (g_uma_diag, "energy_diagram_G_UMA"),
            (dft_diag, "energy_diagram_DFT"),
            (g_dft_diag, "energy_diagram_G_DFT_plus_UMA"),
        ):
            promoted = promote_diag_for_root(diag, stem, out_dir)
            if promoted is not None:
                energy_diagrams.append(promoted)

        summary = {
            "out_dir": str(tsroot),
            "n_images": 0,
            "n_segments": 1,
            "segments": [
                {
                    "index": 1,
                    "tag": "seg_01",
                    "kind": "tsopt",
                    "barrier_kcal": float(barrier),
                    "delta_kcal": float(delta),
                    "bond_changes": bond_summary,
                }
            ],
            "energy_diagrams": list(energy_diagrams),
        }
        _enrich_summary(
            summary,
            version="",
            pipeline_mode="tsopt-only",
            out_dir=out_dir,
            mlip_backend=backend or "unknown",
            charge=q_int,
            spin=spin,
            command=command_str,
            config={
                "refine_path": bool(refine_path),
                "tsopt": do_tsopt,
                "thermo": do_thermo,
                "dft": do_dft,
                "opt_mode": tsopt_opt_mode_default,
            },
        )
        try:
            with open(tsroot / "summary.json", "w") as f:
                json.dump(summary, f, indent=2, ensure_ascii=False)
            shutil.copy2(tsroot / "summary.json", out_dir / "summary.json")
        except Exception as e:
            _echo(f"[write] WARNING: failed to write summary.json: {e}", err=True)

        # Copy R/TS/P structures to out_dir/seg_01/
        try:
            _state_structs = {"R": pR, "TS": pT, "P": pP}
            _seg_out = _copy_structures_to_seg_dir(
                _state_structs, out_dir, 1, ".pdb",
            )
            _echo(f"[all] Wrote R/TS/P for segment 01 → {_seg_out}", narrative=True)
        except Exception as e:
            _echo(f"[all] WARNING: Failed to copy R/TS/P structures: {e}", err=True)

        segment_log: Dict[str, Any] = {
            "index": 1,
            "tag": "seg_01",
            "kind": "tsopt",
            "bond_changes": bond_summary,
            "post_dir": str(tsroot),
            "mep_barrier_kcal": barrier,
            "mep_delta_kcal": delta,
        }
        if irc_plot_path:
            segment_log["irc_plot"] = str(irc_plot_path)
        if irc_trj_path:
            segment_log["irc_traj"] = str(irc_trj_path)
        if do_thermo:
            n_imag = None
            try:
                n_imag = int(thermo_payloads.get("TS", {}).get("num_imag_freq"))
            except Exception:
                n_imag = None
            if n_imag is not None:
                segment_log["ts_imag"] = {"n_imag": n_imag}

        from mlmm.workflows._all_helpers import build_energy_level_dict
        _structs = {"R": pR, "TS": pT, "P": pP}
        segment_log["uma"] = build_energy_level_dict(
            labels=["R", "TS", "P"],
            energies_au=[e_react, eT, e_prod],
            ref_energy=e_react,
            au_to_kcal=AU2KCALPERMOL,
            diagram_path=str((tsroot / "energy_diagram_UMA").with_suffix(".png")),
            structures=_structs,
        )
        if GR is not None and GT is not None and GP is not None:
            segment_log["gibbs_uma"] = build_energy_level_dict(
                labels=["R", "TS", "P"],
                energies_au=[GR, GT, GP],
                ref_energy=GR,
                au_to_kcal=AU2KCALPERMOL,
                diagram_path=str((tsroot / "energy_diagram_G_UMA").with_suffix(".png")),
                structures=_structs,
            )
        if eR_dft is not None and eT_dft is not None and eP_dft is not None:
            segment_log["dft"] = build_energy_level_dict(
                labels=["R", "TS", "P"],
                energies_au=[eR_dft, eT_dft, eP_dft],
                ref_energy=eR_dft,
                au_to_kcal=AU2KCALPERMOL,
                diagram_path=str((tsroot / "energy_diagram_DFT").with_suffix(".png")),
                structures=_structs,
            )
        if GR_dftUMA is not None and GT_dftUMA is not None and GP_dftUMA is not None:
            segment_log["gibbs_dft_uma"] = build_energy_level_dict(
                labels=["R", "TS", "P"],
                energies_au=[GR_dftUMA, GT_dftUMA, GP_dftUMA],
                ref_energy=GR_dftUMA,
                au_to_kcal=AU2KCALPERMOL,
                diagram_path=str((tsroot / "energy_diagram_G_DFT_plus_UMA").with_suffix(".png")),
                structures=_structs,
            )

        summary_payload = {
            "root_out_dir": str(out_dir),
            "path_dir": str(tsroot),
            "path_module_dir": "tsopt_single",
            "pipeline_mode": "tsopt-only",
            "refine_path": bool(refine_path),
            "thresh": thresh,
            "thresh_post": thresh_post,
            "flatten": bool(flatten),
            "tsopt": do_tsopt,
            "thermo": do_thermo,
            "dft": do_dft,
            "opt_mode": tsopt_opt_mode_default,
            "mep_mode": "tsopt-only",
            "uma_model": None,
            "command": command_str,
            "charge": q_int,
            "spin": spin,
            "mep": {"n_images": 0, "n_segments": 1},
            "segments": summary.get("segments", []),
            "energy_diagrams": list(energy_diagrams),
            "post_segments": [segment_log],
            "key_files": {},
        }
        # Refresh summary.json with post_segments and key_output_files
        try:
            summary["post_segments"] = _json_safe([segment_log])
            # Rebuild key_output_files now that seg_01/ exists
            try:
                _kf: Dict[str, Any] = {}
                for _n, _d in [("summary.log", "Human-readable results summary"),
                               ("irc_plot_all.png", "Aggregated IRC plot")]:
                    if (out_dir / _n).exists():
                        _kf[_n] = _d
                _seg_parent = out_dir / SEGMENTS_DIRNAME
                for _child in sorted(_seg_parent.iterdir()) if _seg_parent.exists() else []:
                    if _child.is_dir() and _child.name.startswith("seg_"):
                        _kf[_child.name] = {"files": sorted(f.name for f in _child.iterdir() if f.is_file())}
                if _kf:
                    summary["key_output_files"] = _kf
            except Exception:
                pass
            with open(tsroot / "summary.json", "w") as f:
                json.dump(summary, f, indent=2, ensure_ascii=False)
            shutil.copy2(tsroot / "summary.json", out_dir / "summary.json")
        except Exception:
            pass

        try:
            write_summary_log(tsroot / "summary.log", summary_payload)
            shutil.copy2(tsroot / "summary.log", out_dir / "summary.log")
        except Exception as e:
            _echo(f"[write] WARNING: failed to write summary.log: {e}", err=True)

        try:
            for stem in (
                "energy_diagram_UMA",
                "energy_diagram_G_UMA",
                "energy_diagram_DFT",
                "energy_diagram_G_DFT_plus_UMA",
            ):
                src = tsroot / f"{stem}.png"
                if src.exists():
                    shutil.copy2(src, out_dir / f"{stem}_all.png")
        except Exception as e:
            _echo(f"[all] WARNING: failed to copy *_all diagrams: {e}", err=True)

        try:
            if irc_plot_path:
                irc_plot_src = Path(irc_plot_path)
                if irc_plot_src.exists():
                    shutil.copy2(irc_plot_src, out_dir / "irc_plot_all.png")
        except Exception as e:
            _echo(f"[all] WARNING: failed to copy irc_plot_all.png: {e}", err=True)

        _echo_section("====== [all] TSOPT-only pipeline finished successfully ======")
        _emit_final_summary(out_dir, time_start)
        return

    # Stage 1d: Optional scan (single-structure only) to build ordered pocket inputs
    pockets_for_path: List[Path]
    if is_single and has_scan:
        _echo_section("====== [all] Stage 1d — Staged scan on layered full-system PDB (single-structure mode) ======")
        ensure_dir(scan_dir)
        layered_pdb = Path(layered_inputs[0]).resolve()
        full_input_pdb = Path(input_paths[0]).resolve()
        # Use the layered full-system PDB for scan (no pocket index remapping needed)
        full_atom_meta = load_pdb_atom_metadata(full_input_pdb)
        # Honour --scan-one-based CLI toggle (None defaults to 1-based for backward compat).
        scan_one_based_use = True if scan_one_based is None else bool(scan_one_based)
        converted_scan_stages = _parse_scan_lists_literals(
            scan_lists_raw, atom_meta=full_atom_meta, one_based=scan_one_based_use,
        )
        scan_stage_literals: List[str] = []
        for stage in converted_scan_stages:
            scan_stage_literals.append(_format_scan_stage(stage))
        _echo("[all] Remapped --scan-lists indices from the full PDB to the pocket ordering.")
        scan_preopt_use = pre_opt if scan_preopt_override is None else bool(scan_preopt_override)
        scan_endopt_use = False if scan_endopt_override is None else bool(scan_endopt_override)
        scan_opt_mode_use = path_search_opt_mode

        scan_args: List[str] = [
            "-i", str(layered_pdb),
            "--parm", str(real_parm7_path),
            "-q", str(int(q_int)),
            "-m", str(int(spin)),
            "--out-dir", str(scan_dir),
            "--preopt" if scan_preopt_use else "--no-preopt",
            "--endopt" if scan_endopt_use else "--no-endopt",
            "--opt-mode", str(scan_opt_mode_use),
        ]
        scan_args.append("--detect-layer" if detect_layer else "--no-detect-layer")

        if dump_override_requested:
            scan_args.append("--dump" if dump else "--no-dump")

        # Forward the scan-indexing convention selected by --scan-one-based
        scan_args.append("--one-based" if scan_one_based_use else "--zero-based")

        _append_cli_arg(scan_args, "--max-step-size", scan_max_step_size)
        _append_cli_arg(scan_args, "--bias-k", scan_bias_k)
        _append_cli_arg(scan_args, "--relax-max-cycles", scan_relax_max_cycles)
        scan_args.append("--convert-files" if convert_files else "--no-convert-files")
        if thresh is not None:
            scan_args.extend(["--thresh", str(thresh)])
        if args_yaml is not None:
            scan_args.extend(["--config", str(args_yaml)])
        # Forward all converted --scan-lists (aligned to the pocket atom order)
        if scan_stage_literals:
            scan_args.append("--scan-lists")
            scan_args.extend(scan_stage_literals)

        from mlmm.workflows._all_helpers import append_backend_forwarding_args
        append_backend_forwarding_args(
            scan_args,
            backend=backend,
            embedcharge=embedcharge,
            embedcharge_cutoff=embedcharge_cutoff,
            embedcharge_explicit=embedcharge_explicit,
            link_atom_method=link_atom_method,
            mm_backend=mm_backend,
            use_cmap=use_cmap,
        )

        _echo_detail(
            f"[all] dispatch scan: input={layered_pdb.name}, "
            f"stages={len(scan_stage_literals)}, preopt={'yes' if scan_preopt_use else 'no'}, "
            f"endopt={'yes' if scan_endopt_use else 'no'}, out={scan_dir}"
        )
        _echo("[all] mlmm scan " + " ".join(scan_args))

        _run_cli_main("scan", _scan_cli.cli, scan_args, on_nonzero="raise", on_exception="raise", prefix="all")

        # Collect stage results — prefer XYZ (full precision), keep PDB as ref for topology
        stage_results: List[Path] = []
        stage_refs: List[Path] = []
        for st in sorted(scan_dir.glob("stage_*")):
            if not st.is_dir():
                continue
            xyz = st / "result.xyz"
            pdb = st / "result.pdb"
            if xyz.exists():
                stage_results.append(xyz.resolve())
                stage_refs.append(pdb.resolve() if pdb.exists() else layered_pdb)
            elif pdb.exists():
                stage_results.append(pdb.resolve())
                stage_refs.append(pdb.resolve())
        if not stage_results:
            raise click.ClickException("[all] No stage result files found under scan/.")
        _echo_detail("[all] Collected scan stage files:")
        for p in stage_results:
            _echo_detail(f"  - {p}")

        # Input series to path_search: [preopt result (if available), scan stage results ...]
        # When scan ran with --preopt, its optimized reactant geometry lives in
        # scan/preopt/result.xyz (full precision) or result.pdb.  Using this
        # avoids a redundant ~2000-cycle re-optimization inside path_search.
        preopt_xyz = scan_dir / "preopt" / "result.xyz"
        preopt_pdb = scan_dir / "preopt" / "result.pdb"
        if preopt_xyz.exists():
            init0_geom = preopt_xyz.resolve()
            init0_ref = layered_pdb          # layered PDB has authoritative B-factor layers
        elif preopt_pdb.exists():
            init0_geom = preopt_pdb.resolve()
            init0_ref = layered_pdb
        else:
            # No preopt output — fall back to original layered PDB
            init0_geom = layered_pdb
            init0_ref = layered_pdb
        pockets_for_path = [init0_geom] + stage_results
        refs_for_path = [init0_ref] + stage_refs
        _echo_detail(f"[all] Using scan initial endpoint: {init0_geom}")
    else:
        # Multi-structure standard route: use layered full-system PDBs
        pockets_for_path = list(layered_inputs)

    # --- Global pre-alignment for coordinate continuity across segments ---
    if not refine_path and len(pockets_for_path) >= 2:
        try:
            _echo("[all] Pre-aligning all input structures to first frame...")
            _align_dir = path_dir / "pre_align"
            ensure_dir(_align_dir)
            _bfs = read_bfactors_from_pdb(pockets_for_path[0])
            _fa = [i for i, bf in enumerate(_bfs) if bf >= 15.0]
            if _fa:
                _geoms = [geom_loader(str(p), coord_type="cart") for p in pockets_for_path]
                for _g in _geoms:
                    _g.freeze_atoms = np.array(_fa, dtype=int)
                _calc_kw: Dict[str, Any] = dict(
                    model_charge=q_int, model_mult=int(spin),
                    input_pdb=str(pockets_for_path[0]),
                    real_parm7=str(real_parm7_path),
                    use_bfactor_layers=True,
                    embedcharge=embedcharge,
                )
                if backend is not None:
                    _calc_kw["backend"] = backend
                if link_atom_method is not None:
                    _calc_kw["link_atom_method"] = link_atom_method
                if mm_backend is not None:
                    _calc_kw["mm_backend"] = mm_backend
                if use_cmap is not None:
                    _calc_kw["use_cmap"] = use_cmap
                _align_calc = _mlmm_calc(**_calc_kw)
                align_and_refine_sequence_inplace(
                    _geoms, shared_calc=_align_calc,
                    out_dir=_align_dir / "refine", verbose=True,
                )
                del _align_calc
                _new_pockets: List[Path] = []
                for _i, (_g, _orig) in enumerate(zip(_geoms, pockets_for_path)):
                    _xyz = _align_dir / f"{_i:03d}.xyz"
                    _xyz.write_text(_g.as_xyz() + "\n")
                    _pdb = _align_dir / f"{_i:03d}.pdb"
                    convert_xyz_to_pdb(_xyz, _orig, _pdb)
                    _new_pockets.append(_pdb)
                pockets_for_path = _new_pockets
                _echo("[all] Pre-alignment completed.")
        except Exception as e:
            _echo(
                f"[all] WARNING: Pre-alignment failed: {e}. "
                "Continuing with original files.",
                err=True,
            )

    # Stage 2: Path search on full-system layered PDBs
    if refine_path:
        _echo_section(f"====== [all] Stage 2/{stage_total} — MEP search on full-system layered PDBs (recursive GSM) ======")

        # Build path_search CLI args using *repeated* options (robust for Click)
        ps_args: List[str] = []

        # Inputs: single -i followed by all layered full-system PDBs
        ps_args.append("-i")
        for p in pockets_for_path:
            ps_args.append(str(p))

        # Charge & spin
        ps_args.extend(["-q", str(q_int)])
        ps_args.extend(["-m", str(int(spin))])
        ps_args.extend(["--parm", str(real_parm7_path)])
        # Layered PDBs have B-factors → --detect-layer is True by default in
        # the option declaration. Honor the user's explicit `--no-detect-layer`
        # by forwarding the chosen toggle instead of hardcoding `--detect-layer`.
        ps_args.append("--detect-layer" if detect_layer else "--no-detect-layer")

        # Nodes, cycles, climb, optimizer, dump, out-dir, preopt, args-yaml
        ps_args.extend(["--max-nodes", str(int(max_nodes))])
        ps_args.extend(["--max-cycles", str(int(max_cycles))])
        ps_args.append("--climb" if climb else "--no-climb")
        ps_args.extend(["--opt-mode", str(path_search_opt_mode)])
        ps_args.append("--dump" if dump else "--no-dump")
        ps_args.extend(["--out-dir", str(path_dir)])
        ps_args.append("--preopt" if pre_opt else "--no-preopt")
        ps_args.append("--convert-files" if convert_files else "--no-convert-files")
        if thresh is not None:
            ps_args.extend(["--thresh", str(thresh)])
        if args_yaml is not None:
            ps_args.extend(["--config", str(args_yaml)])

        # Provide --ref-pdb for topology/B-factor info (one per input)
        # MUST use layered PDBs (with B-factor layer info) so that downstream
        # PDB conversion preserves ML/MovableMM/FrozenMM layer encoding.
        ps_args.append("--ref-pdb")
        if is_single and has_scan:
            # single+scan: use refs_for_path which maps to each pocket (XYZ→PDB ref)
            for ref in refs_for_path:
                ps_args.append(str(ref))
        else:
            for lp in layered_inputs:
                ps_args.append(str(lp))

        from mlmm.workflows._all_helpers import append_backend_forwarding_args
        append_backend_forwarding_args(
            ps_args,
            backend=backend,
            embedcharge=embedcharge,
            embedcharge_cutoff=embedcharge_cutoff,
            embedcharge_explicit=embedcharge_explicit,
            link_atom_method=link_atom_method,
            mm_backend=mm_backend,
            use_cmap=use_cmap,
        )

        _echo_detail(
            f"[all] dispatch path-search: inputs={len(pockets_for_path)}, "
            f"mode=recursive-gsm, preopt={'yes' if pre_opt else 'no'}, "
            f"detect_layer={'yes' if detect_layer else 'no'}, out={path_dir}"
        )
        _echo("[all] mlmm path-search " + " ".join(ps_args))

        _run_cli_main("path_search", _path_search.cli, ps_args, on_nonzero="raise", on_exception="raise", prefix="all")
    else:
        # --no-refine-path: run path-opt GSM between each adjacent pair and concatenate
        _echo_section(f"====== [all] Stage 2/{stage_total} — MEP path-opt on full-system layered PDBs (single-pass GSM per pair) ======")

        if len(pockets_for_path) < 2:
            raise click.ClickException("[all] Need at least two structures for path-opt MEP concatenation.")

        ensure_dir(path_dir)
        combined_blocks: List[str] = []
        path_opt_segments: List[Dict[str, Any]] = []

        for pair_idx in range(len(pockets_for_path) - 1):
            p_left = pockets_for_path[pair_idx]
            p_right = pockets_for_path[pair_idx + 1]
            seg_tag = f"seg_{pair_idx:02d}"
            seg_out = path_dir / f"{seg_tag}_mep"
            ensure_dir(seg_out)

            po_args: List[str] = [
                "-i", str(p_left), str(p_right),
                "-q", str(q_int),
                "-m", str(int(spin)),
                "--parm", str(real_parm7_path),
                "--max-nodes", str(int(max_nodes)),
                "--max-cycles", str(int(max_cycles)),
            ]
            # When the single+scan route handed over XYZ pockets, forward the
            # matching layered-template ref PDBs so path-opt can overlay the XYZ
            # coordinates onto full ML/MM topology (path-search receives these too).
            if is_single and has_scan:
                po_args.extend([
                    "--ref-pdb", str(refs_for_path[pair_idx]),
                    "--ref-pdb", str(refs_for_path[pair_idx + 1]),
                ])
            # Forward the chosen --detect-layer/--no-detect-layer toggle
            # (default True). Hardcoded "--detect-layer" silently overrode
            # user's `--no-detect-layer` request.
            po_args.append("--detect-layer" if detect_layer else "--no-detect-layer")
            po_args.append("--climb" if climb else "--no-climb")
            po_args.append("--dump" if dump else "--no-dump")
            po_args.extend(["--out-dir", str(seg_out)])
            po_args.append("--preopt" if pre_opt else "--no-preopt")
            po_args.append("--convert-files" if convert_files else "--no-convert-files")
            if thresh is not None:
                po_args.extend(["--thresh", str(thresh)])
            from mlmm.workflows._all_helpers import append_backend_forwarding_args
            append_backend_forwarding_args(
                po_args,
                backend=backend,
                embedcharge=embedcharge,
                embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method,
                mm_backend=mm_backend,
                use_cmap=use_cmap,
                args_yaml=args_yaml,
            )

            _echo_detail(
                f"[all] dispatch path-opt pair {pair_idx + 1}/{len(pockets_for_path) - 1}: "
                f"preopt={'yes' if pre_opt else 'no'}, climb={'yes' if climb else 'no'}, out={seg_out}"
            )
            _echo(f"[all] mlmm path-opt " + " ".join(po_args))
            _run_cli_main("path_opt", _path_opt.cli, po_args, on_nonzero="raise", on_exception="raise", prefix="all")

            # --- Post-processing per segment ---
            seg_trj = seg_out / "final_geometries_trj.xyz"
            if not seg_trj.exists():
                raise click.ClickException(
                    f"[all] path-opt segment {pair_idx} did not produce final_geometries_trj.xyz"
                )

            # Copy per-segment trajectory to path_dir
            try:
                seg_mep_trj = path_dir / f"mep_seg_{pair_idx:02d}_trj.xyz"
                shutil.copy2(seg_trj, seg_mep_trj)
                if pockets_for_path[0].suffix.lower() == ".pdb":
                    _path_search._maybe_convert_to_pdb(
                        seg_mep_trj,
                        ref_pdb_path=pockets_for_path[0],
                        out_path=path_dir / f"mep_seg_{pair_idx:02d}.pdb",
                    )
            except Exception as e:
                _echo(
                    f"[all] WARNING: failed to emit per-segment trajectory copies for segment {pair_idx:02d}: {e}",
                    err=True,
                )

            # Mirror HEI artifacts
            hei_src = seg_out / "hei.xyz"
            if hei_src.exists():
                try:
                    shutil.copy2(hei_src, path_dir / f"hei_seg_{pair_idx:02d}.xyz")
                    hei_pdb_src = seg_out / "hei.pdb"
                    if hei_pdb_src.exists():
                        shutil.copy2(hei_pdb_src, path_dir / f"hei_seg_{pair_idx:02d}.pdb")
                except Exception as e:
                    _echo(
                        f"[all] WARNING: failed to prepare HEI artifacts for segment {pair_idx:02d}: {e}",
                        err=True,
                    )

            # Parse trajectory blocks for concatenation and energy extraction
            raw_blocks = read_xyz_as_blocks(seg_trj, strict=True)
            blocks = ["\n".join(b) + "\n" for b in raw_blocks]
            if not blocks:
                raise click.ClickException(
                    f"[all] No frames read from path-opt segment {pair_idx} trajectory: {seg_trj}"
                )
            # Skip duplicate first frame for subsequent segments
            if pair_idx > 0:
                blocks = blocks[1:]
            combined_blocks.extend(blocks)

            # Extract energies from trajectory comment lines
            energies_seg: List[float] = []
            for blk in raw_blocks:
                E = np.nan
                if len(blk) >= 2:
                    try:
                        E = float(blk[1].split()[0])
                    except Exception:
                        E = np.nan
                energies_seg.append(E)

            # Parse first/last frame coordinates for bond-change detection
            first_last = None
            try:
                first_last = xyz_blocks_first_last(raw_blocks, path=seg_trj)
            except Exception as e:
                _echo(
                    f"[all] WARNING: failed to parse first/last frames for segment {pair_idx:02d}: {e}",
                    err=True,
                )

            path_opt_segments.append(
                {
                    "tag": seg_tag,
                    "energies": energies_seg,
                    "traj": seg_trj,
                    "inputs": (p_left, p_right),
                    "first_last": first_last,
                }
            )

        # --- Concatenated MEP trajectory ---
        final_trj = path_dir / "mep_trj.xyz"
        try:
            final_trj.write_text("".join(combined_blocks), encoding="utf-8")
            _echo(f"[all] Wrote concatenated MEP trajectory: {final_trj}", narrative=True)
        except Exception as e:
            raise click.ClickException(f"[all] Failed to write concatenated MEP: {e}")

        # Energy plot for concatenated trajectory
        try:
            run_trj2fig(final_trj, [path_dir / "mep_plot.png"], unit="kcal", reference="init", reverse_x=False)
            close_matplotlib_figures()
            _echo_detail(f"[plot] Saved energy plot → '{path_dir / 'mep_plot.png'}'")
        except Exception as e:
            _echo(f"[plot] WARNING: Failed to plot concatenated MEP: {e}", err=True)

        # PDB conversion of concatenated trajectory
        try:
            if pockets_for_path[0].suffix.lower() == ".pdb":
                mep_pdb_path = path_dir / "mep.pdb"
                _path_search._maybe_convert_to_pdb(
                    final_trj, ref_pdb_path=pockets_for_path[0], out_path=mep_pdb_path
                )
                if mep_pdb_path.exists():
                    shutil.copy2(mep_pdb_path, out_dir / mep_pdb_path.name)
                    _echo_detail(f"[all] Copied concatenated MEP PDB → {out_dir / mep_pdb_path.name}")
        except Exception as e:
            _echo(
                f"[all] WARNING: Failed to convert/copy concatenated MEP to PDB: {e}",
                err=True,
            )

        # --- Energy diagram ---
        energy_diagrams_po: List[Dict[str, Any]] = []
        try:
            labels = _build_global_segment_labels(len(path_opt_segments))
            energies_chain: List[float] = []
            for si, seg_info in enumerate(path_opt_segments):
                Es = [float(x) for x in seg_info.get("energies", [])]
                if not Es:
                    continue
                if si == 0:
                    energies_chain.append(Es[0])
                energies_chain.append(float(np.nanmax(Es)))
                energies_chain.append(Es[-1])
            if labels and energies_chain and len(labels) == len(energies_chain):
                title_note = "(GSM; all segments)" if len(path_opt_segments) > 1 else "(GSM)"
                diag_payload = _write_segment_energy_diagram(
                    path_dir / "energy_diagram_MEP",
                    labels=labels,
                    energies_eh=energies_chain,
                    title_note=title_note,
                )
                if diag_payload:
                    energy_diagrams_po.append(diag_payload)
        except Exception as e:
            _echo(f"[diagram] WARNING: Failed to build GSM diagram for path-opt branch: {e}", err=True)

        # --- Bond change detection and summary.json ---
        segments_summary: List[Dict[str, Any]] = []
        bond_cfg = dict(_path_search.BOND_KW)
        for seg_idx, info in enumerate(path_opt_segments):
            Es = [float(x) for x in info.get("energies", []) if np.isfinite(x)]
            if not Es:
                continue
            barrier = (max(Es) - Es[0]) * AU2KCALPERMOL
            delta = (Es[-1] - Es[0]) * AU2KCALPERMOL
            bond_summary = ""
            try:
                first_last = info.get("first_last")
                if first_last:
                    elems, c_first, c_last = first_last
                else:
                    elems, c_first, c_last = read_xyz_first_last(Path(info["traj"]))
                gL = _geom_from_angstrom(elems, c_first, [])
                gR = _geom_from_angstrom(elems, c_last, [])
                changed, bond_summary = _path_search._has_bond_change(gL, gR, bond_cfg)
                if not changed:
                    bond_summary = "(no covalent changes detected)"
            except Exception as e:
                _echo(
                    f"[all] WARNING: Failed to detect bond changes for segment {seg_idx:02d}: {e}",
                    err=True,
                )
                bond_summary = "(no covalent changes detected)"

            segments_summary.append(
                {
                    "index": seg_idx,
                    "tag": info.get("tag", f"seg_{seg_idx:02d}"),
                    "kind": "seg",
                    "barrier_kcal": float(barrier),
                    "delta_kcal": float(delta),
                    "bond_changes": bond_summary,
                }
            )

        po_summary: Dict[str, Any] = {
            "out_dir": str(path_dir),
            "n_images": len(read_xyz_as_blocks(final_trj)),
            "n_segments": len(segments_summary),
            "segments": segments_summary,
        }
        if energy_diagrams_po:
            po_summary["energy_diagrams"] = list(energy_diagrams_po)
        _enrich_summary(
            po_summary,
            version="",
            pipeline_mode="path-search" if refine_path else "path-opt",
            out_dir=out_dir,
            mlip_backend=backend or "unknown",
            charge=q_int,
            spin=spin,
            command=command_str,
            config={
                "refine_path": bool(refine_path),
                "tsopt": do_tsopt,
                "thermo": do_thermo,
                "dft": do_dft,
                "opt_mode": tsopt_opt_mode_default,
                "mep_mode": "gsm",
            },
        )
        try:
            with open(path_dir / "summary.json", "w") as f:
                json.dump(po_summary, f, indent=2, ensure_ascii=False)
            _echo_detail(f"[write] Wrote '{path_dir / 'summary.json'}'.")
        except Exception as e:
            _echo(f"[write] WARNING: Failed to write summary.json for path-opt branch: {e}", err=True)

        # Copy key outputs to out_dir root
        try:
            for name in ("mep_plot.png", "energy_diagram_MEP.png", "summary.json"):
                src = path_dir / name
                if src.exists():
                    shutil.copy2(src, out_dir / name)
            for ext in ("_trj.xyz", ".xyz"):
                src = path_dir / f"mep{ext}"
                if src.exists():
                    shutil.copy2(src, out_dir / src.name)
        except Exception as e:
            _echo(f"[all] WARNING: Failed to relocate path-opt summary files: {e}", err=True)

    # Stage 3: Merge (performed by path_search when --ref-pdb was supplied)
    _echo_section(f"====== [all] Stage 3/{stage_total} — Core MEP outputs ======")
    _echo_detail(f"[all] Final products can be found under: {out_dir}")
    _echo_detail("  - mep_trj.xyz              (concatenated MEP trajectory)")
    _echo_detail("  - mep.pdb                  (PDB conversion, if input was .pdb)")
    _echo_detail("  - summary.json             (segment barriers, ΔE, bond changes)")
    _echo_detail("  - mep_plot.png / energy_diagram_MEP.png / summary.log")
    _echo_detail(f"[all] Raw per-segment MEP-engine files stay under: {path_dir}")
    _echo_detail("  - mep_seg_XX_trj.xyz       (per-segment trajectories)")
    _echo_detail("  - hei_seg_XX.xyz/.pdb      (HEI per segment)")
    _echo_section("====== [all] Core MEP pipeline finished successfully ======")

    summary_json_path = path_dir / "summary.json"
    summary_loaded = {}
    if summary_json_path.exists():
        try:
            summary_loaded = json.loads(summary_json_path.read_text(encoding="utf-8")) or {}
        except Exception:
            summary_loaded = {}
    summary: Dict[str, Any] = summary_loaded if isinstance(summary_loaded, dict) else {}
    segments = _read_summary(summary_json_path)
    energy_diagrams: List[Dict[str, Any]] = []
    existing_diagrams = summary.get("energy_diagrams", [])
    if isinstance(existing_diagrams, list):
        energy_diagrams.extend(existing_diagrams)

    def _copy_path_outputs_to_root() -> None:
        """Copy MEP-engine outputs from ``path_dir`` up into ``out_dir``.

        This closure is a thin shim: the real work lives in
        ``mlmm.workflows._all_helpers.copy_path_outputs_to_root`` so that it
        can be unit-tested in isolation. Any warning message the helper
        reports through ``warn_fn`` is forwarded to ``_echo`` with
        ``err=True``.
        """
        from mlmm.workflows._all_helpers import (
            copy_path_outputs_to_root as _copy_outputs,
        )

        def _report(message):
            # Forward helper warnings to the CLI echo channel (error stream flag set).
            return _echo(message, err=True)

        _copy_outputs(path_dir, out_dir, warn_fn=_report)

    def _write_pipeline_summary_log(post_segment_logs: Sequence[Dict[str, Any]]) -> None:
        """Write ``summary.log`` under ``path_dir`` and mirror outputs to ``out_dir``.

        The log payload is assembled by ``build_pipeline_summary_payload``
        (``mlmm.workflows._all_helpers``) from run settings captured by this
        closure plus ``post_segment_logs``. It is then written with
        ``write_summary_log``. Afterwards ``_copy_path_outputs_to_root`` is
        invoked.

        The two steps are guarded independently:

        * A failure while building or writing ``summary.log`` no longer
          prevents the copy into ``out_dir``.
        * A failure in the copy step is no longer misreported as a
          ``summary.log`` write failure.

        Only ``OSError``, ``KeyError``, ``ValueError`` and ``TypeError`` are
        downgraded to warnings, which is the same set the original single
        ``try`` caught. Any other exception propagates.

        Args:
            post_segment_logs: Per-segment records collected during Stage 4
                post-processing. Callers pass an empty list when Stage 4 is
                skipped.
        """
        from mlmm.workflows._all_helpers import build_pipeline_summary_payload

        recoverable = (OSError, KeyError, ValueError, TypeError)

        try:
            summary_payload = build_pipeline_summary_payload(
                out_dir=out_dir,
                path_dir=path_dir,
                summary=summary,
                refine_path=refine_path,
                thresh=thresh,
                thresh_post=thresh_post,
                flatten=flatten,
                do_tsopt=do_tsopt,
                do_thermo=do_thermo,
                do_dft=do_dft,
                opt_mode_norm=opt_mode_norm,
                opt_mode_post=opt_mode_post,
                command_str=command_str,
                q_int=q_int,
                spin=spin,
                post_segment_logs=post_segment_logs,
            )
            write_summary_log(path_dir / "summary.log", summary_payload)
        except recoverable as e:
            _echo(f"[write] WARNING: Failed to write summary.log: {e}", err=True)

        # Mirror outputs to out_dir regardless of whether summary.log was written,
        # so a log-rendering problem does not also hide the other deliverables.
        try:
            _copy_path_outputs_to_root()
        except recoverable as e:
            _echo(f"[all] WARNING: Failed to copy path outputs to {out_dir}: {e}", err=True)

    # Optional Stage 4: TSOPT / THERMO / DFT (per reactive segment)
    if not (do_tsopt or do_thermo or do_dft):
        _write_pipeline_summary_log([])
        # Elapsed time
        _emit_final_summary(out_dir, time_start)
        return

    _echo_section(f"====== [all] Stage 4/{stage_total} — Post-processing per reactive segment ======")

    # Use segment summary from path_search / path-opt
    if not segments:
        _echo("[post] No segments found in summary; nothing to do.", narrative=True)
        _write_pipeline_summary_log([])
        _emit_final_summary(out_dir, time_start)
        return

    # Iterate only bond-change segments (kind='seg' and bond_changes not empty and not '(no covalent...)')
    reactive = [s for s in segments if (s.get("kind", "seg") == "seg" and str(s.get("bond_changes", "")).strip() and str(s.get("bond_changes", "")).strip() != "(no covalent changes detected)")]
    if not reactive:
        _echo("[post] No bond-change segments. Skipping TS/thermo/DFT.", narrative=True)
        _write_pipeline_summary_log([])
        _emit_final_summary(out_dir, time_start)
        return

    post_segment_logs: List[Dict[str, Any]] = []
    tsopt_seg_energies: List[Tuple[float, float, float]] = []
    g_uma_seg_energies: List[Tuple[float, float, float]] = []
    dft_seg_energies: List[Tuple[float, float, float]] = []
    g_dftuma_seg_energies: List[Tuple[float, float, float]] = []
    irc_trj_for_all: List[Tuple[Path, bool]] = []

    # For each reactive segment
    for s in reactive:
        seg_idx = int(s.get("index", 0) or 0)
        seg_tag = s.get("tag", f"seg_{seg_idx:02d}")
        _echo_section(f"--- [post] seg_{seg_idx:02d} ({seg_tag}) ---")

        seg_root = path_dir  # MEP-engine scratch root (hei_seg_/mep_seg_ live here, under _work/)
        seg_dir = out_dir / SEGMENTS_DIRNAME / f"seg_{seg_idx:02d}"  # per-segment deliverables
        ensure_dir(seg_dir)

        # HEI pocket file prepared by path_search (only for bond-change segments)
        hei_pocket_pdb = seg_root / f"hei_seg_{seg_idx:02d}.pdb"
        if not hei_pocket_pdb.exists():
            _echo(f"[post] WARNING: HEI pocket PDB not found for segment {seg_idx:02d}; skipping TSOPT.", err=True)
            continue

        # 4.1 TS optimization (optional; still needed to drive IRC & diagrams)
        if do_tsopt:
            ts_pdb, g_ts = _run_tsopt_on_hei(
                hei_pocket_pdb,
                q_int,
                spin,
                real_parm7_path,
                ml_region_pdb,
                detect_layer,
                args_yaml,
                seg_dir,
                tsopt_opt_mode_default,
                overrides=tsopt_overrides,
                backend=backend,
                embedcharge=embedcharge,
                embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method,
                mm_backend=mm_backend,
                use_cmap=use_cmap,
                ref_pdb=layered_inputs[0] if layered_inputs else None,
            )
        else:
            # If TSOPT off: use the GSM HEI (pocket) as TS geometry
            ts_pdb = hei_pocket_pdb
            g_ts = geom_loader(ts_pdb, coord_type="cart")
            _hei_calc_kwargs = dict(
                model_charge=int(q_int),
                model_mult=int(spin),
                input_pdb=str(ts_pdb),
                real_parm7=str(real_parm7_path),
                model_pdb=str(ml_region_pdb),
                use_bfactor_layers=detect_layer,
                backend=backend,
                embedcharge=embedcharge,
            )
            if link_atom_method is not None:
                _hei_calc_kwargs["link_atom_method"] = link_atom_method
            if mm_backend is not None:
                _hei_calc_kwargs["mm_backend"] = mm_backend
            if use_cmap is not None:
                _hei_calc_kwargs["use_cmap"] = use_cmap
            calc = _mlmm_calc(**_hei_calc_kwargs)
            g_ts.set_calculator(calc); _ = float(g_ts.energy)

        # 4.2 EulerPC IRC & mapping to (left,right)
        irc_plot_path = None
        irc_trj_path = None
        irc_res = _irc_and_match(seg_idx=seg_idx,
                                 seg_dir=seg_dir,
                                 mep_dir=path_dir,
                                 ref_pdb_for_seg=ts_pdb,
                                 seg_pocket_pdb=hei_pocket_pdb,
                                 g_ts=g_ts,
                                 q_int=q_int,
                                 spin=spin,
                                 real_parm7=real_parm7_path,
                                 model_pdb=ml_region_pdb,
                                 detect_layer=detect_layer,
                                 backend=backend,
                                 embedcharge=embedcharge,
                                 embedcharge_cutoff=embedcharge_cutoff,
                                 embedcharge_explicit=embedcharge_explicit,
                                 link_atom_method=link_atom_method,
                                 mm_backend=mm_backend,
                                 use_cmap=use_cmap,
                                 args_yaml=args_yaml)
        irc_plot_path = irc_res.get("irc_plot")
        irc_trj_path = irc_res.get("irc_trj")
        if irc_trj_path:
            try:
                irc_trj_for_all.append((Path(irc_trj_path), bool(irc_res.get("reverse_irc", False))))
            except Exception:
                logger.debug("Failed to append IRC trajectory path", exc_info=True)

        gL = irc_res["left_min_geom"]
        gR = irc_res["right_min_geom"]
        gT = irc_res["ts_geom"]
        # Save IRC endpoints (XYZ primary), run endpoint-opt, then save optimized structures
        struct_dir = seg_dir / "structures"
        ensure_dir(struct_dir)
        xL_irc, pL_irc = _save_single_geom_for_tools(gL, hei_pocket_pdb, struct_dir, "reactant_irc")
        xT, pT         = _save_single_geom_for_tools(gT, hei_pocket_pdb, struct_dir, "ts")
        xR_irc, pR_irc = _save_single_geom_for_tools(gR, hei_pocket_pdb, struct_dir, "product_irc")

        endpoint_opt_dir = seg_dir / "endpoint_opt"
        ensure_dir(endpoint_opt_dir)

        # Map IRC left/right Hessians → R/P endpoint
        # When reverse_irc is True, _irc_and_match swapped left/right to match GSM endpoints,
        # so "irc_left" (=forward) now corresponds to gR and "irc_right" (=backward) to gL.
        from mlmm.io.hessian_cache import load as _hess_load, store as _hess_store, clear as _clear_hess_cache
        _reversed = bool(irc_res.get("reverse_irc", False))
        _left_hk  = "irc_right" if _reversed else "irc_left"
        _right_hk = "irc_left"  if _reversed else "irc_right"

        _c = _hess_load(_left_hk)
        if _c:
            _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
        try:
            gL, _ = _run_opt_for_state(
                pL_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                endpoint_opt_dir / "R", args_yaml, endpoint_opt_mode_default,
                convert_files=convert_files,
                backend=backend,
                embedcharge=embedcharge,
                embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method,
                mm_backend=mm_backend,
                use_cmap=use_cmap,
                thresh=thresh_post,
                xyz_path=xL_irc,
            )
        except Exception as e:
            _echo(
                f"[post] WARNING: Reactant endpoint optimization failed for segment {seg_idx:02d}: {e}",
                err=True,
            )

        _c = _hess_load(_right_hk)
        if _c:
            _hess_store("irc_endpoint", _c["hessian"], active_dofs=_c.get("active_dofs"), meta=_c.get("meta"))
        try:
            gR, _ = _run_opt_for_state(
                pR_irc, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                endpoint_opt_dir / "P", args_yaml, endpoint_opt_mode_default,
                convert_files=convert_files,
                backend=backend,
                embedcharge=embedcharge,
                embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method,
                mm_backend=mm_backend,
                use_cmap=use_cmap,
                thresh=thresh_post,
                xyz_path=xR_irc,
            )
        except Exception as e:
            _echo(
                f"[post] WARNING: Product endpoint optimization failed for segment {seg_idx:02d}: {e}",
                err=True,
            )
        shutil.rmtree(endpoint_opt_dir, ignore_errors=True)
        _echo_detail("[endpoint-opt] Clean endpoint-opt working dir.")

        xL, pL = _save_single_geom_for_tools(gL, hei_pocket_pdb, struct_dir, "reactant")
        xR, pR = _save_single_geom_for_tools(gR, hei_pocket_pdb, struct_dir, "product")

        # Copy R/TS/P structures to out_dir/seg_XX/
        try:
            _state_structs = {"R": pL, "TS": pT, "P": pR}
            _seg_out = _copy_structures_to_seg_dir(
                _state_structs, out_dir, seg_idx, ".pdb",
            )
            _echo(f"[all] Wrote R/TS/P for segment {seg_idx:02d} → {_seg_out}", narrative=True)
        except Exception as e:
            _echo(f"[all] WARNING: Failed to copy R/TS/P structures for segment {seg_idx:02d}: {e}", err=True)

        # 4.3 Segment-level energy diagram from UMA (R,TS,P)
        eR = float(gL.energy)
        eT = float(gT.energy)
        eP = float(gR.energy)
        tsopt_seg_energies.append((eR, eT, eP))
        uma_prefix = seg_dir / "energy_diagram_UMA"
        _write_segment_energy_diagram(
            uma_prefix,
            labels=["R", f"TS{seg_idx}", "P"],
            energies_eh=[eR, eT, eP],
            title_note="(UMA, TSOPT/IRC)",
        )

        # ── Release GPU memory before freq/thermo/DFT ──
        for _g in (gL, gR, gT):
            if _g is not None and hasattr(_g, "calculator"):
                _g.calculator = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # 4.4 Thermochemistry (UMA freq) and Gibbs diagram
        thermo_payloads: Dict[str, Dict[str, Any]] = {}
        GR = GT = GP = None
        freq_seg_root = _resolve_override_dir(seg_dir / "freq", freq_out_dir)
        dft_seg_root = _resolve_override_dir(seg_dir / "dft", dft_out_dir)

        if do_thermo:
            _echo_detail(f"[thermo] Segment {seg_idx:02d}: freq on TS/R/P")
            tT = _run_freq_for_state(
                pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                freq_seg_root / "TS", args_yaml, overrides=freq_overrides,
                backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xT,
            )
            _clear_hess_cache()  # TS Hessian consumed; R/P need exact computation
            tR = _run_freq_for_state(
                pL, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                freq_seg_root / "R", args_yaml, overrides=freq_overrides,
                backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xL,
            )
            tP = _run_freq_for_state(
                pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                freq_seg_root / "P", args_yaml, overrides=freq_overrides,
                backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xR,
            )
            thermo_payloads = {"R": tR, "TS": tT, "P": tP}
            _echo_energy_triplet(
                "thermo",
                seg_idx,
                "ZPE correction",
                _scale_energy_values(
                    _thermo_correction_values(thermo_payloads, "zpe_correction_ha"),
                    AU2KCALPERMOL,
                ),
                unit="kcal/mol",
                precision=2,
            )
            _echo_energy_triplet(
                "thermo",
                seg_idx,
                "thermal energy correction",
                _scale_energy_values(
                    _thermo_correction_values(thermo_payloads, "thermal_correction_energy_ha"),
                    AU2KCALPERMOL,
                ),
                unit="kcal/mol",
                precision=2,
            )
            _echo_energy_triplet(
                "thermo",
                seg_idx,
                "thermal free-energy correction",
                _scale_energy_values(
                    _thermo_correction_values(thermo_payloads, "thermal_correction_free_energy_ha"),
                    AU2KCALPERMOL,
                ),
                unit="kcal/mol",
                precision=2,
            )
            try:
                GR = float(tR.get("sum_EE_and_thermal_free_energy_ha", eR))
                GT = float(tT.get("sum_EE_and_thermal_free_energy_ha", eT))
                GP = float(tP.get("sum_EE_and_thermal_free_energy_ha", eP))
                g_uma_seg_energies.append((GR, GT, GP))
                _g_rel = _relative_energy_values_kcal({"R": GR, "TS": GT, "P": GP})
                if _g_rel is not None:
                    _echo_energy_triplet(
                        "thermo",
                        seg_idx,
                        "G_UMA relative",
                        _g_rel,
                        unit="kcal/mol",
                        precision=2,
                    )
                _write_segment_energy_diagram(
                    seg_dir / "energy_diagram_G_UMA",
                    labels=["R", f"TS{seg_idx}", "P"],
                    energies_eh=[GR, GT, GP],
                    title_note="(Gibbs, UMA)",
                    ylabel="ΔG (kcal/mol)",
                )
            except Exception as e:
                _echo(f"[thermo] WARNING: failed to build Gibbs diagram: {e}", err=True)

        # 4.5 DFT single-point and (optionally) DFT//UMA Gibbs
        eR_dft = eT_dft = eP_dft = None
        GR_dftUMA = GT_dftUMA = GP_dftUMA = None
        if do_dft:
            _echo_detail(f"[dft] Segment {seg_idx:02d}: DFT on R/TS/P")
            dR = _run_dft_for_state(
                pL, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                dft_seg_root / "R", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
                backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xL,
            )
            dT = _run_dft_for_state(
                pT, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                dft_seg_root / "TS", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
                backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xT,
            )
            dP = _run_dft_for_state(
                pR, q_int, spin, real_parm7_path, ml_region_pdb, detect_layer,
                dft_seg_root / "P", args_yaml, func_basis=dft_func_basis_use, overrides=dft_overrides,
                backend=backend, embedcharge=embedcharge, embedcharge_cutoff=embedcharge_cutoff,
                embedcharge_explicit=embedcharge_explicit,
                link_atom_method=link_atom_method, mm_backend=mm_backend, use_cmap=use_cmap, xyz_path=xR,
            )
            try:
                eR_dft = float(dR.get("energy", {}).get("hartree", np.nan) if dR else np.nan)
                eT_dft = float(dT.get("energy", {}).get("hartree", np.nan) if dT else np.nan)
                eP_dft = float(dP.get("energy", {}).get("hartree", np.nan) if dP else np.nan)
                if all(map(np.isfinite, [eR_dft, eT_dft, eP_dft])):
                    _dft_values = {"R": eR_dft, "TS": eT_dft, "P": eP_dft}
                    _echo_energy_triplet(
                        "dft",
                        seg_idx,
                        "E_DFT",
                        _dft_values,
                        unit="Hartree",
                        precision=8,
                    )
                    _dft_rel = _relative_energy_values_kcal(_dft_values)
                    if _dft_rel is not None:
                        _echo_energy_triplet(
                            "dft",
                            seg_idx,
                            "E_DFT relative",
                            _dft_rel,
                            unit="kcal/mol",
                            precision=2,
                        )
                    _dft_mlmm_values = {
                        "R": _dft_total_mlmm_energy_ha(dR),
                        "TS": _dft_total_mlmm_energy_ha(dT),
                        "P": _dft_total_mlmm_energy_ha(dP),
                    }
                    _echo_energy_triplet(
                        "dft",
                        seg_idx,
                        "E_total ML(dft)/MM",
                        _dft_mlmm_values,
                        unit="Hartree",
                        precision=8,
                    )
                    _dft_mlmm_rel = _relative_energy_values_kcal(_dft_mlmm_values)
                    if _dft_mlmm_rel is not None:
                        _echo_energy_triplet(
                            "dft",
                            seg_idx,
                            "E_total ML(dft)/MM relative",
                            _dft_mlmm_rel,
                            unit="kcal/mol",
                            precision=2,
                        )
                    dft_seg_energies.append((eR_dft, eT_dft, eP_dft))
                    _write_segment_energy_diagram(
                        seg_dir / "energy_diagram_DFT",
                        labels=["R", f"TS{seg_idx}", "P"],
                        energies_eh=[eR_dft, eT_dft, eP_dft],
                        title_note=f"({dft_method_fallback})",
                    )
                else:
                    _echo("[dft] WARNING: some DFT energies missing; diagram skipped.", err=True)
            except Exception as e:
                _echo(f"[dft] WARNING: failed to build DFT diagram: {e}", err=True)

            # DFT//UMA thermal Gibbs (E_DFT + ΔG_therm(UMA))
            if do_thermo:
                try:
                    dG_R = float(thermo_payloads.get("R", {}).get("thermal_correction_free_energy_ha", 0.0))
                    dG_T = float(thermo_payloads.get("TS", {}).get("thermal_correction_free_energy_ha", 0.0))
                    dG_P = float(thermo_payloads.get("P", {}).get("thermal_correction_free_energy_ha", 0.0))
                    eR_dft = float(dR.get("energy", {}).get("hartree", eR) if dR else eR)
                    eT_dft = float(dT.get("energy", {}).get("hartree", eT) if dT else eT)
                    eP_dft = float(dP.get("energy", {}).get("hartree", eP) if dP else eP)
                    GR_dftUMA = eR_dft + dG_R
                    GT_dftUMA = eT_dft + dG_T
                    GP_dftUMA = eP_dft + dG_P
                    g_dftuma_seg_energies.append((GR_dftUMA, GT_dftUMA, GP_dftUMA))
                    _g_dftuma_rel = _relative_energy_values_kcal(
                        {"R": GR_dftUMA, "TS": GT_dftUMA, "P": GP_dftUMA}
                    )
                    if _g_dftuma_rel is not None:
                        _echo_energy_triplet(
                            "dft//uma",
                            seg_idx,
                            "G_DFT+thermo relative",
                            _g_dftuma_rel,
                            unit="kcal/mol",
                            precision=2,
                        )
                    _write_segment_energy_diagram(
                        seg_dir / "energy_diagram_G_DFT_plus_UMA",
                        labels=["R", f"TS{seg_idx}", "P"],
                        energies_eh=[GR_dftUMA, GT_dftUMA, GP_dftUMA],
                        title_note="(Gibbs, DFT//UMA)",
                        ylabel="ΔG (kcal/mol)",
                    )
                except Exception as e:
                    _echo(f"[dft//uma] WARNING: failed to build DFT//UMA Gibbs diagram: {e}", err=True)

        segment_log: Dict[str, Any] = {
            "index": seg_idx,
            "tag": seg_tag,
            "kind": s.get("kind", "seg"),
            "bond_changes": s.get("bond_changes", ""),
            "mep_barrier_kcal": s.get("barrier_kcal"),
            "mep_delta_kcal": s.get("delta_kcal"),
            "post_dir": str(seg_dir),
        }
        if irc_plot_path:
            segment_log["irc_plot"] = str(irc_plot_path)
        if irc_trj_path:
            segment_log["irc_traj"] = str(irc_trj_path)
        if do_thermo:
            n_imag = None
            try:
                n_imag = int(thermo_payloads.get("TS", {}).get("num_imag_freq"))
            except Exception:
                n_imag = None
            if n_imag is not None:
                segment_log["ts_imag"] = {"n_imag": n_imag}
        from mlmm.workflows._all_helpers import build_energy_level_dict
        _structs_seg = {"R": pL, "TS": pT, "P": pR}
        segment_log["uma"] = build_energy_level_dict(
            labels=["R", "TS", "P"],
            energies_au=[eR, eT, eP],
            ref_energy=eR,
            au_to_kcal=AU2KCALPERMOL,
            diagram_path=str((seg_dir / "energy_diagram_UMA").with_suffix(".png")),
            structures=_structs_seg,
        )
        if GR is not None and GT is not None and GP is not None:
            segment_log["gibbs_uma"] = build_energy_level_dict(
                labels=["R", "TS", "P"],
                energies_au=[GR, GT, GP],
                ref_energy=GR,
                au_to_kcal=AU2KCALPERMOL,
                diagram_path=str((seg_dir / "energy_diagram_G_UMA").with_suffix(".png")),
                structures=_structs_seg,
            )
        if eR_dft is not None and eT_dft is not None and eP_dft is not None and all(
            map(np.isfinite, [eR_dft, eT_dft, eP_dft])
        ):
            segment_log["dft"] = build_energy_level_dict(
                labels=["R", "TS", "P"],
                energies_au=[eR_dft, eT_dft, eP_dft],
                ref_energy=eR_dft,
                au_to_kcal=AU2KCALPERMOL,
                diagram_path=str((seg_dir / "energy_diagram_DFT").with_suffix(".png")),
                structures=_structs_seg,
            )
        if GR_dftUMA is not None and GT_dftUMA is not None and GP_dftUMA is not None:
            segment_log["gibbs_dft_uma"] = build_energy_level_dict(
                labels=["R", "TS", "P"],
                energies_au=[GR_dftUMA, GT_dftUMA, GP_dftUMA],
                ref_energy=GR_dftUMA,
                au_to_kcal=AU2KCALPERMOL,
                diagram_path=str((seg_dir / "energy_diagram_G_DFT_plus_UMA").with_suffix(".png")),
                structures=_structs_seg,
            )

        post_segment_logs.append(segment_log)

    _all_diagram_specs = [
        (True, tsopt_seg_energies, "energy_diagram_UMA_all",
         "(UMA, TSOPT + IRC; all segments)", None),
        (do_thermo, g_uma_seg_energies, "energy_diagram_G_UMA_all",
         "(UMA + Thermal Correction; all segments)", "ΔG (kcal/mol)"),
        (do_dft, dft_seg_energies, "energy_diagram_DFT_all",
         f"({dft_method_fallback}; all segments)", None),
        (do_dft and do_thermo, g_dftuma_seg_energies, "energy_diagram_G_DFT_plus_UMA_all",
         f"({dft_method_fallback} // UMA + Thermal Correction; all segments)", "ΔG (kcal/mol)"),
    ]
    for cond, seg_energies, fname_stem, title_note, ylabel in _all_diagram_specs:
        if not cond or not seg_energies:
            continue
        all_energies = [e for triple in seg_energies for e in triple]
        all_labels = _build_global_segment_labels(len(seg_energies))
        if not (all_labels and len(all_labels) == len(all_energies)):
            continue
        extra_kwargs = {"ylabel": ylabel} if ylabel is not None else {}
        diag_payload = _write_segment_energy_diagram(
            out_dir / fname_stem,
            labels=all_labels,
            energies_eh=all_energies,
            title_note=title_note,
            write_html=False,
            **extra_kwargs,
        )
        if diag_payload:
            energy_diagrams.append(diag_payload)

    if irc_trj_for_all:
        _merge_irc_trajectories_to_single_plot(
            irc_trj_for_all, out_dir / "irc_plot_all.png"
        )

    # Refresh summary.json with final energy diagram metadata
    try:
        summary["energy_diagrams"] = list(energy_diagrams)
        _enrich_summary(
            summary,
            version="",
            pipeline_mode="path-search" if refine_path else "path-opt",
            out_dir=out_dir,
            mlip_backend=backend or "unknown",
            charge=q_int,
            spin=spin,
            command=command_str,
            post_segments=post_segment_logs,
            config={
                "refine_path": bool(refine_path),
                "tsopt": do_tsopt,
                "thermo": do_thermo,
                "dft": do_dft,
                "opt_mode": tsopt_opt_mode_default,
                "mep_mode": "gsm",
            },
        )
        with open(path_dir / "summary.json", "w") as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
        try:
            shutil.copy2(path_dir / "summary.json", out_dir / "summary.json")
        except Exception as e:
            _echo(f"[all] WARNING: Failed to mirror summary.json to {out_dir}: {e}", err=True)
    except Exception as e:
        _echo(f"[write] WARNING: Failed to refresh summary.json with energy diagram metadata: {e}", err=True)

    _write_pipeline_summary_log(post_segment_logs)
    _emit_final_summary(out_dir, time_start)


# Configure help visibility on the ``cli`` command; this statement runs once,
# at module import time, so it applies both when imported and when run directly.
_configure_all_help_visibility(cli)


if __name__ == "__main__":
    # Allow executing this module directly as a script.
    cli()
