Coverage for subcell_pipeline/visualization/dimensionality_reduction.py: 0%
96 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
1"""Visualization methods for dimensionality reduction analysis."""
3import os
5import matplotlib.pyplot as plt
6import numpy as np
7from io_collection.load.load_buffer import load_buffer
8from io_collection.load.load_dataframe import load_dataframe
9from io_collection.load.load_pickle import load_pickle
10from io_collection.save.save_buffer import save_buffer
11from matplotlib.colors import Colormap
12from simulariumio import DISPLAY_TYPE, CameraData, DisplayData, MetaData, UnitData
13from sklearn.decomposition import PCA
15from subcell_pipeline.visualization.fiber_points import (
16 generate_trajectory_converter_for_fiber_points,
17)
BOX_SIZE: np.ndarray = np.full(3, 600.0)
"""Edge lengths of the bounding box for dimensionality reduction trajectories."""
23def _rgb_to_hex_color(color: tuple[float, float, float]) -> str:
24 """
25 Convert RGB color to hexadecimal format.
27 Parameters
28 ----------
29 color
30 Red, green, and blue colors (between 0.0 and 1.0).
32 Returns
33 -------
34 :
35 Color in hexadecimal format.
36 """
37 rgb = (int(255 * color[0]), int(255 * color[1]), int(255 * color[2]))
38 return f"#{rgb[0]:02X}{rgb[1]:02X}{rgb[2]:02X}"
def _pca_fiber_points_over_time(
    samples: list[np.ndarray],
    pca: PCA,
    pc_ix: int,
    simulator_name: str = "Combined",
    color: str = "#eaeaea",
) -> tuple[list[np.ndarray], list[str], dict[str, DisplayData]]:
    """
    Get fiber_points for samples of the PC distributions in order to visualize
    the samples over time.

    Parameters
    ----------
    samples
        Sampled values for each PC: ``samples[0]`` holds values along the first
        PC and ``samples[1]`` holds values along the second PC.
    pca
        PCA object used to map a two-component point back to fiber coordinates
        via ``inverse_transform``.
    pc_ix
        Index of the PC to sample. Values below 1 select the first PC; any
        other value selects the second PC.
    simulator_name
        Name of the simulator. ``"Combined"`` is treated as no simulator and
        adds no prefix to the type name.
    color
        Display color assigned to the fiber.

    Returns
    -------
    :
        Single-element list holding an array with one ``(n_points, 3)`` entry
        per sample, single-element list holding the type name, and display
        data keyed by that type name.
    """
    prefix = "" if simulator_name == "Combined" else simulator_name
    if prefix:
        prefix += "#"

    # Iterate over the samples of the PC actually being visualized. Using the
    # length of the first PC's samples for both PCs would fail (or silently
    # drop samples) whenever the two sample arrays differ in length.
    component = 0 if pc_ix < 1 else 1

    fiber_points: list[np.ndarray] = []
    for value in samples[component]:
        # Vary only the selected component; hold the other component at zero.
        data = [0, 0]
        data[component] = value
        fiber_points.append(pca.inverse_transform(data).reshape(-1, 3))

    fiber_points_arr: np.ndarray = np.array(fiber_points)
    type_name: str = f"{prefix}PC{pc_ix + 1}"
    display_data: dict[str, DisplayData] = {
        type_name: DisplayData(
            name=type_name,
            display_type=DISPLAY_TYPE.FIBER,
            color=color,
        )
    }

    return [fiber_points_arr], [type_name], display_data
def _pca_fiber_points_one_timestep(
    samples: list[np.ndarray],
    pca: PCA,
    color_maps: dict[str, Colormap],
    pc_ix: int,
    simulator_name: str = "Combined",
) -> tuple[list[np.ndarray], list[str], dict[str, DisplayData]]:
    """
    Get fiber_points for samples of the PC distributions in order to visualize
    the samples together in one timestep.

    Parameters
    ----------
    samples
        Sampled values for each PC: ``samples[0]`` holds values along the first
        PC and ``samples[1]`` holds values along the second PC.
    pca
        PCA object used to map a two-component point back to fiber coordinates
        via ``inverse_transform``.
    color_maps
        Color map for each simulator name (looked up before the name is
        converted into a type name prefix).
    pc_ix
        Index of the PC to sample (0 for the first PC, 1 for the second PC).
    simulator_name
        Name of the simulator. ``"Combined"`` is treated as no simulator and
        adds no prefix to the type names.

    Returns
    -------
    :
        One ``(1, n_points, 3)`` array per sample, the type name of each
        sample, and display data keyed by type name.
    """
    color_map = color_maps[simulator_name]

    prefix = "" if simulator_name == "Combined" else simulator_name
    if prefix:
        prefix += "_"

    # Only the samples of the selected PC are needed. Indexing both PCs by the
    # positions of the first PC's samples would fail whenever the two sample
    # arrays differ in length, even for the PC that is not being visualized.
    pc_samples = samples[pc_ix]

    fiber_points: list[np.ndarray] = []
    type_names: list[str] = []
    display_data: dict[str, DisplayData] = {}

    for sample in pc_samples:
        # Vary only the selected component; hold the other component at zero.
        data = [0, 0]
        data[pc_ix] = sample
        fiber_points.append(pca.inverse_transform(data).reshape(1, -1, 3))

        # Samples that round to the same integer share a type name, and the
        # first such sample determines the color for that type.
        type_name = f"{prefix}PC{pc_ix + 1}#{round(sample)}"
        type_names.append(type_name)

        if type_name not in display_data:
            # Colors are scaled by the negated first sample.
            # NOTE(review): presumably assumes the sampled range starts at a
            # negative value and is roughly symmetric about zero -- confirm.
            color_range = -pc_samples[0]
            display_data[type_name] = DisplayData(
                name=type_name,
                display_type=DISPLAY_TYPE.FIBER,
                color=_rgb_to_hex_color(color_map(abs(sample) / color_range)),
            )

    return fiber_points, type_names, display_data
def _generate_simularium_and_save(
    name: str,
    fiber_points: list[np.ndarray],
    type_names: list[str],
    display_data: dict[str, DisplayData],
    distribution_over_time: bool,
    simulator_detail: bool,
    bucket: str,
    temp_path: str,
    pc: str,
) -> None:
    """
    Build a simulariumio converter for the fiber points and save its output.

    The output is written under ``temp_path`` and then copied to ``bucket``
    under the key ``{name}/{file name}``. The file name is ``name`` followed by
    ``_time`` (if ``distribution_over_time``), ``_simulators`` (if
    ``simulator_detail``), ``_pc{pc}`` (if ``pc`` is not empty), and the
    ``.simularium`` extension.

    Parameters
    ----------
    name
        Base name for the output file and prefix of the bucket key.
    fiber_points
        Fiber points passed to the trajectory converter.
    type_names
        Type names passed to the trajectory converter.
    display_data
        Display data passed to the trajectory converter.
    distribution_over_time
        True to tag the output file name with ``_time``, False otherwise.
    simulator_detail
        True to tag the output file name with ``_simulators``, False otherwise.
    bucket
        Location the saved file is copied to.
    temp_path
        Local path for saving the output file.
    pc
        PC label appended to the output file name, or empty for no label.
    """
    camera = CameraData(
        position=np.array([70.0, 70.0, 300.0]),
        look_at_position=np.array([70.0, 70.0, 0.0]),
        fov_degrees=60.0,
    )
    meta_data = MetaData(
        box_size=BOX_SIZE,
        camera_defaults=camera,
        trajectory_title="Actin Compression Dimensionality Reduction",
    )
    frame_units = UnitData("count")  # frames
    length_units = UnitData("nm")  # nanometers

    converter = generate_trajectory_converter_for_fiber_points(
        fiber_points,
        type_names,
        meta_data,
        display_data,
        frame_units,
        length_units,
        fiber_radius=1.0,
    )

    # Assemble the file name (without extension) from the enabled options.
    name_parts = [name]
    if distribution_over_time:
        name_parts.append("_time")
    if simulator_detail:
        name_parts.append("_simulators")
    if pc:
        name_parts.append(f"_pc{pc}")
    file_stem = "".join(name_parts)

    # Save locally, then read the saved file back and copy it to the bucket.
    converter.save(output_path=os.path.join(temp_path, file_stem))
    file_name = f"{file_stem}.simularium"
    save_buffer(bucket, f"{name}/{file_name}", load_buffer(temp_path, file_name))
def visualize_dimensionality_reduction(
    bucket: str,
    pca_results_key: str,
    pca_pickle_key: str,
    distribution_over_time: bool,
    simulator_detail: bool,
    sample_ranges: dict[str, list[list[float]]],
    separate_pcs: bool,
    sample_resolution: int,
    temp_path: str,
) -> None:
    """
    Visualize PCA space for actin fibers.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    pca_results_key
        File key for PCA results dataframe.
    pca_pickle_key
        File key for PCA object pickle. The key without its extension is used
        as the name of the output files.
    distribution_over_time
        True to scroll through the PC distributions over time, False otherwise.
    simulator_detail
        True to show individual simulator ranges, False otherwise.
    sample_ranges
        Min and max values to visualize for each PC (and each simulator if
        simulator_detail). Keyed by simulator name (``"Combined"``, plus
        ``"ReaDDy"`` and ``"Cytosim"`` if simulator_detail), with one
        ``[min, max]`` entry per PC. The max value itself is not sampled.
    separate_pcs
        True to visualize PCs in separate files, False otherwise.
    sample_resolution
        Number of samples for each PC distribution.
    temp_path
        Local path for saving visualization output files.

    Raises
    ------
    ValueError
        If ``sample_resolution`` is less than 1.
    """
    num_samples = int(sample_resolution)
    if num_samples < 1:
        raise ValueError(
            f"sample_resolution must be at least 1, got {sample_resolution!r}"
        )

    pca_results = load_dataframe(bucket, pca_results_key)
    # NOTE: unpickling can execute arbitrary code, so the bucket contents must
    # come from a trusted source.
    pca = load_pickle(bucket, pca_pickle_key)

    fiber_points: list[list[np.ndarray]] = [[], []]
    type_names: list[list[str]] = [[], []]
    display_data: list[dict[str, DisplayData]] = [{}, {}]

    # Only the keys of this mapping are used below; they select which
    # simulators are visualized.
    pca_results_simulators = {
        "Combined": pca_results,
    }
    if simulator_detail:
        pca_results_simulators["ReaDDy"] = pca_results.loc[
            pca_results["SIMULATOR"] == "READDY"
        ]
        pca_results_simulators["Cytosim"] = pca_results.loc[
            pca_results["SIMULATOR"] == "CYTOSIM"
        ]

    color_maps = {
        "Combined": plt.colormaps.get_cmap("RdPu"),
        "ReaDDy": plt.colormaps.get_cmap("YlOrRd"),
        "Cytosim": plt.colormaps.get_cmap("GnBu"),
    }
    over_time_colors = {
        "Combined": "#ffffff",
        "ReaDDy": "#ff8f52",
        "Cytosim": "#1cbfaa",
    }

    dataset_name = os.path.splitext(pca_pickle_key)[0]
    pc_ixs = list(range(2))

    for simulator in pca_results_simulators:
        # Evenly spaced samples over the half-open range [min, max). linspace
        # always yields exactly num_samples values, whereas arange with a
        # fractional step can yield one extra value due to floating point
        # rounding, leaving the PCs with different numbers of samples.
        samples = [
            np.linspace(
                sample_ranges[simulator][pc_ix][0],
                sample_ranges[simulator][pc_ix][1],
                num_samples,
                endpoint=False,
            )
            for pc_ix in pc_ixs
        ]

        for pc_ix in pc_ixs:
            if distribution_over_time:
                _fiber_points, _type_names, _display_data = _pca_fiber_points_over_time(
                    samples, pca, pc_ix, simulator, over_time_colors[simulator]
                )
            else:
                _fiber_points, _type_names, _display_data = (
                    _pca_fiber_points_one_timestep(
                        samples, pca, color_maps, pc_ix, simulator
                    )
                )

            # Collect each PC separately, or everything in the first slot when
            # all PCs go into a single file.
            output_ix = pc_ix if separate_pcs else 0
            fiber_points[output_ix] += _fiber_points
            type_names[output_ix] += _type_names
            display_data[output_ix].update(_display_data)

    # Pairs of (slot index, PC label used in the output file name).
    if separate_pcs:
        outputs = [(pc_ix, str(pc_ix + 1)) for pc_ix in pc_ixs]
    else:
        outputs = [(0, "")]

    for output_ix, pc_label in outputs:
        _generate_simularium_and_save(
            dataset_name,
            fiber_points[output_ix],
            type_names[output_ix],
            display_data[output_ix],
            distribution_over_time,
            simulator_detail,
            bucket,
            temp_path,
            pc_label,
        )