Coverage for subcell_pipeline/analysis/dimensionality_reduction/fiber_data.py: 18%
85 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
1"""Methods for fiber data merging and alignment."""
3from typing import Optional
5import matplotlib.pyplot as plt
6import numpy as np
7import pandas as pd
8from io_collection.keys.check_key import check_key
9from io_collection.load.load_dataframe import load_dataframe
10from io_collection.save.save_dataframe import save_dataframe
11from io_collection.save.save_figure import save_figure
12from io_collection.save.save_json import save_json
def get_merged_data(
    bucket: str,
    series_name: str,
    condition_keys: list[str],
    random_seeds: list[int],
    align: bool = True,
) -> pd.DataFrame:
    """
    Load or create merged data for given conditions and random seeds.

    If merged data (aligned or unaligned) already exists, load the data.
    Otherwise, iterate through the conditions and seeds to merge the data,
    save the merged data, and return it.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys. An empty key selects the sample files named
        after the series alone.
    random_seeds
        Random seeds for simulations.
    align
        True if data should be aligned, False otherwise.

    Returns
    -------
    :
        Merged data, with one ``seed`` and one ``key`` column added to the
        samples and a fresh 0..n-1 row index.

    Raises
    ------
    ValueError
        Raised by ``pd.concat`` if no merged data exists yet and there are no
        condition keys or no random seeds to merge.
    """

    align_key = "all_samples_aligned" if align else "all_samples_unaligned"
    data_key = f"{series_name}/analysis/{series_name}_{align_key}.csv"

    # Return data, if merged data already exists.
    if check_key(bucket, data_key):
        print(
            f"Dataframe [ {data_key} ] already exists. Loading existing merged data."
        )
        # Condition keys are read back as strings (presumably so that
        # numeric-looking keys are not parsed as numbers -- confirm).
        return load_dataframe(bucket, data_key, dtype={"key": "str"})

    all_samples: list[pd.DataFrame] = []

    for condition_key in condition_keys:
        series_key = f"{series_name}_{condition_key}" if condition_key else series_name

        for seed in random_seeds:
            print(f"Loading samples for [ {condition_key} ] seed [ {seed} ]")

            sample_key = f"{series_name}/samples/{series_key}_{seed:06d}.csv"
            samples = load_dataframe(bucket, sample_key)
            samples["seed"] = seed
            samples["key"] = condition_key

            if align:
                # Rotates the coordinates of the loaded samples in place.
                align_fibers(samples)

            all_samples.append(samples)

    # Each sample file carries its own row index; rebuild a single index so
    # the returned frame matches what is loaded back from the saved file
    # (which is written without its index).
    samples_dataframe = pd.concat(all_samples, ignore_index=True)
    save_dataframe(bucket, data_key, samples_dataframe, index=False)

    return samples_dataframe
def align_fibers(data: pd.DataFrame) -> None:
    """
    Align fibers for each time point in the data.

    The ``xpos``, ``ypos``, and ``zpos`` columns are overwritten in place.
    Rows at time 0 keep their coordinates; the rows of every other time point
    are rotated together using ``align_fiber``. Rows are matched to their time
    point by position, so rows of the same time point do not have to be
    contiguous in the data. Empty data is left empty.

    Parameters
    ----------
    data
        Simulated fiber data with ``time``, ``xpos``, ``ypos``, and ``zpos``
        columns. Coordinates are written back as floats.
    """

    # Work on an explicit copy so the rotated values can be written by row
    # position without touching the dataframe until all groups are done.
    coords = data[["xpos", "ypos", "zpos"]].to_numpy(dtype=float, copy=True)

    # `indices` maps each time point to the integer positions of its rows,
    # which keeps every rotated fiber attached to the rows it came from.
    for time, positions in data.groupby("time", sort=False).indices.items():
        if time == 0:
            continue

        aligned, _ = align_fiber(coords[positions])
        coords[positions] = aligned

    data["xpos"] = coords[:, 0]
    data["ypos"] = coords[:, 1]
    data["zpos"] = coords[:, 2]
def align_fiber(coords: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Rotate x, y, z coordinates about the x axis onto the positive y axis.

    The point with the largest distance from the origin in the yz-plane sets
    the rotation angle: it is the angle that brings that point onto the
    positive y axis. Every (y, z) pair is rotated by that angle, while the x
    coordinates pass through untouched. For example, a furthest point of
    (0.5, 0, 1) ends up at (0.5, 1, 0), which is a rotation of pi / 2.

    Parameters
    ----------
    coords
        Array of x, y, and z positions, one point per row.

    Returns
    -------
    :
        Rotated coordinates and the 2x2 matrix that was applied to the
        (y, z) rows by right-multiplication.
    """

    x_column = coords[:, :1]
    yz = coords[:, 1:]

    # The point furthest from (0, 0) in the yz-plane defines the angle.
    radii = np.sqrt((yz**2).sum(axis=1))
    farthest = yz[np.argmax(radii)]
    theta = np.arctan2(farthest[1], farthest[0])

    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation = np.array([[cos_theta, -sin_theta], [sin_theta, cos_theta]])

    # Only y and z are rotated; x is stitched back on unchanged.
    return np.hstack((x_column, yz.dot(rotation))), rotation
def reshape_fibers(data: pd.DataFrame) -> tuple[np.ndarray, pd.DataFrame]:
    """
    Reshape data from tidy data format to array of fibers and fiber features.

    Rows are grouped by time, velocity, repeat, and simulator. Each group
    becomes one fiber, flattened point by point as x, y, z, x, y, z, and so
    on, and one row of features.

    Parameters
    ----------
    data
        Simulated fiber data with ``time``, ``velocity``, ``repeat``,
        ``simulator``, ``xpos``, ``ypos``, and ``zpos`` columns.

    Returns
    -------
    :
        Array of fibers and dataframe of fiber features. The array has one
        row per fiber, including when the data holds only a single fiber; the
        dataframe has ``TIME``, ``VELOCITY``, ``REPEAT``, and ``SIMULATOR``
        columns, with simulator names in upper case.
    """

    all_features = []
    all_fibers = []

    for (time, velocity, repeat, simulator), group in data.groupby(
        ["time", "velocity", "repeat", "simulator"]
    ):
        # Flatten to one dimension per fiber so that stacking always yields a
        # (fibers, values) array; squeezing afterwards would also drop the
        # fiber axis when there is exactly one fiber.
        fiber = group[["xpos", "ypos", "zpos"]].values.reshape(-1)
        all_fibers.append(fiber)
        all_features.append(
            {
                "TIME": time,
                "VELOCITY": velocity,
                "REPEAT": repeat,
                "SIMULATOR": simulator.upper(),
            }
        )

    # NOTE(review): stacking assumes every fiber has the same number of
    # points -- confirm against the sampling step.
    return np.array(all_fibers), pd.DataFrame(all_features)
def save_aligned_fibers(
    data: pd.DataFrame, time_map: dict, save_location: str, save_key: str
) -> None:
    """
    Save aligned fiber data.

    For every simulator, repeat, and condition key, only the fiber at the
    time selected in the time map is kept. The kept fibers are written as a
    JSON list with one entry per fiber, holding the upper-cased simulator
    name, the repeat, the key, and the x, y, and z coordinate lists.

    Parameters
    ----------
    data
        Aligned fiber data.
    time_map
        Map of selected aligned time for each simulator and condition.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    grouping = ["simulator", "repeat", "key", "time"]
    selected_fibers = []

    for (simulator, repeat, key, time), rows in data.groupby(grouping):
        # Keep only the time point chosen for this simulator and condition.
        if time == time_map[(simulator, key)]:
            positions = rows[["xpos", "ypos", "zpos"]].to_numpy()
            x_values, y_values, z_values = (
                positions[:, axis].tolist() for axis in range(3)
            )
            selected_fibers.append(
                {
                    "simulator": simulator.upper(),
                    "repeat": int(repeat),
                    "key": key,
                    "x": x_values,
                    "y": y_values,
                    "z": z_values,
                }
            )

    save_json(save_location, save_key, selected_fibers)
def plot_fibers_by_key_and_seed(
    data: pd.DataFrame,
    save_location: Optional[str] = None,
    save_key: str = "aligned_fibers_by_key_and_seed.png",
) -> None:
    """
    Plot simulated fiber data for each condition key and random seed.

    The figure has one row of panels per condition key and one column per
    random seed. Within each panel, every fiber (one per time point and
    simulator) is drawn as its y position against its z position, in red for
    the ``readdy`` simulator and in blue for any other simulator.

    Parameters
    ----------
    data
        Simulated fiber data with ``key``, ``seed``, ``time``, ``simulator``,
        ``xpos``, ``ypos``, and ``zpos`` columns.
    save_location
        Location for output file (local path or S3 bucket). The figure is
        only saved if a location is given.
    save_key
        Name key for output file.
    """

    rows = data["key"].unique()
    cols = data["seed"].unique()

    # squeeze=False keeps the axes as a 2D array even when there is only one
    # key or one seed, so the [row, column] indexing below always works.
    figure, ax = plt.subplots(
        len(rows),
        len(cols),
        figsize=(10, 6),
        sharey=True,
        sharex=True,
        squeeze=False,
    )

    for row_index, row in enumerate(rows):
        for col_index, col in enumerate(cols):
            panel = ax[row_index, col_index]

            # Label only the outer edge of the grid: seeds along the top row
            # and condition keys down the first column.
            if row_index == 0:
                panel.set_title(f"REPEAT = {col}")
            if col_index == 0:
                panel.set_ylabel(f"KEY = {row}")

            subset = data[(data["key"] == row) & (data["seed"] == col)]

            for (_, simulator), group in subset.groupby(["time", "simulator"]):
                color = "red" if simulator == "readdy" else "blue"
                coords = group[["xpos", "ypos", "zpos"]].values
                panel.plot(coords[:, 1], coords[:, 2], lw=0.5, color=color, alpha=0.5)

    plt.tight_layout()
    plt.show()

    if save_location is not None:
        save_figure(save_location, save_key, figure)