Coverage for subcell_pipeline/analysis/tomography_data/tomography_data.py: 0%
86 statements
1"""Methods for analyzing tomography data."""
3from typing import Optional
5import matplotlib.pyplot as plt
6import numpy as np
7import pandas as pd
8from io_collection.keys.check_key import check_key
9from io_collection.load.load_dataframe import load_dataframe
10from io_collection.save.save_dataframe import save_dataframe
11from io_collection.save.save_figure import save_figure
13TOMOGRAPHY_SAMPLE_COLUMNS: list[str] = ["xpos", "ypos", "zpos"]
14"""Columns names used when sampling tomography data."""


def test_consecutive_segment_angles(polymer_trace: np.ndarray) -> bool:
    """
    Test if all angles between consecutive segments of a polymer trace are less
    than 90 degrees.

    Parameters
    ----------
    polymer_trace
        A 2D array where each row is a point in 3D space.

    Returns
    -------
    :
        True if all consecutive angles are less than 90 degrees, False
        otherwise.
    """
    # Unit vectors along each segment of the trace.
    vectors = polymer_trace[1:] - polymer_trace[:-1]
    vectors /= np.linalg.norm(vectors, axis=1)[:, np.newaxis]

    # Dot products of consecutive segment unit vectors; a positive value means
    # the angle between the two segments is less than 90 degrees.
    dot_products = np.sum(vectors[1:] * vectors[:-1], axis=1)

    return np.all(dot_products > 0).item()
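# Illustrative usage (kept as comments so importing this module has no side
# effects); the traces below are made-up examples, not project data. A trace
# that doubles back on itself fails the consecutive-angle test:
#
#   straight = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.1, 0.0]])
#   zigzag = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.1, 0.1, 0.0]])
#   assert test_consecutive_segment_angles(straight)
#   assert not test_consecutive_segment_angles(zigzag)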


def read_tomography_data(file: str, label: str = "fil") -> pd.DataFrame:
    """
    Read tomography data from file as dataframe.

    Parameters
    ----------
    file
        Path to tomography data.
    label
        Label for the filament id column.

    Returns
    -------
    :
        Dataframe of tomography data.
    """
    # Columns are whitespace-delimited (delim_whitespace is deprecated in
    # recent pandas, so use an explicit regex separator instead).
    coordinates = pd.read_table(file, sep=r"\s+")

    if len(coordinates.columns) == 4:
        coordinates.columns = [label, "xpos", "ypos", "zpos"]
    elif len(coordinates.columns) == 5:
        coordinates.columns = ["object", label, "xpos", "ypos", "zpos"]
    else:
        print(f"Data file [ {file} ] has an unexpected number of columns")

    return coordinates
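# Illustrative usage (the path below is a hypothetical placeholder, not a real
# project file):
#
#   coordinates = read_tomography_data("path/to/BranchedActin_example.txt")
#   print(coordinates[["fil", "xpos", "ypos", "zpos"]].head())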


def rescale_tomography_data(data: pd.DataFrame, scale_factor: float = 1.0) -> None:
    """
    Rescale tomography data from pixels to um.

    Parameters
    ----------
    data
        Unscaled tomography data.
    scale_factor
        Data scaling factor (pixels to um).
    """

    data["xpos"] = data["xpos"] * scale_factor
    data["ypos"] = data["ypos"] * scale_factor
    data["zpos"] = data["zpos"] * scale_factor


def get_branched_tomography_data(
    bucket: str,
    name: str,
    repository: str,
    datasets: list[tuple[str, str]],
    scale_factor: float = 1.0,
) -> pd.DataFrame:
    """
    Load or create merged branched actin tomography data for given datasets.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    name
        Name of dataset.
    repository
        Data repository for downloading tomography data.
    datasets
        Folders and names of branched actin datasets.
    scale_factor
        Data scaling factor (pixels to um).

    Returns
    -------
    :
        Merged branched tomography data.
    """

    return get_tomography_data(
        bucket, name, repository, datasets, "branched", scale_factor
    )


def get_unbranched_tomography_data(
    bucket: str,
    name: str,
    repository: str,
    datasets: list[tuple[str, str]],
    scale_factor: float = 1.0,
) -> pd.DataFrame:
    """
    Load or create merged unbranched actin tomography data for given datasets.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    name
        Name of dataset.
    repository
        Data repository for downloading tomography data.
    datasets
        Folders and names of unbranched actin datasets.
    scale_factor
        Data scaling factor (pixels to um).

    Returns
    -------
    :
        Merged unbranched tomography data.
    """

    return get_tomography_data(
        bucket, name, repository, datasets, "unbranched", scale_factor
    )
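# Illustrative usage of the wrappers (bucket, repository, and dataset names are
# hypothetical placeholders, not actual project locations):
#
#   datasets = [("folder_a", "sample_01"), ("folder_b", "sample_02")]
#   branched = get_branched_tomography_data(
#       "example-bucket", "example_dataset", "https://example.org/data", datasets
#   )
#   unbranched = get_unbranched_tomography_data(
#       "example-bucket", "example_dataset", "https://example.org/data", datasets
#   )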


def get_tomography_data(
    bucket: str,
    name: str,
    repository: str,
    datasets: list[tuple[str, str]],
    group: str,
    scale_factor: float = 1.0,
) -> pd.DataFrame:
    """
    Load or create merged tomography data for given datasets.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    name
        Name of dataset.
    repository
        Data repository for downloading tomography data.
    datasets
        Folders and names of actin datasets.
    group
        Actin filament group ("branched" or "unbranched").
    scale_factor
        Data scaling factor (pixels to um).

    Returns
    -------
    :
        Merged tomography data.
    """
    data_key = f"{name}/{name}_coordinates_{group}.csv"

    if check_key(bucket, data_key):
        print(f"Loading existing combined tomogram data from [ { data_key } ]")
        return load_dataframe(bucket, data_key)
    else:
        all_tomogram_dfs = []

        # Use a loop variable distinct from the "name" parameter so the dataset
        # name used for the output data key is not overwritten.
        for folder, dataset_name in datasets:
            print(f"Loading tomogram data for [ { dataset_name } ]")
            tomogram_file = (
                f"{repository}/{folder}/{group.title()}Actin_{dataset_name}.txt"
            )
            tomogram_df = read_tomography_data(tomogram_file)
            tomogram_df["dataset"] = dataset_name
            tomogram_df["id"] = tomogram_df["fil"].apply(
                lambda row, dataset_name=dataset_name: f"{row:02d}_{dataset_name}"
            )
            rescale_tomography_data(tomogram_df, scale_factor)
            all_tomogram_dfs.append(tomogram_df)

        all_tomogram_df = pd.concat(all_tomogram_dfs)

        print(f"Saving combined tomogram data to [ { data_key } ]")
        save_dataframe(bucket, data_key, all_tomogram_df, index=False)

        return all_tomogram_df
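# Note on caching: merged data are stored at "{name}/{name}_coordinates_{group}.csv"
# in the given bucket, so repeated calls with the same name and group load the
# saved dataframe instead of re-reading the individual tomogram files. For
# example (hypothetical names), name="example_dataset" and group="branched"
# resolve to the key "example_dataset/example_dataset_coordinates_branched.csv".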


def sample_tomography_data(
    data: pd.DataFrame,
    save_location: str,
    save_key: str,
    n_monomer_points: int,
    minimum_points: int,
    sampled_columns: list[str] = TOMOGRAPHY_SAMPLE_COLUMNS,
    recalculate: bool = False,
) -> pd.DataFrame:
    """
    Sample selected columns from tomography data at given resolution.

    Parameters
    ----------
    data
        Tomography data to sample.
    save_location
        Location to save sampled data.
    save_key
        File key for sampled data.
    n_monomer_points
        Number of equally spaced monomer points to sample.
    minimum_points
        Minimum number of points for valid fiber.
    sampled_columns
        List of column names to sample.
    recalculate
        True to recalculate the sampled tomography data, False otherwise.

    Returns
    -------
    :
        Sampled tomography data.
    """
    if check_key(save_location, save_key) and not recalculate:
        print(f"Loading existing sampled tomogram data from [ { save_key } ]")
        return load_dataframe(save_location, save_key)
    else:
        all_sampled_points = []

        # TODO: sort experimental samples in order along the fiber before
        # resampling (see simularium visualization)
        for fiber_id, group in data.groupby("id"):
            # Skip fibers with too few points to resample reliably.
            if len(group) < minimum_points:
                continue

            sampled_points = pd.DataFrame()
            sampled_points["monomer_ids"] = np.arange(n_monomer_points)
            sampled_points["dataset"] = group["dataset"].unique()[0]
            sampled_points["id"] = fiber_id

            # Linearly interpolate each selected column onto equally spaced
            # points along the fiber.
            for column in sampled_columns:
                sampled_points[column] = np.interp(
                    np.linspace(0, 1, n_monomer_points),
                    np.linspace(0, 1, group.shape[0]),
                    group[column].to_numpy(),
                )

            # Flag fibers whose resampled points appear ordered along the fiber.
            sampled_points["ordered"] = test_consecutive_segment_angles(
                sampled_points[sampled_columns].to_numpy()
            )

            all_sampled_points.append(sampled_points)

        all_sampled_df = pd.concat(all_sampled_points)

        print(f"Saving sampled tomogram data to [ { save_key } ]")
        save_dataframe(save_location, save_key, all_sampled_df, index=False)

        return all_sampled_df
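# Illustrative usage (location and key are hypothetical placeholders): resample
# each fiber to 50 equally spaced points, skipping fibers with fewer than 10
# measured points.
#
#   sampled = sample_tomography_data(
#       data=branched,
#       save_location="example-bucket",
#       save_key="example_dataset/example_dataset_sampled_branched.csv",
#       n_monomer_points=50,
#       minimum_points=10,
#   )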


def plot_tomography_data_by_dataset(
    data: pd.DataFrame,
    save_location: Optional[str] = None,
    save_key_template: str = "tomography_data_%s.png",
) -> None:
    """
    Plot tomography data for each dataset.

    Parameters
    ----------
    data
        Tomography data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key_template
        Name key template for output file.
    """

    for dataset, group in data.groupby("dataset"):
        figure, ax = plt.subplots(1, 3, figsize=(6, 2))
        ax[1].set_title(dataset)

        views = ["XY", "XZ", "YZ"]
        for index, view in enumerate(views):
            ax[index].set_xticks([])
            ax[index].set_yticks([])
            ax[index].set_xlabel(view[0])
            ax[index].set_ylabel(view[1], rotation=0)

        for _, fiber in group.groupby("id"):
            ax[0].plot(fiber["xpos"], fiber["ypos"], marker="o", ms=1, lw=1)
            ax[1].plot(fiber["xpos"], fiber["zpos"], marker="o", ms=1, lw=1)
            ax[2].plot(fiber["ypos"], fiber["zpos"], marker="o", ms=1, lw=1)

        if save_location is not None:
            save_key = save_key_template % dataset
            save_figure(save_location, save_key, figure)
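# Illustrative usage (the output location is a hypothetical placeholder): plot
# XY, XZ, and YZ projections for every dataset and save one PNG per dataset,
# e.g. "tomography_data_sample_01.png".
#
#   plot_tomography_data_by_dataset(sampled, save_location="example-bucket")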