def chunk_frame_indices(
    frame_indices: list[int],
    chunk_size: int,
) -> list[tuple[int, ...]]:
    """Partition an ordered frame selection into consecutive fixed-size tuples.

    Args:
        frame_indices: Selected frame indices, in the order they should be kept.
        chunk_size: Upper bound on the number of indices placed in one chunk.

    Returns:
        The chunks as tuples, in input order. Every chunk holds ``chunk_size``
        indices except possibly the last one, which holds the remainder. An
        empty selection yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is less than one.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")

    chunks: list[tuple[int, ...]] = []
    total = len(frame_indices)
    start = 0
    # Walk the selection with a moving window; slicing clips the final window.
    while start < total:
        stop = start + chunk_size
        chunks.append(tuple(frame_indices[start:stop]))
        start = stop
    return chunks
+ + Args: + shared_data: Shared workflow data that may contain ``args`` with local Dask + or HPC worker settings. + + Returns: + The requested worker count, clamped to at least one. + """ + args = shared_data.get("args") + + dask_workers = getattr(args, "dask_workers", None) + if dask_workers is not None: + return max(1, int(dask_workers)) + + if bool(getattr(args, "hpc", False)): + hpc_nodes = max(1, int(getattr(args, "hpc_nodes", 1) or 1)) + hpc_processes = max(1, int(getattr(args, "hpc_processes", 1) or 1)) + return hpc_nodes * hpc_processes + + return 1 + + def frame_chunk_size(self, shared_data: dict[str, Any], *, n_frames: int) -> int: + """Choose a deterministic frame chunk size. + + Args: + shared_data: Shared workflow data used to infer requested worker count. + n_frames: Number of selected frames for the current run. + + Returns: + The frame chunk size clamped between the policy minimum and maximum. + """ + n_frames = max(1, int(n_frames)) + workers = self.requested_worker_count(shared_data) + target_chunks = max(1, workers * int(self.target_frame_chunks_per_worker)) + chunk_size = math.ceil(n_frames / target_chunks) + + return max( + int(self.min_frame_chunk_size), + min(int(self.max_frame_chunk_size), int(chunk_size)), + ) + + def max_frame_in_flight_tasks( + self, + shared_data: dict[str, Any], + *, + n_chunks: int, + ) -> int: + """Choose the maximum number of active frame-chunk tasks. + + Args: + shared_data: Shared workflow data used to infer requested worker count. + n_chunks: Number of frame chunks available for submission. + + Returns: + The number of frame-chunk tasks allowed to be active at once. 
def merge_means(old_mean: Any, old_n: int, new_mean: Any, new_n: int) -> Any:
    """Combine two sample means into the mean of their union.

    Args:
        old_mean: Mean accumulated so far, or ``None`` when nothing has been
            accumulated.
        old_n: Number of samples behind ``old_mean``.
        new_mean: Mean of the incoming samples.
        new_n: Number of samples behind ``new_mean``.

    Returns:
        ``old_mean`` unchanged when ``new_n`` is zero or negative. When there
        is no usable history (``old_mean`` is ``None`` or ``old_n`` is not
        positive), ``new_mean`` itself, or ``new_mean.copy()`` if it has a
        ``copy`` attribute. Otherwise the count-weighted combination of both.
    """
    if new_n <= 0:
        return old_mean

    has_history = old_mean is not None and old_n > 0
    if not has_history:
        if hasattr(new_mean, "copy"):
            return new_mean.copy()
        return new_mean

    # Shift the old mean toward the new one by the incoming share of samples.
    weight = float(new_n) / float(old_n + new_n)
    return old_mean + (new_mean - old_mean) * weight
class NeighborReducer:
    """Parent-side accumulation of neighbour counts across frames and chunks."""

    @staticmethod
    def initialise(shared_data: dict[str, Any]) -> None:
        """Create zeroed neighbour accumulators for every group.

        Args:
            shared_data: Shared workflow data containing ``groups``. Two
                separate mappings, ``neighbor_totals`` and
                ``neighbor_samples``, are written with one zero entry per
                group id.
        """
        shared_data["neighbor_totals"] = dict.fromkeys(
            shared_data["groups"].keys(), 0
        )
        shared_data["neighbor_samples"] = dict.fromkeys(
            shared_data["groups"].keys(), 0
        )

    @staticmethod
    def reduce_frame_output(
        shared_data: dict[str, Any],
        frame_neighbors: dict[int, tuple[int, int]] | None,
    ) -> None:
        """Add one frame's neighbour counts to the accumulators.

        Args:
            shared_data: Shared workflow data holding ``neighbor_totals`` and
                ``neighbor_samples``.
            frame_neighbors: Mapping of group id to ``(count, sample_count)``,
                or ``None`` to leave the accumulators untouched.
        """
        if frame_neighbors is None:
            return

        running_totals = shared_data["neighbor_totals"]
        running_samples = shared_data["neighbor_samples"]
        for group_id in stable_keys(frame_neighbors):
            added, seen = frame_neighbors[group_id]
            running_totals[group_id] = running_totals.get(group_id, 0) + int(added)
            running_samples[group_id] = running_samples.get(group_id, 0) + int(seen)

    @staticmethod
    def merge_chunk_partial(
        shared_data: dict[str, Any],
        neighbor_totals: dict[int, int],
        neighbor_samples: dict[int, int],
    ) -> None:
        """Add chunk-level neighbour totals and sample counts.

        If either accumulator is missing from ``shared_data`` the call returns
        without merging anything.

        Args:
            shared_data: Shared workflow data that may hold ``neighbor_totals``
                and ``neighbor_samples``.
            neighbor_totals: Mapping of group id to neighbour total to add.
            neighbor_samples: Mapping of group id to sample count to add.
        """
        running_totals = shared_data.get("neighbor_totals")
        running_samples = shared_data.get("neighbor_samples")
        if running_totals is None or running_samples is None:
            return

        # Totals are merged completely before samples, each in stable key order.
        merges = (
            (running_totals, neighbor_totals),
            (running_samples, neighbor_samples),
        )
        for accumulator, increments in merges:
            for group_id in stable_keys(increments):
                increment = increments[group_id]
                accumulator[group_id] = accumulator.get(group_id, 0) + int(increment)

    @staticmethod
    def finalise(shared_data: dict[str, Any]) -> None:
        """Turn the reduced totals into an average neighbour count per group.

        Args:
            shared_data: Shared workflow data containing ``groups``,
                ``neighbor_totals`` and ``neighbor_samples``. The result is
                written to ``neighbors``; groups without a positive sample
                count get ``0.0``.
        """
        averages = {}
        for group_id in stable_keys(shared_data["groups"]):
            n_samples = shared_data["neighbor_samples"].get(group_id, 0)
            if n_samples > 0:
                total = shared_data["neighbor_totals"].get(group_id, 0)
                averages[group_id] = total / n_samples
            else:
                averages[group_id] = 0.0
        shared_data["neighbors"] = averages
+ """ + self._merge_force_and_torque_partial(shared_data, partial) + self._merge_forcetorque_partial(shared_data, partial) + + def reduce_frame_map_output( + self, + shared_data: dict[str, Any], + frame_out: dict[str, Any], + ) -> None: + """Reduce a complete serial MAP output. + + Args: + shared_data: Shared workflow data containing covariance and neighbour + accumulators. + frame_out: MAP output containing optional ``covariance`` and ``neighbors`` + entries. + """ + covariance = frame_out.get("covariance") + if covariance is not None: + self.reduce_frame_output(shared_data, covariance) + + neighbors = frame_out.get("neighbors") + if neighbors is not None: + NeighborReducer.reduce_frame_output(shared_data, neighbors) + + def _merge_force_and_torque_partial( + self, + shared_data: dict[str, Any], + partial: CovarianceChunkPartial, + ) -> None: + """Merge chunk force and torque means into parent accumulators. + + Args: + shared_data: Shared workflow data containing force/torque accumulators, + frame counts, and ``group_id_to_index``. + partial: Worker covariance partial with force, torque, and count mappings. 
+ """ + f_cov = shared_data["force_covariances"] + t_cov = shared_data["torque_covariances"] + counts = shared_data["frame_counts"] + gid2i = shared_data["group_id_to_index"] + + for key in stable_keys(partial.frame_counts["ua"]): + partial_n = partial.frame_counts["ua"][key] + old_n = int(counts["ua"].get(key, 0)) + if key in partial.force["ua"]: + f_cov["ua"][key] = merge_means( + f_cov["ua"].get(key), old_n, partial.force["ua"][key], partial_n + ) + if key in partial.torque["ua"]: + t_cov["ua"][key] = merge_means( + t_cov["ua"].get(key), old_n, partial.torque["ua"][key], partial_n + ) + counts["ua"][key] = old_n + partial_n + + for gid in stable_keys(partial.frame_counts["res"]): + partial_n = partial.frame_counts["res"][gid] + gi = gid2i[gid] + old_n = int(counts["res"][gi]) + if gid in partial.force["res"]: + f_cov["res"][gi] = merge_means( + f_cov["res"][gi], old_n, partial.force["res"][gid], partial_n + ) + if gid in partial.torque["res"]: + t_cov["res"][gi] = merge_means( + t_cov["res"][gi], old_n, partial.torque["res"][gid], partial_n + ) + counts["res"][gi] = old_n + partial_n + + for gid in stable_keys(partial.frame_counts["poly"]): + partial_n = partial.frame_counts["poly"][gid] + gi = gid2i[gid] + old_n = int(counts["poly"][gi]) + if gid in partial.force["poly"]: + f_cov["poly"][gi] = merge_means( + f_cov["poly"][gi], old_n, partial.force["poly"][gid], partial_n + ) + if gid in partial.torque["poly"]: + t_cov["poly"][gi] = merge_means( + t_cov["poly"][gi], old_n, partial.torque["poly"][gid], partial_n + ) + counts["poly"][gi] = old_n + partial_n + + @staticmethod + def _merge_forcetorque_partial( + shared_data: dict[str, Any], + partial: CovarianceChunkPartial, + ) -> None: + """Merge chunk force-torque block means into parent accumulators. + + Args: + shared_data: Shared workflow data containing force-torque accumulators, + force-torque counts, and ``group_id_to_index``. + partial: Worker covariance partial with force-torque matrices and counts. 
+ """ + ft_cov = shared_data["forcetorque_covariances"] + ft_counts = shared_data["forcetorque_counts"] + gid2i = shared_data["group_id_to_index"] + + for gid in stable_keys(partial.forcetorque_counts["res"]): + partial_n = partial.forcetorque_counts["res"][gid] + gi = gid2i[gid] + old_n = int(ft_counts["res"][gi]) + ft_cov["res"][gi] = merge_means( + ft_cov["res"][gi], old_n, partial.forcetorque["res"][gid], partial_n + ) + ft_counts["res"][gi] = old_n + partial_n + + for gid in stable_keys(partial.forcetorque_counts["poly"]): + partial_n = partial.forcetorque_counts["poly"][gid] + gi = gid2i[gid] + old_n = int(ft_counts["poly"][gi]) + ft_cov["poly"][gi] = merge_means( + ft_cov["poly"][gi], old_n, partial.forcetorque["poly"][gid], partial_n + ) + ft_counts["poly"][gi] = old_n + partial_n + + def _reduce_force_and_torque( + self, + shared_data: dict[str, Any], + frame_out: dict[str, Any], + ) -> None: + """Reduce frame force and torque matrices into running means. + + Args: + shared_data: Shared workflow data containing force/torque accumulators, + frame counts, and ``group_id_to_index``. + frame_out: Frame covariance payload with ``force`` and ``torque`` sections. 
+ """ + f_cov = shared_data["force_covariances"] + t_cov = shared_data["torque_covariances"] + counts = shared_data["frame_counts"] + gid2i = shared_data["group_id_to_index"] + + f_frame = frame_out["force"] + t_frame = frame_out["torque"] + + for key in stable_keys(f_frame["ua"]): + F = f_frame["ua"][key] + counts["ua"][key] = counts["ua"].get(key, 0) + 1 + n = counts["ua"][key] + f_cov["ua"][key] = incremental_mean(f_cov["ua"].get(key), F, n) + + for key in stable_keys(t_frame["ua"]): + T = t_frame["ua"][key] + if key not in counts["ua"]: + counts["ua"][key] = counts["ua"].get(key, 0) + 1 + n = counts["ua"][key] + t_cov["ua"][key] = incremental_mean(t_cov["ua"].get(key), T, n) + + for gid in stable_keys(f_frame["res"]): + F = f_frame["res"][gid] + gi = gid2i[gid] + counts["res"][gi] += 1 + n = counts["res"][gi] + f_cov["res"][gi] = incremental_mean(f_cov["res"][gi], F, n) + + for gid in stable_keys(t_frame["res"]): + T = t_frame["res"][gid] + gi = gid2i[gid] + if counts["res"][gi] == 0: + counts["res"][gi] += 1 + n = counts["res"][gi] + t_cov["res"][gi] = incremental_mean(t_cov["res"][gi], T, n) + + for gid in stable_keys(f_frame["poly"]): + F = f_frame["poly"][gid] + gi = gid2i[gid] + counts["poly"][gi] += 1 + n = counts["poly"][gi] + f_cov["poly"][gi] = incremental_mean(f_cov["poly"][gi], F, n) + + for gid in stable_keys(t_frame["poly"]): + T = t_frame["poly"][gid] + gi = gid2i[gid] + if counts["poly"][gi] == 0: + counts["poly"][gi] += 1 + n = counts["poly"][gi] + t_cov["poly"][gi] = incremental_mean(t_cov["poly"][gi], T, n) + + def _reduce_forcetorque( + self, + shared_data: dict[str, Any], + frame_out: dict[str, Any], + ) -> None: + """Reduce frame force-torque matrices into running means. + + Args: + shared_data: Shared workflow data containing force-torque accumulators, + force-torque counts, and ``group_id_to_index``. + frame_out: Frame covariance payload that may contain a ``forcetorque`` + section. 
+ """ + if "forcetorque" not in frame_out: + return + + ft_cov = shared_data["forcetorque_covariances"] + ft_counts = shared_data["forcetorque_counts"] + gid2i = shared_data["group_id_to_index"] + ft_frame = frame_out["forcetorque"] + + for gid in stable_keys(ft_frame.get("res", {})): + M = ft_frame["res"][gid] + gi = gid2i[gid] + ft_counts["res"][gi] += 1 + n = ft_counts["res"][gi] + ft_cov["res"][gi] = incremental_mean(ft_cov["res"][gi], M, n) + + for gid in stable_keys(ft_frame.get("poly", {})): + M = ft_frame["poly"][gid] + gi = gid2i[gid] + ft_counts["poly"][gi] += 1 + n = ft_counts["poly"][gi] + ft_cov["poly"][gi] = incremental_mean(ft_cov["poly"][gi], M, n) diff --git a/CodeEntropy/levels/execution/scheduler.py b/CodeEntropy/levels/execution/scheduler.py new file mode 100644 index 00000000..5f3d8455 --- /dev/null +++ b/CodeEntropy/levels/execution/scheduler.py @@ -0,0 +1,306 @@ +"""Serial and Dask schedulers for frame-chunk map-reduce execution.""" + +from __future__ import annotations + +from typing import Any + +from rich.progress import TaskID + +from CodeEntropy.levels.execution.chunks import chunk_frame_indices +from CodeEntropy.levels.execution.policy import ExecutionPolicy +from CodeEntropy.levels.execution.reducers import CovarianceReducer, NeighborReducer +from CodeEntropy.levels.execution.tasks import ( + FrameChunkResult, + FrameChunkTask, + execute_frame_chunk_worker, + execute_frame_map_output, + make_frame_worker_shared_data, +) +from CodeEntropy.levels.frame_dag import FrameGraph +from CodeEntropy.levels.neighbors import Neighbors +from CodeEntropy.results.reporter import _RichProgressSink + + +class FrameScheduler: + """Execute frame-local MAP work serially or through Dask. + + Dask futures may complete in any order, but completed chunks are reduced by + chunk index. This keeps parent-side floating-point reductions deterministic. 
+ """ + + def __init__( + self, + *, + frame_dag: FrameGraph, + policy: ExecutionPolicy, + universe_operations: Any | None = None, + ) -> None: + """Initialise the frame scheduler. + + Args: + frame_dag: Built or buildable frame-local DAG used for serial execution. + policy: Internal execution policy for chunking and in-flight limits. + universe_operations: Optional universe-operation adapter forwarded to worker + frame graphs. + """ + self._frame_dag = frame_dag + self._policy = policy + self._universe_operations = universe_operations + self._covariance_reducer = CovarianceReducer() + + def execute( + self, + shared_data: dict[str, Any], + *, + frame_indices: list[int], + progress: _RichProgressSink | None = None, + ) -> None: + """Execute frame-local MAP work and reduce outputs. + + Args: + shared_data: Shared workflow data containing serial or Dask execution + inputs. + frame_indices: Ordered selected frame indices to execute. + progress: Optional progress sink used for reporting frame-stage progress. + + Raises: + RuntimeError: If Dask execution is requested but unavailable or incomplete. + """ + task: TaskID | None = None + if progress is not None: + task = progress.add_task( + "[green]Frame processing", + total=len(frame_indices), + title="Initializing frame stage", + ) + + client = shared_data.get("dask_client") + parallel_frames = bool(shared_data.get("parallel_frames", client is not None)) + + if parallel_frames and client is not None and len(frame_indices) > 1: + self._run_dask( + shared_data, + frame_indices=frame_indices, + client=client, + progress=progress, + task=task, + ) + return + + self._run_serial( + shared_data, + frame_indices=frame_indices, + progress=progress, + task=task, + ) + + def _run_serial( + self, + shared_data: dict[str, Any], + *, + frame_indices: list[int], + progress: _RichProgressSink | None, + task: TaskID | None, + ) -> None: + """Execute frame MAP work serially. 
    def _run_dask(
        self,
        shared_data: dict[str, Any],
        *,
        frame_indices: list[int],
        client: Any,
        progress: _RichProgressSink | None,
        task: TaskID | None,
    ) -> None:
        """Execute frame MAP work using bounded Dask futures.

        Frames are grouped into chunk tasks, at most ``max_in_flight`` of which
        are submitted at a time. Finished chunks are parked in
        ``pending_results`` and reduced strictly in ascending ``chunk_index``
        order, regardless of the order in which their futures complete.

        Args:
            shared_data: Shared workflow data mutated by parent-side reducers.
            frame_indices: Ordered frame indices to process.
            client: Dask distributed client-like object. The methods used here
                are ``scatter``, ``submit`` and ``cancel``.
            progress: Optional progress sink.
            task: Optional progress task identifier.

        Raises:
            RuntimeError: If ``dask.distributed`` is unavailable or the number of
                reduced frames does not match the selected frame count.
            Exception: Propagates worker or Dask client errors after cancelling active
                futures.
        """
        # Imported lazily so that a missing ``distributed`` package only
        # surfaces when the Dask path is actually taken.
        try:
            from distributed import wait
        except ImportError as exc:
            raise RuntimeError(
                "Parallel frame execution requires dask.distributed to be installed."
            ) from exc

        frame_tasks = self._make_frame_chunk_tasks(shared_data, frame_indices)
        max_in_flight = self._policy.max_frame_in_flight_tasks(
            shared_data,
            n_chunks=len(frame_tasks),
        )
        # Scatter the worker-visible subset of the shared data once; each
        # submitted task receives the resulting future instead of the data.
        worker_shared = make_frame_worker_shared_data(shared_data)
        worker_shared_future = client.scatter(
            [worker_shared],
            broadcast=True,
            hash=False,
        )[0]

        frame_task_iter = iter(frame_tasks)
        active_futures: set[Any] = set()
        submitted = 0  # number of chunk tasks handed to the client so far
        completed = 0  # number of frames (not chunks) already reduced
        next_reduce_index = 0  # chunk index that must be reduced next
        # Finished chunks waiting for their turn, keyed by chunk index.
        pending_results: dict[int, FrameChunkResult] = {}

        def submit_next() -> bool:
            """Submit the next frame-chunk task if one is available.

            Returns:
                ``True`` if a task was submitted, otherwise ``False`` when all tasks
                have already been submitted.
            """
            nonlocal submitted
            try:
                frame_task = next(frame_task_iter)
            except StopIteration:
                return False

            future = client.submit(
                execute_frame_chunk_worker,
                frame_task,
                worker_shared_future,
                self._universe_operations,
                pure=False,
            )
            # ``active_futures`` is resolved at call time, so this adds to the
            # set most recently bound by the wait loop below.
            active_futures.add(future)
            submitted += 1
            return True

        def reduce_ready_results() -> None:
            """Reduce completed frame chunks in deterministic chunk-index order.

            Mutates enclosing scheduler state by consuming pending results, advancing
            the next expected reduce index, and updating the completed-frame count.
            """
            nonlocal completed, next_reduce_index
            # Stops at the first gap: a later chunk stays pending until every
            # earlier chunk index has been reduced.
            while next_reduce_index in pending_results:
                chunk_result = pending_results.pop(next_reduce_index)

                self._covariance_reducer.merge_chunk_partial(
                    shared_data,
                    chunk_result.covariance_partial,
                )
                NeighborReducer.merge_chunk_partial(
                    shared_data,
                    chunk_result.neighbor_totals,
                    chunk_result.neighbor_samples,
                )

                completed += len(chunk_result.frame_indices)
                next_reduce_index += 1

                if progress is not None and task is not None:
                    progress.advance(task, len(chunk_result.frame_indices))

        try:
            # Prime the pipeline with the initial window of tasks.
            for _ in range(min(max_in_flight, len(frame_tasks))):
                submit_next()

            if progress is not None and task is not None:
                progress.update(
                    task,
                    title=f"Submitted {submitted} of {len(frame_tasks)} frame chunks",
                )

            while active_futures:
                if progress is not None and task is not None and completed == 0:
                    progress.update(task, title="Waiting for first frame chunk")

                done, not_done = wait(
                    active_futures,
                    return_when="FIRST_COMPLETED",
                )
                active_futures = set(not_done)

                for future in done:
                    # ``result()`` re-raises a worker failure here, which is
                    # handled by the ``except`` clause below.
                    chunk_result = future.result()
                    pending_results[chunk_result.chunk_index] = chunk_result
                    future.release()

                    # One replacement submission per finished future keeps the
                    # number of active futures bounded.
                    if submit_next() and progress is not None and task is not None:
                        progress.update(
                            task,
                            title=(
                                f"Submitted {submitted} of {len(frame_tasks)} "
                                "frame chunks"
                            ),
                        )

                    reduce_ready_results()

            reduce_ready_results()

            # Sanity check: every selected frame must have been reduced.
            if completed != len(frame_indices):
                raise RuntimeError(
                    f"Parallel frame execution completed {completed} frames, "
                    f"but expected {len(frame_indices)}."
                )

        # NOTE(review): only ``Exception`` subclasses trigger cancellation;
        # KeyboardInterrupt/SystemExit bypass it — confirm that is intended.
        except Exception:
            client.cancel(list(active_futures))
            raise
        finally:
            # Runs on success and failure alike.
            worker_shared_future.release()
@dataclass(frozen=True)
class FrameChunkTask:
    """MAP-stage input for a chunk of selected trajectory frames.

    Instances are immutable (``frozen=True``).
    """

    # Identifier of the chunk; the worker copies it into the chunk result.
    chunk_index: int
    # Frame indices assigned to this chunk; the worker iterates them in order.
    frame_indices: tuple[int, ...]
_THREE_LEVELS = ("ua", "res", "poly")
_FORCETORQUE_LEVELS = ("res", "poly")


def _empty_per_level(levels: tuple[str, ...]) -> dict[str, dict[Any, Any]]:
    """Return a new mapping with one fresh empty dict per level name."""
    return {level: {} for level in levels}


@dataclass
class CovarianceChunkPartial:
    """Mergeable covariance means and sample counts gathered over one chunk.

    Every field is a mapping from level name to a per-key mapping. Each
    instance starts with its own empty per-level dicts: ``force``, ``torque``
    and ``frame_counts`` cover ``ua``, ``res`` and ``poly``; ``forcetorque``
    and ``forcetorque_counts`` cover ``res`` and ``poly`` only.
    """

    force: dict[str, dict[Any, Any]] = field(
        default_factory=lambda: _empty_per_level(_THREE_LEVELS)
    )
    torque: dict[str, dict[Any, Any]] = field(
        default_factory=lambda: _empty_per_level(_THREE_LEVELS)
    )
    frame_counts: dict[str, dict[Any, int]] = field(
        default_factory=lambda: _empty_per_level(_THREE_LEVELS)
    )
    forcetorque: dict[str, dict[Any, Any]] = field(
        default_factory=lambda: _empty_per_level(_FORCETORQUE_LEVELS)
    )
    forcetorque_counts: dict[str, dict[Any, int]] = field(
        default_factory=lambda: _empty_per_level(_FORCETORQUE_LEVELS)
    )
def reduce_frame_covariance_into_partial(
    partial: CovarianceChunkPartial,
    frame_out: dict[str, Any],
) -> None:
    """Fold one frame's covariance matrices into a chunk-local partial.

    Levels are handled in the order ``ua``, ``res``, ``poly``. Within a level
    the force entries come first and always advance the entry's sample count;
    torque entries reuse that count and only create it when the entry has not
    been counted before. Force-torque entries keep their own counts.

    Args:
        partial: Worker-local covariance partial, mutated in place.
        frame_out: Frame covariance payload with ``force`` and ``torque``
            sections and an optional ``forcetorque`` section.

    Raises:
        KeyError: If required force or torque sections are missing.
    """
    frame_forces = frame_out["force"]
    frame_torques = frame_out["torque"]

    counts = partial.frame_counts
    for level in ("ua", "res", "poly"):
        for key, matrix in frame_forces[level].items():
            seen = counts[level].get(key, 0) + 1
            counts[level][key] = seen
            partial.force[level][key] = incremental_mean_value(
                partial.force[level].get(key),
                matrix,
                seen,
            )

        for key, matrix in frame_torques[level].items():
            if key not in counts[level]:
                counts[level][key] = counts[level].get(key, 0) + 1
            seen = counts[level][key]
            partial.torque[level][key] = incremental_mean_value(
                partial.torque[level].get(key),
                matrix,
                seen,
            )

    if "forcetorque" not in frame_out:
        return

    frame_ft = frame_out["forcetorque"]
    ft_counts = partial.forcetorque_counts
    for level in ("res", "poly"):
        # A level missing from the section is treated as having no entries.
        for group_id, matrix in frame_ft.get(level, {}).items():
            seen = ft_counts[level].get(group_id, 0) + 1
            ft_counts[level][group_id] = seen
            partial.forcetorque[level][group_id] = incremental_mean_value(
                partial.forcetorque[level].get(group_id),
                matrix,
                seen,
            )
def execute_frame_chunk_worker(
    task: FrameChunkTask,
    worker_shared_data: dict[str, Any],
    universe_operations: Any | None = None,
) -> FrameChunkResult:
    """Process every frame of one chunk and return its mergeable partials.

    A frame graph and a neighbour helper are built once per call. For each
    frame the covariance payload is folded into a chunk-local partial and the
    per-group neighbour counts are added to chunk-local totals.

    Args:
        task: Frame chunk descriptor containing chunk index and selected frames.
        worker_shared_data: Worker-visible shared workflow data. The keys read
            here are ``groups``, ``levels``, ``frame_source``, ``args`` and
            ``reduced_universe`` (falling back to ``universe``).
        universe_operations: Optional universe-operation adapter used to build
            the worker-local frame graph.

    Returns:
        A ``FrameChunkResult`` carrying the task's chunk index and frame
        indices together with the covariance partial, neighbour totals and
        neighbour sample counts.

    Raises:
        KeyError: If required worker shared-data keys are missing.
    """
    graph = FrameGraph(universe_operations=universe_operations).build()
    neighbor_finder = Neighbors()

    partial = CovarianceChunkPartial()
    totals = dict.fromkeys(worker_shared_data["groups"].keys(), 0)
    samples = dict.fromkeys(worker_shared_data["groups"].keys(), 0)

    for raw_index in task.frame_indices:
        frame_index = int(raw_index)

        reduce_frame_covariance_into_partial(
            partial,
            graph.execute_frame(worker_shared_data, frame_index),
        )

        # Shared-data lookups stay inside the loop, after ``execute_frame``.
        # NOTE(review): hoisting them would change behaviour if
        # ``execute_frame`` updates ``worker_shared_data`` — not verifiable
        # from this file.
        fallback_universe = worker_shared_data.get("universe")
        universe = worker_shared_data.get("reduced_universe", fallback_universe)
        per_group = neighbor_finder.get_frame_neighbor_counts(
            universe=universe,
            levels=worker_shared_data["levels"],
            groups=worker_shared_data["groups"],
            frame_source=worker_shared_data["frame_source"],
            frame_index=frame_index,
            search_type=worker_shared_data["args"].search_type,
        )

        for group_id, (count, sample_count) in per_group.items():
            totals[group_id] = totals.get(group_id, 0) + int(count)
            samples[group_id] = samples.get(group_id, 0) + int(sample_count)

    return FrameChunkResult(
        chunk_index=task.chunk_index,
        covariance_partial=partial,
        neighbor_totals=totals,
        neighbor_samples=samples,
        frame_indices=task.frame_indices,
    )
def __init__(self, universe_operations: Any | None = None) -> None:
    """Initialise a frame-local DAG.

    The graph starts empty; nodes are registered later through ``build``.

    Args:
        universe_operations: Optional universe-operation adapter retained for frame
            graph construction consistency. This constructor only stores it.
    """
    self._universe_operations = universe_operations
    # Dependency graph over node names; ``execute_frame`` runs nodes in its
    # topological order.
    self._graph = nx.DiGraph()
    # Maps each registered node name to the node object that is run.
    self._nodes: dict[str, Any] = {}
+ The current ``FrameGraph`` instance for fluent construction. """ self._add("frame_covariance", FrameCovarianceNode()) return self def execute_frame(self, shared_data: dict[str, Any], frame_index: int) -> Any: - """Execute the frame DAG for one selected analysis frame. - - FrameGraph owns trajectory positioning for frame-local execution. Higher-level - orchestration passes explicit frame indices but must not rely on hidden - MDAnalysis cursor state. + """Execute frame-local nodes for one selected frame. Args: - shared_data: Shared workflow data dictionary. Must contain - ``"frame_source"``. - frame_index: Frame index valid for the active analysis universe. During - this migration stage this is local to the frame-reduced universe. + shared_data: Shared workflow data containing ``frame_source``. + frame_index: Frame index in the active analysis frame-source space. Returns: - Frame-local covariance payload produced by ``FrameCovarianceNode``. + The frame covariance payload produced by the frame-local covariance node. Raises: - KeyError: If ``"frame_source"`` is missing from ``shared_data``. - IndexError: If ``frame_index`` is outside trajectory bounds. + KeyError: If ``frame_source`` is missing from ``shared_data``. + IndexError: If ``frame_index`` is outside the active trajectory bounds. """ frame_source = shared_data["frame_source"] frame_index = int(frame_index) @@ -98,19 +66,21 @@ def execute_frame(self, shared_data: dict[str, Any], frame_index: int) -> Any: f"for trajectory with {n_frames} frames." 
) from exc - ctx = self._make_frame_ctx( - shared_data=shared_data, - frame_index=frame_index, - ) + ctx = self._make_frame_ctx(shared_data=shared_data, frame_index=frame_index) for node_name in nx.topological_sort(self._graph): - logger.debug("[FrameGraph] running %s @ frame=%s", node_name, frame_index) self._nodes[node_name].run(ctx) return ctx["frame_covariance"] def _add(self, name: str, node: Any, deps: list[str] | None = None) -> None: - """Register a node and its dependencies in the DAG.""" + """Register a frame-local node and dependency edges. + + Args: + name: Unique node name in the frame DAG. + node: Node object exposing a ``run`` method. + deps: Optional upstream node names that must run before ``name``. + """ self._nodes[name] = node self._graph.add_node(name) for dep in deps or []: @@ -118,21 +88,18 @@ def _add(self, name: str, node: Any, deps: list[str] | None = None) -> None: @staticmethod def _make_frame_ctx( - shared_data: dict[str, Any], frame_index: int + shared_data: dict[str, Any], + frame_index: int, ) -> dict[str, Any]: - """Create a frame context dictionary for node execution. - - Notes: - - The context includes a reference to `shared_data` via "shared". - - The context is intentionally frame-scoped and should not be used as - a replacement for shared workflow state. + """Build a frame-local execution context. Args: - shared_data: Shared workflow data dict. - frame_index: Absolute trajectory frame index. + shared_data: Shared workflow data referenced by frame-local nodes. + frame_index: Frame index currently being executed. Returns: - Frame context dict with required keys. + A frame context dictionary containing shared data, frame index, and a + placeholder for frame covariance output. 
""" return { "shared": shared_data, diff --git a/CodeEntropy/levels/level_dag.py b/CodeEntropy/levels/level_dag.py index 0c463ced..ecbf1d49 100644 --- a/CodeEntropy/levels/level_dag.py +++ b/CodeEntropy/levels/level_dag.py @@ -1,114 +1,55 @@ -"""Hierarchy-level DAG orchestration and reduction. +"""Hierarchy-level DAG orchestration. -This module defines the `LevelDAG`, which coordinates two stages of the hierarchy -workflow: - -1) Static stage (runs once): - - Detect molecules and available resolution levels. - - Build beads for each (molecule, level) definition. - - Initialise accumulators used during per-frame reduction. - - Compute conformational state descriptors required later by entropy nodes. - -2) Frame stage (runs for each trajectory frame): - - Execute the `FrameGraph` to produce frame-local covariance outputs. - - Reduce frame-local outputs into running (incremental) means. +LevelDAG owns hierarchy-level workflow order. Static setup nodes prepare +structural and conformational data, then frame-local covariance and neighbour +observables are executed through deterministic frame map-reduce. 
""" from __future__ import annotations -import logging from typing import Any import networkx as nx -from rich.progress import TaskID from CodeEntropy.levels.axes import AxesCalculator +from CodeEntropy.levels.execution.policy import ExecutionPolicy +from CodeEntropy.levels.execution.reducers import NeighborReducer +from CodeEntropy.levels.execution.scheduler import FrameScheduler from CodeEntropy.levels.frame_dag import FrameGraph +from CodeEntropy.levels.neighbors import Neighbors from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode from CodeEntropy.levels.nodes.beads import BuildBeadsNode from CodeEntropy.levels.nodes.conformations import ComputeConformationalStatesNode from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode -from CodeEntropy.levels.nodes.find_neighbors import ComputeNeighborsNode from CodeEntropy.results.reporter import _RichProgressSink -logger = logging.getLogger(__name__) - - -_FRAME_WORKER_EXCLUDED_SHARED_KEYS = { - "force_covariances", - "torque_covariances", - "forcetorque_covariances", - "frame_counts", - "forcetorque_counts", - "force_torque_stats", - "force_torque_counts", - "n_frames", - "entropy_manager", - "run_manager", - "reporter", - "dask_client", -} - - -def _execute_frame_worker( - shared_data: dict[str, Any], - frame_index: int, - universe_operations: Any | None = None, -) -> tuple[int, Any]: - """Execute one frame on a Dask worker. - - Args: - shared_data: Worker-local shared calculation inputs. - frame_index: Frame index to process. - universe_operations: Optional universe operations adapter. - - Returns: - Tuple of frame index and frame-local covariance output. 
- """ - frame_dag = FrameGraph(universe_operations=universe_operations).build() - return int(frame_index), frame_dag.execute_frame(shared_data, int(frame_index)) - class LevelDAG: - """Execute hierarchy detection, per-frame covariance calculation, and reduction. - - The LevelDAG is responsible for: - - Running a static DAG (once) to prepare shared inputs. - - Running a per-frame DAG (for each frame) to compute frame-local outputs. - - Reducing frame-local outputs into shared running means. - - The reduction performed here is an incremental mean across frames (and across - molecules within a group when frame nodes average within-frame first). - """ + """Execute static setup and deterministic frame map-reduce execution.""" def __init__(self, universe_operations: Any | None = None) -> None: - """Initialise a LevelDAG. + """Initialise the hierarchy-level DAG. Args: - universe_operations: Optional adapter providing universe operations. - Passed to the FrameGraph and the conformational-state node. + universe_operations: Optional universe-operation adapter passed to static + conformational-state setup and frame-local execution. """ self._universe_operations = universe_operations - self._static_graph = nx.DiGraph() self._static_nodes: dict[str, Any] = {} - self._frame_dag = FrameGraph(universe_operations=universe_operations) + self._policy = ExecutionPolicy() def build(self) -> LevelDAG: - """Build the static and frame DAG topology. - - This registers all static nodes and their dependencies, and builds the - internal FrameGraph used for per-frame execution. + """Build static and frame-level DAG topology. Returns: - Self, to allow fluent chaining. + The current ``LevelDAG`` instance for fluent construction. 
""" self._add_static("detect_molecules", DetectMoleculesNode()) self._add_static("detect_levels", DetectLevelsNode(), deps=["detect_molecules"]) self._add_static("build_beads", BuildBeadsNode(), deps=["detect_levels"]) - self._add_static( "init_covariance_accumulators", InitCovarianceAccumulatorsNode(), @@ -119,71 +60,80 @@ def build(self) -> LevelDAG: ComputeConformationalStatesNode(self._universe_operations), deps=["detect_levels"], ) - self._add_static( - "find_neighbors", ComputeNeighborsNode(), deps=["detect_levels"] - ) self._frame_dag.build() return self def execute( - self, shared_data: dict[str, Any], *, progress: _RichProgressSink | None = None + self, + shared_data: dict[str, Any], + *, + progress: _RichProgressSink | None = None, ) -> dict[str, Any]: - """Execute the full hierarchy workflow and mutate shared_data. - - This method ensures required shared components exist, runs the static stage - once, then iterates through trajectory frames to run the per-frame stage and - reduce outputs into running means. + """Execute the hierarchy workflow. Args: - shared_data: Shared workflow data dict. This mapping is mutated in-place - by both static and frame stages. - progress: Optional progress sink passed through to nodes and used for - per-frame progress reporting when supported. + shared_data: Shared workflow data mutated by static setup, frame execution, + and parent-side reductions. + progress: Optional progress sink passed to supported static nodes and frame + scheduling. Returns: - The same shared_data mapping passed in, after mutation. + The same ``shared_data`` mapping after workflow execution. + + Raises: + KeyError: If required shared workflow keys are missing. 
""" shared_data.setdefault("axes_manager", AxesCalculator()) + self._run_static_stage(shared_data, progress=progress) + self._initialise_neighbor_metadata(shared_data) + NeighborReducer.initialise(shared_data) self._run_frame_stage(shared_data, progress=progress) + NeighborReducer.finalise(shared_data) + return shared_data def _run_static_stage( - self, shared_data: dict[str, Any], *, progress: _RichProgressSink | None = None + self, + shared_data: dict[str, Any], + *, + progress: _RichProgressSink | None = None, ) -> None: - """Run all static nodes in dependency order. - - Nodes are executed in topological order of the static DAG. If a progress - object is provided, it is passed to node.run when the node accepts it. + """Run static setup nodes in dependency order. Args: - shared_data: Shared workflow data dict to be mutated by static nodes. - progress: Optional progress sink to pass to nodes that support it. + shared_data: Shared workflow data mutated by each static node. + progress: Optional progress sink passed to nodes that accept it. """ for node_name in nx.topological_sort(self._static_graph): node = self._static_nodes[node_name] + if progress is not None: try: node.run(shared_data, progress=progress) continue except TypeError: pass + node.run(shared_data) - def _add_static(self, name: str, node: Any, deps: list[str] | None = None) -> None: - """Register a static node and its dependencies in the static DAG. + def _add_static( + self, + name: str, + node: Any, + deps: list[str] | None = None, + ) -> None: + """Register a static node in the hierarchy DAG. Args: - name: Unique node name used in the static DAG. - node: Node object exposing a run(shared_data, **kwargs) method. - deps: Optional list of upstream node names that must run before this node. - - Returns: - None. Mutates the internal static graph and node registry. + name: Unique node name in the static DAG. + node: Node object exposing a ``run`` method. 
+ deps: Optional upstream node names that must execute before ``name``. """ self._static_nodes[name] = node self._static_graph.add_node(name) + for dep in deps or []: self._static_graph.add_edge(dep, name) @@ -193,28 +143,15 @@ def _run_frame_stage( *, progress: _RichProgressSink | None = None, ) -> None: - """Execute the per-frame DAG stage and reduce frame outputs. - - This method iterates over explicit frame indices provided by - ``shared_data["frame_source"]``. During this migration stage, those indices - are local indices into the physically frame-reduced analysis universe. After - physical frame slicing is removed, they will be absolute source-trajectory - indices. - - FrameGraph owns trajectory positioning. LevelDAG only chooses which frame - indices to process and reduces each frame-local output into shared - accumulators. - - If ``shared_data["dask_client"]`` exists and parallel frame execution is - enabled, frame-local outputs are computed on Dask workers and reduced in - the parent process. + """Execute frame map-reduce work through the frame scheduler. Args: - shared_data: Shared data dictionary. Must contain ``frame_source``. - progress: Optional progress sink. + shared_data: Shared workflow data containing ``frame_source`` and + frame-stage inputs. The method writes ``n_frames``. + progress: Optional progress sink forwarded to the frame scheduler. - Returns: - None. Mutates ``shared_data`` in-place via reduction. + Raises: + KeyError: If ``frame_source`` is missing from ``shared_data``. 
""" frame_source = shared_data["frame_source"] frame_indices = [ @@ -222,240 +159,36 @@ def _run_frame_stage( ] shared_data["n_frames"] = len(frame_indices) - task: TaskID | None = None - - if progress is not None: - task = progress.add_task( - "[green]Frame processing", - total=len(frame_indices), - title="Initializing", - ) - - client = shared_data.get("dask_client") - parallel_frames = bool(shared_data.get("parallel_frames", client is not None)) - - if parallel_frames and client is not None and len(frame_indices) > 1: - self._run_frame_stage_dask( - shared_data, - frame_indices=frame_indices, - client=client, - progress=progress, - task=task, - ) - return - - for frame_index in frame_indices: - if progress is not None and task is not None: - progress.update(task, title=f"Frame {frame_index}") - - frame_out = self._frame_dag.execute_frame( - shared_data, - frame_index, - ) - - self._reduce_one_frame(shared_data, frame_out) - - if progress is not None and task is not None: - progress.advance(task) - - @staticmethod - def _make_frame_worker_shared_data(shared_data: dict[str, Any]) -> dict[str, Any]: - """Return the subset of shared data required by frame workers. - - Reduction accumulators and parent orchestration/reporting objects are - intentionally excluded because workers should only compute frame-local - outputs. - """ - return { - key: value - for key, value in shared_data.items() - if key not in _FRAME_WORKER_EXCLUDED_SHARED_KEYS - } - - def _run_frame_stage_dask( - self, - shared_data: dict[str, Any], - *, - frame_indices: list[int], - client: Any, - progress: _RichProgressSink | None = None, - task: TaskID | None = None, - ) -> None: - """Execute frame-local DAG tasks in parallel using Dask. - - Workers return frame-local covariance payloads. The parent process performs - all reductions into the shared accumulators. - - Important: - Do not scatter/broadcast worker_shared. It contains stateful objects - such as frame_source / universe trajectory state. 
@staticmethod
def _initialise_neighbor_metadata(shared_data: dict[str, Any]) -> None:
    """Store per-group symmetry numbers and linearity flags in ``shared_data``.

    The values come from ``Neighbors.get_symmetry`` and are written under the
    ``symmetry_number`` and ``linear`` keys.

    Args:
        shared_data: Shared workflow data containing ``groups``. The universe
            is read from ``reduced_universe`` when that key is present,
            otherwise from ``universe``.

    Raises:
        KeyError: If ``groups`` is missing from ``shared_data``.
    """
    neighbor_helper = Neighbors()
    fallback_universe = shared_data.get("universe")
    active_universe = shared_data.get("reduced_universe", fallback_universe)

    metadata = neighbor_helper.get_symmetry(
        universe=active_universe,
        groups=shared_data["groups"],
    )

    shared_data["symmetry_number"], shared_data["linear"] = metadata
""" -import logging +from __future__ import annotations + +from typing import Any -import numpy as np from rdkit import Chem from CodeEntropy.levels.search import Search -logger = logging.getLogger(__name__) +NeighborCounts = dict[int, tuple[int, int]] class Neighbors: - """ - Class to find the neighbors and any related information needed for the - calculation of orientational entropy. - """ - - def __init__(self): - """ - Initializes the Neighbors class with placeholders for data, - including the system trajectory, groups, and levels. + """Compute neighbour-count and orientational metadata observables.""" + + def __init__(self, search: Search | None = None) -> None: + self._search = search or Search() + + def get_frame_neighbor_counts( + self, + *, + universe: Any, + levels: list[list[str]], + groups: dict[int, list[int]], + frame_source: Any, + frame_index: int, + search_type: str, + ) -> NeighborCounts: + """Return neighbour-count totals for one selected frame. + + The returned ``(total, sample_count)`` pairs are intentionally additive. + Parent-side reducers combine them across frames and divide at finalisation. """ - - self._universe = None - self._groups = None - self._levels = None - self._search = Search() - - def get_neighbors(self, universe, levels, groups, frame_source, search_type): - """Find average neighbour counts for each molecule group. - - Args: - universe: MDAnalysis universe object for the active analysis system. - levels: Level list for each molecule. - groups: Mapping of group id to molecule ids. - frame_source: FrameSource controlling selected trajectory access. - search_type: Neighbour search method, either ``"RAD"`` or ``"grid"``. - - Returns: - Dict mapping group id to average number of neighbours. - - Raises: - ValueError: If ``search_type`` is unknown. 
- """ - frame_indices = [ - int(frame_index) for frame_index in frame_source.iter_indices() - ] - n_frames = len(frame_indices) - - if n_frames <= 0: - return {group_id: 0.0 for group_id in groups.keys()} - - number_neighbors = {} - average_number_neighbors = {} - - for group_id in groups.keys(): - molecules = groups[group_id] - highest_level = levels[molecules[0]][-1] - - for mol_id in molecules: - for frame_index in frame_indices: - if search_type == "RAD": - neighbors = self._search.get_RAD_neighbors( - universe=universe, - mol_id=mol_id, - frame_source=frame_source, - frame_index=frame_index, - ) - - elif search_type == "grid": - neighbors = self._search.get_grid_neighbors( - universe=universe, - mol_id=mol_id, - highest_level=highest_level, - frame_source=frame_source, - frame_index=frame_index, - ) - else: - raise ValueError(f"unknown search_type {search_type}") - - number_neighbors.setdefault(group_id, []).append(len(neighbors)) - - number = np.sum(number_neighbors[group_id]) - average_number_neighbors[group_id] = number / (len(molecules) * n_frames) - logger.debug( - "group: %s number neighbors %s", - group_id, - average_number_neighbors[group_id], - ) - - return average_number_neighbors - - def get_symmetry(self, universe, groups): - """ - Calculate symmetry number for the molecule. - - This function converts the MDAnalysis instance of a molecule into - an RDKit object and then calculates the symmetry number and - determines if the molecule is linear. 
- - Args: - universe: MDAnalysis object - mol_id: index of the molecule of interest - - Returns: - symmetry_number: dict, symmetry number of each group - linear: dict, linear for each group - """ - symmetry_number = {} - linear = {} - - for group_id in groups.keys(): - molecules = groups[group_id] + frame_index = int(frame_index) + frame_counts: NeighborCounts = {} + + for group_id, molecule_ids in groups.items(): + if not molecule_ids: + frame_counts[group_id] = (0, 0) + continue + + highest_level = levels[molecule_ids[0]][-1] + total_neighbors = 0 + sample_count = 0 + + for molecule_id in molecule_ids: + neighbors = self._get_neighbors_for_molecule( + universe=universe, + molecule_id=molecule_id, + highest_level=highest_level, + frame_source=frame_source, + frame_index=frame_index, + search_type=search_type, + ) + total_neighbors += len(neighbors) + sample_count += 1 + + frame_counts[group_id] = (total_neighbors, sample_count) + + return frame_counts + + def get_symmetry( + self, + *, + universe: Any, + groups: dict[int, list[int]], + ) -> tuple[dict[int, int], dict[int, bool]]: + """Return symmetry numbers and linearity flags for each molecule group.""" + symmetry_number: dict[int, int] = {} + linear: dict[int, bool] = {} + + for group_id, molecule_ids in groups.items(): + if not molecule_ids: + symmetry_number[group_id] = 0 + linear[group_id] = False + continue rdkit_mol, number_heavy, number_hydrogen = self._get_rdkit_mol( - universe, molecules[0] + universe, + molecule_ids[0], ) - symmetry_number[group_id] = self._get_symmetry_number( - rdkit_mol, number_heavy, number_hydrogen + rdkit_mol, + number_heavy, + number_hydrogen, ) - linear[group_id] = self._get_linear(rdkit_mol, number_heavy) - logger.debug( - f"group: {group_id}, symmetry: {symmetry_number}, linear: {linear}" + return symmetry_number, linear + + def _get_neighbors_for_molecule( + self, + *, + universe: Any, + molecule_id: int, + highest_level: str, + frame_source: Any, + frame_index: int, + 
search_type: str, + ) -> Any: + """Run the configured neighbour search for one molecule and frame.""" + if search_type == "RAD": + return self._search.get_RAD_neighbors( + universe=universe, + mol_id=molecule_id, + frame_source=frame_source, + frame_index=frame_index, ) - return symmetry_number, linear + if search_type == "grid": + return self._search.get_grid_neighbors( + universe=universe, + mol_id=molecule_id, + highest_level=highest_level, + frame_source=frame_source, + frame_index=frame_index, + ) - def _get_rdkit_mol(self, universe, mol_id): - """ - Convert molecule to rdkit object. - - MDAnalysis convert_to(RDKIT) needs elements. - We are removing dummy atoms and converting to rkdit format. - If there are dummy atoms you need inferrer=None otherwise you - get errors from it getting the valence wrong. - If possible it is better to use the inferrer to get the bonds - and hybridization correct. - The convert_to argument force=True forces it to continue even when - it cannot find hydrogens, this is needed to avoid errors for molecules - like carbon dioxide which do not have hydrogens. 
- - Args: - universe: MDAnalysis object - mol_id: index of the molecule of interest - - Returns: - rdkit_mol - number_heavy - number_hydrogen - """ + raise ValueError(f"unknown search_type {search_type}") + @staticmethod + def _get_rdkit_mol(universe: Any, molecule_id: int) -> tuple[Any, int, int]: + """Convert one molecular fragment into an RDKit molecule.""" if not hasattr(universe.atoms, "elements"): universe.guess_TopologyAttrs(to_guess=["elements"]) - molecule = universe.atoms.fragments[mol_id] + molecule = universe.atoms.fragments[molecule_id] + dummy_atoms = molecule.select_atoms("prop mass < 0.1") - dummy = molecule.select_atoms("prop mass < 0.1") - if len(dummy) > 0: - frag = molecule.select_atoms("prop mass > 0.1") - rdkit_mol = frag.convert_to("RDKIT", force=True, inferrer=None) - logger.debug("Warning: Dummy atoms found") + if len(dummy_atoms) > 0: + fragment = molecule.select_atoms("prop mass > 0.1") + rdkit_mol = fragment.convert_to("RDKIT", force=True, inferrer=None) else: try: rdkit_mol = molecule.convert_to("RDKIT", force=True) except Exception: - logger.debug("Warning: Constraint bonds to H atoms found") rdkit_mol = molecule.convert_to("RDKIT", force=True, inferrer=None) number_heavy = rdkit_mol.GetNumHeavyAtoms() number_hydrogen = rdkit_mol.GetNumAtoms() - number_heavy - return rdkit_mol, number_heavy, number_hydrogen - def _get_symmetry_number(self, rdkit_mol, number_heavy, number_hydrogen): - """ - Calculate symmetry number for the molecule. - - For larger molecules, removing the hydrogens reduces the over counting - of symmetry states. When there is only one heavy atom the hydrogens - are important to the symmetry. If there is a single heavy atom with - no hydrogens then the molecule is spherically symmetric. - - Using the RDKit GetSubstructMatches function often works well as - a guess for the symmetry number, but it occasionally returns a - symmetry number 2x the expected value (for example, cyclohexane). 
- - Args: - rdkit_mol: rdkit object for molecule of interest - number_heavy (int): number of heavy atoms - number_hydrogen (int): number of hydrogen atoms - - Returns: - symmetry_number (int): symmetry number of molecule - """ - + @staticmethod + def _get_symmetry_number( + rdkit_mol: Any, + number_heavy: int, + number_hydrogen: int, + ) -> int: + """Calculate the molecular symmetry number used by orientational entropy.""" if number_heavy > 1: - rdkit_heavy = Chem.RemoveHs(rdkit_mol) + heavy_atom_mol = Chem.RemoveHs(rdkit_mol) matches = rdkit_mol.GetSubstructMatches( - rdkit_heavy, uniquify=False, useChirality=True + heavy_atom_mol, + uniquify=False, + useChirality=True, ) - symmetry_number = len(matches) - elif number_hydrogen > 0: + return len(matches) + + if number_hydrogen > 0: matches = rdkit_mol.GetSubstructMatches( - rdkit_mol, uniquify=False, useChirality=True + rdkit_mol, + uniquify=False, + useChirality=True, ) - symmetry_number = len(matches) - else: - symmetry_number = 0 - - return symmetry_number + return len(matches) - def _get_linear(self, rdkit_mol, number_heavy): - """ - Determine if the molecule is linear. + return 0 - We are not considering the hydrogens, just the united atom beads. - So, molecules like methanol are treated as linear since they have only - two united atoms. 
+ @staticmethod + def _get_linear(rdkit_mol: Any, number_heavy: int) -> bool: + """Return whether a molecule is treated as linear.""" + if number_heavy == 1: + return False - Args: - rkdit_mol: rdkit object for molecule of interest - number_heavy (int): number of heavy atoms + if number_heavy == 2: + return True - Returns: - linear (bool): True if molecule linear - """ - linear = False - if number_heavy == 1: - linear = False - elif number_heavy == 2: - linear = True - else: - rdkit_heavy = Chem.RemoveHs(rdkit_mol) - sp_count = 0 - for x in rdkit_heavy.GetAtoms(): - if x.GetHybridization() == Chem.HybridizationType.SP: - sp_count += 1 - if sp_count >= (number_heavy - 2): - linear = True - - return linear + heavy_atom_mol = Chem.RemoveHs(rdkit_mol) + sp_count = sum( + atom.GetHybridization() == Chem.HybridizationType.SP + for atom in heavy_atom_mol.GetAtoms() + ) + return sp_count >= number_heavy - 2 diff --git a/CodeEntropy/levels/nodes/accumulators.py b/CodeEntropy/levels/nodes/accumulators.py index 722cc96b..38902275 100644 --- a/CodeEntropy/levels/nodes/accumulators.py +++ b/CodeEntropy/levels/nodes/accumulators.py @@ -15,15 +15,12 @@ from __future__ import annotations -import logging from collections.abc import MutableMapping from dataclasses import dataclass from typing import Any import numpy as np -logger = logging.getLogger(__name__) - SharedData = MutableMapping[str, Any] @@ -63,10 +60,6 @@ class InitCovarianceAccumulatorsNode: Group index mapping: - group_id_to_index: {group_id: index} - index_to_group_id: [group_id_by_index] - - Backwards-compatible aliases (kept for older consumers): - - force_torque_stats -> forcetorque_covariances - - force_torque_counts -> forcetorque_counts """ def run(self, shared_data: dict[str, Any]) -> dict[str, Any]: @@ -89,8 +82,6 @@ def run(self, shared_data: dict[str, Any]) -> dict[str, Any]: ) self._attach_to_shared_data(shared_data, group_index, accumulators) - self._attach_backwards_compatible_aliases(shared_data) - 
return self._build_return_payload(shared_data) @staticmethod @@ -161,16 +152,6 @@ def _attach_to_shared_data( shared_data["forcetorque_covariances"] = acc.forcetorque_covariances shared_data["forcetorque_counts"] = acc.forcetorque_counts - @staticmethod - def _attach_backwards_compatible_aliases(shared_data: SharedData) -> None: - """Attach backwards-compatible aliases. - - Args: - shared_data: Shared pipeline dictionary. - """ - shared_data["force_torque_stats"] = shared_data["forcetorque_covariances"] - shared_data["force_torque_counts"] = shared_data["forcetorque_counts"] - @staticmethod def _build_return_payload(shared_data: SharedData) -> dict[str, Any]: """Build the return payload containing initialized keys. @@ -189,6 +170,4 @@ def _build_return_payload(shared_data: SharedData) -> dict[str, Any]: "frame_counts": shared_data["frame_counts"], "forcetorque_covariances": shared_data["forcetorque_covariances"], "forcetorque_counts": shared_data["forcetorque_counts"], - "force_torque_stats": shared_data["force_torque_stats"], - "force_torque_counts": shared_data["force_torque_counts"], } diff --git a/CodeEntropy/levels/nodes/conformations.py b/CodeEntropy/levels/nodes/conformations.py index c73c153d..5b4220a3 100644 --- a/CodeEntropy/levels/nodes/conformations.py +++ b/CodeEntropy/levels/nodes/conformations.py @@ -1,14 +1,7 @@ -"""Compute conformational states for configurational entropy calculations. - -This module defines a static DAG node that scans the selected trajectory frames -and builds conformational state descriptors (united-atom and residue level). -The resulting states are stored in ``shared_data`` for later use by -configurational entropy calculations. 
-""" +"""Compute conformational states for configurational entropy calculations.""" from __future__ import annotations -from dataclasses import dataclass from typing import Any from CodeEntropy.levels.dihedrals import ConformationStateBuilder @@ -19,76 +12,53 @@ FlexibleStates = dict[str, Any] -@dataclass(frozen=True) -class ConformationalStateConfig: - """Configuration for conformational state construction. - - Attributes: - n_frames: Number of frames to be analysed. - bin_width: Histogram bin width in degrees. - """ - - n_frames: int - bin_width: int - - class ComputeConformationalStatesNode: - """Static node that computes conformational states from trajectory dihedrals. + """Static node that computes conformational states from selected frames. Produces: shared_data["conformational_states"] = {"ua": states_ua, "res": states_res} shared_data["flexible_dihedrals"] = {"ua": flexible_ua, "res": flexible_res} - - Where: - - states_ua is a dict keyed by ``(group_id, local_residue_id)``. - - states_res is a list-like structure indexed by group id. - - flexible_ua is a dict keyed by ``(group_id, local_residue_id)``. - - flexible_res is a list-like structure indexed by group id. - - Notes: - Frame selection is provided through ``shared_data["frame_selection"]``. - During the current migration stage, that selection uses local - analysis-universe frame indices because the workflow still physically - frame-slices the universe. """ - def __init__(self, universe_operations: Any) -> None: - """Initialize the node. + def __init__(self, universe_operations: Any | None = None) -> None: + """Initialise the conformational-state node. Args: - universe_operations: Object providing universe selection utilities used - by ``ConformationStateBuilder``. + universe_operations: Optional universe-operation adapter passed to the + underlying conformation-state builder. 
""" - self._dihedral_analysis = ConformationStateBuilder( + self._builder = ConformationStateBuilder( universe_operations=universe_operations ) def run( - self, shared_data: SharedData, *, progress: object | None = None + self, + shared_data: SharedData, + *, + progress: object | None = None, ) -> dict[str, ConformationalStates]: - """Compute conformational states and store them in shared_data. + """Compute conformational states and store them in shared workflow data. Args: - shared_data: Shared data dictionary. Requires: - - ``"reduced_universe"`` - - ``"levels"`` - - ``"groups"`` - - ``"frame_selection"`` - - ``"args"`` with attribute ``bin_width`` - progress: Optional progress sink provided by ResultsReporter.progress(). + shared_data: Shared workflow data containing ``reduced_universe``, + ``levels``, ``groups``, ``frame_selection``, and ``args.bin_width``. + progress: Optional progress sink forwarded to the conformation builder. Returns: - Dict containing ``"conformational_states"``. + A dictionary containing the computed ``conformational_states`` mapping. + + Raises: + KeyError: If required entries are missing from ``shared_data``. 
""" - u = shared_data["reduced_universe"] + universe = shared_data["reduced_universe"] levels = shared_data["levels"] groups = shared_data["groups"] frame_selection: FrameSelection = shared_data["frame_selection"] bin_width = int(shared_data["args"].bin_width) states_ua, states_res, flexible_ua, flexible_res = ( - self._dihedral_analysis.build_conformational_states( - data_container=u, + self._builder.build_conformational_states( + data_container=universe, levels=levels, groups=groups, bin_width=bin_width, @@ -101,12 +71,12 @@ def run( "ua": states_ua, "res": states_res, } - shared_data["conformational_states"] = conformational_states - - flexible_states: FlexibleStates = { + flexible_dihedrals: FlexibleStates = { "ua": flexible_ua, "res": flexible_res, } - shared_data["flexible_dihedrals"] = flexible_states + + shared_data["conformational_states"] = conformational_states + shared_data["flexible_dihedrals"] = flexible_dihedrals return {"conformational_states": conformational_states} diff --git a/CodeEntropy/levels/nodes/covariance.py b/CodeEntropy/levels/nodes/covariance.py index 25375d47..c98d052c 100644 --- a/CodeEntropy/levels/nodes/covariance.py +++ b/CodeEntropy/levels/nodes/covariance.py @@ -18,7 +18,6 @@ from __future__ import annotations -import logging from typing import Any import numpy as np @@ -26,8 +25,6 @@ from CodeEntropy.levels.forces import ForceTorqueCalculator -logger = logging.getLogger(__name__) - FrameCtx = dict[str, Any] Matrix = np.ndarray @@ -49,23 +46,26 @@ class FrameCovarianceNode: """ def __init__(self) -> None: - """Initialise the frame covariance node.""" + """Initialise the frame covariance node. + + Creates the force/torque calculator used by all frame-local covariance helper + methods. + """ self._ft = ForceTorqueCalculator() def run(self, ctx: FrameCtx) -> dict[str, Any]: - """Compute and store per-frame force/torque (and optional FT) matrices. + """Compute frame-local force, torque, and optional force-torque matrices. 
Args: - ctx: Frame context dict expected to include: - - "shared": dict containing reduced_universe, groups, levels, beads, - args - - shared["axes_manager"] (created in static stage) + ctx: Frame context containing ``shared`` workflow data. The shared data must + provide ``reduced_universe``, ``groups``, ``levels``, ``beads``, and + ``args``. Returns: - The frame covariance payload also stored at ctx["frame_covariance"]. + The frame covariance payload written to ``ctx["frame_covariance"]``. Raises: - KeyError: If ctx is missing required fields. + KeyError: If ``ctx`` or the shared workflow data is missing required keys. """ shared = self._get_shared(ctx) @@ -176,29 +176,22 @@ def _process_united_atom( out_torque: dict[str, dict[Any, Matrix]], molcount: dict[tuple[int, int], int], ) -> None: - """Compute UA-level force/torque second moments for one molecule. - - For each residue in the molecule, bead vectors are computed for all UA - beads in that residue. The resulting second-moment matrices are then - incrementally averaged across molecules in the same group for this frame. + """Compute united-atom second moments for one molecule. Args: - u: MDAnalysis Universe (or compatible) providing atom access. - mol: Molecule/fragment object providing residues/atoms. - mol_id: Molecule id used for bead keying. - group_id: Group identifier used for within-frame averaging. - beads: Mapping from bead keys to lists of atom indices. - axes_manager: Axes manager used to determine axes/centers/MOI. - box: Optional box vector used for PBC-aware displacements. - force_partitioning: Force scaling factor applied at highest level. - customised_axes: Whether to use customised axes methods when available. - is_highest: Whether the UA level is the highest level for the molecule. - out_force: Output accumulator for UA force second moments. - out_torque: Output accumulator for UA torque second moments. - molcount: Per-(group_id, local_res_i) molecule counters for averaging. 
- - Returns: - None. Mutates out_force/out_torque and molcount in-place. + u: Universe-like object used to resolve bead atom indices. + mol: Molecule fragment containing residues and atoms. + mol_id: Molecule index used in bead lookup keys. + group_id: Molecule-group identifier used for within-frame averaging. + beads: Mapping of bead keys to reduced-universe atom-index arrays. + axes_manager: Axes helper used to build translation and rotation axes. + box: Optional periodic box vector. + force_partitioning: Force partitioning factor for highest-level vectors. + customised_axes: Whether customised UA axes should be used. + is_highest: Whether united atom is the highest active level. + out_force: Frame-local force second-moment accumulator, mutated in place. + out_torque: Frame-local torque second-moment accumulator, mutated in place. + molcount: Per-residue group sample counters, mutated in place. """ for local_res_i, res in enumerate(mol.residues): bead_key = (mol_id, "united_atom", local_res_i) @@ -247,34 +240,24 @@ def _process_residue( molcount: dict[int, int], combined: bool, ) -> None: - """Compute residue-level force/torque (and optional FT) moments for one - molecule. - - Residue bead vectors are constructed for the molecule and used to compute - per-frame force and torque second-moment matrices. Outputs are then - incrementally averaged across molecules in the same group for this frame. - If combined FT matrices are enabled and this is the highest level, a - force-torque block matrix is also constructed and averaged. + """Compute residue-level second moments for one molecule. Args: - u: MDAnalysis Universe (or compatible) providing atom access. - mol: Molecule/fragment object providing atoms/residues. - mol_id: Molecule id used for bead keying. - group_id: Group identifier used for within-frame averaging. - beads: Mapping from bead keys to lists of atom indices. - axes_manager: Axes manager used to determine axes/centers/MOI. 
- box: Optional box vector used for PBC-aware displacements. - customised_axes: Whether to use customised axes methods when available. - force_partitioning: Force scaling factor applied at highest level. - is_highest: Whether residue level is the highest level for the molecule. - out_force: Output accumulator for residue force second moments. - out_torque: Output accumulator for residue torque second moments. - out_ft: Optional output accumulator for residue combined FT matrices. - molcount: Per-group molecule counter for within-frame averaging. - combined: Whether combined force-torque matrices are enabled. - - Returns: - None. Mutates output dictionaries and molcount in-place. + u: Universe-like object used to resolve bead atom indices. + mol: Molecule fragment containing residues and atoms. + mol_id: Molecule index used in bead lookup keys. + group_id: Molecule-group identifier used for within-frame averaging. + beads: Mapping of bead keys to reduced-universe atom-index arrays. + axes_manager: Axes helper used to build translation and rotation axes. + box: Optional periodic box vector. + customised_axes: Whether customised residue axes should be used. + force_partitioning: Force partitioning factor for highest-level vectors. + is_highest: Whether residue is the highest active level. + out_force: Frame-local force second-moment accumulator, mutated in place. + out_torque: Frame-local torque second-moment accumulator, mutated in place. + out_ft: Optional combined force-torque accumulator, mutated in place. + molcount: Per-group sample counters, mutated in place. + combined: Whether combined force-torque matrices should be produced. """ bead_key = (mol_id, "residue") bead_idx_list = beads.get(bead_key, []) @@ -328,34 +311,23 @@ def _process_polymer( molcount: dict[int, int], combined: bool, ) -> None: - """Compute polymer-level force/torque (and optional FT) moments for one - molecule. - - Polymer level uses a single bead. 
Translation/rotation axes, center, and - principal moments of inertia are computed, then used to build the - generalized force and torque vectors. Outputs are incrementally averaged - across molecules in the same group for this frame. If combined FT matrices - are enabled and this is the highest level, a force-torque block matrix is - also constructed and averaged. + """Compute polymer-level second moments for one molecule. Args: - u: MDAnalysis Universe (or compatible) providing atom access. - mol: Molecule/fragment object providing atoms. - mol_id: Molecule id used for bead keying. - group_id: Group identifier used for within-frame averaging. - beads: Mapping from bead keys to lists of atom indices. - axes_manager: Axes manager used to determine axes/centers/MOI. - box: Optional box vector used for PBC-aware displacements. - force_partitioning: Force scaling factor applied at highest level. - is_highest: Whether polymer level is the highest level for the molecule. - out_force: Output accumulator for polymer force second moments. - out_torque: Output accumulator for polymer torque second moments. - out_ft: Optional output accumulator for polymer combined FT matrices. - molcount: Per-group molecule counter for within-frame averaging. - combined: Whether combined force-torque matrices are enabled. - - Returns: - None. Mutates output dictionaries and molcount in-place. + u: Universe-like object used to resolve bead atom indices. + mol: Molecule fragment containing atoms. + mol_id: Molecule index used in bead lookup keys. + group_id: Molecule-group identifier used for within-frame averaging. + beads: Mapping of bead keys to reduced-universe atom-index arrays. + axes_manager: Axes helper used to build translation and rotation axes. + box: Optional periodic box vector. + force_partitioning: Force partitioning factor for highest-level vectors. + is_highest: Whether polymer is the highest active level. 
+ out_force: Frame-local force second-moment accumulator, mutated in place. + out_torque: Frame-local torque second-moment accumulator, mutated in place. + out_ft: Optional combined force-torque accumulator, mutated in place. + molcount: Per-group sample counters, mutated in place. + combined: Whether combined force-torque matrices should be produced. """ bead_key = (mol_id, "polymer") bead_idx_list = beads.get(bead_key, []) @@ -420,20 +392,19 @@ def _build_ua_vectors( customised_axes: bool, is_highest: bool, ) -> tuple[list[np.ndarray], list[np.ndarray]]: - """Build force/torque vectors for UA-level beads of one residue. + """Build force and torque vectors for united-atom beads. Args: - bead_groups: List of UA bead AtomGroups for the residue. - residue_atoms: AtomGroup for the residue atoms (used for axes when vanilla). - axes_manager: Axes manager used to determine axes/centers/MOI. - box: Optional box vector used for PBC-aware displacements. - force_partitioning: Force scaling factor applied at highest level. - customised_axes: Whether to use customised axes methods when available. - is_highest: Whether UA level is the highest level for the molecule. + bead_groups: Atom groups representing UA beads in a residue. + residue_atoms: Atom group for the parent residue. + axes_manager: Axes helper used to select axes, centres, and moments. + box: Optional periodic box vector. + force_partitioning: Force partitioning factor for highest-level vectors. + customised_axes: Whether customised UA axes should be used. + is_highest: Whether UA is the highest active level. Returns: - A tuple (force_vecs, torque_vecs), each a list of (3,) vectors ordered - by UA bead index within the residue. + A tuple containing lists of force vectors and torque vectors. 
""" force_vecs: list[np.ndarray] = [] torque_vecs: list[np.ndarray] = [] @@ -484,20 +455,19 @@ def _build_residue_vectors( force_partitioning: float, is_highest: bool, ) -> tuple[list[np.ndarray], list[np.ndarray]]: - """Build force/torque vectors for residue-level beads of one molecule. + """Build force and torque vectors for residue beads. Args: - mol: Molecule/fragment object providing residues/atoms. - bead_groups: List of residue bead AtomGroups for the molecule. - axes_manager: Axes manager used to determine axes/centers/MOI. - box: Optional box vector used for PBC-aware displacements. - customised_axes: Whether to use customised axes methods when available. - force_partitioning: Force scaling factor applied at highest level. - is_highest: Whether residue level is the highest level for the molecule. + mol: Molecule fragment containing residues and atoms. + bead_groups: Atom groups representing residue beads. + axes_manager: Axes helper used to select axes, centres, and moments. + box: Optional periodic box vector. + customised_axes: Whether customised residue axes should be used. + force_partitioning: Force partitioning factor for highest-level vectors. + is_highest: Whether residue is the highest active level. Returns: - A tuple (force_vecs, torque_vecs), each a list of (3,) vectors ordered - by residue index within the molecule. + A tuple containing lists of force vectors and torque vectors. """ force_vecs: list[np.ndarray] = [] torque_vecs: list[np.ndarray] = [] @@ -542,21 +512,17 @@ def _get_residue_axes( axes_manager: Any, customised_axes: bool, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Get translation/rotation axes, center and MOI for a residue bead. + """Return axes, centre, and inertia data for a residue bead. Args: - mol: Molecule/fragment object providing residues/atoms. - bead: Residue bead AtomGroup. - local_res_i: Residue index within the molecule. - axes_manager: Axes manager used to determine axes/centers/MOI. 
- customised_axes: Whether to use customised axes methods when available. + mol: Molecule fragment containing residues and atoms. + bead: Atom group representing the residue bead. + local_res_i: Residue index local to ``mol``. + axes_manager: Axes helper used to select axes, centres, and moments. + customised_axes: Whether customised residue axes should be used. Returns: - Tuple (trans_axes, rot_axes, center, moi) where: - - trans_axes: (3, 3) translation axes - - rot_axes: (3, 3) rotation axes - - center: (3,) center of mass - - moi: (3,) principal moments of inertia + A tuple of translation axes, rotation axes, centre, and moments of inertia. """ if customised_axes: res = mol.residues[local_res_i] @@ -582,16 +548,15 @@ def _get_polymer_axes( bead: Any, axes_manager: Any, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Get translation/rotation axes, center and MOI for a polymer bead. + """Return axes, centre, and inertia data for a polymer bead. Args: - mol: Molecule/fragment object providing atoms. - bead: Polymer bead AtomGroup. - axes_manager: Axes manager used to determine axes/centers/MOI. + mol: Molecule fragment containing atoms. + bead: Atom group representing the polymer bead. + axes_manager: Axes helper used to select axes, centres, and moments. Returns: - Tuple (trans_axes, rot_axes, center, moi) with shapes (3,3), (3,3), (3,), - and (3,) respectively. + A tuple of translation axes, rotation axes, centre, and moments of inertia. """ make_whole(mol.atoms) make_whole(bead) @@ -609,16 +574,16 @@ def _get_polymer_axes( @staticmethod def _get_shared(ctx: FrameCtx) -> dict[str, Any]: - """Fetch shared context from a frame context dict. + """Return shared workflow data from a frame context. Args: - ctx: Frame context dictionary expected to contain a "shared" key. + ctx: Frame-local context dictionary. Returns: - The shared context dict stored at ctx["shared"]. + The shared workflow data stored at ``ctx["shared"]``. 
Raises: - KeyError: If "shared" is not present in ctx. + KeyError: If ``ctx`` does not contain a ``shared`` entry. """ if "shared" not in ctx: raise KeyError("FrameCovarianceNode expects ctx['shared'].") @@ -626,14 +591,13 @@ def _get_shared(ctx: FrameCtx) -> dict[str, Any]: @staticmethod def _try_get_box(u: Any) -> np.ndarray | None: - """Extract a (3,) box vector from an MDAnalysis universe when available. + """Extract periodic box lengths from a universe-like object. Args: - u: MDAnalysis Universe (or compatible) that may expose dimensions. + u: Universe-like object that may expose ``dimensions``. Returns: - A numpy array of shape (3,) containing box lengths, or None if not - available. + A three-element NumPy array of box lengths, or ``None`` if unavailable. """ try: return np.asarray(u.dimensions[:3], dtype=float) @@ -642,15 +606,15 @@ def _try_get_box(u: Any) -> np.ndarray | None: @staticmethod def _inc_mean(old: np.ndarray | None, new: np.ndarray, n: int) -> np.ndarray: - """Compute an incremental mean (streaming average). + """Update a running mean with one new sample. Args: - old: Previous running mean value, or None for the first sample. + old: Existing running mean, or ``None`` for the first sample. new: New sample to incorporate. - n: 1-based sample count after adding the new sample. + n: One-based sample count after adding ``new``. Returns: - Updated running mean. + The updated running mean. """ if old is None: return new.copy() @@ -660,21 +624,18 @@ def _inc_mean(old: np.ndarray | None, new: np.ndarray, n: int) -> np.ndarray: def _build_ft_block( force_vecs: list[np.ndarray], torque_vecs: list[np.ndarray] ) -> np.ndarray: - """Build a combined force-torque block matrix for a frame. - - For each bead i, create a 6-vector [Fi, Ti]. The block matrix is built - from outer products of these 6-vectors. + """Build a combined force-torque block matrix. Args: - force_vecs: List of per-bead force vectors, each of shape (3,). 
- torque_vecs: List of per-bead torque vectors, each of shape (3,). + force_vecs: Per-bead force vectors with length three. + torque_vecs: Per-bead torque vectors with length three. Returns: - A block matrix of shape (6N, 6N) where N is the number of beads. + A block matrix with shape ``(6N, 6N)`` for ``N`` bead vectors. Raises: - ValueError: If force_vecs and torque_vecs have different lengths, if no - bead vectors are provided, or if any input vector is not length 3. + ValueError: If the vector lists differ in length, are empty, or contain + vectors that are not length three. """ if len(force_vecs) != len(torque_vecs): raise ValueError("force_vecs and torque_vecs must have the same length.") diff --git a/CodeEntropy/levels/nodes/find_neighbors.py b/CodeEntropy/levels/nodes/find_neighbors.py deleted file mode 100644 index 9b905b2a..00000000 --- a/CodeEntropy/levels/nodes/find_neighbors.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Find neighbor and symmetry info for orientational entropy calculations. - -This module defines a static DAG node that scans the trajectory and -finds neighbors and symmetry numbers. The resulting states are stored -in `shared_data` for later use by configurational entropy calculations. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -from CodeEntropy.levels.neighbors import Neighbors - -SharedData = dict[str, Any] -ConformationalStates = dict[str, Any] - - -@dataclass(frozen=True) -class NeighborConfig: - """Configuration for neighbor finding. - - Attributes: - start: Start frame index (inclusive). - end: End frame index (exclusive). - step: Frame stride. - """ - - start: int - end: int - step: int - - -class ComputeNeighborsNode: - """Static node that finds neighbors from trajectory. 
- - Produces: - shared_data["neighbors"] = {} - shared_data["symmetry_number"] = {} - shared_data["linear"] = {} - - Where: - - neighbors is the average number of neighbors - - symmetry_number is the symmetry number of the molecule, int - - linear is a boolean; True for linear, False for non-linear - """ - - def __init__(self) -> None: - """Initialize the node.""" - self._neighbor_analysis = Neighbors() - - def run( - self, shared_data: SharedData, *, progress: object | None = None - ) -> SharedData: - """Compute neighbour and symmetry information and store it in shared_data. - - Args: - shared_data: Shared data dictionary. Requires: - - ``reduced_universe`` - - ``levels`` - - ``groups`` - - ``frame_source`` - - ``args.search_type`` - progress: Optional progress sink. Currently unused. - - Returns: - The mutated shared data dictionary. - """ - u = shared_data["reduced_universe"] - levels = shared_data["levels"] - groups = shared_data["groups"] - frame_source = shared_data["frame_source"] - search_type = shared_data["args"].search_type - - number_neighbors = self._neighbor_analysis.get_neighbors( - universe=u, - levels=levels, - groups=groups, - frame_source=frame_source, - search_type=search_type, - ) - - symmetry_number, linear = self._neighbor_analysis.get_symmetry( - universe=u, - groups=groups, - ) - - shared_data["neighbors"] = number_neighbors - shared_data["symmetry_number"] = symmetry_number - shared_data["linear"] = linear - - return shared_data diff --git a/docs/api/CodeEntropy.core.dask_clusters.rst b/docs/api/CodeEntropy.core.dask_clusters.rst new file mode 100644 index 00000000..6430e64c --- /dev/null +++ b/docs/api/CodeEntropy.core.dask_clusters.rst @@ -0,0 +1,7 @@ +CodeEntropy.core.dask\_clusters module +====================================== + +.. 
automodule:: CodeEntropy.core.dask_clusters + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/api/CodeEntropy.core.rst b/docs/api/CodeEntropy.core.rst index 9457905b..695ee31d 100644 --- a/docs/api/CodeEntropy.core.rst +++ b/docs/api/CodeEntropy.core.rst @@ -12,4 +12,5 @@ Submodules .. toctree:: :maxdepth: 4 + CodeEntropy.core.dask_clusters CodeEntropy.core.logging diff --git a/docs/api/CodeEntropy.levels.nodes.find_neighbors.rst b/docs/api/CodeEntropy.levels.nodes.find_neighbors.rst deleted file mode 100644 index 99bde1f6..00000000 --- a/docs/api/CodeEntropy.levels.nodes.find_neighbors.rst +++ /dev/null @@ -1,7 +0,0 @@ -CodeEntropy.levels.nodes.find\_neighbors module -=============================================== - -.. automodule:: CodeEntropy.levels.nodes.find_neighbors - :members: - :show-inheritance: - :undoc-members: diff --git a/docs/api/CodeEntropy.levels.nodes.rst b/docs/api/CodeEntropy.levels.nodes.rst index e6112640..7ce5004f 100644 --- a/docs/api/CodeEntropy.levels.nodes.rst +++ b/docs/api/CodeEntropy.levels.nodes.rst @@ -18,4 +18,3 @@ Submodules CodeEntropy.levels.nodes.covariance CodeEntropy.levels.nodes.detect_levels CodeEntropy.levels.nodes.detect_molecules - CodeEntropy.levels.nodes.find_neighbors diff --git a/tests/unit/CodeEntropy/levels/execution/test_chunks.py b/tests/unit/CodeEntropy/levels/execution/test_chunks.py new file mode 100644 index 00000000..4348e85f --- /dev/null +++ b/tests/unit/CodeEntropy/levels/execution/test_chunks.py @@ -0,0 +1,28 @@ +"""Unit tests for frame chunking helpers.""" + +from __future__ import annotations + +import pytest + +from CodeEntropy.levels.execution.chunks import chunk_frame_indices + + +def test_chunk_frame_indices_splits_into_fixed_size_chunks(): + assert chunk_frame_indices([0, 1, 2, 3, 4], chunk_size=2) == [ + (0, 1), + (2, 3), + (4,), + ] + + +def test_chunk_frame_indices_returns_empty_list_for_no_frames(): + assert chunk_frame_indices([], chunk_size=3) == [] + + +def 
test_chunk_frame_indices_returns_single_chunk_when_chunk_size_exceeds_frames(): + assert chunk_frame_indices([1, 2], chunk_size=10) == [(1, 2)] + + +def test_chunk_frame_indices_rejects_non_positive_chunk_size(): + with pytest.raises(ValueError, match="chunk_size must be >= 1"): + chunk_frame_indices([0], chunk_size=0) diff --git a/tests/unit/CodeEntropy/levels/execution/test_policy.py b/tests/unit/CodeEntropy/levels/execution/test_policy.py new file mode 100644 index 00000000..a0dbc8a8 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/execution/test_policy.py @@ -0,0 +1,94 @@ +"""Unit tests for internal frame execution policy.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from CodeEntropy.levels.execution.policy import ExecutionPolicy + + +def test_requested_worker_count_uses_dask_workers_when_provided(): + policy = ExecutionPolicy() + shared_data = { + "args": SimpleNamespace( + dask_workers=8, + hpc=True, + hpc_nodes=2, + hpc_processes=4, + ) + } + + assert policy.requested_worker_count(shared_data) == 8 + + +def test_requested_worker_count_clamps_dask_workers_to_at_least_one(): + policy = ExecutionPolicy() + shared_data = {"args": SimpleNamespace(dask_workers=0, hpc=False)} + + assert policy.requested_worker_count(shared_data) == 1 + + +def test_requested_worker_count_uses_hpc_nodes_and_processes_without_local_workers(): + policy = ExecutionPolicy() + shared_data = { + "args": SimpleNamespace( + dask_workers=None, + hpc=True, + hpc_nodes=3, + hpc_processes=2, + ) + } + + assert policy.requested_worker_count(shared_data) == 6 + + +def test_requested_worker_count_clamps_hpc_values_to_at_least_one(): + policy = ExecutionPolicy() + shared_data = { + "args": SimpleNamespace( + dask_workers=None, + hpc=True, + hpc_nodes=0, + hpc_processes=0, + ) + } + + assert policy.requested_worker_count(shared_data) == 1 + + +def test_requested_worker_count_defaults_to_one_without_args(): + assert ExecutionPolicy().requested_worker_count({}) == 1 
+ + +def test_requested_worker_count_defaults_to_one_for_non_parallel_run(): + policy = ExecutionPolicy() + shared_data = {"args": SimpleNamespace(dask_workers=None, hpc=False)} + + assert policy.requested_worker_count(shared_data) == 1 + + +def test_frame_chunk_size_is_deterministic_and_clamped(): + policy = ExecutionPolicy( + target_frame_chunks_per_worker=2, + min_frame_chunk_size=2, + max_frame_chunk_size=10, + ) + shared_data = {"args": SimpleNamespace(dask_workers=4, hpc=False)} + + assert policy.frame_chunk_size(shared_data, n_frames=100) == 10 + assert policy.frame_chunk_size(shared_data, n_frames=3) == 2 + + +def test_frame_chunk_size_treats_zero_frames_as_one(): + policy = ExecutionPolicy(min_frame_chunk_size=1, max_frame_chunk_size=32) + shared_data = {"args": SimpleNamespace(dask_workers=1, hpc=False)} + + assert policy.frame_chunk_size(shared_data, n_frames=0) == 1 + + +def test_max_frame_in_flight_tasks_is_bounded_by_chunk_count(): + policy = ExecutionPolicy(max_frame_in_flight_multiplier=2) + shared_data = {"args": SimpleNamespace(dask_workers=4, hpc=False)} + + assert policy.max_frame_in_flight_tasks(shared_data, n_chunks=3) == 3 + assert policy.max_frame_in_flight_tasks(shared_data, n_chunks=20) == 8 diff --git a/tests/unit/CodeEntropy/levels/execution/test_reducers.py b/tests/unit/CodeEntropy/levels/execution/test_reducers.py new file mode 100644 index 00000000..0d6fed62 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/execution/test_reducers.py @@ -0,0 +1,327 @@ +"""Unit tests for frame map-reduce reducers.""" + +from __future__ import annotations + +import numpy as np + +from CodeEntropy.levels.execution.reducers import ( + CovarianceReducer, + NeighborReducer, + incremental_mean, + merge_means, + stable_keys, +) +from CodeEntropy.levels.execution.tasks import CovarianceChunkPartial + + +def _shared_covariance_state() -> dict: + return { + "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, + "torque_covariances": {"ua": {}, "res": 
[None], "poly": [None]}, + "frame_counts": { + "ua": {}, + "res": np.zeros(1, dtype=int), + "poly": np.zeros(1, dtype=int), + }, + "forcetorque_covariances": {"res": [None], "poly": [None]}, + "forcetorque_counts": { + "res": np.zeros(1, dtype=int), + "poly": np.zeros(1, dtype=int), + }, + "group_id_to_index": {7: 0}, + "groups": {7: [0]}, + } + + +def test_stable_keys_orders_mixed_key_types_deterministically(): + mapping = {(2, 0): "tuple", 1: "int", "a": "str"} + + assert stable_keys(mapping) == [1, "a", (2, 0)] + + +def test_merge_means_returns_old_mean_when_new_count_is_zero(): + old = np.array([1.0, 2.0]) + + assert merge_means(old, old_n=2, new_mean=np.array([9.0, 9.0]), new_n=0) is old + + +def test_merge_means_returns_copy_for_first_numpy_value(): + new = np.array([1.0, 2.0]) + + merged = merge_means(None, old_n=0, new_mean=new, new_n=3) + + np.testing.assert_allclose(merged, new) + new[0] = 99.0 + assert merged[0] != 99.0 + + +def test_merge_means_combines_weighted_means(): + old = np.array([2.0, 4.0]) + new = np.array([8.0, 10.0]) + + merged = merge_means(old, old_n=2, new_mean=new, new_n=1) + + np.testing.assert_allclose(merged, np.array([4.0, 6.0])) + + +def test_incremental_mean_returns_copy_for_first_numpy_value(): + new = np.array([1.0, 2.0]) + + out = incremental_mean(None, new, n=1) + + np.testing.assert_allclose(out, new) + new[0] = 99.0 + assert out[0] != 99.0 + + +def test_incremental_mean_updates_mean(): + old = np.array([2.0, 2.0]) + new = np.array([4.0, 0.0]) + + out = incremental_mean(old, new, n=2) + + np.testing.assert_allclose(out, np.array([3.0, 1.0])) + + +def test_neighbor_reducer_initialise_merge_and_finalise(): + shared_data = {"groups": {7: [0], 9: [1]}} + + NeighborReducer.initialise(shared_data) + NeighborReducer.merge_chunk_partial( + shared_data, + neighbor_totals={7: 6, 9: 0}, + neighbor_samples={7: 3, 9: 0}, + ) + NeighborReducer.finalise(shared_data) + + assert shared_data["neighbor_totals"] == {7: 6, 9: 0} + assert 
shared_data["neighbor_samples"] == {7: 3, 9: 0} + assert shared_data["neighbors"] == {7: 2.0, 9: 0.0} + + +def test_neighbor_reducer_reduce_frame_output_none_is_noop(): + shared_data = {"groups": {7: [0]}} + NeighborReducer.initialise(shared_data) + + NeighborReducer.reduce_frame_output(shared_data, None) + + assert shared_data["neighbor_totals"] == {7: 0} + assert shared_data["neighbor_samples"] == {7: 0} + + +def test_neighbor_reducer_reduce_frame_output_merges_counts(): + shared_data = {"groups": {7: [0]}} + NeighborReducer.initialise(shared_data) + + NeighborReducer.reduce_frame_output(shared_data, {7: (4, 2)}) + + assert shared_data["neighbor_totals"] == {7: 4} + assert shared_data["neighbor_samples"] == {7: 2} + + +def test_neighbor_reducer_merge_chunk_partial_noops_if_not_initialised(): + shared_data = {} + + NeighborReducer.merge_chunk_partial( + shared_data, + neighbor_totals={7: 1}, + neighbor_samples={7: 1}, + ) + + assert shared_data == {} + + +def test_covariance_reducer_reduce_frame_map_output_merges_covariance_and_neighbors(): + shared_data = _shared_covariance_state() + NeighborReducer.initialise(shared_data) + + force = np.eye(3) + torque = 2.0 * np.eye(3) + ft = np.ones((6, 6)) + + frame_out = { + "covariance": { + "force": {"ua": {(7, 0): force}, "res": {7: force}, "poly": {}}, + "torque": {"ua": {(7, 0): torque}, "res": {7: torque}, "poly": {}}, + "forcetorque": {"res": {7: ft}, "poly": {}}, + }, + "neighbors": {7: (5, 1)}, + } + + CovarianceReducer().reduce_frame_map_output(shared_data, frame_out) + + np.testing.assert_allclose(shared_data["force_covariances"]["ua"][(7, 0)], force) + np.testing.assert_allclose(shared_data["torque_covariances"]["ua"][(7, 0)], torque) + np.testing.assert_allclose(shared_data["force_covariances"]["res"][0], force) + np.testing.assert_allclose(shared_data["torque_covariances"]["res"][0], torque) + np.testing.assert_allclose(shared_data["forcetorque_covariances"]["res"][0], ft) + + assert 
shared_data["frame_counts"]["ua"][(7, 0)] == 1 + assert shared_data["frame_counts"]["res"][0] == 1 + assert shared_data["forcetorque_counts"]["res"][0] == 1 + assert shared_data["neighbor_totals"] == {7: 5} + assert shared_data["neighbor_samples"] == {7: 1} + + +def test_covariance_reducer_reduce_frame_map_output_accepts_missing_sections(): + shared_data = _shared_covariance_state() + NeighborReducer.initialise(shared_data) + + CovarianceReducer().reduce_frame_map_output(shared_data, {}) + + assert shared_data["frame_counts"]["res"][0] == 0 + assert shared_data["neighbor_totals"] == {7: 0} + + +def test_covariance_reducer_merge_chunk_partial(): + shared_data = _shared_covariance_state() + + force = np.eye(3) + torque = 2.0 * np.eye(3) + ft = np.ones((6, 6)) + + partial = CovarianceChunkPartial() + partial.force["ua"][(7, 0)] = force + partial.torque["ua"][(7, 0)] = torque + partial.frame_counts["ua"][(7, 0)] = 2 + partial.force["res"][7] = force + partial.torque["res"][7] = torque + partial.frame_counts["res"][7] = 2 + partial.force["poly"][7] = force + partial.torque["poly"][7] = torque + partial.frame_counts["poly"][7] = 2 + partial.forcetorque["res"][7] = ft + partial.forcetorque_counts["res"][7] = 2 + partial.forcetorque["poly"][7] = ft + partial.forcetorque_counts["poly"][7] = 2 + + CovarianceReducer().merge_chunk_partial(shared_data, partial) + + np.testing.assert_allclose(shared_data["force_covariances"]["ua"][(7, 0)], force) + np.testing.assert_allclose(shared_data["torque_covariances"]["ua"][(7, 0)], torque) + np.testing.assert_allclose(shared_data["force_covariances"]["res"][0], force) + np.testing.assert_allclose(shared_data["torque_covariances"]["res"][0], torque) + np.testing.assert_allclose(shared_data["force_covariances"]["poly"][0], force) + np.testing.assert_allclose(shared_data["torque_covariances"]["poly"][0], torque) + np.testing.assert_allclose(shared_data["forcetorque_covariances"]["res"][0], ft) + 
np.testing.assert_allclose(shared_data["forcetorque_covariances"]["poly"][0], ft) + + assert shared_data["frame_counts"]["ua"][(7, 0)] == 2 + assert shared_data["frame_counts"]["res"][0] == 2 + assert shared_data["frame_counts"]["poly"][0] == 2 + assert shared_data["forcetorque_counts"]["res"][0] == 2 + assert shared_data["forcetorque_counts"]["poly"][0] == 2 + + +def test_covariance_reducer_reduce_frame_output_handles_torque_only_branches(): + shared_data = _shared_covariance_state() + + ua_torque = np.eye(3) + res_torque = 2.0 * np.eye(3) + poly_torque = 3.0 * np.eye(3) + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": { + "ua": {(7, 0): ua_torque}, + "res": {7: res_torque}, + "poly": {7: poly_torque}, + }, + } + + CovarianceReducer().reduce_frame_output(shared_data, frame_out) + + assert shared_data["frame_counts"]["ua"][(7, 0)] == 1 + assert shared_data["frame_counts"]["res"][0] == 1 + assert shared_data["frame_counts"]["poly"][0] == 1 + + np.testing.assert_allclose( + shared_data["torque_covariances"]["ua"][(7, 0)], + ua_torque, + ) + np.testing.assert_allclose( + shared_data["torque_covariances"]["res"][0], + res_torque, + ) + np.testing.assert_allclose( + shared_data["torque_covariances"]["poly"][0], + poly_torque, + ) + + assert shared_data["force_covariances"]["ua"] == {} + assert shared_data["force_covariances"]["res"][0] is None + assert shared_data["force_covariances"]["poly"][0] is None + + +def test_covariance_reducer_reduce_frame_output_updates_poly_force_and_torque(): + shared_data = _shared_covariance_state() + + poly_force = np.eye(3) + poly_torque = 2.0 * np.eye(3) + + frame_out = { + "force": { + "ua": {}, + "res": {}, + "poly": {7: poly_force}, + }, + "torque": { + "ua": {}, + "res": {}, + "poly": {7: poly_torque}, + }, + } + + CovarianceReducer().reduce_frame_output(shared_data, frame_out) + + assert shared_data["frame_counts"]["poly"][0] == 1 + np.testing.assert_allclose( + 
shared_data["force_covariances"]["poly"][0], + poly_force, + ) + np.testing.assert_allclose( + shared_data["torque_covariances"]["poly"][0], + poly_torque, + ) + + +def test_covariance_reducer_reduce_frame_output_without_forcetorque_is_noop_for_ft(): + shared_data = _shared_covariance_state() + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + } + + CovarianceReducer().reduce_frame_output(shared_data, frame_out) + + assert shared_data["forcetorque_counts"]["res"][0] == 0 + assert shared_data["forcetorque_counts"]["poly"][0] == 0 + assert shared_data["forcetorque_covariances"]["res"][0] is None + assert shared_data["forcetorque_covariances"]["poly"][0] is None + + +def test_covariance_reducer_reduce_frame_output_updates_poly_forcetorque(): + shared_data = _shared_covariance_state() + + poly_ft = np.ones((6, 6)) + + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + "forcetorque": { + "res": {}, + "poly": {7: poly_ft}, + }, + } + + CovarianceReducer().reduce_frame_output(shared_data, frame_out) + + assert shared_data["forcetorque_counts"]["poly"][0] == 1 + np.testing.assert_allclose( + shared_data["forcetorque_covariances"]["poly"][0], + poly_ft, + ) + + assert shared_data["forcetorque_counts"]["res"][0] == 0 + assert shared_data["forcetorque_covariances"]["res"][0] is None diff --git a/tests/unit/CodeEntropy/levels/execution/test_scheduler.py b/tests/unit/CodeEntropy/levels/execution/test_scheduler.py new file mode 100644 index 00000000..fb0ce05c --- /dev/null +++ b/tests/unit/CodeEntropy/levels/execution/test_scheduler.py @@ -0,0 +1,399 @@ +"""Unit tests for serial and Dask frame schedulers.""" + +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, call, patch + +import pytest + +from CodeEntropy.levels.execution.policy import ExecutionPolicy +from 
CodeEntropy.levels.execution.scheduler import FrameScheduler +from CodeEntropy.levels.execution.tasks import ( + CovarianceChunkPartial, + FrameChunkResult, + FrameChunkTask, +) + + +def _scheduler(policy: ExecutionPolicy | MagicMock | None = None) -> FrameScheduler: + return FrameScheduler( + frame_dag=MagicMock(), + policy=policy or ExecutionPolicy(), + universe_operations=MagicMock(), + ) + + +def _chunk_result(chunk_index: int, frame_indices: tuple[int, ...]) -> FrameChunkResult: + return FrameChunkResult( + chunk_index=chunk_index, + covariance_partial=CovarianceChunkPartial(), + neighbor_totals={0: len(frame_indices)}, + neighbor_samples={0: len(frame_indices)}, + frame_indices=frame_indices, + ) + + +def test_execute_creates_progress_task_when_progress_is_supplied(): + scheduler = _scheduler() + scheduler._run_serial = MagicMock() + + progress = MagicMock() + progress.add_task.return_value = "task-id" + + scheduler.execute({}, frame_indices=[0], progress=progress) + + progress.add_task.assert_called_once_with( + "[green]Frame processing", + total=1, + title="Initializing frame stage", + ) + scheduler._run_serial.assert_called_once_with( + {}, + frame_indices=[0], + progress=progress, + task="task-id", + ) + + +def test_execute_uses_dask_when_client_is_available_and_multiple_frames(): + scheduler = _scheduler() + scheduler._run_dask = MagicMock() + scheduler._run_serial = MagicMock() + + client = MagicMock() + shared_data = {"dask_client": client, "parallel_frames": True} + + scheduler.execute(shared_data, frame_indices=[0, 1], progress=None) + + scheduler._run_dask.assert_called_once_with( + shared_data, + frame_indices=[0, 1], + client=client, + progress=None, + task=None, + ) + scheduler._run_serial.assert_not_called() + + +def test_execute_uses_serial_when_only_one_frame_even_with_client(): + scheduler = _scheduler() + scheduler._run_dask = MagicMock() + scheduler._run_serial = MagicMock() + + shared_data = {"dask_client": MagicMock(), "parallel_frames": 
True} + + scheduler.execute(shared_data, frame_indices=[0]) + + scheduler._run_dask.assert_not_called() + scheduler._run_serial.assert_called_once() + + +def test_execute_uses_serial_when_no_client(): + scheduler = _scheduler() + scheduler._run_dask = MagicMock() + scheduler._run_serial = MagicMock() + + shared_data = {"parallel_frames": True} + + scheduler.execute(shared_data, frame_indices=[0, 1]) + + scheduler._run_dask.assert_not_called() + scheduler._run_serial.assert_called_once() + + +def test_run_serial_executes_and_reduces_each_frame(): + scheduler = _scheduler() + shared_data = {"groups": {0: [0]}} + progress = MagicMock() + task_id = "task-id" + + frame_out0 = {"covariance": "cov0", "neighbors": {0: (1, 1)}} + frame_out1 = {"covariance": "cov1", "neighbors": {0: (2, 1)}} + + with patch( + "CodeEntropy.levels.execution.scheduler.execute_frame_map_output", + side_effect=[frame_out0, frame_out1], + ) as execute_frame: + scheduler._covariance_reducer.reduce_frame_map_output = MagicMock() + + scheduler._run_serial( + shared_data, + frame_indices=[0, 1], + progress=progress, + task=task_id, + ) + + assert execute_frame.call_count == 2 + scheduler._covariance_reducer.reduce_frame_map_output.assert_has_calls( + [ + call(shared_data, frame_out0), + call(shared_data, frame_out1), + ] + ) + progress.update.assert_has_calls( + [ + call(task_id, title="Frame 0"), + call(task_id, title="Frame 1"), + ] + ) + assert progress.advance.call_count == 2 + + +def test_make_frame_chunk_tasks_uses_policy_chunk_size(): + policy = ExecutionPolicy(target_frame_chunks_per_worker=1) + scheduler = _scheduler(policy=policy) + shared_data = {"args": SimpleNamespace(dask_workers=2, hpc=False)} + + tasks = scheduler._make_frame_chunk_tasks(shared_data, [0, 1, 2, 3, 4]) + + assert tasks == [ + FrameChunkTask(chunk_index=0, frame_indices=(0, 1, 2)), + FrameChunkTask(chunk_index=1, frame_indices=(3, 4)), + ] + + +def test_run_dask_scatters_worker_shared_once_and_reduces_in_chunk_order(): + 
policy = MagicMock() + policy.max_frame_in_flight_tasks.return_value = 2 + scheduler = _scheduler(policy=policy) + + frame_tasks = [ + FrameChunkTask(chunk_index=0, frame_indices=(0,)), + FrameChunkTask(chunk_index=1, frame_indices=(1,)), + ] + scheduler._make_frame_chunk_tasks = MagicMock(return_value=frame_tasks) + scheduler._covariance_reducer.merge_chunk_partial = MagicMock() + + shared_data = { + "groups": {0: [0]}, + "args": SimpleNamespace(dask_workers=2, hpc=False), + "force_covariances": "parent-only", + "frame_source": "kept", + } + + worker_future = MagicMock(name="worker_shared_future") + future_zero = MagicMock(name="future_zero") + future_one = MagicMock(name="future_one") + future_zero.result.return_value = _chunk_result(0, (0,)) + future_one.result.return_value = _chunk_result(1, (1,)) + + client = MagicMock() + client.scatter.return_value = [worker_future] + client.submit.side_effect = [future_zero, future_one] + + fake_distributed = types.ModuleType("distributed") + fake_distributed.wait = MagicMock(return_value=({future_one, future_zero}, set())) + + with ( + patch.dict(sys.modules, {"distributed": fake_distributed}), + patch( + "CodeEntropy.levels.execution.scheduler.execute_frame_chunk_worker" + ) as worker_func, + patch( + "CodeEntropy.levels.execution.scheduler.NeighborReducer.merge_chunk_partial" + ) as merge_neighbors, + ): + scheduler._run_dask( + shared_data, + frame_indices=[0, 1], + client=client, + progress=None, + task=None, + ) + + client.scatter.assert_called_once() + scattered_payload = client.scatter.call_args.args[0] + assert isinstance(scattered_payload, list) + assert len(scattered_payload) == 1 + assert "force_covariances" not in scattered_payload[0] + assert scattered_payload[0]["frame_source"] == "kept" + + client.submit.assert_has_calls( + [ + call( + worker_func, + frame_tasks[0], + worker_future, + scheduler._universe_operations, + pure=False, + ), + call( + worker_func, + frame_tasks[1], + worker_future, + 
scheduler._universe_operations, + pure=False, + ), + ] + ) + + scheduler._covariance_reducer.merge_chunk_partial.assert_has_calls( + [ + call(shared_data, future_zero.result.return_value.covariance_partial), + call(shared_data, future_one.result.return_value.covariance_partial), + ] + ) + merge_neighbors.assert_has_calls( + [ + call(shared_data, {0: 1}, {0: 1}), + call(shared_data, {0: 1}, {0: 1}), + ] + ) + + future_zero.release.assert_called_once() + future_one.release.assert_called_once() + worker_future.release.assert_called_once() + + +def test_run_dask_submits_more_tasks_as_futures_complete(): + policy = MagicMock() + policy.max_frame_in_flight_tasks.return_value = 1 + scheduler = _scheduler(policy=policy) + + frame_tasks = [ + FrameChunkTask(chunk_index=0, frame_indices=(0,)), + FrameChunkTask(chunk_index=1, frame_indices=(1,)), + ] + scheduler._make_frame_chunk_tasks = MagicMock(return_value=frame_tasks) + scheduler._covariance_reducer.merge_chunk_partial = MagicMock() + + worker_future = MagicMock() + future_zero = MagicMock() + future_one = MagicMock() + future_zero.result.return_value = _chunk_result(0, (0,)) + future_one.result.return_value = _chunk_result(1, (1,)) + + client = MagicMock() + client.scatter.return_value = [worker_future] + client.submit.side_effect = [future_zero, future_one] + + fake_distributed = types.ModuleType("distributed") + fake_distributed.wait = MagicMock( + side_effect=[ + ({future_zero}, set()), + ({future_one}, set()), + ] + ) + + progress = MagicMock() + + with ( + patch.dict(sys.modules, {"distributed": fake_distributed}), + patch( + "CodeEntropy.levels.execution.scheduler.NeighborReducer.merge_chunk_partial" + ), + ): + scheduler._run_dask( + {"groups": {0: [0]}, "args": SimpleNamespace(dask_workers=1, hpc=False)}, + frame_indices=[0, 1], + client=client, + progress=progress, + task="task-id", + ) + + assert client.submit.call_count == 2 + assert progress.advance.call_count == 2 + 
worker_future.release.assert_called_once() + + +def test_run_dask_cancels_active_futures_and_releases_scattered_data_on_error(): + policy = MagicMock() + policy.max_frame_in_flight_tasks.return_value = 1 + scheduler = _scheduler(policy=policy) + + task = FrameChunkTask(chunk_index=0, frame_indices=(0,)) + scheduler._make_frame_chunk_tasks = MagicMock(return_value=[task]) + + worker_future = MagicMock() + failed_future = MagicMock() + failed_future.result.side_effect = RuntimeError("worker failed") + + client = MagicMock() + client.scatter.return_value = [worker_future] + client.submit.return_value = failed_future + + fake_distributed = types.ModuleType("distributed") + fake_distributed.wait = MagicMock(return_value=({failed_future}, set())) + + with patch.dict(sys.modules, {"distributed": fake_distributed}): + with pytest.raises(RuntimeError, match="worker failed"): + scheduler._run_dask( + { + "groups": {0: [0]}, + "args": SimpleNamespace(dask_workers=1, hpc=False), + }, + frame_indices=[0], + client=client, + progress=None, + task=None, + ) + + client.cancel.assert_called_once() + worker_future.release.assert_called_once() + + +def test_run_dask_raises_when_distributed_is_missing(): + scheduler = _scheduler() + client = MagicMock() + + real_import = __import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "distributed": + raise ImportError("No module named distributed") + return real_import(name, globals, locals, fromlist, level) + + with patch("builtins.__import__", side_effect=fake_import): + with pytest.raises(RuntimeError, match="requires dask.distributed"): + scheduler._run_dask( + { + "groups": {0: [0]}, + "args": SimpleNamespace(dask_workers=1, hpc=False), + }, + frame_indices=[0], + client=client, + progress=None, + task=None, + ) + + +def test_run_dask_raises_if_completed_frame_count_mismatches(): + policy = MagicMock() + policy.max_frame_in_flight_tasks.return_value = 1 + scheduler = _scheduler(policy=policy) + 
+ task = FrameChunkTask(chunk_index=0, frame_indices=(0,)) + scheduler._make_frame_chunk_tasks = MagicMock(return_value=[task]) + scheduler._covariance_reducer.merge_chunk_partial = MagicMock() + + worker_future = MagicMock() + future = MagicMock() + future.result.return_value = _chunk_result(0, (0,)) + + client = MagicMock() + client.scatter.return_value = [worker_future] + client.submit.return_value = future + + fake_distributed = types.ModuleType("distributed") + fake_distributed.wait = MagicMock(return_value=({future}, set())) + + with patch.dict(sys.modules, {"distributed": fake_distributed}): + with pytest.raises( + RuntimeError, + match="Parallel frame execution completed 1 frames, but expected 2", + ): + scheduler._run_dask( + { + "groups": {0: [0]}, + "args": SimpleNamespace(dask_workers=1, hpc=False), + }, + frame_indices=[0, 1], + client=client, + progress=None, + task=None, + ) + + worker_future.release.assert_called_once() diff --git a/tests/unit/CodeEntropy/levels/execution/test_tasks.py b/tests/unit/CodeEntropy/levels/execution/test_tasks.py new file mode 100644 index 00000000..534df137 --- /dev/null +++ b/tests/unit/CodeEntropy/levels/execution/test_tasks.py @@ -0,0 +1,433 @@ +"""Unit tests for frame-chunk task helpers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np + +from CodeEntropy.levels.execution.tasks import ( + CovarianceChunkPartial, + FrameChunkResult, + FrameChunkTask, + execute_frame_chunk_worker, + execute_frame_map_output, + incremental_mean_value, + make_frame_worker_shared_data, + reduce_frame_covariance_into_partial, +) + + +def _frame_covariance(force_value: float) -> dict: + force = force_value * np.eye(3) + torque = (force_value + 1.0) * np.eye(3) + + return { + "force": {"ua": {(0, 0): force}, "res": {0: force}, "poly": {}}, + "torque": {"ua": {(0, 0): torque}, "res": {0: torque}, "poly": {}}, + } + + +def 
test_make_frame_worker_shared_data_excludes_parent_owned_state(): + shared_data = { + "force_covariances": "exclude", + "torque_covariances": "exclude", + "forcetorque_covariances": "exclude", + "frame_counts": "exclude", + "forcetorque_counts": "exclude", + "neighbor_totals": "exclude", + "neighbor_samples": "exclude", + "n_frames": 10, + "entropy_manager": "exclude", + "run_manager": "exclude", + "reporter": "exclude", + "dask_client": "exclude", + "frame_source": "keep", + "levels": "keep", + "groups": "keep", + "args": "keep", + } + + assert make_frame_worker_shared_data(shared_data) == { + "frame_source": "keep", + "levels": "keep", + "groups": "keep", + "args": "keep", + } + + +def test_frame_chunk_task_contains_only_lightweight_task_descriptor(): + task = FrameChunkTask(chunk_index=3, frame_indices=(10, 11)) + + assert task.chunk_index == 3 + assert task.frame_indices == (10, 11) + assert not hasattr(task, "worker_shared_data") + assert not hasattr(task, "include_neighbors") + + +def test_incremental_mean_value_returns_copy_for_first_numpy_value(): + new = np.array([1.0, 2.0]) + + out = incremental_mean_value(None, new, n=1) + + np.testing.assert_allclose(out, new) + new[0] = 99.0 + assert out[0] != 99.0 + + +def test_incremental_mean_value_handles_non_copyable_first_value(): + assert incremental_mean_value(None, 3.0, n=1) == 3.0 + + +def test_incremental_mean_value_updates_mean(): + old = np.array([2.0, 2.0]) + new = np.array([4.0, 0.0]) + + np.testing.assert_allclose( + incremental_mean_value(old, new, n=2), + np.array([3.0, 1.0]), + ) + + +def test_reduce_frame_covariance_into_partial_accumulates_running_means(): + partial = CovarianceChunkPartial() + + reduce_frame_covariance_into_partial(partial, _frame_covariance(1.0)) + reduce_frame_covariance_into_partial(partial, _frame_covariance(3.0)) + + np.testing.assert_allclose(partial.force["ua"][(0, 0)], 2.0 * np.eye(3)) + np.testing.assert_allclose(partial.torque["ua"][(0, 0)], 3.0 * np.eye(3)) + 
np.testing.assert_allclose(partial.force["res"][0], 2.0 * np.eye(3)) + np.testing.assert_allclose(partial.torque["res"][0], 3.0 * np.eye(3)) + + assert partial.frame_counts["ua"][(0, 0)] == 2 + assert partial.frame_counts["res"][0] == 2 + + +def test_reduce_frame_covariance_into_partial_handles_missing_force_keys_for_torque(): + partial = CovarianceChunkPartial() + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": { + "ua": {(0, 0): np.eye(3)}, + "res": {0: np.eye(3)}, + "poly": {0: np.eye(3)}, + }, + } + + reduce_frame_covariance_into_partial(partial, frame_out) + + assert partial.frame_counts["ua"][(0, 0)] == 1 + assert partial.frame_counts["res"][0] == 1 + assert partial.frame_counts["poly"][0] == 1 + + +def test_reduce_frame_covariance_into_partial_handles_forcetorque_blocks(): + partial = CovarianceChunkPartial() + ft = np.ones((6, 6)) + frame_out = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + "forcetorque": {"res": {0: ft}, "poly": {0: 2.0 * ft}}, + } + + reduce_frame_covariance_into_partial(partial, frame_out) + + np.testing.assert_allclose(partial.forcetorque["res"][0], ft) + np.testing.assert_allclose(partial.forcetorque["poly"][0], 2.0 * ft) + assert partial.forcetorque_counts["res"][0] == 1 + assert partial.forcetorque_counts["poly"][0] == 1 + + +def test_execute_frame_map_output_runs_covariance_and_neighbor_count(): + frame_dag = MagicMock() + frame_dag.execute_frame.return_value = _frame_covariance(1.0) + + neighbor_helper = MagicMock() + neighbor_helper.get_frame_neighbor_counts.return_value = {0: (4, 2)} + + shared_data = { + "reduced_universe": "universe", + "levels": [["united_atom"]], + "groups": {0: [0]}, + "frame_source": "frame-source", + "args": SimpleNamespace(search_type="RAD"), + } + + out = execute_frame_map_output( + shared_data=shared_data, + frame_index="5", + frame_dag=frame_dag, + neighbor_helper=neighbor_helper, + ) + + 
frame_dag.execute_frame.assert_called_once_with(shared_data, 5) + neighbor_helper.get_frame_neighbor_counts.assert_called_once_with( + universe="universe", + levels=[["united_atom"]], + groups={0: [0]}, + frame_source="frame-source", + frame_index=5, + search_type="RAD", + ) + + assert out["covariance"] == frame_dag.execute_frame.return_value + assert out["neighbors"] == {0: (4, 2)} + + +def test_execute_frame_map_output_constructs_neighbor_helper_when_not_provided(): + frame_dag = MagicMock() + frame_dag.execute_frame.return_value = _frame_covariance(1.0) + + shared_data = { + "universe": "fallback-universe", + "levels": [["united_atom"]], + "groups": {0: [0]}, + "frame_source": "frame-source", + "args": SimpleNamespace(search_type="RAD"), + } + + with patch("CodeEntropy.levels.execution.tasks.Neighbors") as Neighbors: + helper = Neighbors.return_value + helper.get_frame_neighbor_counts.return_value = {0: (1, 1)} + + out = execute_frame_map_output( + shared_data=shared_data, + frame_index=0, + frame_dag=frame_dag, + ) + + helper.get_frame_neighbor_counts.assert_called_once_with( + universe="fallback-universe", + levels=[["united_atom"]], + groups={0: [0]}, + frame_source="frame-source", + frame_index=0, + search_type="RAD", + ) + assert out["neighbors"] == {0: (1, 1)} + + +def test_execute_frame_chunk_worker_returns_compact_partials(): + worker_shared_data = { + "reduced_universe": "universe", + "levels": [["united_atom"]], + "groups": {0: [0]}, + "frame_source": "frame-source", + "args": SimpleNamespace(search_type="RAD"), + } + task = FrameChunkTask(chunk_index=1, frame_indices=(0, 1)) + universe_operations = object() + + graph = MagicMock() + graph.execute_frame.side_effect = [ + _frame_covariance(1.0), + _frame_covariance(3.0), + ] + + neighbor_helper = MagicMock() + neighbor_helper.get_frame_neighbor_counts.side_effect = [ + {0: (2, 1)}, + {0: (4, 1)}, + ] + + with ( + patch("CodeEntropy.levels.execution.tasks.FrameGraph") as FrameGraph, + 
patch("CodeEntropy.levels.execution.tasks.Neighbors") as Neighbors, + ): + FrameGraph.return_value.build.return_value = graph + Neighbors.return_value = neighbor_helper + + result = execute_frame_chunk_worker( + task, + worker_shared_data, + universe_operations=universe_operations, + ) + + FrameGraph.assert_called_once_with(universe_operations=universe_operations) + graph.execute_frame.assert_any_call(worker_shared_data, 0) + graph.execute_frame.assert_any_call(worker_shared_data, 1) + + assert isinstance(result, FrameChunkResult) + assert result.chunk_index == 1 + assert result.frame_indices == (0, 1) + assert result.neighbor_totals == {0: 6} + assert result.neighbor_samples == {0: 2} + np.testing.assert_allclose( + result.covariance_partial.force["ua"][(0, 0)], + 2.0 * np.eye(3), + ) + + +def test_covariance_chunk_partial_default_factories_are_independent(): + partial_a = CovarianceChunkPartial() + partial_b = CovarianceChunkPartial() + + partial_a.force["ua"][(0, 0)] = "value" + partial_a.frame_counts["res"][0] = 1 + partial_a.forcetorque["poly"][0] = "ft" + + assert partial_b.force["ua"] == {} + assert partial_b.frame_counts["res"] == {} + assert partial_b.forcetorque["poly"] == {} + + +def test_reduce_frame_covariance_into_partial_accumulates_poly_force_and_torque(): + partial = CovarianceChunkPartial() + + poly_force_1 = np.eye(3) + poly_torque_1 = 2.0 * np.eye(3) + + poly_force_2 = 3.0 * np.eye(3) + poly_torque_2 = 4.0 * np.eye(3) + + frame_out_1 = { + "force": {"ua": {}, "res": {}, "poly": {0: poly_force_1}}, + "torque": {"ua": {}, "res": {}, "poly": {0: poly_torque_1}}, + } + frame_out_2 = { + "force": {"ua": {}, "res": {}, "poly": {0: poly_force_2}}, + "torque": {"ua": {}, "res": {}, "poly": {0: poly_torque_2}}, + } + + reduce_frame_covariance_into_partial(partial, frame_out_1) + reduce_frame_covariance_into_partial(partial, frame_out_2) + + assert partial.frame_counts["poly"][0] == 2 + np.testing.assert_allclose( + partial.force["poly"][0], + 2.0 * 
np.eye(3), + ) + np.testing.assert_allclose( + partial.torque["poly"][0], + 3.0 * np.eye(3), + ) + + +def test_reduce_frame_covariance_into_partial_accumulates_forcetorque_running_means(): + partial = CovarianceChunkPartial() + + res_ft_1 = np.ones((6, 6)) + res_ft_2 = 3.0 * np.ones((6, 6)) + + poly_ft_1 = 2.0 * np.ones((6, 6)) + poly_ft_2 = 4.0 * np.ones((6, 6)) + + frame_out_1 = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + "forcetorque": { + "res": {0: res_ft_1}, + "poly": {0: poly_ft_1}, + }, + } + frame_out_2 = { + "force": {"ua": {}, "res": {}, "poly": {}}, + "torque": {"ua": {}, "res": {}, "poly": {}}, + "forcetorque": { + "res": {0: res_ft_2}, + "poly": {0: poly_ft_2}, + }, + } + + reduce_frame_covariance_into_partial(partial, frame_out_1) + reduce_frame_covariance_into_partial(partial, frame_out_2) + + assert partial.forcetorque_counts["res"][0] == 2 + assert partial.forcetorque_counts["poly"][0] == 2 + + np.testing.assert_allclose( + partial.forcetorque["res"][0], + 2.0 * np.ones((6, 6)), + ) + np.testing.assert_allclose( + partial.forcetorque["poly"][0], + 3.0 * np.ones((6, 6)), + ) + + +def test_execute_frame_chunk_worker_handles_empty_chunk(): + worker_shared_data = { + "reduced_universe": "universe", + "levels": [["united_atom"]], + "groups": {0: [0], 1: [1]}, + "frame_source": "frame-source", + "args": SimpleNamespace(search_type="RAD"), + } + task = FrameChunkTask(chunk_index=4, frame_indices=()) + + graph = MagicMock() + neighbor_helper = MagicMock() + + with ( + patch("CodeEntropy.levels.execution.tasks.FrameGraph") as FrameGraph, + patch("CodeEntropy.levels.execution.tasks.Neighbors") as Neighbors, + ): + FrameGraph.return_value.build.return_value = graph + Neighbors.return_value = neighbor_helper + + result = execute_frame_chunk_worker(task, worker_shared_data) + + FrameGraph.assert_called_once_with(universe_operations=None) + graph.execute_frame.assert_not_called() + 
neighbor_helper.get_frame_neighbor_counts.assert_not_called() + + assert isinstance(result, FrameChunkResult) + assert result.chunk_index == 4 + assert result.frame_indices == () + assert result.neighbor_totals == {0: 0, 1: 0} + assert result.neighbor_samples == {0: 0, 1: 0} + assert result.covariance_partial.force == {"ua": {}, "res": {}, "poly": {}} + assert result.covariance_partial.torque == {"ua": {}, "res": {}, "poly": {}} + + +def test_execute_frame_chunk_worker_falls_back_to_universe_and_adds_new_neighbor(): + worker_shared_data = { + "universe": "fallback-universe", + "levels": [["united_atom"]], + "groups": {0: [0]}, + "frame_source": "frame-source", + "args": SimpleNamespace(search_type="grid"), + } + task = FrameChunkTask(chunk_index=2, frame_indices=("5",)) + + graph = MagicMock() + graph.execute_frame.return_value = _frame_covariance(1.0) + + neighbor_helper = MagicMock() + neighbor_helper.get_frame_neighbor_counts.return_value = { + 99: (3, 2), + } + + universe_operations = object() + + with ( + patch("CodeEntropy.levels.execution.tasks.FrameGraph") as FrameGraph, + patch("CodeEntropy.levels.execution.tasks.Neighbors") as Neighbors, + ): + FrameGraph.return_value.build.return_value = graph + Neighbors.return_value = neighbor_helper + + result = execute_frame_chunk_worker( + task, + worker_shared_data, + universe_operations=universe_operations, + ) + + FrameGraph.assert_called_once_with(universe_operations=universe_operations) + graph.execute_frame.assert_called_once_with(worker_shared_data, 5) + + neighbor_helper.get_frame_neighbor_counts.assert_called_once_with( + universe="fallback-universe", + levels=[["united_atom"]], + groups={0: [0]}, + frame_source="frame-source", + frame_index=5, + search_type="grid", + ) + + assert result.chunk_index == 2 + assert result.frame_indices == ("5",) + assert result.neighbor_totals == {0: 0, 99: 3} + assert result.neighbor_samples == {0: 0, 99: 2} diff --git a/tests/unit/CodeEntropy/levels/nodes/test_beads_node.py 
b/tests/unit/CodeEntropy/levels/nodes/test_beads_node.py new file mode 100644 index 00000000..606929ec --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_beads_node.py @@ -0,0 +1,238 @@ +"""Atomic unit tests for bead-definition construction.""" + +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import pytest + +from CodeEntropy.levels.nodes.beads import BuildBeadsNode + + +class FakeHeavyAtoms: + """Minimal heavy-atom selection result.""" + + def __init__(self, resindices): + self._atoms = [SimpleNamespace(resindex=resindex) for resindex in resindices] + + def __len__(self): + return len(self._atoms) + + def __getitem__(self, index): + return self._atoms[index] + + +class FakeBead: + """Minimal AtomGroup-like bead.""" + + def __init__(self, indices, *, heavy_resindices=None): + self.indices = np.asarray(indices, dtype=int) + self._heavy_resindices = [] if heavy_resindices is None else heavy_resindices + + def __len__(self): + return int(self.indices.size) + + def select_atoms(self, selection): + assert selection == "prop mass > 1.1" + return FakeHeavyAtoms(self._heavy_resindices) + + +class FakeMolecule: + """Minimal molecule fragment with residues.""" + + def __init__(self, name, residue_indices): + self.name = name + self.residues = [ + SimpleNamespace(resindex=resindex) for resindex in residue_indices + ] + + +class FakeUniverse: + """Minimal universe with atom fragments.""" + + def __init__(self, fragments): + self.atoms = SimpleNamespace(fragments=fragments) + + +class FakeHierarchy: + """Controlled HierarchyBuilder test double.""" + + def __init__(self, bead_map): + self.bead_map = bead_map + self.calls = [] + + def get_beads(self, mol, level): + self.calls.append((mol.name, level)) + return self.bead_map.get((mol.name, level), []) + + +def test_run_builds_all_requested_bead_levels_and_writes_shared_data(): + mol0 = FakeMolecule("mol0", residue_indices=[10, 20]) + mol1 = FakeMolecule("mol1", 
residue_indices=[30]) + universe = FakeUniverse([mol0, mol1]) + + hierarchy = FakeHierarchy( + { + ("mol0", "united_atom"): [ + FakeBead([0, 1], heavy_resindices=[10]), + FakeBead([2], heavy_resindices=[20]), + FakeBead([], heavy_resindices=[]), + ], + ("mol0", "residue"): [FakeBead([0, 1]), FakeBead([2])], + ("mol0", "polymer"): [FakeBead([0, 1, 2])], + ("mol1", "residue"): [FakeBead([3, 4])], + } + ) + + shared_data = { + "reduced_universe": universe, + "levels": [ + ["united_atom", "residue", "polymer"], + ["residue"], + ], + } + + result = BuildBeadsNode(hierarchy=hierarchy).run(shared_data) + + beads = shared_data["beads"] + assert result == {"beads": beads} + + assert hierarchy.calls == [ + ("mol0", "united_atom"), + ("mol0", "residue"), + ("mol0", "polymer"), + ("mol1", "residue"), + ] + + np.testing.assert_array_equal(beads[(0, "united_atom", 0)][0], np.array([0, 1])) + np.testing.assert_array_equal(beads[(0, "united_atom", 1)][0], np.array([2])) + + np.testing.assert_array_equal(beads[(0, "residue")][0], np.array([0, 1])) + np.testing.assert_array_equal(beads[(0, "residue")][1], np.array([2])) + np.testing.assert_array_equal(beads[(0, "polymer")][0], np.array([0, 1, 2])) + np.testing.assert_array_equal(beads[(1, "residue")][0], np.array([3, 4])) + + +def test_run_requires_reduced_universe(): + with pytest.raises(KeyError): + BuildBeadsNode(hierarchy=FakeHierarchy({})).run({"levels": []}) + + +def test_run_requires_levels(): + with pytest.raises(KeyError): + BuildBeadsNode(hierarchy=FakeHierarchy({})).run( + {"reduced_universe": FakeUniverse([])} + ) + + +def test_add_united_atom_beads_creates_empty_bucket_for_missing_residue(): + mol = FakeMolecule("mol", residue_indices=[10, 20, 30]) + hierarchy = FakeHierarchy( + { + ("mol", "united_atom"): [ + FakeBead([5], heavy_resindices=[20]), + ] + } + ) + beads = {} + + BuildBeadsNode(hierarchy=hierarchy)._add_united_atom_beads( + beads=beads, + mol_id=4, + mol=mol, + ) + + assert beads[(4, "united_atom", 0)] == 
[] + np.testing.assert_array_equal(beads[(4, "united_atom", 1)][0], np.array([5])) + assert beads[(4, "united_atom", 2)] == [] + + +def test_add_residue_beads_logs_error_when_all_beads_are_empty(caplog): + mol = FakeMolecule("mol", residue_indices=[10]) + hierarchy = FakeHierarchy({("mol", "residue"): [FakeBead([])]}) + beads = {} + + BuildBeadsNode(hierarchy=hierarchy)._add_residue_beads( + beads=beads, + mol_id=1, + mol=mol, + ) + + assert beads[(1, "residue")] == [] + assert "No residue beads kept" in caplog.text + + +def test_add_polymer_beads_skips_empty_beads(): + mol = FakeMolecule("mol", residue_indices=[10]) + hierarchy = FakeHierarchy({("mol", "polymer"): [FakeBead([]), FakeBead([1, 2])]}) + beads = {} + + BuildBeadsNode(hierarchy=hierarchy)._add_polymer_beads( + beads=beads, + mol_id=1, + mol=mol, + ) + + assert len(beads[(1, "polymer")]) == 1 + np.testing.assert_array_equal(beads[(1, "polymer")][0], np.array([1, 2])) + + +def test_validate_bead_indices_returns_copy_and_skips_empty_beads(caplog): + bead = FakeBead([1, 2, 3]) + out = BuildBeadsNode._validate_bead_indices( + bead, + mol_id=1, + level="residue", + bead_i=0, + ) + + np.testing.assert_array_equal(out, np.array([1, 2, 3])) + bead.indices[0] = 99 + assert out[0] == 1 + + empty = BuildBeadsNode._validate_bead_indices( + FakeBead([]), + mol_id=1, + level="residue", + bead_i=1, + ) + + assert empty is None + assert "Empty bead skipped" in caplog.text + + +def test_infer_local_residue_id_uses_first_heavy_atom_resindex(): + mol = FakeMolecule("mol", residue_indices=[10, 20]) + + assert ( + BuildBeadsNode._infer_local_residue_id( + mol=mol, + bead=FakeBead([1], heavy_resindices=[20]), + ) + == 1 + ) + + +def test_infer_local_residue_id_falls_back_to_zero_without_heavy_atoms(): + mol = FakeMolecule("mol", residue_indices=[10, 20]) + + assert ( + BuildBeadsNode._infer_local_residue_id( + mol=mol, + bead=FakeBead([1], heavy_resindices=[]), + ) + == 0 + ) + + +def 
test_infer_local_residue_id_falls_back_to_zero_when_resindex_not_in_molecule(): + mol = FakeMolecule("mol", residue_indices=[10, 20]) + + assert ( + BuildBeadsNode._infer_local_residue_id( + mol=mol, + bead=FakeBead([1], heavy_resindices=[99]), + ) + == 0 + ) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_build_beads_node.py b/tests/unit/CodeEntropy/levels/nodes/test_build_beads_node.py deleted file mode 100644 index a48e83e1..00000000 --- a/tests/unit/CodeEntropy/levels/nodes/test_build_beads_node.py +++ /dev/null @@ -1,190 +0,0 @@ -from unittest.mock import MagicMock - -import numpy as np - -from CodeEntropy.levels.nodes.beads import BuildBeadsNode - - -def _bead(indices, heavy_resindex=None, empty=False): - b = MagicMock() - b.__len__.return_value = 0 if empty else len(indices) - b.indices = np.asarray(indices, dtype=int) - - heavy = MagicMock() - if heavy_resindex is None: - heavy.__len__.return_value = 0 - heavy.__iter__.return_value = iter([]) - else: - a0 = MagicMock() - a0.resindex = int(heavy_resindex) - heavy.__len__.return_value = 1 - heavy.__getitem__.side_effect = lambda i: a0 - heavy.__iter__.return_value = iter([a0]) - - b.select_atoms.return_value = heavy - return b - - -def test_build_beads_node_groups_united_atom_beads_into_local_residue_buckets(): - r0 = MagicMock() - r0.resindex = 10 - r1 = MagicMock() - r1.resindex = 11 - - mol = MagicMock() - mol.residues = [r0, r1] - - ua0 = _bead([1, 2], heavy_resindex=10) - ua1 = _bead([3], heavy_resindex=11) - ua_empty = _bead([], heavy_resindex=10, empty=True) - - hier = MagicMock() - hier.get_beads.side_effect = lambda m, lvl: ( - [ua0, ua1, ua_empty] if lvl == "united_atom" else [] - ) - - node = BuildBeadsNode(hierarchy=hier) - - u = MagicMock() - u.atoms = MagicMock() - u.atoms.fragments = [mol] - - shared = {"reduced_universe": u, "levels": [["united_atom"]]} - - out = node.run(shared) - beads = out["beads"] - - assert (0, "united_atom", 0) in beads - assert (0, "united_atom", 1) in beads - 
assert len(beads[(0, "united_atom", 0)]) == 1 - assert len(beads[(0, "united_atom", 1)]) == 1 - - np.testing.assert_array_equal(beads[(0, "united_atom", 0)][0], np.array([1, 2])) - np.testing.assert_array_equal(beads[(0, "united_atom", 1)][0], np.array([3])) - - -def test_add_residue_beads_logs_error_if_none_kept(caplog): - hier = MagicMock() - # returns one empty bead -> skipped -> kept stays 0 - empty_bead = MagicMock() - empty_bead.__len__.return_value = 0 - hier.get_beads.return_value = [empty_bead] - - node = BuildBeadsNode(hierarchy=hier) - - beads = {} - mol = MagicMock() - mol.residues = [MagicMock()] - - node._add_residue_beads(beads=beads, mol_id=0, mol=mol) - - assert (0, "residue") in beads - assert beads[(0, "residue")] == [] - assert any("No residue beads kept" in rec.message for rec in caplog.records) - - -def test_infer_local_residue_id_returns_zero_if_no_heavy_atoms(): - mol = MagicMock() - mol.residues = [MagicMock(resindex=10), MagicMock(resindex=11)] - - bead = MagicMock() - heavy = MagicMock() - heavy.__len__.return_value = 0 - bead.select_atoms.return_value = heavy - - out = BuildBeadsNode._infer_local_residue_id(mol=mol, bead=bead) - assert out == 0 - - -def test_infer_local_residue_id_returns_zero_if_resindex_not_found(): - mol = MagicMock() - mol.residues = [MagicMock(resindex=10), MagicMock(resindex=11)] - - bead = MagicMock() - heavy = MagicMock() - heavy.__len__.return_value = 1 - heavy0 = MagicMock(resindex=999) - heavy.__getitem__.return_value = heavy0 - bead.select_atoms.return_value = heavy - - out = BuildBeadsNode._infer_local_residue_id(mol=mol, bead=bead) - assert out == 0 - - -def test_build_beads_node_skips_when_no_levels(): - """ - Covers: early return when levels missing (92/95 style guard branches) - """ - node = BuildBeadsNode(hierarchy=MagicMock()) - out = node.run({"reduced_universe": MagicMock(), "levels": []}) - assert out["beads"] == {} - - -def test_build_beads_node_residue_level_adds_residue_beads(): - """ - Covers: 
residue path + _add_residue_beads bookkeeping (log around 145 and 166-177) - """ - r0 = MagicMock(resindex=10) - r1 = MagicMock(resindex=11) - - mol = MagicMock() - mol.residues = [r0, r1] - - res0 = _bead([100, 101], heavy_resindex=10) - res1 = _bead([200], heavy_resindex=11) - - ua0 = _bead([1, 2], heavy_resindex=10) - ua1 = _bead([3], heavy_resindex=11) - - hier = MagicMock() - - def _get_beads(m, lvl): - if lvl == "residue": - return [res0, res1] - if lvl == "united_atom": - return [ua0, ua1] - return [] - - hier.get_beads.side_effect = _get_beads - - node = BuildBeadsNode(hierarchy=hier) - - u = MagicMock() - u.atoms = MagicMock() - u.atoms.fragments = [mol] - - shared = {"reduced_universe": u, "levels": [["residue", "united_atom"]]} - out = node.run(shared) - - beads = out["beads"] - - assert (0, "residue") in beads - assert len(beads[(0, "residue")]) == 2 - assert np.array_equal(beads[(0, "residue")][0], np.array([100, 101])) - assert np.array_equal(beads[(0, "residue")][1], np.array([200])) - - -def test_build_beads_node_polymer_level_adds_polymer_beads_and_skips_empty(): - mol0 = MagicMock() - mol0.residues = [MagicMock(resindex=10)] - - u = MagicMock() - u.atoms.fragments = [mol0] - - polymer_beads = [_bead([]), _bead([7, 8, 9])] - - hier = MagicMock() - hier.get_beads.side_effect = lambda m, lvl: ( - polymer_beads if lvl == "polymer" else [] - ) - - node = BuildBeadsNode(hierarchy=hier) - - shared = {"reduced_universe": u, "levels": [["polymer"]]} - out = node.run(shared) - - beads = out["beads"] - assert (0, "polymer") in beads - - assert len(beads[(0, "polymer")]) == 1 - np.testing.assert_array_equal(beads[(0, "polymer")][0], np.array([7, 8, 9])) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py b/tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py index 5f927f11..22b81070 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_conformations_node.py 
@@ -1,56 +1,127 @@ +"""Unit tests for the conformational-state static node.""" + +from __future__ import annotations + from types import SimpleNamespace -from unittest.mock import MagicMock +from CodeEntropy.levels.nodes import conformations from CodeEntropy.levels.nodes.conformations import ComputeConformationalStatesNode -from CodeEntropy.trajectory.frames import FrameSelection -def test_compute_conformational_states_node_runs_and_writes_shared_data(): - uops = MagicMock() - node = ComputeConformationalStatesNode(universe_operations=uops) +class FakeConformationStateBuilder: + """Test double for ConformationStateBuilder.""" - frame_selection = FrameSelection.from_bounds(start=0, stop=10, step=1) + def __init__(self, universe_operations): + self.universe_operations = universe_operations + self.calls = [] - node._dihedral_analysis.build_conformational_states = MagicMock( - return_value=( - {"ua_key": ["0", "1"]}, - [["00", "01"]], - {"ua_key": [0]}, - [0], + def build_conformational_states( + self, + *, + data_container, + levels, + groups, + bin_width, + frame_selection, + progress=None, + ): + self.calls.append( + { + "data_container": data_container, + "levels": levels, + "groups": groups, + "bin_width": bin_width, + "frame_selection": frame_selection, + "progress": progress, + } ) + return ( + {"ua_key": ["state_a"]}, + [["res_state"]], + {"ua_key": 1}, + [1], + ) + + +def test_compute_conformational_states_node_runs_and_writes_shared_data(monkeypatch): + builder_holder = {} + + def builder_factory(universe_operations): + builder = FakeConformationStateBuilder(universe_operations) + builder_holder["builder"] = builder + return builder + + monkeypatch.setattr( + conformations, + "ConformationStateBuilder", + builder_factory, ) - shared = { - "reduced_universe": MagicMock(), - "levels": {0: ["united_atom"]}, + universe_operations = object() + node = ComputeConformationalStatesNode(universe_operations) + + universe = object() + frame_selection = object() + 
progress = object() + + shared_data = { + "reduced_universe": universe, + "levels": [["united_atom", "residue"]], "groups": {0: [0]}, "frame_selection": frame_selection, - "args": SimpleNamespace(bin_width=10), + "args": SimpleNamespace(bin_width=30), } - out = node.run(shared) + result = node.run(shared_data, progress=progress) - assert out == { - "conformational_states": { - "ua": {"ua_key": ["0", "1"]}, - "res": [["00", "01"]], - } + assert shared_data["conformational_states"] == { + "ua": {"ua_key": ["state_a"]}, + "res": [["res_state"]], } - - assert shared["conformational_states"] == { - "ua": {"ua_key": ["0", "1"]}, - "res": [["00", "01"]], + assert shared_data["flexible_dihedrals"] == { + "ua": {"ua_key": 1}, + "res": [1], } - assert shared["flexible_dihedrals"] == { - "ua": {"ua_key": [0]}, - "res": [0], + assert result == { + "conformational_states": shared_data["conformational_states"], } - node._dihedral_analysis.build_conformational_states.assert_called_once_with( - data_container=shared["reduced_universe"], - levels=shared["levels"], - groups=shared["groups"], - bin_width=10, - frame_selection=frame_selection, - progress=None, - ) + builder = builder_holder["builder"] + assert builder.universe_operations is universe_operations + assert builder.calls == [ + { + "data_container": universe, + "levels": [["united_atom", "residue"]], + "groups": {0: [0]}, + "bin_width": 30, + "frame_selection": frame_selection, + "progress": progress, + } + ] + + +def test_compute_conformational_states_node_converts_bin_width_to_int(monkeypatch): + captured = {} + + class Builder: + def __init__(self, universe_operations): + pass + + def build_conformational_states(self, **kwargs): + captured.update(kwargs) + return {}, [], {}, [] + + monkeypatch.setattr(conformations, "ConformationStateBuilder", Builder) + + node = ComputeConformationalStatesNode() + shared_data = { + "reduced_universe": object(), + "levels": [], + "groups": {}, + "frame_selection": object(), + "args": 
SimpleNamespace(bin_width="45"), + } + + node.run(shared_data) + + assert captured["bin_width"] == 45 diff --git a/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py new file mode 100644 index 00000000..8398113c --- /dev/null +++ b/tests/unit/CodeEntropy/levels/nodes/test_covariance_node.py @@ -0,0 +1,600 @@ +"""Atomic unit tests for frame-local covariance construction.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from CodeEntropy.levels.nodes.covariance import FrameCovarianceNode + + +class FakeAtomGroup: + """Small AtomGroup-like object for covariance-node tests.""" + + def __init__(self, name="ag", *, length=1): + self.name = name + self._length = length + self.indices = np.arange(length) + + def __len__(self): + return self._length + + def principal_axes(self): + return np.eye(3) + + def center_of_mass(self, unwrap=False): + return np.array([1.0, 2.0, 3.0]) + + +class FakeResidue: + """Small residue-like object.""" + + def __init__(self, atoms=None): + self.atoms = atoms or FakeAtomGroup("residue-atoms") + + +class FakeMolecule: + """Small molecule-like fragment.""" + + def __init__(self, n_residues=1): + self.atoms = FakeAtomGroup("mol-atoms") + self.residues = [FakeResidue() for _ in range(n_residues)] + + +class FakeAtoms: + """Container supporting u.atoms.fragments and u.atoms[index_array].""" + + def __init__(self, fragments, returned_groups=None): + self.fragments = fragments + self.returned_groups = list(returned_groups or []) + + def __getitem__(self, index): + if self.returned_groups: + return self.returned_groups.pop(0) + return FakeAtomGroup(f"group-{index}", length=1) + + +class FakeUniverse: + """Small universe-like object.""" + + def __init__(self, fragments, *, dimensions=None, returned_groups=None): + self.atoms = FakeAtoms(fragments, 
returned_groups=returned_groups) + if dimensions is not None: + self.dimensions = dimensions + + +def _args( + *, + force_partitioning=0.5, + combined_forcetorque=False, + customised_axes=False, +): + return SimpleNamespace( + force_partitioning=force_partitioning, + combined_forcetorque=combined_forcetorque, + customised_axes=customised_axes, + ) + + +def test_run_processes_all_levels_and_writes_frame_covariance(): + node = FrameCovarianceNode() + node._process_united_atom = MagicMock() + node._process_residue = MagicMock() + node._process_polymer = MagicMock() + + mol = FakeMolecule() + universe = FakeUniverse([mol], dimensions=np.array([10.0, 20.0, 30.0, 90.0])) + axes_manager = object() + + ctx = { + "shared": { + "reduced_universe": universe, + "groups": {7: [0]}, + "levels": [["united_atom", "residue", "polymer"]], + "beads": {}, + "args": _args(combined_forcetorque=True, customised_axes=True), + "axes_manager": axes_manager, + } + } + + result = node.run(ctx) + + assert ctx["frame_covariance"] is result + assert set(result) == {"force", "torque", "forcetorque"} + + node._process_united_atom.assert_called_once() + node._process_residue.assert_called_once() + node._process_polymer.assert_called_once() + + ua_kwargs = node._process_united_atom.call_args.kwargs + assert ua_kwargs["u"] is universe + assert ua_kwargs["mol"] is mol + assert ua_kwargs["mol_id"] == 0 + assert ua_kwargs["group_id"] == 7 + assert ua_kwargs["axes_manager"] is axes_manager + assert ua_kwargs["force_partitioning"] == 0.5 + assert ua_kwargs["customised_axes"] is True + assert ua_kwargs["is_highest"] is False + + +def test_run_omits_forcetorque_when_combined_is_false(): + node = FrameCovarianceNode() + node._process_united_atom = MagicMock() + node._process_residue = MagicMock() + node._process_polymer = MagicMock() + + ctx = { + "shared": { + "reduced_universe": FakeUniverse([FakeMolecule()]), + "groups": {0: [0]}, + "levels": [["residue"]], + "beads": {}, + "args": 
_args(combined_forcetorque=False), + } + } + + result = node.run(ctx) + + assert set(result) == {"force", "torque"} + + +def test_process_united_atom_updates_outputs_and_molcount(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + bead_group = FakeAtomGroup("ua", length=1) + universe = FakeUniverse([mol], returned_groups=[bead_group]) + + node._build_ua_vectors = MagicMock( + return_value=([np.array([1.0, 0.0, 0.0])], [np.array([0.0, 1.0, 0.0])]) + ) + node._ft.compute_frame_covariance = MagicMock( + return_value=(np.eye(3), 2.0 * np.eye(3)) + ) + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + molcount = {} + + node._process_united_atom( + u=universe, + mol=mol, + mol_id=0, + group_id=7, + beads={(0, "united_atom", 0): [np.array([0])]}, + axes_manager="axes", + box=None, + force_partitioning=0.5, + customised_axes=False, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + molcount=molcount, + ) + + np.testing.assert_allclose(out_force["ua"][(7, 0)], np.eye(3)) + np.testing.assert_allclose(out_torque["ua"][(7, 0)], 2.0 * np.eye(3)) + assert molcount[(7, 0)] == 1 + + +def test_process_united_atom_returns_when_no_beads_or_empty_atom_groups(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + + node._process_united_atom( + u=FakeUniverse([mol]), + mol=mol, + mol_id=0, + group_id=7, + beads={}, + axes_manager=None, + box=None, + force_partitioning=0.5, + customised_axes=False, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + molcount={}, + ) + + assert out_force["ua"] == {} + + empty_group_universe = FakeUniverse( + [mol], returned_groups=[FakeAtomGroup(length=0)] + ) + node._process_united_atom( + u=empty_group_universe, + mol=mol, + mol_id=0, + group_id=7, + beads={(0, "united_atom", 0): [np.array([0])]}, + axes_manager=None, + box=None, + 
force_partitioning=0.5, + customised_axes=False, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + molcount={}, + ) + + assert out_force["ua"] == {} + + +def test_process_residue_updates_outputs_and_combined_ft(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + universe = FakeUniverse([mol], returned_groups=[FakeAtomGroup("residue", length=1)]) + + force_vecs = [np.array([1.0, 0.0, 0.0])] + torque_vecs = [np.array([0.0, 1.0, 0.0])] + node._build_residue_vectors = MagicMock(return_value=(force_vecs, torque_vecs)) + node._ft.compute_frame_covariance = MagicMock( + return_value=(np.eye(3), 2.0 * np.eye(3)) + ) + node._build_ft_block = MagicMock(return_value=np.ones((6, 6))) + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + out_ft = {"ua": {}, "res": {}, "poly": {}} + + node._process_residue( + u=universe, + mol=mol, + mol_id=0, + group_id=7, + beads={(0, "residue"): [np.array([0])]}, + axes_manager="axes", + box=None, + customised_axes=True, + force_partitioning=0.5, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=out_ft, + molcount={}, + combined=True, + ) + + np.testing.assert_allclose(out_force["res"][7], np.eye(3)) + np.testing.assert_allclose(out_torque["res"][7], 2.0 * np.eye(3)) + np.testing.assert_allclose(out_ft["res"][7], np.ones((6, 6))) + + +def test_process_residue_returns_when_no_beads_or_empty_groups(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + + node._process_residue( + u=FakeUniverse([mol]), + mol=mol, + mol_id=0, + group_id=7, + beads={}, + axes_manager=None, + box=None, + customised_axes=False, + force_partitioning=0.5, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=None, + molcount={}, + combined=False, + ) + + assert out_force["res"] == {} + + empty_universe = FakeUniverse([mol], 
returned_groups=[FakeAtomGroup(length=0)]) + node._process_residue( + u=empty_universe, + mol=mol, + mol_id=0, + group_id=7, + beads={(0, "residue"): [np.array([0])]}, + axes_manager=None, + box=None, + customised_axes=False, + force_partitioning=0.5, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=None, + molcount={}, + combined=False, + ) + + assert out_force["res"] == {} + + +def test_process_polymer_updates_outputs_and_combined_ft(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + bead = FakeAtomGroup("polymer", length=1) + universe = FakeUniverse([mol], returned_groups=[bead]) + + node._get_polymer_axes = MagicMock( + return_value=(np.eye(3), np.eye(3), np.zeros(3), np.ones(3)) + ) + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) + node._ft.compute_frame_covariance = MagicMock( + return_value=(np.eye(3), 2.0 * np.eye(3)) + ) + node._build_ft_block = MagicMock(return_value=np.ones((6, 6))) + + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + out_ft = {"ua": {}, "res": {}, "poly": {}} + + node._process_polymer( + u=universe, + mol=mol, + mol_id=0, + group_id=7, + beads={(0, "polymer"): [np.array([0])]}, + axes_manager="axes", + box=None, + force_partitioning=0.5, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=out_ft, + molcount={}, + combined=True, + ) + + np.testing.assert_allclose(out_force["poly"][7], np.eye(3)) + np.testing.assert_allclose(out_torque["poly"][7], 2.0 * np.eye(3)) + np.testing.assert_allclose(out_ft["poly"][7], np.ones((6, 6))) + + +def test_process_polymer_returns_when_no_beads_or_empty_groups(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + out_force = {"ua": {}, "res": {}, "poly": {}} + out_torque = {"ua": {}, "res": {}, "poly": {}} + + node._process_polymer( + 
u=FakeUniverse([mol]), + mol=mol, + mol_id=0, + group_id=7, + beads={}, + axes_manager=None, + box=None, + force_partitioning=0.5, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=None, + molcount={}, + combined=False, + ) + + assert out_force["poly"] == {} + + empty_universe = FakeUniverse([mol], returned_groups=[FakeAtomGroup(length=0)]) + node._process_polymer( + u=empty_universe, + mol=mol, + mol_id=0, + group_id=7, + beads={(0, "polymer"): [np.array([0])]}, + axes_manager=None, + box=None, + force_partitioning=0.5, + is_highest=True, + out_force=out_force, + out_torque=out_torque, + out_ft=None, + molcount={}, + combined=False, + ) + + assert out_force["poly"] == {} + + +def test_build_ua_vectors_uses_customised_axes(): + node = FrameCovarianceNode() + axes_manager = MagicMock() + axes_manager.get_UA_axes.return_value = ( + np.eye(3), + 2.0 * np.eye(3), + np.ones(3), + np.array([1.0, 2.0, 3.0]), + ) + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) + + force_vecs, torque_vecs = node._build_ua_vectors( + bead_groups=[FakeAtomGroup("ua")], + residue_atoms=FakeAtomGroup("res"), + axes_manager=axes_manager, + box=None, + force_partitioning=0.5, + customised_axes=True, + is_highest=True, + ) + + assert len(force_vecs) == 1 + assert len(torque_vecs) == 1 + axes_manager.get_UA_axes.assert_called_once() + + +def test_build_ua_vectors_uses_vanilla_axes_when_not_customised(): + node = FrameCovarianceNode() + axes_manager = MagicMock() + axes_manager.get_vanilla_axes.return_value = ( + np.eye(3), + np.array([1.0, 2.0, 3.0]), + ) + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) + + with patch("CodeEntropy.levels.nodes.covariance.make_whole") as make_whole: + node._build_ua_vectors( + 
bead_groups=[FakeAtomGroup("ua")], + residue_atoms=FakeAtomGroup("res"), + axes_manager=axes_manager, + box=None, + force_partitioning=0.5, + customised_axes=False, + is_highest=False, + ) + + assert make_whole.call_count == 2 + axes_manager.get_vanilla_axes.assert_called_once() + + +def test_build_residue_vectors_uses_residue_axes(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + axes_manager = MagicMock() + + node._get_residue_axes = MagicMock( + return_value=(np.eye(3), np.eye(3), np.zeros(3), np.ones(3)) + ) + node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) + node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) + + force_vecs, torque_vecs = node._build_residue_vectors( + mol=mol, + bead_groups=[FakeAtomGroup("res")], + axes_manager=axes_manager, + box=None, + customised_axes=True, + force_partitioning=0.5, + is_highest=True, + ) + + assert len(force_vecs) == 1 + assert len(torque_vecs) == 1 + node._get_residue_axes.assert_called_once() + + +def test_get_residue_axes_customised_delegates_to_axes_manager(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + axes_manager = MagicMock() + expected = (np.eye(3), np.eye(3), np.zeros(3), np.ones(3)) + axes_manager.get_residue_axes.return_value = expected + + assert ( + node._get_residue_axes( + mol=mol, + bead=FakeAtomGroup("res"), + local_res_i=0, + axes_manager=axes_manager, + customised_axes=True, + ) + == expected + ) + + axes_manager.get_residue_axes.assert_called_once_with( + mol, + 0, + residue=mol.residues[0].atoms, + ) + + +def test_get_residue_axes_vanilla_uses_make_whole_and_vanilla_axes(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + bead = FakeAtomGroup("res") + axes_manager = MagicMock() + axes_manager.get_vanilla_axes.return_value = ( + np.eye(3), + np.array([1.0, 2.0, 3.0]), + ) + + with patch("CodeEntropy.levels.nodes.covariance.make_whole") as make_whole: + trans_axes, rot_axes, 
center, moi = node._get_residue_axes( + mol=mol, + bead=bead, + local_res_i=0, + axes_manager=axes_manager, + customised_axes=False, + ) + + assert make_whole.call_count == 2 + np.testing.assert_allclose(trans_axes, np.eye(3)) + np.testing.assert_allclose(rot_axes, np.eye(3)) + np.testing.assert_allclose(center, np.array([1.0, 2.0, 3.0])) + np.testing.assert_allclose(moi, np.array([1.0, 2.0, 3.0])) + + +def test_get_polymer_axes_uses_make_whole_and_vanilla_axes(): + node = FrameCovarianceNode() + mol = FakeMolecule(n_residues=1) + bead = FakeAtomGroup("poly") + axes_manager = MagicMock() + axes_manager.get_vanilla_axes.return_value = ( + np.eye(3), + np.array([1.0, 2.0, 3.0]), + ) + + with patch("CodeEntropy.levels.nodes.covariance.make_whole") as make_whole: + trans_axes, rot_axes, center, moi = node._get_polymer_axes( + mol=mol, + bead=bead, + axes_manager=axes_manager, + ) + + assert make_whole.call_count == 2 + np.testing.assert_allclose(trans_axes, np.eye(3)) + np.testing.assert_allclose(rot_axes, np.eye(3)) + np.testing.assert_allclose(center, np.array([1.0, 2.0, 3.0])) + np.testing.assert_allclose(moi, np.array([1.0, 2.0, 3.0])) + + +def test_get_shared_requires_shared_key(): + with pytest.raises(KeyError, match="ctx\\['shared'\\]"): + FrameCovarianceNode._get_shared({}) + + +def test_try_get_box_returns_dimensions_or_none(): + assert np.allclose( + FrameCovarianceNode._try_get_box(SimpleNamespace(dimensions=[1, 2, 3, 90])), + np.array([1.0, 2.0, 3.0]), + ) + assert FrameCovarianceNode._try_get_box(object()) is None + + +def test_inc_mean_copies_first_value_and_updates_existing_mean(): + new = np.array([1.0, 2.0]) + out = FrameCovarianceNode._inc_mean(None, new, n=1) + new[0] = 99.0 + assert out[0] == 1.0 + + np.testing.assert_allclose( + FrameCovarianceNode._inc_mean(np.array([2.0, 2.0]), np.array([4.0, 0.0]), n=2), + np.array([3.0, 1.0]), + ) + + +def test_build_ft_block_builds_symmetric_block_matrix(): + force_vecs = [np.array([1.0, 0.0, 0.0]), 
np.array([0.0, 1.0, 0.0])] + torque_vecs = [np.array([0.0, 0.0, 1.0]), np.array([1.0, 1.0, 0.0])] + + out = FrameCovarianceNode._build_ft_block(force_vecs, torque_vecs) + + assert out.shape == (12, 12) + np.testing.assert_allclose(out[:6, 6:], out[6:, :6].T) + + +def test_build_ft_block_rejects_invalid_inputs(): + with pytest.raises(ValueError, match="same length"): + FrameCovarianceNode._build_ft_block([np.ones(3)], []) + + with pytest.raises(ValueError, match="No bead vectors"): + FrameCovarianceNode._build_ft_block([], []) + + with pytest.raises(ValueError, match="length 3"): + FrameCovarianceNode._build_ft_block([np.ones(2)], [np.ones(3)]) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_detect_levels_node.py b/tests/unit/CodeEntropy/levels/nodes/test_detect_levels_node.py index e9500810..2d9d04ce 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_detect_levels_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_detect_levels_node.py @@ -1,19 +1,46 @@ -from unittest.mock import patch +"""Atomic unit tests for hierarchy-level detection.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest from CodeEntropy.levels.nodes.detect_levels import DetectLevelsNode -def test_detect_levels_node_stores_results(reduced_universe): +def test_run_detects_levels_and_writes_shared_data(): + universe = object() + node = DetectLevelsNode() + node._hierarchy = MagicMock() + node._hierarchy.select_levels.return_value = ( + 2, + [["united_atom"], ["united_atom", "residue"]], + ) + + shared_data = {"reduced_universe": universe} + + result = node.run(shared_data) + + node._hierarchy.select_levels.assert_called_once_with(universe) + assert shared_data["number_molecules"] == 2 + assert shared_data["levels"] == [["united_atom"], ["united_atom", "residue"]] + assert result == { + "number_molecules": 2, + "levels": [["united_atom"], ["united_atom", "residue"]], + } + + +def test_run_requires_reduced_universe(): + with 
pytest.raises(KeyError): + DetectLevelsNode().run({}) + + +def test_detect_levels_delegates_to_hierarchy_builder(): + universe = object() node = DetectLevelsNode() - shared = {"reduced_universe": reduced_universe} - - with patch.object( - node._hierarchy, - "select_levels", - return_value=(2, [["united_atom"], ["united_atom", "residue"]]), - ): - out = node.run(shared) - - assert shared["number_molecules"] == 2 - assert shared["levels"] == [["united_atom"], ["united_atom", "residue"]] - assert out["levels"] == shared["levels"] + node._hierarchy = MagicMock() + node._hierarchy.select_levels.return_value = (1, [["polymer"]]) + + assert node._detect_levels(universe) == (1, [["polymer"]]) + node._hierarchy.select_levels.assert_called_once_with(universe) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_detect_molecules_node.py b/tests/unit/CodeEntropy/levels/nodes/test_detect_molecules_node.py index 249db1e4..38a5a505 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_detect_molecules_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_detect_molecules_node.py @@ -1,45 +1,97 @@ +"""Atomic unit tests for molecule detection and grouping.""" + +from __future__ import annotations + from types import SimpleNamespace -from unittest.mock import patch +from unittest.mock import MagicMock import pytest from CodeEntropy.levels.nodes.detect_molecules import DetectMoleculesNode -def test_run_sets_reduced_universe_when_missing(args, universe_with_fragments): +class FakeUniverse: + """Minimal universe exposing molecule fragments.""" + + def __init__(self, n_fragments: int): + self.atoms = SimpleNamespace(fragments=[object() for _ in range(n_fragments)]) + + +def test_run_uses_existing_reduced_universe_and_configured_grouping(): + universe = FakeUniverse(3) node = DetectMoleculesNode() + node._grouping = MagicMock() + node._grouping.grouping_molecules.return_value = {0: [0, 1], 1: [2]} - shared = { - "universe": universe_with_fragments, - "args": args, + shared_data = { + 
"reduced_universe": universe, + "universe": FakeUniverse(99), + "args": SimpleNamespace(grouping="molecules"), } - with patch.object(node._grouping, "grouping_molecules", return_value={0: [1]}): - out = node.run(shared) + result = node.run(shared_data) - assert shared["reduced_universe"] is universe_with_fragments - assert shared["groups"] == {0: [1]} - assert shared["number_molecules"] == len(universe_with_fragments.atoms.fragments) - assert out["number_molecules"] == shared["number_molecules"] + node._grouping.grouping_molecules.assert_called_once_with(universe, "molecules") + assert shared_data["groups"] == {0: [0, 1], 1: [2]} + assert shared_data["number_molecules"] == 3 + assert result == { + "groups": {0: [0, 1], 1: [2]}, + "number_molecules": 3, + } -def test_run_uses_args_grouping_strategy(universe_with_fragments): +def test_run_falls_back_to_universe_when_reduced_universe_missing(): + universe = FakeUniverse(2) node = DetectMoleculesNode() - shared = { - "universe": universe_with_fragments, - "args": SimpleNamespace(grouping="molecules"), + node._grouping = MagicMock() + node._grouping.grouping_molecules.return_value = {0: [0], 1: [1]} + + shared_data = { + "universe": universe, + "args": SimpleNamespace(grouping="each"), } - with patch.object( - node._grouping, "grouping_molecules", return_value={"g": [1]} - ) as gm: - node.run(shared) + node.run(shared_data) - gm.assert_called_once() - assert gm.call_args[0][1] == "molecules" + assert shared_data["reduced_universe"] is universe + node._grouping.grouping_molecules.assert_called_once_with(universe, "each") -def test_ensure_reduced_universe_raises_if_missing_universe(): +def test_run_uses_default_grouping_when_args_has_no_grouping_attribute(): + universe = FakeUniverse(1) node = DetectMoleculesNode() + node._grouping = MagicMock() + node._grouping.grouping_molecules.return_value = {0: [0]} + + shared_data = { + "reduced_universe": universe, + "args": SimpleNamespace(), + } + + node.run(shared_data) + + 
node._grouping.grouping_molecules.assert_called_once_with(universe, "each") + + +def test_run_requires_args(): with pytest.raises(KeyError): - node._ensure_reduced_universe({}) + DetectMoleculesNode().run({"reduced_universe": FakeUniverse(1)}) + + +def test_ensure_reduced_universe_raises_when_no_universe_available(): + with pytest.raises(KeyError, match="shared_data must contain 'universe'"): + DetectMoleculesNode()._ensure_reduced_universe({}) + + +def test_get_grouping_strategy_reads_args_with_default(): + node = DetectMoleculesNode() + + assert ( + node._get_grouping_strategy({"args": SimpleNamespace(grouping="molecules")}) + == "molecules" + ) + assert node._get_grouping_strategy({"args": SimpleNamespace()}) == "each" + + +def test_count_molecules_counts_fragments(): + assert DetectMoleculesNode._count_molecules(FakeUniverse(4)) == 4 diff --git a/tests/unit/CodeEntropy/levels/nodes/test_find_neighbors.py b/tests/unit/CodeEntropy/levels/nodes/test_find_neighbors.py deleted file mode 100644 index 56790306..00000000 --- a/tests/unit/CodeEntropy/levels/nodes/test_find_neighbors.py +++ /dev/null @@ -1,40 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -from CodeEntropy.levels.nodes.find_neighbors import ComputeNeighborsNode - - -def test_compute_find_neighbors_node_runs_and_writes_shared_data(): - node = ComputeNeighborsNode() - - frame_source = MagicMock() - - node._neighbor_analysis.get_neighbors = MagicMock(return_value={0: 7.8}) - node._neighbor_analysis.get_symmetry = MagicMock(return_value=({0: 2}, {0: False})) - - shared = { - "reduced_universe": MagicMock(), - "levels": {0: ["united_atom"]}, - "groups": {0: [0]}, - "frame_source": frame_source, - "args": SimpleNamespace(search_type="RAD"), - } - - out = node.run(shared) - - assert out is shared - assert shared["neighbors"] == {0: 7.8} - assert shared["symmetry_number"] == {0: 2} - assert shared["linear"] == {0: False} - - 
node._neighbor_analysis.get_neighbors.assert_called_once_with( - universe=shared["reduced_universe"], - levels=shared["levels"], - groups=shared["groups"], - frame_source=frame_source, - search_type="RAD", - ) - node._neighbor_analysis.get_symmetry.assert_called_once_with( - universe=shared["reduced_universe"], - groups=shared["groups"], - ) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py b/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py deleted file mode 100644 index 1b51006b..00000000 --- a/tests/unit/CodeEntropy/levels/nodes/test_frame_covariance_node.py +++ /dev/null @@ -1,634 +0,0 @@ -from unittest.mock import MagicMock, patch - -import numpy as np -import pytest - -from CodeEntropy.levels.nodes import covariance as covmod -from CodeEntropy.levels.nodes.covariance import FrameCovarianceNode - - -class _BeadGroup: - def __init__(self, n=1): - self._n = n - - def __len__(self): - return self._n - - def center_of_mass(self, unwrap=False): - return np.array([0.0, 0.0, 0.0], dtype=float) - - -class _EmptyGroup: - def __len__(self): - return 0 - - -def _mk_atomgroup(n=1): - g = MagicMock() - g.__len__.return_value = n - return g - - -def test_get_shared_missing_raises_keyerror(): - node = FrameCovarianceNode() - with pytest.raises(KeyError): - node._get_shared({}) - - -def test_try_get_box_returns_none_on_failure(): - node = FrameCovarianceNode() - u = MagicMock() - type(u).dimensions = property(lambda self: (_ for _ in ()).throw(RuntimeError("x"))) - assert node._try_get_box(u) is None - - -def test_inc_mean_first_sample_copies(): - node = FrameCovarianceNode() - new = np.eye(2) - out = node._inc_mean(None, new, n=1) - np.testing.assert_allclose(out, new) - new[0, 0] = 999.0 - assert out[0, 0] != 999.0 - - -def test_inc_mean_updates_streaming_average(): - node = FrameCovarianceNode() - old = np.array([[2.0, 2.0], [2.0, 2.0]]) - new = np.array([[4.0, 0.0], [0.0, 4.0]]) - out = node._inc_mean(old, new, n=2) - 
np.testing.assert_allclose(out, np.array([[3.0, 1.0], [1.0, 3.0]])) - - -def test_build_ft_block_rejects_mismatched_lengths(): - node = FrameCovarianceNode() - with pytest.raises(ValueError): - node._build_ft_block([np.zeros(3)], [np.zeros(3), np.zeros(3)]) - - -def test_build_ft_block_rejects_empty(): - node = FrameCovarianceNode() - with pytest.raises(ValueError): - node._build_ft_block([], []) - - -def test_build_ft_block_rejects_non_length3_vectors(): - node = FrameCovarianceNode() - with pytest.raises(ValueError): - node._build_ft_block([np.zeros(2)], [np.zeros(3)]) - - -def test_build_ft_block_returns_symmetric_block_matrix(): - node = FrameCovarianceNode() - - force_vecs = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 2.0, 0.0])] - torque_vecs = [np.array([0.0, 0.0, 3.0]), np.array([4.0, 0.0, 0.0])] - - M = node._build_ft_block(force_vecs, torque_vecs) - assert M.shape == (12, 12) - - np.testing.assert_allclose(M, M.T) - - -def test_process_residue_skips_when_no_beads_key_present(): - node = FrameCovarianceNode() - - shared = { - "reduced_universe": MagicMock(), - "groups": {0: [0]}, - "levels": [["residue"]], - "beads": {}, - "args": MagicMock( - force_partitioning=1.0, combined_forcetorque=False, customised_axes=False - ), - "axes_manager": MagicMock(), - } - ctx = {"shared": shared} - - out = node.run(ctx) - assert out["force"]["res"] == {} - assert out["torque"]["res"] == {} - assert "forcetorque" not in out - - -def test_process_residue_combined_only_when_highest_level(): - node = FrameCovarianceNode() - - u = MagicMock() - u.atoms = MagicMock() - frag = MagicMock() - frag.residues = [MagicMock()] - u.atoms.fragments = [frag] - u.atoms.__getitem__.side_effect = lambda idx: _mk_atomgroup(1) - u.dimensions = np.array([10.0, 10.0, 10.0, 90.0, 90.0, 90.0]) - - args = MagicMock() - args.force_partitioning = 1.0 - args.combined_forcetorque = True - args.customised_axes = True - - axes_manager = MagicMock() - axes_manager.get_residue_axes.return_value = ( - 
np.eye(3), - np.eye(3), - np.zeros(3), - np.array([1.0, 1.0, 1.0]), - ) - - shared = { - "reduced_universe": u, - "groups": {7: [0]}, - "levels": [["residue"]], - "beads": {(0, "residue"): [np.array([1, 2, 3])]}, - "args": args, - "axes_manager": axes_manager, - } - - with ( - patch.object( - node._ft, "get_weighted_forces", return_value=np.array([1.0, 0.0, 0.0]) - ), - patch.object( - node._ft, "get_weighted_torques", return_value=np.array([0.0, 1.0, 0.0]) - ), - patch.object( - node._ft, - "compute_frame_covariance", - return_value=(np.eye(3), 2.0 * np.eye(3)), - ), - ): - ctx = {"shared": shared} - out = node.run(ctx) - - assert "forcetorque" in out - assert 7 in out["force"]["res"] - assert 7 in out["torque"]["res"] - assert 7 in out["forcetorque"]["res"] - - -def test_process_residue_combined_not_added_if_not_highest_level(): - node = FrameCovarianceNode() - - u = MagicMock() - u.atoms = MagicMock() - frag = MagicMock() - frag.residues = [MagicMock()] - u.atoms.fragments = [frag] - u.atoms.__getitem__.side_effect = lambda idx: _mk_atomgroup(1) - u.dimensions = np.array([10.0, 10.0, 10.0, 90.0, 90.0, 90.0]) - - args = MagicMock( - force_partitioning=1.0, combined_forcetorque=True, customised_axes=True - ) - - axes_manager = MagicMock() - axes_manager.get_residue_axes.return_value = ( - np.eye(3), - np.eye(3), - np.zeros(3), - np.ones(3), - ) - - shared = { - "reduced_universe": u, - "groups": {7: [0]}, - "levels": [["united_atom", "residue", "polymer"]], - "beads": {(0, "residue"): [np.array([1, 2, 3])]}, - "args": args, - "axes_manager": axes_manager, - } - - with ( - patch.object( - node._ft, "get_weighted_forces", return_value=np.array([1.0, 0.0, 0.0]) - ), - patch.object( - node._ft, "get_weighted_torques", return_value=np.array([0.0, 1.0, 0.0]) - ), - patch.object( - node._ft, - "compute_frame_covariance", - return_value=(np.eye(3), 2.0 * np.eye(3)), - ), - ): - out = node.run({"shared": shared}) - - assert "forcetorque" in out - assert 
out["forcetorque"]["res"] == {} - - -def test_process_united_atom_returns_when_no_beads_for_level(): - node = FrameCovarianceNode() - - res = MagicMock() - res.atoms = MagicMock() - mol = MagicMock() - mol.residues = [res] - - axes_manager = MagicMock() - - out_force = {"ua": {}, "res": {}, "poly": {}} - out_torque = {"ua": {}, "res": {}, "poly": {}} - molcount = {} - - node._process_united_atom( - u=MagicMock(), - mol=mol, - mol_id=0, - group_id=0, - beads={}, - axes_manager=axes_manager, - box=np.array([10.0, 10.0, 10.0], dtype=float), - force_partitioning=1.0, - customised_axes=False, - is_highest=True, - out_force=out_force, - out_torque=out_torque, - molcount=molcount, - ) - - assert out_force["ua"] == {} - assert out_torque["ua"] == {} - assert molcount == {} - axes_manager.get_UA_axes.assert_not_called() - axes_manager.get_vanilla_axes.assert_not_called() - - -def test_get_residue_axes_vanilla_branch_returns_arrays(monkeypatch): - node = FrameCovarianceNode() - - monkeypatch.setattr( - "CodeEntropy.levels.nodes.covariance.make_whole", lambda _ag: None - ) - - mol = MagicMock() - mol.atoms.principal_axes.return_value = np.eye(3) * 2 - - bead = MagicMock() - bead.center_of_mass.return_value = np.array([1.0, 2.0, 3.0]) - - axes_manager = MagicMock() - axes_manager.get_vanilla_axes.return_value = (np.eye(3), np.array([9.0, 8.0, 7.0])) - - trans, rot, center, moi = node._get_residue_axes( - mol=mol, - bead=bead, - local_res_i=0, - axes_manager=axes_manager, - customised_axes=False, - ) - - assert trans.shape == (3, 3) - assert rot.shape == (3, 3) - assert center.shape == (3,) - assert moi.shape == (3,) - assert np.allclose(trans, np.eye(3) * 2) - assert np.allclose(rot, np.eye(3)) - assert np.allclose(center, np.array([1.0, 2.0, 3.0])) - assert np.allclose(moi, np.array([9.0, 8.0, 7.0])) - - -def test_get_polymer_axes_returns_arrays(monkeypatch): - node = FrameCovarianceNode() - - monkeypatch.setattr( - "CodeEntropy.levels.nodes.covariance.make_whole", lambda 
_ag: None - ) - - mol = MagicMock() - mol.atoms.principal_axes.return_value = np.eye(3) * 3 - - bead = MagicMock() - bead.center_of_mass.return_value = np.array([0.0, 0.0, 0.0]) - - axes_manager = MagicMock() - axes_manager.get_vanilla_axes.return_value = (np.eye(3), np.array([1.0, 1.0, 1.0])) - - trans, rot, center, moi = node._get_polymer_axes( - mol=mol, - bead=bead, - axes_manager=axes_manager, - ) - - assert trans.shape == (3, 3) - assert rot.shape == (3, 3) - assert center.shape == (3,) - assert moi.shape == (3,) - assert np.allclose(trans, np.eye(3) * 3) - assert np.allclose(rot, np.eye(3)) - assert np.allclose(center, np.array([0.0, 0.0, 0.0])) - assert np.allclose(moi, np.array([1.0, 1.0, 1.0])) - - -def test_process_united_atom_updates_outputs_and_molcount(): - node = FrameCovarianceNode() - - node._build_ua_vectors = MagicMock( - return_value=( - [np.array([1.0, 0.0, 0.0])], - [np.array([0.0, 1.0, 0.0])], - ) - ) - - F = np.eye(3) - T = np.eye(3) * 2 - node._ft.compute_frame_covariance = MagicMock(return_value=(F, T)) - - u = MagicMock() - u.atoms = MagicMock() - u.atoms.__getitem__.side_effect = lambda idx: _BeadGroup(1) - - res = MagicMock() - res.atoms = MagicMock() - mol = MagicMock() - mol.residues = [res] - - beads = {(0, "united_atom", 0): [123]} - out_force = {"ua": {}, "res": {}, "poly": {}} - out_torque = {"ua": {}, "res": {}, "poly": {}} - molcount = {} - - node._process_united_atom( - u=u, - mol=mol, - mol_id=0, - group_id=7, - beads=beads, - axes_manager=MagicMock(), - box=np.array([10.0, 10.0, 10.0]), - force_partitioning=1.0, - customised_axes=False, - is_highest=True, - out_force=out_force, - out_torque=out_torque, - molcount=molcount, - ) - - key = (7, 0) - assert np.allclose(out_force["ua"][key], F) - assert np.allclose(out_torque["ua"][key], T) - assert molcount[key] == 1 - - -def test_process_residue_returns_early_when_no_beads(): - node = FrameCovarianceNode() - - out_force = {"ua": {}, "res": {}, "poly": {}} - out_torque = {"ua": 
{}, "res": {}, "poly": {}} - - node._process_residue( - u=MagicMock(), - mol=MagicMock(), - mol_id=0, - group_id=0, - beads={}, - axes_manager=MagicMock(), - box=np.array([10.0, 10.0, 10.0]), - customised_axes=False, - force_partitioning=1.0, - is_highest=True, - out_force=out_force, - out_torque=out_torque, - out_ft=None, - molcount={}, - combined=False, - ) - - assert out_force["res"] == {} - assert out_torque["res"] == {} - - -def test_build_ua_vectors_customised_axes_true_calls_get_UA_axes(): - node = FrameCovarianceNode() - - bead = _BeadGroup(1) - residue_atoms = MagicMock() - - axes_manager = MagicMock() - axes_manager.get_UA_axes.return_value = ( - np.eye(3), - np.eye(3), - np.array([0.0, 0.0, 0.0]), - np.array([1.0, 1.0, 1.0]), - ) - - node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) - node._ft.get_weighted_torques = MagicMock(return_value=np.array([4.0, 5.0, 6.0])) - - force_vecs, torque_vecs = node._build_ua_vectors( - bead_groups=[bead], - residue_atoms=residue_atoms, - axes_manager=axes_manager, - box=np.array([10.0, 10.0, 10.0]), - force_partitioning=1.0, - customised_axes=True, - is_highest=True, - ) - - axes_manager.get_UA_axes.assert_called_once() - assert len(force_vecs) == 1 and len(torque_vecs) == 1 - - -def test_build_ua_vectors_vanilla_path_uses_principal_axes_and_vanilla_axes( - monkeypatch, -): - node = FrameCovarianceNode() - - residue_atoms = MagicMock() - residue_atoms.principal_axes.return_value = np.eye(3) - - bead = _BeadGroup(1) - - axes_manager = MagicMock() - axes_manager.get_vanilla_axes.return_value = ( - np.eye(3) * 2, - np.array([9.0, 8.0, 7.0]), - ) - - monkeypatch.setattr(covmod, "make_whole", lambda *_: None) - - node._ft.get_weighted_forces = MagicMock(return_value=np.array([1.0, 0.0, 0.0])) - node._ft.get_weighted_torques = MagicMock(return_value=np.array([0.0, 1.0, 0.0])) - - force_vecs, torque_vecs = node._build_ua_vectors( - bead_groups=[bead], - residue_atoms=residue_atoms, - 
axes_manager=axes_manager, - box=np.array([10.0, 10.0, 10.0]), - force_partitioning=1.0, - customised_axes=False, - is_highest=True, - ) - - axes_manager.get_vanilla_axes.assert_called_once() - assert len(force_vecs) == 1 and len(torque_vecs) == 1 - - -def test_process_united_atom_skips_when_any_bead_group_is_empty(): - node = FrameCovarianceNode() - - res = MagicMock() - res.atoms = MagicMock() - mol = MagicMock() - mol.residues = [res] - - u = MagicMock() - u.atoms = MagicMock() - u.atoms.__getitem__.side_effect = lambda idx: _EmptyGroup() - - out_force = {"ua": {}, "res": {}, "poly": {}} - out_torque = {"ua": {}, "res": {}, "poly": {}} - - node._process_united_atom( - u=u, - mol=mol, - mol_id=0, - group_id=0, - beads={(0, "united_atom", 0): [123]}, - axes_manager=MagicMock(), - box=np.array([10.0, 10.0, 10.0]), - force_partitioning=1.0, - customised_axes=False, - is_highest=True, - out_force=out_force, - out_torque=out_torque, - molcount={}, - ) - - assert out_force["ua"] == {} - assert out_torque["ua"] == {} - - -def test_process_residue_returns_early_when_any_bead_group_is_empty(): - node = FrameCovarianceNode() - - u = MagicMock() - u.atoms = MagicMock() - u.atoms.__getitem__.side_effect = lambda idx: _EmptyGroup() - - out_force = {"ua": {}, "res": {}, "poly": {}} - out_torque = {"ua": {}, "res": {}, "poly": {}} - - node._process_residue( - u=u, - mol=MagicMock(), - mol_id=0, - group_id=0, - beads={(0, "residue"): [np.array([1, 2, 3])]}, - axes_manager=MagicMock(), - box=np.array([10.0, 10.0, 10.0]), - customised_axes=False, - force_partitioning=1.0, - is_highest=True, - out_force=out_force, - out_torque=out_torque, - out_ft=None, - molcount={}, - combined=False, - ) - - assert out_force["res"] == {} - assert out_torque["res"] == {} - - -def test_process_polymer_skips_when_any_bead_group_is_empty(): - node = FrameCovarianceNode() - - u = MagicMock() - u.atoms = MagicMock() - u.atoms.__getitem__.side_effect = lambda idx: _EmptyGroup() - - out_force = {"ua": 
{}, "res": {}, "poly": {}} - out_torque = {"ua": {}, "res": {}, "poly": {}} - out_ft = {"ua": {}, "res": {}, "poly": {}} - - node._process_polymer( - u=u, - mol=MagicMock(), - mol_id=0, - group_id=7, - beads={(0, "polymer"): [np.array([1, 2, 3])]}, - axes_manager=MagicMock(), - box=np.array([10.0, 10.0, 10.0]), - force_partitioning=1.0, - is_highest=True, - out_force=out_force, - out_torque=out_torque, - out_ft=out_ft, - molcount={}, - combined=True, - ) - - assert out_force["poly"] == {} - assert out_torque["poly"] == {} - assert out_ft["poly"] == {} - - -def test_process_polymer_happy_path_updates_force_torque_and_optional_ft(): - node = FrameCovarianceNode() - - u = MagicMock() - u.atoms = MagicMock() - - bead_obj = _BeadGroup(1) - u.atoms.__getitem__.side_effect = lambda idx: bead_obj - - mol = MagicMock() - mol.atoms = MagicMock() - - axes_manager = MagicMock() - - f_vec = np.array([1.0, 0.0, 0.0], dtype=float) - t_vec = np.array([0.0, 1.0, 0.0], dtype=float) - - F = np.eye(3) - T = 2.0 * np.eye(3) - FT = np.eye(6) - - out_force = {"ua": {}, "res": {}, "poly": {}} - out_torque = {"ua": {}, "res": {}, "poly": {}} - out_ft = {"ua": {}, "res": {}, "poly": {}} - molcount = {} - - with ( - patch.object( - node, - "_get_polymer_axes", - return_value=(np.eye(3), np.eye(3), np.zeros(3), np.ones(3)), - ) as axes_spy, - patch.object(node._ft, "get_weighted_forces", return_value=f_vec) as f_spy, - patch.object(node._ft, "get_weighted_torques", return_value=t_vec) as t_spy, - patch.object( - node._ft, "compute_frame_covariance", return_value=(F, T) - ) as cov_spy, - patch.object(node, "_build_ft_block", return_value=FT) as ft_spy, - ): - node._process_polymer( - u=u, - mol=mol, - mol_id=0, - group_id=7, - beads={(0, "polymer"): [np.array([1, 2, 3])]}, - axes_manager=axes_manager, - box=np.array([10.0, 10.0, 10.0]), - force_partitioning=0.5, - is_highest=True, - out_force=out_force, - out_torque=out_torque, - out_ft=out_ft, - molcount=molcount, - combined=True, - ) - - 
assert u.atoms.__getitem__.call_count == 1 - axes_spy.assert_called_once_with(mol=mol, bead=bead_obj, axes_manager=axes_manager) - - f_spy.assert_called_once() - t_spy.assert_called_once() - cov_spy.assert_called_once() - - np.testing.assert_allclose(out_force["poly"][7], F) - np.testing.assert_allclose(out_torque["poly"][7], T) - assert molcount[7] == 1 - - ft_spy.assert_called_once() - np.testing.assert_allclose(out_ft["poly"][7], FT) diff --git a/tests/unit/CodeEntropy/levels/nodes/test_init_covariance_accumulators_node.py b/tests/unit/CodeEntropy/levels/nodes/test_init_covariance_accumulators_node.py index fd5d31e8..2ba4bdac 100644 --- a/tests/unit/CodeEntropy/levels/nodes/test_init_covariance_accumulators_node.py +++ b/tests/unit/CodeEntropy/levels/nodes/test_init_covariance_accumulators_node.py @@ -1,23 +1,58 @@ +"""Unit tests for covariance accumulator initialisation.""" + +from __future__ import annotations + import numpy as np from CodeEntropy.levels.nodes.accumulators import InitCovarianceAccumulatorsNode -def test_init_covariance_accumulators_allocates_and_sets_aliases(): - node = InitCovarianceAccumulatorsNode() +def test_init_covariance_accumulators_allocates_canonical_accumulators(): + shared = { + "groups": {10: [0], 20: [1]}, + } + + result = InitCovarianceAccumulatorsNode().run(shared) + + assert shared["group_id_to_index"] == {10: 0, 20: 1} + assert shared["index_to_group_id"] == [10, 20] + + assert shared["force_covariances"]["ua"] == {} + assert shared["torque_covariances"]["ua"] == {} + + assert len(shared["force_covariances"]["res"]) == 2 + assert len(shared["torque_covariances"]["res"]) == 2 + assert len(shared["force_covariances"]["poly"]) == 2 + assert len(shared["torque_covariances"]["poly"]) == 2 + + np.testing.assert_array_equal(shared["frame_counts"]["res"], np.zeros(2, dtype=int)) + np.testing.assert_array_equal( + shared["frame_counts"]["poly"], np.zeros(2, dtype=int) + ) - shared = {"groups": {9: [1, 2], 2: [3]}} + assert 
len(shared["forcetorque_covariances"]["res"]) == 2 + assert len(shared["forcetorque_covariances"]["poly"]) == 2 + np.testing.assert_array_equal( + shared["forcetorque_counts"]["res"], + np.zeros(2, dtype=int), + ) + np.testing.assert_array_equal( + shared["forcetorque_counts"]["poly"], + np.zeros(2, dtype=int), + ) - out = node.run(shared) + assert "force_torque_stats" not in shared + assert "force_torque_counts" not in shared - assert out["group_id_to_index"] == {9: 0, 2: 1} - assert out["index_to_group_id"] == [9, 2] + assert result["force_covariances"] is shared["force_covariances"] + assert result["torque_covariances"] is shared["torque_covariances"] + assert result["frame_counts"] is shared["frame_counts"] + assert result["forcetorque_covariances"] is shared["forcetorque_covariances"] + assert result["forcetorque_counts"] is shared["forcetorque_counts"] - assert shared["force_covariances"]["res"] == [None, None] - assert shared["torque_covariances"]["poly"] == [None, None] - assert np.all(shared["frame_counts"]["res"] == np.array([0, 0])) - assert np.all(shared["forcetorque_counts"]["poly"] == np.array([0, 0])) +def test_init_covariance_accumulators_requires_groups(): + import pytest - assert shared["force_torque_stats"] is shared["forcetorque_covariances"] - assert shared["force_torque_counts"] is shared["forcetorque_counts"] + with pytest.raises(KeyError): + InitCovarianceAccumulatorsNode().run({}) diff --git a/tests/unit/CodeEntropy/levels/test_level_dag.py b/tests/unit/CodeEntropy/levels/test_level_dag.py index 067f8dd3..30ae4077 100644 --- a/tests/unit/CodeEntropy/levels/test_level_dag.py +++ b/tests/unit/CodeEntropy/levels/test_level_dag.py @@ -1,100 +1,12 @@ -"""Unit tests for LevelDAG orchestration, reduction, and parallel frame execution.""" +"""Unit tests for hierarchy-level DAG orchestration.""" from __future__ import annotations -import sys -import types -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch 
-import numpy as np -import pytest - -from CodeEntropy.levels import level_dag as level_dag_module from CodeEntropy.levels.level_dag import LevelDAG -def _empty_frame_out() -> dict: - """Return an empty frame-local covariance payload.""" - return { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - -def _shared_force_torque() -> dict: - """Return minimal shared data for force/torque reduction tests.""" - return { - "force_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "torque_covariances": {"ua": {}, "res": [None], "poly": [None]}, - "frame_counts": { - "ua": {}, - "res": np.zeros(1, dtype=int), - "poly": np.zeros(1, dtype=int), - }, - "group_id_to_index": {7: 0, 9: 0}, - } - - -def _shared_forcetorque() -> dict: - """Return minimal shared data for combined force-torque reduction tests.""" - return { - "forcetorque_covariances": {"res": [None], "poly": [None]}, - "forcetorque_counts": { - "res": np.zeros(1, dtype=int), - "poly": np.zeros(1, dtype=int), - }, - "group_id_to_index": {7: 0, 9: 0}, - } - - -def test_incremental_mean_none_returns_copy_for_numpy(): - arr = np.array([1.0, 2.0]) - - out = LevelDAG._incremental_mean(None, arr, n=1) - - assert np.all(out == arr) - - arr[0] = 999.0 - assert out[0] != 999.0 - - -def test_incremental_mean_updates_mean_correctly(): - old = np.array([2.0, 2.0]) - new = np.array([4.0, 0.0]) - - out = LevelDAG._incremental_mean(old, new, n=2) - - np.testing.assert_allclose(out, np.array([3.0, 1.0])) - - -def test_incremental_mean_handles_non_copyable_values(): - out = LevelDAG._incremental_mean(old=None, new=3.0, n=1) - - assert out == 3.0 - - -def test_execute_sets_default_axes_manager_and_runs_stages(): - dag = LevelDAG() - - shared = { - "reduced_universe": MagicMock(), - "start": 0, - "end": 0, - "step": 1, - "n_frames": 1, - } - - dag._run_static_stage = MagicMock() - dag._run_frame_stage = MagicMock() - - out = dag.execute(shared) - - assert out is shared - assert 
"axes_manager" in shared - dag._run_static_stage.assert_called_once_with(shared, progress=None) - dag._run_frame_stage.assert_called_once_with(shared, progress=None) - - def test_build_registers_static_nodes_and_builds_frame_dag(): with ( patch("CodeEntropy.levels.level_dag.DetectMoleculesNode"), @@ -102,668 +14,172 @@ def test_build_registers_static_nodes_and_builds_frame_dag(): patch("CodeEntropy.levels.level_dag.BuildBeadsNode"), patch("CodeEntropy.levels.level_dag.InitCovarianceAccumulatorsNode"), patch("CodeEntropy.levels.level_dag.ComputeConformationalStatesNode"), - patch("CodeEntropy.levels.level_dag.ComputeNeighborsNode"), ): - dag = LevelDAG(universe_operations=MagicMock()) + universe_operations = MagicMock() + dag = LevelDAG(universe_operations=universe_operations) dag._frame_dag.build = MagicMock() out = dag.build() assert out is dag - assert "detect_molecules" in dag._static_nodes - assert "detect_levels" in dag._static_nodes - assert "build_beads" in dag._static_nodes - assert "init_covariance_accumulators" in dag._static_nodes - assert "compute_conformational_states" in dag._static_nodes - assert "find_neighbors" in dag._static_nodes - dag._frame_dag.build.assert_called_once() - - -def test_add_static_adds_dependency_edges(): - dag = LevelDAG() - - dag._add_static("A", MagicMock()) - dag._add_static("B", MagicMock(), deps=["A"]) - - assert dag._static_nodes["A"] is not None - assert dag._static_nodes["B"] is not None - assert ("A", "B") in dag._static_graph.edges - - -def test_run_static_stage_calls_nodes_in_topological_sort_order(): - dag = LevelDAG() - dag._static_graph.add_node("a") - dag._static_graph.add_node("b") - - dag._static_nodes["a"] = MagicMock() - dag._static_nodes["b"] = MagicMock() + assert set(dag._static_nodes) == { + "detect_molecules", + "detect_levels", + "build_beads", + "init_covariance_accumulators", + "compute_conformational_states", + } + assert "find_neighbors" not in dag._static_nodes - with 
patch("networkx.topological_sort", return_value=["a", "b"]): - dag._run_static_stage({"X": 1}) + assert ("detect_molecules", "detect_levels") in dag._static_graph.edges + assert ("detect_levels", "build_beads") in dag._static_graph.edges + assert ("detect_levels", "init_covariance_accumulators") in dag._static_graph.edges + assert ("detect_levels", "compute_conformational_states") in dag._static_graph.edges - dag._static_nodes["a"].run.assert_called_once() - dag._static_nodes["b"].run.assert_called_once() + dag._frame_dag.build.assert_called_once() -def test_run_static_stage_forwards_progress_when_node_accepts_it(): +def test_execute_sets_default_axes_manager_and_runs_workflow_stages(): dag = LevelDAG() - dag._static_graph.add_node("a") - - node = MagicMock() - dag._static_nodes["a"] = node + shared_data = {"groups": {0: [0]}} progress = MagicMock() - with patch("networkx.topological_sort", return_value=["a"]): - dag._run_static_stage({"X": 1}, progress=progress) - - node.run.assert_called_once_with({"X": 1}, progress=progress) - - -def test_run_static_stage_falls_back_when_node_does_not_accept_progress(): - dag = LevelDAG() - dag._static_graph.add_node("a") - - node = MagicMock() - node.run.side_effect = [TypeError("no progress"), None] - dag._static_nodes["a"] = node + dag._run_static_stage = MagicMock() + dag._initialise_neighbor_metadata = MagicMock() + dag._run_frame_stage = MagicMock() - progress = MagicMock() + with ( + patch("CodeEntropy.levels.level_dag.NeighborReducer.initialise") as initialise, + patch("CodeEntropy.levels.level_dag.NeighborReducer.finalise") as finalise, + ): + out = dag.execute(shared_data, progress=progress) - with patch("networkx.topological_sort", return_value=["a"]): - dag._run_static_stage({"X": 1}, progress=progress) + assert out is shared_data + assert "axes_manager" in shared_data - assert node.run.call_count == 2 - node.run.assert_any_call({"X": 1}, progress=progress) - node.run.assert_any_call({"X": 1}) + 
dag._run_static_stage.assert_called_once_with(shared_data, progress=progress) + dag._initialise_neighbor_metadata.assert_called_once_with(shared_data) + initialise.assert_called_once_with(shared_data) + dag._run_frame_stage.assert_called_once_with(shared_data, progress=progress) + finalise.assert_called_once_with(shared_data) -def test_run_frame_stage_iterates_selected_frames_and_reduces_each(): +def test_add_static_adds_dependency_edges(): dag = LevelDAG() - frame_source = MagicMock() - frame_source.iter_indices.return_value = [10, 11] - - shared = { - "frame_source": frame_source, - "n_frames": 2, - } - - frame_outputs = [_empty_frame_out(), _empty_frame_out()] - - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.side_effect = frame_outputs - dag._reduce_one_frame = MagicMock() - - dag._run_frame_stage(shared) + node_a = MagicMock() + node_b = MagicMock() - assert shared["n_frames"] == 2 - frame_source.iter_indices.assert_called_once() - - assert dag._frame_dag.execute_frame.call_count == 2 - dag._frame_dag.execute_frame.assert_any_call(shared, 10) - dag._frame_dag.execute_frame.assert_any_call(shared, 11) + dag._add_static("A", node_a) + dag._add_static("B", node_b, deps=["A"]) - assert dag._reduce_one_frame.call_count == 2 - dag._reduce_one_frame.assert_any_call(shared, frame_outputs[0]) - dag._reduce_one_frame.assert_any_call(shared, frame_outputs[1]) + assert dag._static_nodes["A"] is node_a + assert dag._static_nodes["B"] is node_b + assert ("A", "B") in dag._static_graph.edges -def test_run_frame_stage_progress_total_comes_from_frame_source_indices(): +def test_run_static_stage_calls_nodes_in_topological_order(): dag = LevelDAG() + node_a = MagicMock() + node_b = MagicMock() - frame_source = MagicMock() - frame_source.iter_indices.return_value = list(range(10)) - - shared = { - "frame_source": frame_source, - "n_frames": 0, - } + dag._add_static("a", node_a) + dag._add_static("b", node_b, deps=["a"]) - frame_out = _empty_frame_out() + shared_data 
= {} - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.return_value = frame_out - dag._reduce_one_frame = MagicMock() + dag._run_static_stage(shared_data) - progress = MagicMock() - progress.add_task.return_value = 123 + node_a.run.assert_called_once_with(shared_data) + node_b.run.assert_called_once_with(shared_data) - dag._run_frame_stage(shared, progress=progress) - progress.add_task.assert_called_once_with( - "[green]Frame processing", - total=10, - title="Initializing", - ) - - assert shared["n_frames"] == 10 - frame_source.iter_indices.assert_called_once() - assert dag._frame_dag.execute_frame.call_count == 10 - assert dag._reduce_one_frame.call_count == 10 - assert progress.advance.call_count == 10 - - -def test_run_frame_stage_with_progress_creates_task_and_updates_titles(): +def test_run_static_stage_forwards_progress_when_node_accepts_it(): dag = LevelDAG() - - frame_source = MagicMock() - frame_source.iter_indices.return_value = [10, 11] - - shared = { - "frame_source": frame_source, - "n_frames": 2, - } - - frame_out = _empty_frame_out() - - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.return_value = frame_out - dag._reduce_one_frame = MagicMock() - + node = MagicMock() progress = MagicMock() - progress.add_task.return_value = 77 - - dag._run_frame_stage(shared, progress=progress) + shared_data = {} - progress.add_task.assert_called_once_with( - "[green]Frame processing", - total=2, - title="Initializing", - ) + dag._add_static("node", node) - assert progress.update.call_count == 2 - progress.update.assert_any_call(77, title="Frame 10") - progress.update.assert_any_call(77, title="Frame 11") + dag._run_static_stage(shared_data, progress=progress) - assert progress.advance.call_count == 2 - progress.advance.assert_any_call(77) + node.run.assert_called_once_with(shared_data, progress=progress) - assert dag._frame_dag.execute_frame.call_count == 2 - dag._frame_dag.execute_frame.assert_any_call(shared, 10) - 
dag._frame_dag.execute_frame.assert_any_call(shared, 11) - assert dag._reduce_one_frame.call_count == 2 - dag._reduce_one_frame.assert_any_call(shared, frame_out) - - -def test_run_frame_stage_falls_back_to_sequential_when_only_one_frame(): +def test_run_static_stage_falls_back_when_node_does_not_accept_progress(): dag = LevelDAG() + node = MagicMock() + node.run.side_effect = [TypeError("unexpected keyword argument progress"), None] + progress = MagicMock() + shared_data = {} - frame_source = MagicMock() - frame_source.iter_indices.return_value = [0] - - client = MagicMock() - - shared_data = { - "frame_source": frame_source, - "dask_client": client, - "parallel_frames": True, - } - - frame_out = _empty_frame_out() - - dag._run_frame_stage_dask = MagicMock() - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.return_value = frame_out - dag._reduce_one_frame = MagicMock() + dag._add_static("node", node) - dag._run_frame_stage(shared_data) + dag._run_static_stage(shared_data, progress=progress) - dag._run_frame_stage_dask.assert_not_called() - dag._frame_dag.execute_frame.assert_called_once_with(shared_data, 0) - dag._reduce_one_frame.assert_called_once_with(shared_data, frame_out) - assert shared_data["n_frames"] == 1 + assert node.run.call_count == 2 + assert node.run.call_args_list == [ + call(shared_data, progress=progress), + call(shared_data), + ] -def test_run_frame_stage_falls_back_to_sequential_without_client(): - dag = LevelDAG() +def test_run_frame_stage_collects_frame_indices_and_delegates_to_scheduler(): + universe_operations = MagicMock() + dag = LevelDAG(universe_operations=universe_operations) + progress = MagicMock() frame_source = MagicMock() - frame_source.iter_indices.return_value = [0, 1] + frame_source.iter_indices.return_value = ["2", 4] - shared_data = { - "frame_source": frame_source, - "parallel_frames": True, - } - - frame_out = _empty_frame_out() + shared_data = {"frame_source": frame_source} - dag._run_frame_stage_dask = 
MagicMock() - dag._frame_dag = MagicMock() - dag._frame_dag.execute_frame.return_value = frame_out - dag._reduce_one_frame = MagicMock() + with patch("CodeEntropy.levels.level_dag.FrameScheduler") as Scheduler: + scheduler = Scheduler.return_value - dag._run_frame_stage(shared_data) + dag._run_frame_stage(shared_data, progress=progress) - dag._run_frame_stage_dask.assert_not_called() - assert dag._frame_dag.execute_frame.call_count == 2 - assert dag._reduce_one_frame.call_count == 2 assert shared_data["n_frames"] == 2 + frame_source.iter_indices.assert_called_once() - -def test_reduce_force_and_torque_handles_empty_frame_gracefully(): - dag = LevelDAG() - shared = _shared_force_torque() - - dag._reduce_force_and_torque(shared_data=shared, frame_out=_empty_frame_out()) - - assert shared["force_covariances"]["ua"] == {} - assert shared["torque_covariances"]["ua"] == {} - assert shared["frame_counts"]["res"][0] == 0 - assert shared["frame_counts"]["poly"][0] == 0 - - -def test_reduce_force_and_torque_updates_counts_and_means(): - dag = LevelDAG() - shared = _shared_force_torque() - - F1 = np.eye(3) - T1 = 2.0 * np.eye(3) - - frame_out = { - "force": {"ua": {(0, 0): F1}, "res": {9: F1}, "poly": {}}, - "torque": {"ua": {(0, 0): T1}, "res": {9: T1}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][(0, 0)] == 1 - np.testing.assert_allclose(shared["force_covariances"]["ua"][(0, 0)], F1) - np.testing.assert_allclose(shared["torque_covariances"]["ua"][(0, 0)], T1) - - assert shared["frame_counts"]["res"][0] == 1 - np.testing.assert_allclose(shared["force_covariances"]["res"][0], F1) - np.testing.assert_allclose(shared["torque_covariances"]["res"][0], T1) - - -def test_reduce_force_and_torque_exercises_count_branches(): - dag = LevelDAG() - shared = _shared_force_torque() - - frame_out = { - "force": { - "ua": {(9, 0): np.array([1.0])}, - "res": {7: np.array([2.0])}, - "poly": {7: np.array([3.0])}, - }, - "torque": { 
- "ua": {(9, 0): np.array([4.0])}, - "res": {7: np.array([5.0])}, - "poly": {7: np.array([6.0])}, - }, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert (9, 0) in shared["torque_covariances"]["ua"] - assert shared["frame_counts"]["res"][0] == 1 - assert shared["frame_counts"]["poly"][0] == 1 - np.testing.assert_allclose(shared["force_covariances"]["res"][0], np.array([2.0])) - np.testing.assert_allclose(shared["torque_covariances"]["res"][0], np.array([5.0])) - np.testing.assert_allclose(shared["force_covariances"]["poly"][0], np.array([3.0])) - np.testing.assert_allclose(shared["torque_covariances"]["poly"][0], np.array([6.0])) - - -def test_reduce_force_and_torque_res_torque_increments_when_res_count_is_zero(): - dag = LevelDAG() - shared = _shared_force_torque() - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {7: np.eye(3)}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["res"][0] == 1 - assert shared["torque_covariances"]["res"][0] is not None - - -def test_reduce_force_and_torque_poly_torque_increments_when_poly_count_is_zero(): - dag = LevelDAG() - shared = _shared_force_torque() - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {7: np.eye(3)}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["poly"][0] == 1 - assert shared["torque_covariances"]["poly"][0] is not None - - -def test_reduce_force_and_torque_increments_ua_frame_counts_for_force(): - dag = LevelDAG() - shared = _shared_force_torque() - - key = (9, 0) - F = np.eye(3) - - frame_out = { - "force": {"ua": {key: F}, "res": {}, "poly": {}}, - "torque": {"ua": {}, "res": {}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][key] == 1 - assert key in shared["force_covariances"]["ua"] - 
np.testing.assert_array_equal(shared["force_covariances"]["ua"][key], F) - - -def test_reduce_force_and_torque_ua_torque_increments_count_when_force_missing_key(): - dag = LevelDAG() - shared = _shared_force_torque() - - key = (9, 0) - T = np.eye(3) - - frame_out = { - "force": {"ua": {}, "res": {}, "poly": {}}, - "torque": {"ua": {key: T}, "res": {}, "poly": {}}, - } - - dag._reduce_force_and_torque(shared, frame_out) - - assert shared["frame_counts"]["ua"][key] == 1 - np.testing.assert_array_equal(shared["torque_covariances"]["ua"][key], T) - - -def test_reduce_one_frame_calls_force_torque_and_forcetorque_reducers(): - dag = LevelDAG() - shared = {} - frame_out = {} - - dag._reduce_force_and_torque = MagicMock() - dag._reduce_forcetorque = MagicMock() - - dag._reduce_one_frame(shared, frame_out) - - dag._reduce_force_and_torque.assert_called_once_with(shared, frame_out) - dag._reduce_forcetorque.assert_called_once_with(shared, frame_out) - - -def test_reduce_forcetorque_no_key_is_noop(): - dag = LevelDAG() - shared = _shared_forcetorque() - - dag._reduce_forcetorque(shared, frame_out={}) - - assert shared["forcetorque_counts"]["res"][0] == 0 - assert shared["forcetorque_covariances"]["res"][0] is None - - -def test_reduce_forcetorque_updates_res_and_poly(): - dag = LevelDAG() - shared = _shared_forcetorque() - - frame_out = { - "forcetorque": { - "res": {7: np.array([1.0, 1.0])}, - "poly": {7: np.array([2.0, 2.0])}, - } - } - - dag._reduce_forcetorque(shared, frame_out) - - assert shared["forcetorque_counts"]["res"][0] == 1 - assert shared["forcetorque_counts"]["poly"][0] == 1 - np.testing.assert_allclose( - shared["forcetorque_covariances"]["res"][0], - np.array([1.0, 1.0]), - ) - np.testing.assert_allclose( - shared["forcetorque_covariances"]["poly"][0], - np.array([2.0, 2.0]), + Scheduler.assert_called_once_with( + frame_dag=dag._frame_dag, + policy=dag._policy, + universe_operations=universe_operations, ) - - -def 
test_make_frame_worker_shared_data_excludes_parent_only_keys(): - shared_data = { - "force_covariances": "force accumulator", - "torque_covariances": "torque accumulator", - "forcetorque_covariances": "ft accumulator", - "frame_counts": "frame counts", - "forcetorque_counts": "ft counts", - "force_torque_stats": "legacy ft accumulator alias", - "force_torque_counts": "legacy ft counts alias", - "n_frames": 10, - "entropy_manager": "manager", - "run_manager": "run manager", - "reporter": "reporter", - "dask_client": "client", - "frame_source": "frame source", - "levels": "levels", - "groups": "groups", - "args": "args", - } - - worker_shared = LevelDAG._make_frame_worker_shared_data(shared_data) - - assert worker_shared == { - "frame_source": "frame source", - "levels": "levels", - "groups": "groups", - "args": "args", - } - - -def test_execute_frame_worker_builds_frame_graph_and_returns_frame_output(): - shared_data = {"x": 1} - universe_operations = MagicMock() - - with patch("CodeEntropy.levels.level_dag.FrameGraph") as FrameGraphCls: - graph = MagicMock() - graph.execute_frame.return_value = {"force": {}, "torque": {}} - FrameGraphCls.return_value.build.return_value = graph - - frame_index, frame_out = level_dag_module._execute_frame_worker( - shared_data, - frame_index="5", - universe_operations=universe_operations, - ) - - FrameGraphCls.assert_called_once_with(universe_operations=universe_operations) - FrameGraphCls.return_value.build.assert_called_once() - graph.execute_frame.assert_called_once_with(shared_data, 5) - - assert frame_index == 5 - assert frame_out == {"force": {}, "torque": {}} - - -def test_run_frame_stage_uses_dask_when_client_present(): - dag = LevelDAG() - - frame_source = MagicMock() - frame_source.iter_indices.return_value = [0, 1, 2] - - client = MagicMock() - - shared_data = { - "frame_source": frame_source, - "dask_client": client, - "parallel_frames": True, - } - - dag._run_frame_stage_dask = MagicMock() - dag._frame_dag = MagicMock() 
- dag._reduce_one_frame = MagicMock() - - dag._run_frame_stage(shared_data) - - dag._run_frame_stage_dask.assert_called_once_with( + scheduler.execute.assert_called_once_with( shared_data, - frame_indices=[0, 1, 2], - client=client, - progress=None, - task=None, + frame_indices=[2, 4], + progress=progress, ) - dag._frame_dag.execute_frame.assert_not_called() - dag._reduce_one_frame.assert_not_called() - assert shared_data["n_frames"] == 3 - - -def test_run_frame_stage_dask_submits_each_frame_and_reduces_completed_results(): - dag = LevelDAG() - - shared_data = { - "keep": "value", - "force_covariances": "exclude me", - "reporter": "exclude me too", - } - - client = MagicMock() - - frame_out0 = _empty_frame_out() - frame_out1 = _empty_frame_out() - - future0 = MagicMock() - future0.result.return_value = (0, frame_out0) - - future1 = MagicMock() - future1.result.return_value = (1, frame_out1) - - client.submit.side_effect = [future0, future1] - - fake_distributed = types.ModuleType("distributed") - fake_distributed.as_completed = MagicMock(return_value=[future0, future1]) - dag._reduce_one_frame = MagicMock() - with patch.dict(sys.modules, {"distributed": fake_distributed}): - dag._run_frame_stage_dask( - shared_data, - frame_indices=[0, 1], - client=client, - progress=None, - task=None, - ) - - assert client.submit.call_count == 2 - - for call in client.submit.call_args_list: - args, kwargs = call - assert args[0] is level_dag_module._execute_frame_worker - assert args[1] == {"keep": "value"} - assert kwargs == {"pure": False} - - assert dag._reduce_one_frame.call_count == 2 - dag._reduce_one_frame.assert_any_call(shared_data, frame_out0) - dag._reduce_one_frame.assert_any_call(shared_data, frame_out1) - client.cancel.assert_not_called() - - -def test_run_frame_stage_dask_updates_progress(): - dag = LevelDAG() - - shared_data = {"keep": "value"} - client = MagicMock() - - frame_out = _empty_frame_out() - future = MagicMock() - future.result.return_value = (7, 
frame_out) - client.submit.return_value = future - - fake_distributed = types.ModuleType("distributed") - fake_distributed.as_completed = MagicMock(return_value=[future]) - - progress = MagicMock() - dag._reduce_one_frame = MagicMock() - - with patch.dict(sys.modules, {"distributed": fake_distributed}): - dag._run_frame_stage_dask( - shared_data, - frame_indices=[7], - client=client, - progress=progress, - task="task-id", - ) - - progress.update.assert_called_once_with("task-id", title="Frame 7") - progress.advance.assert_called_once_with("task-id") - dag._reduce_one_frame.assert_called_once_with(shared_data, frame_out) - - -def test_run_frame_stage_dask_cancels_futures_and_reraises_on_result_error(): - dag = LevelDAG() - - shared_data = {"keep": "value"} - client = MagicMock() - - future = MagicMock() - future.result.side_effect = RuntimeError("worker failed") - client.submit.return_value = future - - fake_distributed = types.ModuleType("distributed") - fake_distributed.as_completed = MagicMock(return_value=[future]) - - with patch.dict(sys.modules, {"distributed": fake_distributed}): - with pytest.raises(RuntimeError, match="worker failed"): - dag._run_frame_stage_dask( - shared_data, - frame_indices=[0], - client=client, - progress=None, - task=None, - ) - - client.cancel.assert_called_once_with([future]) - - -def test_run_frame_stage_dask_raises_when_distributed_missing(): - dag = LevelDAG() - client = MagicMock() - - real_import = __import__ - - def fake_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == "distributed": - raise ImportError("No module named distributed") - return real_import(name, globals, locals, fromlist, level) - - with patch("builtins.__import__", side_effect=fake_import): - with pytest.raises(RuntimeError, match="requires dask.distributed"): - dag._run_frame_stage_dask( - {"keep": "value"}, - frame_indices=[0], - client=client, - progress=None, - task=None, - ) - - -def 
test_run_frame_stage_dask_raises_if_completed_count_mismatch(): - dag = LevelDAG() +def test_initialise_neighbor_metadata_writes_symmetry_and_linearity(): + universe = object() + groups = {0: [0], 1: [1]} + shared_data = {"reduced_universe": universe, "groups": groups} - shared_data = {"keep": "value"} - client = MagicMock() + with patch("CodeEntropy.levels.level_dag.Neighbors") as Neighbors: + helper = Neighbors.return_value + helper.get_symmetry.return_value = ({0: 12, 1: 2}, {0: False, 1: True}) - future0 = MagicMock() - future0.result.return_value = (0, _empty_frame_out()) + LevelDAG._initialise_neighbor_metadata(shared_data) - future1 = MagicMock() + helper.get_symmetry.assert_called_once_with(universe=universe, groups=groups) + assert shared_data["symmetry_number"] == {0: 12, 1: 2} + assert shared_data["linear"] == {0: False, 1: True} - client.submit.side_effect = [future0, future1] - fake_distributed = types.ModuleType("distributed") - fake_distributed.as_completed = MagicMock(return_value=[future0]) +def test_initialise_neighbor_metadata_falls_back_to_universe_key(): + universe = object() + shared_data = {"universe": universe, "groups": {0: [0]}} - dag._reduce_one_frame = MagicMock() + with patch("CodeEntropy.levels.level_dag.Neighbors") as Neighbors: + helper = Neighbors.return_value + helper.get_symmetry.return_value = ({0: 1}, {0: False}) - with patch.dict(sys.modules, {"distributed": fake_distributed}): - with pytest.raises( - RuntimeError, - match="Parallel frame execution completed 1 frames, but expected 2", - ): - dag._run_frame_stage_dask( - shared_data, - frame_indices=[0, 1], - client=client, - progress=None, - task=None, - ) + LevelDAG._initialise_neighbor_metadata(shared_data) - client.cancel.assert_called_once() + helper.get_symmetry.assert_called_once_with(universe=universe, groups={0: [0]}) diff --git a/tests/unit/CodeEntropy/levels/test_neighbors.py b/tests/unit/CodeEntropy/levels/test_neighbors.py index f8814267..6bdf9b68 100644 --- 
a/tests/unit/CodeEntropy/levels/test_neighbors.py +++ b/tests/unit/CodeEntropy/levels/test_neighbors.py @@ -1,619 +1,560 @@ -import contextlib +"""Unit tests for neighbour-count and symmetry helpers.""" + +from __future__ import annotations + +from types import SimpleNamespace from unittest.mock import MagicMock, call, patch -import numpy as np import pytest -from CodeEntropy.levels.neighbors import Neighbors +from CodeEntropy.levels.neighbors import Chem, Neighbors -class _FakeProgress: - def __enter__(self): - return self +class FakeSearch: + """Minimal Search test double.""" - def __exit__(self, exc_type, exc, tb): - return False + def __init__(self): + self.rad_calls = [] + self.grid_calls = [] - def add_task(self, *args, **kwargs): - return 1 + def get_RAD_neighbors(self, *, universe, mol_id, frame_source, frame_index): + self.rad_calls.append( + { + "universe": universe, + "mol_id": mol_id, + "frame_source": frame_source, + "frame_index": frame_index, + } + ) + return [1, 2] + + def get_grid_neighbors( + self, + *, + universe, + mol_id, + highest_level, + frame_source, + frame_index, + ): + self.grid_calls.append( + { + "universe": universe, + "mol_id": mol_id, + "highest_level": highest_level, + "frame_source": frame_source, + "frame_index": frame_index, + } + ) + return [1] - def advance(self, *args, **kwargs): - return None +class FakeRdkitMol: + """Minimal RDKit-like molecule.""" -@contextlib.contextmanager -def _fake_progress_bar(*_args, **_kwargs): - yield _FakeProgress() + def __init__(self, *, heavy_atoms: int, total_atoms: int): + self._heavy_atoms = heavy_atoms + self._total_atoms = total_atoms + def GetNumHeavyAtoms(self): + return self._heavy_atoms -def _make_frame_source(indices): - frame_source = MagicMock() - frame_source.iter_indices.return_value = list(indices) - return frame_source + def GetNumAtoms(self): + return self._total_atoms -def test_raises_error_unknown_search_type(): - neighbors = Neighbors() +class FakeAtomSelection: + """Minimal 
AtomGroup-like object returned by select_atoms.""" - universe = MagicMock() - levels = {0: ["united_atom"]} - groups = {0: [0]} - frame_source = _make_frame_source([0, 1]) + def __init__(self, *, length: int, rdkit_mol=None): + self._length = length + self._rdkit_mol = rdkit_mol + self.convert_to = MagicMock(return_value=rdkit_mol) - with pytest.raises(ValueError, match="unknown search_type"): - neighbors.get_neighbors( - universe=universe, - levels=levels, - groups=groups, - frame_source=frame_source, - search_type="weird", - ) + def __len__(self): + return self._length -def test_average_number_neighbors_RAD(): - neighbors = Neighbors() +class FakeMolecule: + """Minimal molecule fragment for RDKit conversion tests.""" - universe = MagicMock() - levels = {0: ["united_atom"]} - groups = {0: [0]} - frame_source = _make_frame_source([0, 1]) + def __init__(self, *, dummy_atoms, heavy_fragment, rdkit_mol): + self._dummy_atoms = dummy_atoms + self._heavy_fragment = heavy_fragment + self._rdkit_mol = rdkit_mol + self.convert_to = MagicMock(return_value=rdkit_mol) - neighbors._search.get_RAD_neighbors = MagicMock(side_effect=[[1, 2, 3], [1, 3]]) + def select_atoms(self, selection): + if selection == "prop mass < 0.1": + return self._dummy_atoms + if selection == "prop mass > 0.1": + return self._heavy_fragment + raise AssertionError(f"Unexpected selection: {selection}") - result = neighbors.get_neighbors( - universe=universe, - levels=levels, - groups=groups, - frame_source=frame_source, - search_type="RAD", - ) - assert result == {0: np.float64(2.5)} - assert neighbors._search.get_RAD_neighbors.call_args_list == [ - call( - universe=universe, - mol_id=0, - frame_source=frame_source, - frame_index=0, - ), - call( - universe=universe, - mol_id=0, - frame_source=frame_source, - frame_index=1, - ), - ] +class FakeAtomsWithoutElements: + """Universe atoms object without an elements attribute.""" + + def __init__(self, fragments): + self.fragments = fragments + + +class 
FakeAtomsWithElements: + """Universe atoms object with an elements attribute.""" + def __init__(self, fragments): + self.fragments = fragments + self.elements = ["C"] -def test_average_number_neighbors_grid(): - neighbors = Neighbors() - universe = MagicMock() - levels = {0: ["united_atom"]} - groups = {0: [0]} - frame_source = _make_frame_source([0, 1]) +def test_neighbors_accepts_injected_search_dependency(): + search = FakeSearch() - neighbors._search.get_grid_neighbors = MagicMock(side_effect=[[1, 2, 3], [1, 3]]) + helper = Neighbors(search=search) + + assert helper._search is search + + +def test_get_frame_neighbor_counts_rad(): + helper = Neighbors() + helper._search = FakeSearch() + + universe = object() + frame_source = object() + groups = {0: [0, 1]} + levels = [["united_atom"], ["united_atom", "residue"]] - result = neighbors.get_neighbors( + result = helper.get_frame_neighbor_counts( universe=universe, levels=levels, groups=groups, frame_source=frame_source, - search_type="grid", + frame_index=5, + search_type="RAD", ) - assert result == {0: np.float64(2.5)} - assert neighbors._search.get_grid_neighbors.call_args_list == [ - call( - universe=universe, - mol_id=0, - highest_level="united_atom", - frame_source=frame_source, - frame_index=0, - ), - call( - universe=universe, - mol_id=0, - highest_level="united_atom", - frame_source=frame_source, - frame_index=1, - ), + assert result == {0: (4, 2)} + assert helper._search.rad_calls == [ + { + "universe": universe, + "mol_id": 0, + "frame_source": frame_source, + "frame_index": 5, + }, + { + "universe": universe, + "mol_id": 1, + "frame_source": frame_source, + "frame_index": 5, + }, ] -def test_average_number_neighbors_RAD_multiple(): - neighbors = Neighbors() +def test_get_frame_neighbor_counts_grid(): + helper = Neighbors() + helper._search = FakeSearch() - universe = MagicMock() - levels = {0: ["united_atom"]} + universe = object() + frame_source = object() groups = {0: [0, 1]} - frame_source = 
_make_frame_source([0, 1]) + levels = [["united_atom"], ["united_atom", "residue"]] - neighbors._search.get_RAD_neighbors = MagicMock( - side_effect=[[1, 2, 3, 5], [1, 3], [2, 3, 4, 5], [3, 5]] - ) - - result = neighbors.get_neighbors( + result = helper.get_frame_neighbor_counts( universe=universe, levels=levels, groups=groups, frame_source=frame_source, - search_type="RAD", + frame_index=3, + search_type="grid", ) - assert result == {0: np.float64(3.0)} - assert neighbors._search.get_RAD_neighbors.call_count == 4 + assert result == {0: (2, 2)} + assert helper._search.grid_calls == [ + { + "universe": universe, + "mol_id": 0, + "highest_level": "united_atom", + "frame_source": frame_source, + "frame_index": 3, + }, + { + "universe": universe, + "mol_id": 1, + "highest_level": "united_atom", + "frame_source": frame_source, + "frame_index": 3, + }, + ] -def test_get_symmetry_number_res(): - neighbors = Neighbors() - rdkit_mol = MagicMock() - number_heavy = 3 - number_hydrogen = 8 +def test_get_frame_neighbor_counts_empty_group(): + helper = Neighbors() + helper._search = FakeSearch() - rdkit_mol.GetSubstructMatches = MagicMock( - side_effect=[((0, 1, 2), (0, 2, 1), (1, 0, 2))] + result = helper.get_frame_neighbor_counts( + universe=object(), + levels=[], + groups={0: []}, + frame_source=object(), + frame_index=0, + search_type="RAD", ) - class _FakeRDKit_Chem: - """Class to mock rdkit functionality.""" - - def RemoveHs(mol): - rdkit_heavy = MagicMock() - return rdkit_heavy - - with patch("CodeEntropy.levels.neighbors.Chem", _FakeRDKit_Chem): - result = neighbors._get_symmetry_number( - rdkit_mol, number_heavy, number_hydrogen - ) - - assert result == 3 - - -def test_get_symmetry_number_ua(): - neighbors = Neighbors() - rdkit_mol = MagicMock() - number_heavy = 1 - number_hydrogen = 2 - - rdkit_mol.GetSubstructMatches = MagicMock(side_effect=[((0, 1, 2), (0, 2, 1))]) - - result = neighbors._get_symmetry_number(rdkit_mol, number_heavy, number_hydrogen) + assert result 
== {0: (0, 0)} - assert result == 2 +def test_get_frame_neighbor_counts_converts_frame_index_and_handles_multiple_groups(): + search = FakeSearch() + helper = Neighbors(search=search) -def test_get_symmetry_number_sphere(): - neighbors = Neighbors() - rdkit_mol = MagicMock() - number_heavy = 1 - number_hydrogen = 0 - - rdkit_mol.GetSubstructMatches = MagicMock(side_effect=[((0, 1, 2), (0, 2, 1))]) - - result = neighbors._get_symmetry_number(rdkit_mol, number_heavy, number_hydrogen) - - assert result == 0 - - -def test_get_linear_ua(): - neighbors = Neighbors() - rdkit_mol = MagicMock() - number_heavy = 1 - - class _FakeRDKit_Chem: - """Class to mock rdkit functionality.""" - - def RemoveHs(mol): - rdkit_heavy = MagicMock() - return rdkit_heavy - - with patch("CodeEntropy.levels.neighbors.Chem", _FakeRDKit_Chem): - result = neighbors._get_linear(rdkit_mol, number_heavy) - - assert not result - - -def test_get_linear_diatomic(): - neighbors = Neighbors() - rdkit_mol = MagicMock() - number_heavy = 2 - - class _FakeRDKit_Chem: - """Class to mock rdkit functionality.""" - - def RemoveHs(mol): - rdkit_heavy = MagicMock() - return rdkit_heavy - - with patch("CodeEntropy.levels.neighbors.Chem", _FakeRDKit_Chem): - result = neighbors._get_linear(rdkit_mol, number_heavy) - - assert result - - -def test_get_linear_true(): - neighbors = Neighbors() - rdkit_mol = MagicMock() - rdkit_heavy = MagicMock() - number_heavy = 3 - a1 = MagicMock() - a2 = MagicMock() - a3 = MagicMock() - - class _FakeRDKit_Chem: - """Class to mock rdkit functionality.""" - - @staticmethod - def RemoveHs(mol): - return rdkit_heavy - - class HybridizationType: - SP = "SP" - - rdkit_heavy.GetAtoms = MagicMock(return_value=[a1, a2, a3]) - - a1.GetHybridization = MagicMock(return_value="SP3") - a2.GetHybridization = MagicMock(return_value="SP") - a3.GetHybridization = MagicMock(return_value="SP3") - - with patch("CodeEntropy.levels.neighbors.Chem", _FakeRDKit_Chem): - result = 
neighbors._get_linear(rdkit_mol, number_heavy) - - assert result is True - - -def test_get_linear_false(): - neighbors = Neighbors() - rdkit_mol = MagicMock() - rdkit_heavy = MagicMock() - number_heavy = 3 - a1 = MagicMock() - a2 = MagicMock() - a3 = MagicMock() - - class _FakeRDKit_Chem: - """Class to mock rdkit functionality.""" - - @staticmethod - def RemoveHs(mol): - return rdkit_heavy - - class HybridizationType: - SP = "SP" + universe = object() + frame_source = object() + groups = { + 7: [0], + 9: [1, 2], + } + levels = [ + ["united_atom"], + ["united_atom", "residue"], + ["united_atom", "residue", "polymer"], + ] - rdkit_heavy.GetAtoms = MagicMock(return_value=[a1, a2, a3]) - a1.GetHybridization = MagicMock(return_value="SP3") - a2.GetHybridization = MagicMock(return_value="SP3") - a3.GetHybridization = MagicMock(return_value="SP3") + result = helper.get_frame_neighbor_counts( + universe=universe, + levels=levels, + groups=groups, + frame_source=frame_source, + frame_index="12", + search_type="RAD", + ) - with patch("CodeEntropy.levels.neighbors.Chem", _FakeRDKit_Chem): - result = neighbors._get_linear(rdkit_mol, number_heavy) + assert result == { + 7: (2, 1), + 9: (4, 2), + } + assert search.rad_calls == [ + { + "universe": universe, + "mol_id": 0, + "frame_source": frame_source, + "frame_index": 12, + }, + { + "universe": universe, + "mol_id": 1, + "frame_source": frame_source, + "frame_index": 12, + }, + { + "universe": universe, + "mol_id": 2, + "frame_source": frame_source, + "frame_index": 12, + }, + ] - assert result is False +def test_get_frame_neighbor_counts_raises_for_unknown_search_type(): + helper = Neighbors() + helper._search = FakeSearch() -def test_get_symmetry_returns_dicts_for_single_group(): - neighbors = Neighbors() - universe = MagicMock() - groups = {7: [42, 99]} + with pytest.raises(ValueError, match="unknown search_type"): + helper.get_frame_neighbor_counts( + universe=object(), + levels=[["united_atom"]], + groups={0: [0]}, + 
frame_source=object(), + frame_index=0, + search_type="unknown", + ) - rdkit_mol = MagicMock() - neighbors._get_rdkit_mol = MagicMock(return_value=(rdkit_mol, 5, 8)) - neighbors._get_symmetry_number = MagicMock(return_value=12) - neighbors._get_linear = MagicMock(return_value=True) +def test_get_neighbors_for_molecule_rad_delegates_to_search(): + search = FakeSearch() + helper = Neighbors(search=search) - symmetry_number, linear = neighbors.get_symmetry(universe, groups) + universe = object() + frame_source = object() - assert symmetry_number == {7: 12} - assert linear == {7: True} + result = helper._get_neighbors_for_molecule( + universe=universe, + molecule_id=3, + highest_level="residue", + frame_source=frame_source, + frame_index=8, + search_type="RAD", + ) - neighbors._get_rdkit_mol.assert_called_once_with(universe, 42) - neighbors._get_symmetry_number.assert_called_once_with(rdkit_mol, 5, 8) - neighbors._get_linear.assert_called_once_with(rdkit_mol, 5) + assert result == [1, 2] + assert search.rad_calls == [ + { + "universe": universe, + "mol_id": 3, + "frame_source": frame_source, + "frame_index": 8, + } + ] -def test_get_symmetry_uses_first_molecule_in_each_group_only(): - neighbors = Neighbors() - universe = MagicMock() - groups = { - 0: [10, 11, 12], - 1: [20, 21], - } +def test_get_neighbors_for_molecule_grid_delegates_to_search_with_highest_level(): + search = FakeSearch() + helper = Neighbors(search=search) - rdkit_mol_0 = MagicMock() - rdkit_mol_1 = MagicMock() + universe = object() + frame_source = object() - neighbors._get_rdkit_mol = MagicMock( - side_effect=[ - (rdkit_mol_0, 3, 6), - (rdkit_mol_1, 4, 8), - ] + result = helper._get_neighbors_for_molecule( + universe=universe, + molecule_id=4, + highest_level="polymer", + frame_source=frame_source, + frame_index=9, + search_type="grid", ) - neighbors._get_symmetry_number = MagicMock(side_effect=[2, 4]) - neighbors._get_linear = MagicMock(side_effect=[False, True]) - - symmetry_number, linear = 
neighbors.get_symmetry(universe, groups) - assert symmetry_number == {0: 2, 1: 4} - assert linear == {0: False, 1: True} - - assert neighbors._get_rdkit_mol.call_args_list == [ - call(universe, 10), - call(universe, 20), + assert result == [1] + assert search.grid_calls == [ + { + "universe": universe, + "mol_id": 4, + "highest_level": "polymer", + "frame_source": frame_source, + "frame_index": 9, + } ] -def test_get_symmetry_calls_helpers_for_each_group_in_order(): - neighbors = Neighbors() - universe = MagicMock() - groups = { - 3: [100], - 5: [200], - } - - rdkit_mol_a = MagicMock() - rdkit_mol_b = MagicMock() +def test_get_neighbors_for_molecule_raises_for_unknown_search_type(): + helper = Neighbors(search=FakeSearch()) - neighbors._get_rdkit_mol = MagicMock( - side_effect=[ - (rdkit_mol_a, 1, 2), - (rdkit_mol_b, 7, 0), - ] - ) - neighbors._get_symmetry_number = MagicMock(side_effect=[9, 1]) - neighbors._get_linear = MagicMock(side_effect=[True, False]) + with pytest.raises(ValueError, match="unknown search_type unknown"): + helper._get_neighbors_for_molecule( + universe=object(), + molecule_id=0, + highest_level="united_atom", + frame_source=object(), + frame_index=0, + search_type="unknown", + ) - symmetry_number, linear = neighbors.get_symmetry(universe, groups) - assert symmetry_number == {3: 9, 5: 1} - assert linear == {3: True, 5: False} +def test_get_symmetry_calls_helpers_for_first_molecule_in_each_group(): + helper = Neighbors() + calls = [] - assert neighbors._get_symmetry_number.call_args_list == [ - call(rdkit_mol_a, 1, 2), - call(rdkit_mol_b, 7, 0), - ] - assert neighbors._get_linear.call_args_list == [ - call(rdkit_mol_a, 1), - call(rdkit_mol_b, 7), - ] + def fake_get_rdkit_mol(universe, molecule_id): + calls.append(molecule_id) + return f"mol-{molecule_id}", 2 + molecule_id, molecule_id + helper._get_rdkit_mol = fake_get_rdkit_mol + helper._get_symmetry_number = lambda rdkit_mol, number_heavy, number_hydrogen: ( + number_heavy + number_hydrogen + 
) + helper._get_linear = lambda rdkit_mol, number_heavy: number_heavy == 2 -def test_get_symmetry_returns_empty_dicts_for_empty_groups(): - neighbors = Neighbors() - universe = MagicMock() - groups = {} + symmetry, linear = helper.get_symmetry( + universe=object(), + groups={7: [0, 1], 9: [2]}, + ) - neighbors._get_rdkit_mol = MagicMock() - neighbors._get_symmetry_number = MagicMock() - neighbors._get_linear = MagicMock() + assert calls == [0, 2] + assert symmetry == {7: 2, 9: 6} + assert linear == {7: True, 9: False} - symmetry_number, linear = neighbors.get_symmetry(universe, groups) - assert symmetry_number == {} - assert linear == {} +def test_get_symmetry_returns_zero_for_empty_groups(): + helper = Neighbors() + helper._get_rdkit_mol = MagicMock() - neighbors._get_rdkit_mol.assert_not_called() - neighbors._get_symmetry_number.assert_not_called() - neighbors._get_linear.assert_not_called() + symmetry, linear = helper.get_symmetry(universe=object(), groups={7: []}) + assert symmetry == {7: 0} + assert linear == {7: False} + helper._get_rdkit_mol.assert_not_called() -def test_get_symmetry_propagates_error_from_get_rdkit_mol(): - neighbors = Neighbors() - universe = MagicMock() - groups = {0: [123]} - neighbors._get_rdkit_mol = MagicMock(side_effect=RuntimeError("bad molecule")) - neighbors._get_symmetry_number = MagicMock() - neighbors._get_linear = MagicMock() +def test_get_symmetry_propagates_error_from_rdkit_conversion(): + helper = Neighbors() + helper._get_rdkit_mol = MagicMock(side_effect=RuntimeError("bad molecule")) with pytest.raises(RuntimeError, match="bad molecule"): - neighbors.get_symmetry(universe, groups) - - neighbors._get_symmetry_number.assert_not_called() - neighbors._get_linear.assert_not_called() - + helper.get_symmetry(universe=object(), groups={7: [0]}) -def test_get_rdkit_mol_guesses_elements_when_missing(): - neighbors = Neighbors() - universe = MagicMock() - molecule = MagicMock() - dummy = MagicMock() +def 
test_get_rdkit_mol_guesses_elements_when_missing_and_uses_normal_conversion(): + rdkit_mol = FakeRdkitMol(heavy_atoms=2, total_atoms=6) - del universe.atoms.elements - universe.atoms.fragments = [molecule] - - molecule.select_atoms.side_effect = [dummy] - dummy.__len__.return_value = 0 + dummy_atoms = FakeAtomSelection(length=0) + heavy_fragment = FakeAtomSelection(length=2, rdkit_mol=rdkit_mol) + molecule = FakeMolecule( + dummy_atoms=dummy_atoms, + heavy_fragment=heavy_fragment, + rdkit_mol=rdkit_mol, + ) - rdkit_mol = MagicMock() - rdkit_mol.GetNumHeavyAtoms.return_value = 3 - rdkit_mol.GetNumAtoms.return_value = 8 - molecule.convert_to.return_value = rdkit_mol + universe = SimpleNamespace( + atoms=FakeAtomsWithoutElements([molecule]), + guess_TopologyAttrs=MagicMock(), + ) - result = neighbors._get_rdkit_mol(universe, 0) + out_mol, number_heavy, number_hydrogen = Neighbors._get_rdkit_mol(universe, 0) universe.guess_TopologyAttrs.assert_called_once_with(to_guess=["elements"]) molecule.convert_to.assert_called_once_with("RDKIT", force=True) - assert result == (rdkit_mol, 3, 5) - -def test_get_rdkit_mol_does_not_guess_elements_when_present(): - neighbors = Neighbors() + assert out_mol is rdkit_mol + assert number_heavy == 2 + assert number_hydrogen == 4 - universe = MagicMock() - molecule = MagicMock() - dummy = MagicMock() - universe.atoms.elements = ["C", "H"] - universe.atoms.fragments = [molecule] +def test_get_rdkit_mol_skips_guessing_when_elements_exist(): + rdkit_mol = FakeRdkitMol(heavy_atoms=1, total_atoms=4) - molecule.select_atoms.side_effect = [dummy] - dummy.__len__.return_value = 0 + molecule = FakeMolecule( + dummy_atoms=FakeAtomSelection(length=0), + heavy_fragment=FakeAtomSelection(length=1, rdkit_mol=rdkit_mol), + rdkit_mol=rdkit_mol, + ) - rdkit_mol = MagicMock() - rdkit_mol.GetNumHeavyAtoms.return_value = 2 - rdkit_mol.GetNumAtoms.return_value = 6 - molecule.convert_to.return_value = rdkit_mol + universe = SimpleNamespace( + 
atoms=FakeAtomsWithElements([molecule]), + guess_TopologyAttrs=MagicMock(), + ) - result = neighbors._get_rdkit_mol(universe, 0) + out_mol, number_heavy, number_hydrogen = Neighbors._get_rdkit_mol(universe, 0) universe.guess_TopologyAttrs.assert_not_called() molecule.convert_to.assert_called_once_with("RDKIT", force=True) - assert result == (rdkit_mol, 2, 4) - - -def test_get_rdkit_mol_uses_full_molecule_when_no_dummy_atoms(): - neighbors = Neighbors() - universe = MagicMock() - molecule = MagicMock() - dummy = MagicMock() + assert out_mol is rdkit_mol + assert number_heavy == 1 + assert number_hydrogen == 3 - universe.atoms.elements = ["C"] - universe.atoms.fragments = [molecule] - molecule.select_atoms.side_effect = [dummy] - dummy.__len__.return_value = 0 +def test_get_rdkit_mol_uses_heavy_fragment_when_dummy_atoms_are_present(): + rdkit_mol = FakeRdkitMol(heavy_atoms=3, total_atoms=8) - rdkit_mol = MagicMock() - rdkit_mol.GetNumHeavyAtoms.return_value = 4 - rdkit_mol.GetNumAtoms.return_value = 10 - molecule.convert_to.return_value = rdkit_mol + dummy_atoms = FakeAtomSelection(length=2) + heavy_fragment = FakeAtomSelection(length=3, rdkit_mol=rdkit_mol) + molecule = FakeMolecule( + dummy_atoms=dummy_atoms, + heavy_fragment=heavy_fragment, + rdkit_mol=MagicMock(), + ) - result = neighbors._get_rdkit_mol(universe, 0) + universe = SimpleNamespace( + atoms=FakeAtomsWithElements([molecule]), + guess_TopologyAttrs=MagicMock(), + ) - molecule.select_atoms.assert_called_once_with("prop mass < 0.1") - molecule.convert_to.assert_called_once_with("RDKIT", force=True) - assert result == (rdkit_mol, 4, 6) + out_mol, number_heavy, number_hydrogen = Neighbors._get_rdkit_mol(universe, 0) + molecule.convert_to.assert_not_called() + heavy_fragment.convert_to.assert_called_once_with( + "RDKIT", + force=True, + inferrer=None, + ) -def test_get_rdkit_mol_removes_dummy_atoms_and_uses_inferrer_none(): - neighbors = Neighbors() + assert out_mol is rdkit_mol + assert number_heavy == 3 + 
assert number_hydrogen == 5 - universe = MagicMock() - molecule = MagicMock() - dummy = MagicMock() - frag = MagicMock() - universe.atoms.elements = ["C"] - universe.atoms.fragments = [molecule] +def test_get_rdkit_mol_falls_back_to_inferrer_none_when_normal_conversion_fails(): + rdkit_mol = FakeRdkitMol(heavy_atoms=2, total_atoms=5) - molecule.select_atoms.side_effect = [dummy, frag] - dummy.__len__.return_value = 2 + molecule = FakeMolecule( + dummy_atoms=FakeAtomSelection(length=0), + heavy_fragment=FakeAtomSelection(length=2, rdkit_mol=rdkit_mol), + rdkit_mol=rdkit_mol, + ) + molecule.convert_to = MagicMock( + side_effect=[RuntimeError("bad valence"), rdkit_mol] + ) - rdkit_mol = MagicMock() - rdkit_mol.GetNumHeavyAtoms.return_value = 5 - rdkit_mol.GetNumAtoms.return_value = 12 - frag.convert_to.return_value = rdkit_mol + universe = SimpleNamespace( + atoms=FakeAtomsWithElements([molecule]), + guess_TopologyAttrs=MagicMock(), + ) - result = neighbors._get_rdkit_mol(universe, 0) + out_mol, number_heavy, number_hydrogen = Neighbors._get_rdkit_mol(universe, 0) - assert molecule.select_atoms.call_args_list == [ - (("prop mass < 0.1",),), - (("prop mass > 0.1",),), + assert molecule.convert_to.call_args_list == [ + call("RDKIT", force=True), + call("RDKIT", force=True, inferrer=None), ] - frag.convert_to.assert_called_once_with("RDKIT", force=True, inferrer=None) - molecule.convert_to.assert_not_called() - assert result == (rdkit_mol, 5, 7) - -def test_get_rdkit_mol_returns_correct_heavy_and_hydrogen_counts(): - neighbors = Neighbors() + assert out_mol is rdkit_mol + assert number_heavy == 2 + assert number_hydrogen == 3 - universe = MagicMock() - molecule = MagicMock() - dummy = MagicMock() - universe.atoms.elements = ["O", "H", "H"] - universe.atoms.fragments = [molecule] +def test_get_symmetry_number_uses_heavy_atom_matches_for_multi_heavy_molecule(): + helper = Neighbors() + rdkit_mol = MagicMock() + rdkit_mol.GetSubstructMatches.return_value = [1, 2, 3] - 
molecule.select_atoms.side_effect = [dummy] - dummy.__len__.return_value = 0 + with patch("CodeEntropy.levels.neighbors.Chem.RemoveHs", return_value="heavy"): + assert helper._get_symmetry_number(rdkit_mol, 2, 0) == 3 - rdkit_mol = MagicMock() - rdkit_mol.GetNumHeavyAtoms.return_value = 1 - rdkit_mol.GetNumAtoms.return_value = 3 - molecule.convert_to.return_value = rdkit_mol + rdkit_mol.GetSubstructMatches.assert_called_once_with( + "heavy", + uniquify=False, + useChirality=True, + ) - rdkit_out, number_heavy, number_hydrogen = neighbors._get_rdkit_mol(universe, 0) - assert rdkit_out is rdkit_mol - assert number_heavy == 1 - assert number_hydrogen == 2 +def test_get_symmetry_number_uses_full_molecule_for_single_heavy_with_hydrogens(): + helper = Neighbors() + rdkit_mol = MagicMock() + rdkit_mol.GetSubstructMatches.return_value = [1, 2] + assert helper._get_symmetry_number(rdkit_mol, 1, 4) == 2 -def test_get_neighbors_returns_zero_for_each_group_when_no_frames_selected(): - neighbors = Neighbors() + rdkit_mol.GetSubstructMatches.assert_called_once_with( + rdkit_mol, + uniquify=False, + useChirality=True, + ) - universe = MagicMock() - levels = { - 0: ["united_atom"], - 1: ["residue"], - } - groups = { - 0: [0], - 1: [1], - } - frame_source = MagicMock() - frame_source.iter_indices.return_value = [] +def test_get_symmetry_number_returns_zero_for_single_heavy_without_hydrogens(): + assert Neighbors()._get_symmetry_number(MagicMock(), 1, 0) == 0 - neighbors._search.get_RAD_neighbors = MagicMock() - neighbors._search.get_grid_neighbors = MagicMock() - result = neighbors.get_neighbors( - universe=universe, - levels=levels, - groups=groups, - frame_source=frame_source, - search_type="RAD", - ) +def test_get_linear_for_one_or_two_heavy_atoms(): + helper = Neighbors() - assert result == { - 0: 0.0, - 1: 0.0, - } + assert helper._get_linear(MagicMock(), 1) is False + assert helper._get_linear(MagicMock(), 2) is True - frame_source.iter_indices.assert_called_once() - 
neighbors._search.get_RAD_neighbors.assert_not_called() - neighbors._search.get_grid_neighbors.assert_not_called() +def test_get_linear_for_larger_molecule_uses_sp_hybridisation_count(): + helper = Neighbors() + rdkit_mol = MagicMock() -def test_get_rdkit_mol_falls_back_to_inferrer_none_when_convert_to_raises(): - neighbors = Neighbors() + sp_atom = MagicMock() + sp_atom.GetHybridization.return_value = Chem.HybridizationType.SP + non_sp_atom = MagicMock() + non_sp_atom.GetHybridization.return_value = Chem.HybridizationType.SP3 - universe = MagicMock() - molecule = MagicMock() - dummy = MagicMock() + rdkit_heavy = MagicMock() + rdkit_heavy.GetAtoms.return_value = [sp_atom, sp_atom, non_sp_atom] - universe.atoms.elements = ["C", "H"] - universe.atoms.fragments = [molecule] + with patch("CodeEntropy.levels.neighbors.Chem.RemoveHs", return_value=rdkit_heavy): + assert helper._get_linear(rdkit_mol, 4) is True - molecule.select_atoms.return_value = dummy - dummy.__len__.return_value = 0 +def test_get_linear_for_larger_molecule_returns_false_when_too_few_sp_atoms(): + helper = Neighbors() rdkit_mol = MagicMock() - rdkit_mol.GetNumHeavyAtoms.return_value = 1 - rdkit_mol.GetNumAtoms.return_value = 4 - - molecule.convert_to.side_effect = [ - RuntimeError("constraint bond issue"), - rdkit_mol, - ] - result = neighbors._get_rdkit_mol(universe, mol_id=0) + non_sp_atom = MagicMock() + non_sp_atom.GetHybridization.return_value = Chem.HybridizationType.SP3 - assert result == (rdkit_mol, 1, 3) + rdkit_heavy = MagicMock() + rdkit_heavy.GetAtoms.return_value = [non_sp_atom, non_sp_atom, non_sp_atom] - molecule.select_atoms.assert_called_once_with("prop mass < 0.1") - molecule.convert_to.assert_has_calls( - [ - call("RDKIT", force=True), - call("RDKIT", force=True, inferrer=None), - ] - ) - universe.guess_TopologyAttrs.assert_not_called() + with patch("CodeEntropy.levels.neighbors.Chem.RemoveHs", return_value=rdkit_heavy): + assert helper._get_linear(rdkit_mol, 4) is False