Coverage for subcell_pipeline/visualization/combined_trajectory.py: 0%
75 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
1"""Visualization methods for combined simulators."""
3import os
4from typing import Optional
6import numpy as np
7import pandas as pd
8from io_collection.keys.check_key import check_key
9from io_collection.load.load_buffer import load_buffer
10from io_collection.load.load_dataframe import load_dataframe
11from io_collection.save.save_buffer import save_buffer
12from simulariumio import (
13 DISPLAY_TYPE,
14 AgentData,
15 CameraData,
16 DisplayData,
17 MetaData,
18 TrajectoryConverter,
19 TrajectoryData,
20 UnitData,
21)
23from subcell_pipeline.analysis.compression_metrics.compression_analysis import (
24 get_compression_metric_data,
25)
26from subcell_pipeline.analysis.compression_metrics.compression_metric import (
27 CompressionMetric,
28)
29from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import align_fibers
30from subcell_pipeline.visualization.scatter_plots import make_empty_scatter_plots
# Cubic 600 nm bounding box shared by all combined trajectories.
BOX_SIZE: np.ndarray = np.full(3, 600.0)
"""Bounding box size for combined simulator trajectories."""
def _load_fiber_points_from_dataframe(
    dataframe: pd.DataFrame, n_timepoints: int
) -> np.ndarray:
    """
    Load and reshape fiber points from sampled dataframe.

    Sampled dataframe is in the shape (n_timepoints x n_fiber_points, 3);
    method returns the dataframe reshaped to (n_timepoints, n_fiber_points x
    3). If the sampled dataframe does not have the expected number of
    timepoints, method will raise a ValueError.
    """

    # BUG FIX: pandas sort_values returns a new dataframe; the previous call
    # discarded the result, so rows were never actually ordered by time and
    # fiber point. Reassign so the sort takes effect.
    dataframe = dataframe.sort_values(by=["time", "fiber_point"])

    total_steps = dataframe.time.unique().shape[0]

    if total_steps != n_timepoints:
        # ValueError is a subclass of Exception, so existing callers that
        # catch Exception still work.
        raise ValueError(
            f"Requested number of timesteps [ {n_timepoints} ] does not match "
            f"number of timesteps in dataset [ {total_steps} ]."
        )

    # Align fiber orientations across timepoints (mutates the dataframe).
    align_fibers(dataframe)

    # One flattened (n_fiber_points x 3) row per timepoint.
    fiber_points = [
        group[["xpos", "ypos", "zpos"]].values.flatten()
        for _, group in dataframe.groupby("time")
    ]

    return np.array(fiber_points)
def _get_combined_trajectory_converter(
    fiber_points: list[np.ndarray],
    type_names: list[str],
    display_data: dict[str, DisplayData],
) -> TrajectoryConverter:
    """
    Generate a TrajectoryConverter to visualize simulations from ReaDDy and
    Cytosim together in one Simularium scene.
    """

    n_fibers = len(fiber_points)
    n_steps = fiber_points[0].shape[0]
    n_subpoints = fiber_points[0].shape[1]

    meta_data = MetaData(
        box_size=BOX_SIZE,
        camera_defaults=CameraData(
            position=np.array([75.0, 220.0, 15.0]),
            look_at_position=np.array([75.0, 75.0, 0.0]),
            fov_degrees=60.0,
        ),
        trajectory_title="Actin compression in Cytosim and Readdy",
    )

    # One fiber agent per condition; each agent's subpoints hold the
    # flattened fiber polyline for that timestep.
    agent_data = AgentData(
        times=np.arange(n_steps),
        n_agents=n_fibers * np.ones(n_steps),
        viz_types=1001 * np.ones((n_steps, n_fibers)),  # fiber viz type = 1001
        unique_ids=np.array(n_steps * [list(range(n_fibers))]),
        types=n_steps * [type_names],
        positions=np.zeros((n_steps, n_fibers, 3)),
        radii=np.ones((n_steps, n_fibers)),
        n_subpoints=n_subpoints * np.ones((n_steps, n_fibers)),
        # Swap (condition, time) axes to the (time, agent) layout AgentData expects.
        subpoints=np.moveaxis(np.array(fiber_points), [0, 1], [1, 0]),
        display_data=display_data,
    )

    trajectory = TrajectoryData(
        meta_data=meta_data,
        agent_data=agent_data,
        time_units=UnitData("count"),  # frames
        spatial_units=UnitData("nm"),  # nanometer
    )

    return TrajectoryConverter(trajectory)
def _add_combined_plots(
    converter: TrajectoryConverter,
    metrics: list[CompressionMetric],
    metrics_data: pd.DataFrame,
    n_timepoints: int,
    plot_names: list[tuple[str, str, int]],
    type_names: list[str],
) -> None:
    """Add plots for combined trajectories with calculated metrics."""

    scatter_plots = make_empty_scatter_plots(metrics, total_steps=n_timepoints)

    for metric, plot in scatter_plots.items():
        # One trace per (simulator, condition, replicate), keyed by type name.
        for (simulator, key, seed), type_name in zip(plot_names, type_names):
            simulator_data = metrics_data[simulator]
            selection = (simulator_data["key"] == key) & (
                simulator_data["seed"] == seed
            )
            plot.ytraces[type_name] = np.array(
                simulator_data[selection][metric.value]
            )
        converter.add_plot(plot, "scatter")
def visualize_combined_trajectories(
    buckets: dict[str, str],
    series_name: str,
    condition_keys: list[str],
    replicates: list[int],
    n_timepoints: int,
    simulator_colors: dict[str, str],
    temp_path: str,
    metrics: Optional[list[CompressionMetric]] = None,
    recalculate: bool = False,
) -> None:
    """
    Visualize combined simulations from ReaDDy and Cytosim for select conditions
    and number of replicates.

    Parameters
    ----------
    buckets
        Names of S3 buckets for input and output files for each simulator and
        visualization.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys.
    replicates
        Simulation replicates ids.
    n_timepoints
        Number of equally spaced timepoints to visualize.
    simulator_colors
        Map of simulator name to color.
    temp_path
        Local path for saving visualization output files.
    metrics
        List of metrics to include in visualization plots.
    recalculate
        True to recalculate visualization files, False otherwise.
    """

    # BUG FIX: normalize metrics once, before the loop. Previously the
    # reassignment "metrics = []" happened inside the loop's else branch, so
    # when metrics was None, the first simulator took the empty-DataFrame
    # path but every later simulator satisfied "metrics is not None" and
    # called get_compression_metric_data with an empty metric list instead.
    if metrics is None:
        metrics = []

    fiber_points = []
    type_names = []
    plot_names = []
    display_data = {}
    all_metrics_data = {}

    for simulator, color in simulator_colors.items():
        bucket = buckets[simulator]

        # Load calculated compression metric data.
        if metrics:
            all_metrics_data[simulator] = get_compression_metric_data(
                bucket,
                series_name,
                condition_keys,
                replicates,
                metrics,
                recalculate=recalculate,
            )
        else:
            all_metrics_data[simulator] = pd.DataFrame(columns=["key", "seed"])

        for condition_key in condition_keys:
            series_key = (
                f"{series_name}_{condition_key}" if condition_key else series_name
            )

            for replicate in replicates:
                dataframe_key = (
                    f"{series_name}/samples/{series_key}_{replicate:06d}.csv"
                )

                # Skip if input dataframe does not exist.
                if not check_key(bucket, dataframe_key):
                    print(
                        f"Dataframe not available for {simulator} "
                        f"[ { dataframe_key } ]. Skipping."
                    )
                    continue

                print(
                    f"Loading data for [ {simulator} ] "
                    f"condition [ { dataframe_key } ] "
                    f"replicate [ {replicate} ]"
                )

                dataframe = load_dataframe(bucket, dataframe_key)
                fiber_points.append(
                    _load_fiber_points_from_dataframe(dataframe, n_timepoints)
                )

                # Condition keys appear to encode velocity x10 (labels are in
                # um/s); keys ending in "0" are shown as whole numbers.
                # ROBUSTNESS: empty condition keys (supported when building
                # series_key above) previously crashed on int("") and [-1].
                if condition_key:
                    condition = int(condition_key) / 10
                    if condition_key[-1] == "0":
                        condition = round(condition)
                    type_name = f"{simulator}#{condition} um/s {replicate}"
                else:
                    type_name = f"{simulator}#{replicate}"

                type_names.append(type_name)
                plot_names.append((simulator, condition_key, replicate))
                display_data[type_name] = DisplayData(
                    name=type_name,
                    display_type=DISPLAY_TYPE.FIBER,
                    color=color,
                )

    # ROBUSTNESS: nothing loaded means nothing to convert; the converter
    # indexes fiber_points[0] and would raise an opaque IndexError.
    if not fiber_points:
        print("No input dataframes found. Skipping visualization.")
        return

    converter = _get_combined_trajectory_converter(
        fiber_points, type_names, display_data
    )

    if metrics:
        _add_combined_plots(
            converter, metrics, all_metrics_data, n_timepoints, plot_names, type_names
        )

    output_key = "actin_compression_cytosim_readdy.simularium"
    local_file_path = os.path.join(temp_path, output_key)

    # Converter appends the .simularium extension itself.
    converter.save(output_path=local_file_path.replace(".simularium", ""))

    output_bucket = buckets["combined"]
    save_buffer(output_bucket, output_key, load_buffer(temp_path, output_key))