"""Visualization methods for combined simulators."""
import os
from typing import Optional
import numpy as np
import pandas as pd
from io_collection.keys.check_key import check_key
from io_collection.load.load_buffer import load_buffer
from io_collection.load.load_dataframe import load_dataframe
from io_collection.save.save_buffer import save_buffer
from simulariumio import (
DISPLAY_TYPE,
AgentData,
CameraData,
DisplayData,
MetaData,
TrajectoryConverter,
TrajectoryData,
UnitData,
)
from subcell_pipeline.analysis.compression_metrics.compression_analysis import (
get_compression_metric_data,
)
from subcell_pipeline.analysis.compression_metrics.compression_metric import (
CompressionMetric,
)
from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import align_fibers
from subcell_pipeline.visualization.scatter_plots import make_empty_scatter_plots
BOX_SIZE: np.ndarray = np.array(3 * [600.0])
"""Bounding box size for combined simulator trajectories."""
def _load_fiber_points_from_dataframe(
dataframe: pd.DataFrame, n_timepoints: int
) -> np.ndarray:
"""
Load and reshape fiber points from sampled dataframe.
Sampled dataframe is in the shape (n_timepoints x n_fiber_points, 3); method
returns the dataframe reshaped to (n_timepoints, n_fiber_points x 3). If the
sampled dataframe does not have the expected number of timepoints, method
will raise an exception.
"""
    dataframe = dataframe.sort_values(by=["time", "fiber_point"])
total_steps = dataframe.time.unique().shape[0]
if total_steps != n_timepoints:
raise Exception(
f"Requested number of timesteps [ {n_timepoints} ] does not match "
f"number of timesteps in dataset [ {total_steps} ]."
)
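    # Align fibers to a consistent orientation before extracting points.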
align_fibers(dataframe)
fiber_points = []
for _, group in dataframe.groupby("time"):
fiber_points.append(group[["xpos", "ypos", "zpos"]].values.flatten())
return np.array(fiber_points)
def _get_combined_trajectory_converter(
fiber_points: list[np.ndarray],
type_names: list[str],
display_data: dict[str, DisplayData],
) -> TrajectoryConverter:
"""
Generate a TrajectoryConverter to visualize simulations from ReaDDy and
Cytosim together.
"""
total_conditions = len(fiber_points)
total_steps = fiber_points[0].shape[0]
total_subpoints = fiber_points[0].shape[1]
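    # Each loaded condition/replicate pair becomes one fiber agent that is
    # present at every timepoint.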
traj_data = TrajectoryData(
meta_data=MetaData(
box_size=BOX_SIZE,
camera_defaults=CameraData(
position=np.array([75.0, 220.0, 15.0]),
look_at_position=np.array([75.0, 75.0, 0.0]),
fov_degrees=60.0,
),
trajectory_title="Actin compression in Cytosim and Readdy",
),
agent_data=AgentData(
times=np.arange(total_steps),
n_agents=total_conditions * np.ones(total_steps),
viz_types=1001
* np.ones((total_steps, total_conditions)), # fiber viz type = 1001
unique_ids=np.array(total_steps * [list(range(total_conditions))]),
types=total_steps * [type_names],
positions=np.zeros((total_steps, total_conditions, 3)),
radii=np.ones((total_steps, total_conditions)),
n_subpoints=total_subpoints * np.ones((total_steps, total_conditions)),
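            # Reorder fiber points from (agents, timepoints, values) to the
            # (timepoints, agents, values) layout expected by AgentData.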
subpoints=np.moveaxis(np.array(fiber_points), [0, 1], [1, 0]),
display_data=display_data,
),
time_units=UnitData("count"), # frames
spatial_units=UnitData("nm"), # nanometer
)
return TrajectoryConverter(traj_data)
def _add_combined_plots(
converter: TrajectoryConverter,
metrics: list[CompressionMetric],
    metrics_data: dict[str, pd.DataFrame],
n_timepoints: int,
plot_names: list[tuple[str, str, int]],
type_names: list[str],
) -> None:
"""Add plots for combined trajectories with calculated metrics."""
scatter_plots = make_empty_scatter_plots(metrics, total_steps=n_timepoints)
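    # Add one trace per combined trajectory (simulator, condition, replicate)
    # to each metric's scatter plot.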
for metric, plot in scatter_plots.items():
for plot_name, type_name in zip(plot_names, type_names):
simulator, key, seed = plot_name
simulator_data = metrics_data[simulator]
data = simulator_data[
(simulator_data["key"] == key) & (simulator_data["seed"] == seed)
]
plot.ytraces[type_name] = np.array(data[metric.value])
converter.add_plot(plot, "scatter")
def visualize_combined_trajectories(
buckets: dict[str, str],
series_name: str,
condition_keys: list[str],
replicates: list[int],
n_timepoints: int,
simulator_colors: dict[str, str],
temp_path: str,
metrics: Optional[list[CompressionMetric]] = None,
recalculate: bool = False,
) -> None:
"""
    Visualize combined simulations from ReaDDy and Cytosim for selected
    conditions and replicates.
Parameters
----------
buckets
Names of S3 buckets for input and output files for each simulator and
visualization.
series_name
Name of simulation series.
condition_keys
List of condition keys.
replicates
        Simulation replicate IDs.
n_timepoints
Number of equally spaced timepoints to visualize.
simulator_colors
Map of simulator name to color.
temp_path
Local path for saving visualization output files.
metrics
List of metrics to include in visualization plots.
recalculate
True to recalculate visualization files, False otherwise.
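
    Examples
    --------
    A minimal usage sketch; the bucket names, series name, condition keys,
    colors, and local path below are hypothetical placeholders::

        visualize_combined_trajectories(
            buckets={
                "readdy": "readdy-working-bucket",
                "cytosim": "cytosim-working-bucket",
                "combined": "combined-working-bucket",
            },
            series_name="ACTIN_COMPRESSION_VELOCITY",
            condition_keys=["0047", "0150"],
            replicates=[1, 2, 3],
            n_timepoints=200,
            simulator_colors={"readdy": "#ca562c", "cytosim": "#005f73"},
            temp_path="/tmp/subcell_viz",
        )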
"""
fiber_points = []
type_names = []
plot_names = []
display_data = {}
all_metrics_data = {}
    if metrics is None:
        metrics = []

    for simulator, color in simulator_colors.items():
        bucket = buckets[simulator]

        # Load calculated compression metric data.
        if metrics:
            all_metrics_data[simulator] = get_compression_metric_data(
                bucket,
                series_name,
                condition_keys,
                replicates,
                metrics,
                recalculate=recalculate,
            )
        else:
            all_metrics_data[simulator] = pd.DataFrame(columns=["key", "seed"])
for condition_key in condition_keys:
series_key = (
f"{series_name}_{condition_key}" if condition_key else series_name
)
for replicate in replicates:
dataframe_key = (
f"{series_name}/samples/{series_key}_{replicate:06d}.csv"
)
# Skip if input dataframe does not exist.
if not check_key(bucket, dataframe_key):
print(
f"Dataframe not available for {simulator} "
f"[ { dataframe_key } ]. Skipping."
)
continue
                print(
                    f"Loading data for simulator [ {simulator} ] "
                    f"condition [ {condition_key} ] "
                    f"replicate [ {replicate} ]"
                )
dataframe = load_dataframe(bucket, dataframe_key)
fiber_points.append(
_load_fiber_points_from_dataframe(dataframe, n_timepoints)
)
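                # Format the condition key as a velocity label in um/s
                # (key value divided by ten, dropping any trailing ".0").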
condition = int(condition_key) / 10
condition = round(condition) if condition_key[-1] == "0" else condition
type_names.append(f"{simulator}#{condition} um/s {replicate}")
plot_names.append((simulator, condition_key, replicate))
display_data[type_names[-1]] = DisplayData(
name=type_names[-1],
display_type=DISPLAY_TYPE.FIBER,
color=color,
)
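    # Combine all loaded trajectories into a single Simularium converter.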
converter = _get_combined_trajectory_converter(
fiber_points, type_names, display_data
)
if metrics:
_add_combined_plots(
converter, metrics, all_metrics_data, n_timepoints, plot_names, type_names
)
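    # Save the Simularium file locally, then upload it to the combined bucket.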
output_key = "actin_compression_cytosim_readdy.simularium"
local_file_path = os.path.join(temp_path, output_key)
converter.save(output_path=local_file_path.replace(".simularium", ""))
output_bucket = buckets["combined"]
save_buffer(output_bucket, output_key, load_buffer(temp_path, output_key))