"""Visualization methods for individual simulators."""
from typing import Optional
import numpy as np
import pandas as pd
from io_collection.keys.check_key import check_key
from io_collection.load.load_buffer import load_buffer
from io_collection.load.load_text import load_text
from io_collection.save.save_buffer import save_buffer
from pint import UnitRegistry
from simulariumio import (
DISPLAY_TYPE,
CameraData,
DisplayData,
InputFileData,
MetaData,
TrajectoryConverter,
UnitData,
)
from simulariumio.cytosim import CytosimConverter, CytosimData, CytosimObjectInfo
from simulariumio.filters import EveryNthTimestepFilter
from simulariumio.readdy import ReaddyConverter, ReaddyData
from subcell_pipeline.analysis.compression_metrics.compression_analysis import (
get_compression_metric_data,
)
from subcell_pipeline.analysis.compression_metrics.compression_metric import (
CompressionMetric,
)
from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import align_fiber
from subcell_pipeline.simulation.cytosim.post_processing import CYTOSIM_SCALE_FACTOR
from subcell_pipeline.simulation.readdy.loader import ReaddyLoader
from subcell_pipeline.simulation.readdy.parser import BOX_SIZE as READDY_BOX_SIZE
from subcell_pipeline.simulation.readdy.parser import (
READDY_TIMESTEP,
download_readdy_hdf5,
)
from subcell_pipeline.simulation.readdy.post_processor import ReaddyPostProcessor
from subcell_pipeline.visualization.display_data import get_readdy_display_data
from subcell_pipeline.visualization.scatter_plots import make_empty_scatter_plots
from subcell_pipeline.visualization.spatial_annotator import (
add_fiber_annotation_agents,
add_sphere_annotation_agents,
)
READDY_SAVED_FRAMES: int = 1000
"""Number of saved frames for ReaDDy simulations."""
BOX_SIZE: np.ndarray = np.array(3 * [600.0])
"""Bounding box size for individual simulator trajectory."""
def _add_individual_plots(
converter: TrajectoryConverter,
metrics: list[CompressionMetric],
metrics_data: pd.DataFrame,
times: np.ndarray,
time_units: UnitData,
) -> None:
"""Add plots to individual trajectory with calculated metrics."""
scatter_plots = make_empty_scatter_plots(
metrics, times=times, time_units=time_units
)
for metric, plot in scatter_plots.items():
plot.ytraces["filament"] = np.array(metrics_data[metric.value])
converter.add_plot(plot, "scatter")
def _add_readdy_spatial_annotations(
converter: TrajectoryConverter,
post_processor: ReaddyPostProcessor,
n_monomer_points: int,
) -> None:
"""
Add visualizations of edges, normals, and control points to the ReaDDy
Simularium data.
"""
fiber_chain_ids = post_processor.linear_fiber_chain_ids(polymer_number_range=5)
axis_positions, _ = post_processor.linear_fiber_axis_positions(fiber_chain_ids)
fiber_points = post_processor.linear_fiber_control_points(
axis_positions=axis_positions,
n_points=n_monomer_points,
)
converter._data.agent_data.positions, fiber_points = (
post_processor.align_trajectory(fiber_points)
)
axis_positions, _ = post_processor.linear_fiber_axis_positions(fiber_chain_ids)
edges = post_processor.edge_positions()
# edges
converter._data = add_fiber_annotation_agents(
converter._data,
fiber_points=edges,
type_name="edge",
fiber_width=0.5,
color="#eaeaea",
)
# normals
normals = post_processor.linear_fiber_normals(
fiber_chain_ids=fiber_chain_ids,
axis_positions=axis_positions,
normal_length=10.0,
)
converter._data = add_fiber_annotation_agents(
converter._data,
fiber_points=normals,
type_name="normal",
fiber_width=0.5,
color="#685bf3",
)
# control points
sphere_positions = []
for time_ix in range(len(fiber_points)):
sphere_positions.append(fiber_points[time_ix][0])
converter._data = add_sphere_annotation_agents(
converter._data,
sphere_positions,
type_name="fiber point",
radius=0.8,
rainbow_colors=True,
)
def _get_readdy_simularium_converter(
path_to_readdy_h5: str,
total_steps: int,
n_timepoints: int,
) -> TrajectoryConverter:
"""
Load from ReaDDy outputs and generate a TrajectoryConverter to visualize an
actin trajectory in Simularium.
"""
converter = ReaddyConverter(
ReaddyData(
timestep=1e-6 * (READDY_TIMESTEP * total_steps / READDY_SAVED_FRAMES),
path_to_readdy_h5=path_to_readdy_h5,
meta_data=MetaData(
box_size=READDY_BOX_SIZE,
camera_defaults=CameraData(
position=np.array([70.0, 70.0, 300.0]),
look_at_position=np.array([70.0, 70.0, 0.0]),
fov_degrees=60.0,
),
scale_factor=1.0,
),
display_data=get_readdy_display_data(),
time_units=UnitData("ms"),
spatial_units=UnitData("nm"),
)
)
return _filter_time(converter, n_timepoints)
[docs]
def visualize_individual_readdy_trajectory(
bucket: str,
series_name: str,
series_key: str,
rep_ix: int,
n_timepoints: int,
n_monomer_points: int,
total_steps: int,
temp_path: str,
metrics: list[CompressionMetric],
metrics_data: pd.DataFrame,
) -> None:
"""
Save a Simularium file for a single ReaDDy trajectory with plots and spatial
annotations.
Parameters
----------
bucket
Name of S3 bucket for input and output files.
series_name
Name of simulation series.
series_key
Combination of series and condition names.
rep_ix
Replicate index.
n_timepoints
Number of equally spaced timepoints to visualize.
n_monomer_points
Number of equally spaced monomer points to visualize.
total_steps
Total number of steps for each simulation key.
temp_path
Local path for saving visualization output files.
metrics
List of metrics to include in visualization plots.
metrics_data
Calculated compression metrics data.
"""
h5_file_path = download_readdy_hdf5(
bucket, series_name, series_key, rep_ix, temp_path
)
assert isinstance(h5_file_path, str)
converter = _get_readdy_simularium_converter(
h5_file_path, total_steps, n_timepoints
)
if metrics:
times = 2 * metrics_data["time"].values # "time" seems to range (0, 0.5)
times *= 1e-6 * (READDY_TIMESTEP * total_steps / n_timepoints)
_add_individual_plots(
converter, metrics, metrics_data, times, converter._data.time_units
)
assert isinstance(h5_file_path, str)
rep_id = rep_ix + 1
pickle_key = f"{series_name}/data/{series_key}_{rep_id:06d}.pkl"
time_inc = total_steps // n_timepoints
readdy_loader = ReaddyLoader(
h5_file_path=h5_file_path,
time_inc=time_inc,
timestep=READDY_TIMESTEP,
pickle_location=bucket,
pickle_key=pickle_key,
)
post_processor = ReaddyPostProcessor(
readdy_loader.trajectory(), box_size=READDY_BOX_SIZE
)
_add_readdy_spatial_annotations(converter, post_processor, n_monomer_points)
# Save simularium file. Turn off validate IDs for performance.
converter.save(output_path=h5_file_path, validate_ids=False)
[docs]
def visualize_individual_readdy_trajectories(
bucket: str,
series_name: str,
condition_keys: list[str],
n_replicates: int,
n_timepoints: int,
n_monomer_points: int,
total_steps: dict[str, int],
temp_path: str,
metrics: Optional[list[CompressionMetric]] = None,
recalculate: bool = True,
) -> None:
"""
Visualize individual ReaDDy simulations for select conditions and
replicates.
Parameters
----------
bucket
Name of S3 bucket for input and output files.
series_name
Name of simulation series.
condition_keys
List of condition keys.
n_replicates
Number of simulation replicates.
n_timepoints
Number of equally spaced timepoints to visualize.
n_monomer_points
Number of equally spaced monomer points to visualize.
total_steps
Total number of steps for each simulation key.
temp_path
Local path for saving visualization output files.
metrics
List of metrics to include in visualization plots.
recalculate
True to recalculate visualization files, False otherwise.
"""
if metrics is not None:
all_metrics_data = get_compression_metric_data(
bucket,
series_name,
condition_keys,
list(range(1, n_replicates + 1)),
metrics,
recalculate=False,
)
else:
metrics = []
all_metrics_data = pd.DataFrame(columns=["key", "seed"])
for condition_key in condition_keys:
series_key = f"{series_name}_{condition_key}" if condition_key else series_name
for rep_ix in range(n_replicates):
rep_id = rep_ix + 1
output_key = f"{series_name}/viz/{series_key}_{rep_id:06d}.simularium"
# Skip if output file already exists.
if not recalculate and check_key(bucket, output_key):
print(
f"Simularium file for [ { output_key } ] already exists. Skipping."
)
continue
print(f"Visualizing data for [ {condition_key} ] replicate [ {rep_ix} ]")
# Filter metrics data for specific conditon and replicate.
if condition_key:
metrics_data = all_metrics_data[
(all_metrics_data["key"] == condition_key)
& (all_metrics_data["seed"] == rep_id)
]
else:
metrics_data = all_metrics_data[(all_metrics_data["seed"] == rep_id)]
visualize_individual_readdy_trajectory(
bucket,
series_name,
series_key,
rep_ix,
n_timepoints,
n_monomer_points,
total_steps[condition_key],
temp_path,
metrics,
metrics_data,
)
# Upload saved file to S3.
temp_key = f"{series_key}_{rep_ix}.h5.simularium"
save_buffer(bucket, output_key, load_buffer(temp_path, temp_key))
def _find_time_units(raw_time: float, units: str = "s") -> tuple[str, float]:
"""Get compact time units and a multiplier to put the times in those units."""
time = UnitRegistry().Quantity(raw_time, units)
time_compact = time.to_compact()
return f"{time_compact.units:~}", time_compact.magnitude / raw_time
def _filter_time(
converter: TrajectoryConverter, n_timepoints: int
) -> TrajectoryConverter:
"""Filter times using simulariumio time filter."""
time_inc = int(converter._data.agent_data.times.shape[0] / n_timepoints)
if time_inc < 2:
return converter
converter._data = converter.filter_data([EveryNthTimestepFilter(n=time_inc)])
return converter
def _align_cytosim_fiber(converter: TrajectoryConverter) -> None:
"""
Align the fiber subpoints so that the furthest point from the x-axis
is aligned with the positive y-axis at the last time point.
"""
fiber_points = converter._data.agent_data.subpoints[:, 0, :]
n_timesteps = fiber_points.shape[0]
n_points = int(fiber_points.shape[1] / 3)
fiber_points = fiber_points.reshape((n_timesteps, n_points, 3))
_, rotation = align_fiber(fiber_points[-1])
for time_ix in range(n_timesteps):
rotated = np.dot(fiber_points[time_ix][:, 1:], rotation)
converter._data.agent_data.subpoints[time_ix, 0, :] = np.concatenate(
(fiber_points[time_ix][:, 0:1], rotated), axis=1
).reshape(n_points * 3)
def _get_cytosim_simularium_converter(
fiber_points_data: str,
singles_data: str,
n_timepoints: int,
) -> TrajectoryConverter:
"""
Load from Cytosim outputs and generate a TrajectoryConverter to visualize an
actin trajectory in Simularium.
"""
singles_display_data = DisplayData(
name="linker",
radius=0.004,
display_type=DISPLAY_TYPE.SPHERE,
color="#eaeaea",
)
converter = CytosimConverter(
CytosimData(
meta_data=MetaData(
box_size=BOX_SIZE,
camera_defaults=CameraData(
position=np.array([70.0, 70.0, 300.0]),
look_at_position=np.array([70.0, 70.0, 0.0]),
fov_degrees=60.0,
),
scale_factor=1,
),
object_info={
"fibers": CytosimObjectInfo(
cytosim_file=InputFileData(
file_contents=fiber_points_data,
),
display_data={
1: DisplayData(
name="actin",
radius=0.002,
display_type=DISPLAY_TYPE.FIBER,
)
},
),
"singles": CytosimObjectInfo(
cytosim_file=InputFileData(
file_contents=singles_data,
),
display_data={
1: singles_display_data,
2: singles_display_data,
3: singles_display_data,
4: singles_display_data,
},
),
},
)
)
_align_cytosim_fiber(converter)
converter._data.agent_data.radii *= CYTOSIM_SCALE_FACTOR
converter._data.agent_data.positions *= CYTOSIM_SCALE_FACTOR
converter._data.agent_data.subpoints *= CYTOSIM_SCALE_FACTOR
converter = _filter_time(converter, n_timepoints)
time_units, time_multiplier = _find_time_units(converter._data.agent_data.times[-1])
converter._data.agent_data.times *= time_multiplier
converter._data.time_units = UnitData(time_units)
return converter
[docs]
def visualize_individual_cytosim_trajectory(
bucket: str,
series_name: str,
series_key: str,
index: int,
n_timepoints: int,
temp_path: str,
metrics: list[CompressionMetric],
metrics_data: pd.DataFrame,
) -> None:
"""
Save a Simularium file for a single Cytosim trajectory with plots and
spatial annotations.
Parameters
----------
bucket
Name of S3 bucket for input and output files.
series_name
Name of simulation series.
series_key
Combination of series and condition names.
index
Simulation replicate index.
n_timepoints
Number of equally spaced timepoints to visualize.
temp_path
Local path for saving visualization output files.
metrics
List of metrics to include in visualization plots.
metrics_data
Calculated compression metrics data.
"""
output_key_template = f"{series_name}/outputs/{series_key}_{index}/%s"
fiber_points_data = load_text(bucket, output_key_template % "fiber_points.txt")
singles_data = load_text(bucket, output_key_template % "singles.txt")
converter = _get_cytosim_simularium_converter(
fiber_points_data, singles_data, n_timepoints
)
if metrics:
times = 1e3 * metrics_data["time"].values # s --> ms
_add_individual_plots(
converter, metrics, metrics_data, times, converter._data.time_units
)
# Save simularium file. Turn off validate IDs for performance.
local_file_path = f"{temp_path}/{series_key}_{index}"
converter.save(output_path=local_file_path, validate_ids=False)
[docs]
def visualize_individual_cytosim_trajectories(
bucket: str,
series_name: str,
condition_keys: list[str],
random_seeds: list[int],
n_timepoints: int,
temp_path: str,
metrics: Optional[list[CompressionMetric]] = None,
recalculate: bool = True,
) -> None:
"""
Visualize individual Cytosim simulations for select conditions and
replicates.
Parameters
----------
bucket
Name of S3 bucket for input and output files.
series_name
Name of simulation series.
condition_keys
List of condition keys.
random_seeds
Random seeds for simulations.
n_timepoints
Number of equally spaced timepoints to visualize.
temp_path
Local path for saving visualization output files.
metrics
List of metrics to include in visualization plots.
recalculate
True to recalculate visualization files, False otherwise.
"""
if metrics is not None:
all_metrics_data = get_compression_metric_data(
bucket,
series_name,
condition_keys,
random_seeds,
metrics,
recalculate=False,
)
else:
metrics = []
all_metrics_data = pd.DataFrame(columns=["key", "seed"])
for condition_key in condition_keys:
series_key = f"{series_name}_{condition_key}" if condition_key else series_name
for index, seed in enumerate(random_seeds):
output_key = f"{series_name}/viz/{series_key}_{seed:06d}.simularium"
# Skip if output file already exists.
if not recalculate and check_key(bucket, output_key):
print(
f"Simularium file for [ { output_key } ] already exists. Skipping."
)
continue
print(f"Visualizing data for [ {condition_key} ] seed [ {seed} ]")
# Filter metrics data for specific conditon and replicate.
if condition_key:
metrics_data = all_metrics_data[
(all_metrics_data["key"] == condition_key)
& (all_metrics_data["seed"] == seed)
]
else:
metrics_data = all_metrics_data[(all_metrics_data["seed"] == seed)]
visualize_individual_cytosim_trajectory(
bucket,
series_name,
series_key,
index,
n_timepoints,
temp_path,
metrics,
metrics_data,
)
temp_key = f"{series_key}_{index}.simularium"
save_buffer(bucket, output_key, load_buffer(temp_path, temp_key))