from collections.abc import Iterable
import numpy as np
import plotly.colors as pc
import plotly.graph_objects as go
from wetting_angle_kit.analysis.slicing.results import SlicingResults
from wetting_angle_kit.visualization.base_trajectory_plotter import (
BaseTrajectoryPlotter,
)
from wetting_angle_kit.visualization.stats import TrajectoryStats
def _shoelace_area(points: np.ndarray) -> float:
    """Return the area of a polygon via the shoelace formula.

    Parameters
    ----------
    points : np.ndarray
        Polygon vertices in traversal order, one ``(x, y)`` pair per row.
        Any array-like accepted by :func:`numpy.asarray` works (for example
        a list of 2-tuples). Only the first two columns are used.

    Returns
    -------
    float
        Unsigned enclosed area; the vertex orientation (clockwise or
        counter-clockwise) does not matter. ``0.0`` for empty input.
    """
    vertices = np.asarray(points, dtype=float)
    if vertices.size == 0:
        # An empty list converts to a 1-D array, which cannot be
        # column-indexed; an empty polygon encloses no area.
        return 0.0
    x = vertices[:, 0]
    y = vertices[:, 1]
    # np.roll pairs every vertex with its predecessor and wraps around, so
    # the closing edge (last vertex -> first vertex) is included.
    return float(0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))))
def _hex_to_rgba(hex_color: str, alpha: float) -> str:
    """Return a CSS ``rgba(...)`` string from a hex color.

    Parameters
    ----------
    hex_color : str
        Color in ``#rrggbb`` form. The leading ``#`` is optional, the
        shorthand forms ``#rgb`` / ``#rgba`` are expanded, and an alpha
        component embedded in ``#rrggbbaa`` is ignored in favour of
        ``alpha``.
    alpha : float
        Opacity written verbatim as the fourth ``rgba`` component.

    Returns
    -------
    str
        ``"rgba(r,g,b,alpha)"`` with decimal channel values.

    Raises
    ------
    ValueError
        If ``hex_color`` does not have 3, 4, 6 or 8 hex digits, or contains
        characters that are not hexadecimal digits.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        # CSS shorthand: every digit stands for a doubled pair (#abc -> #aabbcc).
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(
            f"expected a hex color such as '#rrggbb' or '#rgb', got {hex_color!r}"
        )
    red, green, blue = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
    return f"rgba({red},{green},{blue},{alpha})"
class SlicingTrajectoryPlotter(BaseTrajectoryPlotter):
    """Plot statistics derived from one or more :class:`SlicingResults`.

    Every results container is treated as one trajectory and is paired, by
    position, with a display label and a frame-to-time conversion factor.
    """

    def __init__(
        self,
        results: SlicingResults | Iterable[SlicingResults],
        labels: list[str] | None = None,
        time_steps: list[float] | None = None,
        time_unit: str = "ps",
    ) -> None:
        """
        Parameters
        ----------
        results : SlicingResults or iterable of SlicingResults
            One results container per trajectory.
        labels : list of str, optional
            Display labels (one per results container). Defaults to
            ``["trajectory_0", ...]``.
        time_steps : list of float, optional
            Per-trajectory time step applied to ``frames`` for the time
            axis of evolution plots. Defaults to ``1.0`` for each.
        time_unit : str, optional
            Time unit shown on x-axis labels.

        Raises
        ------
        ValueError
            If ``labels`` or ``time_steps`` does not hold exactly one entry
            per results container.
        """
        if isinstance(results, SlicingResults):
            results = [results]
        else:
            # Materialize once so one-shot iterables (e.g. generators) can be
            # traversed again by every summary/plot call.
            results = list(results)
        n_results = len(results)
        self.results = results
        self.labels = labels or [f"trajectory_{i}" for i in range(n_results)]
        self.time_steps = time_steps or [1.0] * n_results
        self.time_unit = time_unit
        # The other methods pair these sequences with zip(); a length mismatch
        # would silently drop trajectories, so reject it up front.
        for arg_name, values in (
            ("labels", self.labels),
            ("time_steps", self.time_steps),
        ):
            if len(values) != n_results:
                raise ValueError(
                    f"{arg_name} must have one entry per results container "
                    f"(expected {n_results}, got {len(values)})"
                )

    def _mean_surface_areas(self, result: SlicingResults) -> np.ndarray:
        """Per-frame mean polygon area (shoelace over the frame's slices).

        Iterates ``result.surfaces`` frame by frame and averages the area of
        every slice polygon in that frame. A frame without slices yields
        ``nan``.
        """
        frame_means: list[float] = []
        for frame_surfaces in result.surfaces:
            areas = [_shoelace_area(s) for s in frame_surfaces]
            # Guard the empty case explicitly: np.mean([]) also returns nan
            # but emits a RuntimeWarning.
            frame_means.append(float(np.mean(areas)) if areas else float("nan"))
        return np.array(frame_means)

    def summary(self) -> list[TrajectoryStats]:
        """Return one :class:`TrajectoryStats` per trajectory.

        Returns
        -------
        list of TrajectoryStats
            For each label/result pair: the mean of the per-frame mean slice
            areas, the result's ``mean_angle`` and ``std_angle``, and
            ``len(result)`` as the sample count. The mean area is ``nan``
            for a result without frames.
        """
        stats: list[TrajectoryStats] = []
        for label, result in zip(self.labels, self.results, strict=False):
            surfaces = self._mean_surface_areas(result)
            mean_area = float(np.mean(surfaces)) if surfaces.size else float("nan")
            stats.append(
                TrajectoryStats(
                    method_name="Slicing Analysis",
                    label=label,
                    mean_surface_area=mean_area,
                    mean_contact_angle=result.mean_angle,
                    std_contact_angle=result.std_angle,
                    n_samples=len(result),
                )
            )
        return stats

    @staticmethod
    def _sigma_band(
        times: np.ndarray,
        center: np.ndarray,
        spread: np.ndarray,
        *,
        fillcolor: str,
        name: str,
        legendgroup: str,
    ) -> go.Scatter:
        """Build a filled ``center ± spread`` band as a closed polygon trace."""
        # Walk the upper edge forward and the lower edge backward so that
        # fill="toself" closes the outline into a single band.
        return go.Scatter(
            x=np.concatenate([times, times[::-1]]),
            y=np.concatenate([center + spread, (center - spread)[::-1]]),
            fill="toself",
            fillcolor=fillcolor,
            line=dict(width=0),
            name=name,
            legendgroup=legendgroup,
            showlegend=False,
            hoverinfo="skip",
        )

    def plot_angle_evolution(
        self,
        stat: str = "median",
        per_frame_std: bool = True,
        running_mean: bool = True,
        timestep: float | None = None,
        time_unit: str | None = None,
        save_path: str | None = None,
    ) -> go.Figure:
        """Plot per-frame contact angle as a function of time.

        Parameters
        ----------
        stat : str, default "median"
            Per-frame aggregation across slices; one of ``"median"`` or
            ``"mean"``.
        per_frame_std : bool, default True
            If True, draw a transparent ±σ band around the per-frame curve
            using the inter-slice spread within each frame — shows how noisy
            the contact angle estimate is at each instant.
        running_mean : bool, default True
            If True, overlay the cumulative running mean of the per-frame
            central tendency as a dashed line, plus a transparent ±σ band
            of that cumulative series — shows how the time-averaged contact
            angle converges as more frames are accumulated.
        timestep : float, optional
            Time between two consecutive frames *in the trajectory file*
            (i.e. dump interval × MD integration timestep). Applied
            uniformly to all trajectories, overriding the per-trajectory
            ``time_steps`` passed at construction. This is **not** the MD
            integration timestep — it is the spacing between frames as
            they appear in the dump.
        time_unit : str, optional
            Override for the x-axis time unit label. Defaults to the
            ``time_unit`` passed at construction.
        save_path : str, optional
            If provided, write the figure as standalone HTML.

        Returns
        -------
        plotly.graph_objects.Figure
            Figure with one per-frame line per trajectory, optionally with
            an inter-slice ±σ band and/or a running mean line with its
            cumulative ±σ band.

        Raises
        ------
        ValueError
            If ``stat`` is neither ``"median"`` nor ``"mean"``.
        """
        if stat not in ("median", "mean"):
            raise ValueError(f"stat must be 'median' or 'mean', got {stat!r}")
        agg = np.median if stat == "median" else np.mean
        palette = pc.qualitative.Plotly
        effective_unit = time_unit if time_unit is not None else self.time_unit

        # Bands and lines are collected separately so that every band is added
        # to the figure before any line, keeping the lines drawn on top.
        band_traces: list[go.Scatter] = []
        line_traces: list[go.Scatter] = []

        for idx, (label, result, default_dt) in enumerate(
            zip(self.labels, self.results, self.time_steps, strict=False)
        ):
            dt = timestep if timestep is not None else default_dt
            # Cycle through the palette when there are more trajectories
            # than colors.
            color = palette[idx % len(palette)]
            band_color = _hex_to_rgba(color, 0.2)
            times = np.array(result.frames) * dt
            per_frame = np.array([float(agg(a)) for a in result.angles])
            per_frame_group = label
            running_group = f"{label} running mean"

            line_traces.append(
                go.Scatter(
                    x=times,
                    y=per_frame,
                    mode="lines",
                    name=label,
                    line=dict(width=2, color=color),
                    legendgroup=per_frame_group,
                )
            )

            if per_frame_std:
                std = np.array([float(np.std(a)) for a in result.angles])
                band_traces.append(
                    self._sigma_band(
                        times,
                        per_frame,
                        std,
                        fillcolor=band_color,
                        name=f"{label} ±σ",
                        legendgroup=per_frame_group,
                    )
                )

            if running_mean:
                counts = np.arange(1, len(per_frame) + 1)
                cum_mean = np.cumsum(per_frame) / counts
                sq_mean = np.cumsum(per_frame**2) / counts
                # Var = E[x²] − E[x]²; clamp at zero because round-off can
                # push the difference slightly negative.
                cum_std = np.sqrt(np.maximum(sq_mean - cum_mean**2, 0.0))
                band_traces.append(
                    self._sigma_band(
                        times,
                        cum_mean,
                        cum_std,
                        fillcolor=band_color,
                        name=f"{label} running ±σ",
                        legendgroup=running_group,
                    )
                )
                line_traces.append(
                    go.Scatter(
                        x=times,
                        y=cum_mean,
                        mode="lines",
                        name=running_group,
                        line=dict(width=2, color=color, dash="dash"),
                        legendgroup=running_group,
                    )
                )

        fig = go.Figure()
        for trace in band_traces:
            fig.add_trace(trace)
        for trace in line_traces:
            fig.add_trace(trace)
        fig.update_layout(
            title=f"Contact angle evolution ({stat})",
            xaxis_title=f"Time ({effective_unit})",
            yaxis_title="Contact angle (°)",
            template="plotly_white",
        )
        if save_path:
            fig.write_html(save_path)
        return fig