Coverage for subcell_pipeline/analysis/dimensionality_reduction/pca_dim_reduction.py: 0%
81 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-29 15:14 +0000
1"""Methods for dimensionality reduction using PCA."""
3import random
4from typing import Optional
6import matplotlib.pyplot as plt
7import numpy as np
8import pandas as pd
9from io_collection.save.save_dataframe import save_dataframe
10from io_collection.save.save_figure import save_figure
11from io_collection.save.save_json import save_json
12from sklearn.decomposition import PCA
14from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import reshape_fibers
def run_pca(
    data: pd.DataFrame, n_components: int = 2
) -> tuple[pd.DataFrame, PCA]:
    """
    Run Principal Component Analysis (PCA) on simulation data.

    Fiber data is reshaped with ``reshape_fibers``, a PCA model is fit on the
    reshaped fibers, and every fiber is projected onto the fitted components.

    Parameters
    ----------
    data
        Simulated fiber data.
    n_components
        Number of principal components to keep (default 2). Projected
        coordinates are stored in columns ``PCA1`` ... ``PCA<n>``.

    Returns
    -------
    :
        Dataframe with PCA components appended and the PCA object.
    """

    all_fibers, all_features = reshape_fibers(data)

    pca = PCA(n_components=n_components)
    pca = pca.fit(all_fibers)
    transform = pca.transform(all_fibers)

    columns = [f"PCA{index + 1}" for index in range(transform.shape[1])]

    # pd.concat(axis=1) joins on index labels, not position. The projection
    # frame gets a fresh RangeIndex, so reset the feature index as well to
    # join row-by-row (assumes reshape_fibers returns features in the same
    # row order as fibers — TODO confirm).
    pca_results = pd.concat(
        [
            pd.DataFrame(transform, columns=columns),
            all_features.reset_index(drop=True),
        ],
        axis=1,
    )

    return pca_results, pca
def save_pca_results(
    pca_results: pd.DataFrame, save_location: str, save_key: str, resample: bool = True
) -> None:
    """
    Write PCA results to a dataframe file, optionally shuffling rows first.

    Parameters
    ----------
    pca_results
        PCA trajectory data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    resample
        True to shuffle all rows (seeded with ``random_state=1``, so the order
        is repeatable) before saving, False to keep the input row order.
    """

    # DataFrame.sample returns a new frame, so the caller's data is untouched.
    to_save = (
        pca_results.sample(frac=1.0, random_state=1) if resample else pca_results
    )
    save_dataframe(save_location, save_key, to_save, index=False)
def save_pca_trajectories(
    pca_results: pd.DataFrame, save_location: str, save_key: str
) -> None:
    """
    Save PCA trajectories data.

    Rows are grouped by simulator, repeat, and velocity. Each group becomes
    one record holding its PCA1 values as ``x`` and PCA2 values as ``y``.
    The record list is shuffled with a fixed seed before saving as JSON.

    Parameters
    ----------
    pca_results
        PCA trajectory data with ``SIMULATOR``, ``REPEAT``, ``VELOCITY``,
        ``PCA1`` and ``PCA2`` columns.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    output = [
        {
            "simulator": simulator.upper(),
            "replicate": int(repeat),
            # Group keys from numeric columns arrive as NumPy scalars (hence
            # the int() on repeat). Convert velocity too, since e.g. np.int64
            # is not serializable by the stdlib json encoder.
            "velocity": (
                velocity.item() if isinstance(velocity, np.generic) else velocity
            ),
            "x": group["PCA1"].tolist(),
            "y": group["PCA2"].tolist(),
        }
        for (simulator, repeat, velocity), group in pca_results.groupby(
            ["SIMULATOR", "REPEAT", "VELOCITY"]
        )
    ]

    # Seeded shuffle so the saved order is mixed but reproducible.
    random.Random(1).shuffle(output)
    save_json(save_location, save_key, output)
def save_pca_transforms(
    pca: PCA, points: list[list[float]], save_location: str, save_key: str
) -> None:
    """
    Save PCA transform data.

    For component ``k`` (1-based) and every value in ``points[k - 1]``, a
    coordinate vector is built with that value on component ``k`` and zeros
    on all other components. It is mapped back with ``pca.inverse_transform``
    and reshaped into rows of (x, y, z).

    Parameters
    ----------
    pca
        Fitted PCA object.
    points
        Lists of inverse transform points, one per component starting with
        PC1 (e.g. ``[pc1_points, pc2_points]``).
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.

    Raises
    ------
    ValueError
        If more point lists are given than the PCA has components.
    """

    n_components = pca.n_components_
    if len(points) > n_components:
        raise ValueError(
            f"Got {len(points)} point lists but PCA has {n_components} components."
        )

    output = []

    for component_index, component_points in enumerate(points):
        for point in component_points:
            # inverse_transform expects a 2-D (n_samples, n_components) input.
            coordinates = np.zeros((1, n_components))
            coordinates[0, component_index] = point
            fiber = pca.inverse_transform(coordinates).reshape(-1, 3)
            output.append(
                {
                    "component": component_index + 1,
                    "point": point,
                    "x": fiber[:, 0].tolist(),
                    "y": fiber[:, 1].tolist(),
                    "z": fiber[:, 2].tolist(),
                }
            )

    save_json(save_location, save_key, output)
def plot_pca_feature_scatter(
    data: pd.DataFrame,
    features: dict,
    pca: PCA,
    save_location: Optional[str] = None,
    save_key: str = "pca_feature_scatter.png",
) -> None:
    """
    Plot scatter of PCA components colored by the given features.

    One subplot is drawn per feature, all sharing the PCA1/PCA2 axes. Each
    axis label shows the explained variance ratio of its component.

    Parameters
    ----------
    data
        PCA results data.
    features
        Map of feature name to coloring. A value can be one of three forms:
        a dict mapping feature values to colors; a ``(mapping, cmap)`` tuple,
        where feature values are mapped and then colored with ``cmap``; or a
        colormap applied directly to the raw feature values.
    pca
        PCA object.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when only one feature is given (plain subplots returns a bare Axes).
    figure, axes = plt.subplots(
        1, len(features), figsize=(10, 3), sharey=True, sharex=True, squeeze=False
    )

    x_label = f"PCA1 ({(pca.explained_variance_ratio_[0] * 100):.1f} %)"
    y_label = f"PCA2 ({(pca.explained_variance_ratio_[1] * 100):.1f} %)"

    for ax, (feature, colors) in zip(axes[0], features.items()):
        if isinstance(colors, dict):
            color_kwargs = {"c": data[feature].map(colors)}
        elif isinstance(colors, tuple):
            color_kwargs = {"c": data[feature].map(colors[0]), "cmap": colors[1]}
        else:
            color_kwargs = {"c": data[feature], "cmap": colors}

        ax.scatter(data["PCA1"], data["PCA2"], s=2, **color_kwargs)
        ax.set_title(feature)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

    plt.tight_layout()
    plt.show()

    if save_location is not None:
        save_figure(save_location, save_key, figure)
def plot_pca_inverse_transform(
    pca: PCA,
    pca_results: pd.DataFrame,
    save_location: Optional[str] = None,
    save_key: str = "pca_inverse_transform.png",
) -> None:
    """
    Plot inverse transform of PCA.

    PC1 (top row) and PC2 (bottom row) are each traversed from -2 to +2
    standard deviations, holding all other components at zero. Each point is
    mapped back to fiber coordinates, and the resulting fibers are drawn in
    the XY, YZ, and XZ planes, colored by position along the traversal.

    Parameters
    ----------
    pca
        PCA object.
    pca_results
        PCA results data (``PCA1`` / ``PCA2`` columns set the step scale).
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    figure, ax = plt.subplots(2, 3, figsize=(10, 6))

    # linspace includes both endpoints, so the traversal is symmetric about
    # zero and (point + 2) / 4 spans the full [0, 1] colormap range.
    points = np.linspace(-2, 2, 9)
    cmap = plt.colormaps.get_cmap("RdBu_r")
    n_components = pca.n_components_

    # (horizontal, vertical) fiber coordinate indices for each plot column.
    projections = [(0, 1), (1, 2), (0, 2)]
    axis_names = "XYZ"

    for row, component in enumerate(["PCA1", "PCA2"]):
        stdev = pca_results[component].std(ddof=0)

        for point in points:
            # inverse_transform expects a 2-D (n_samples, n_components) input.
            coordinates = np.zeros((1, n_components))
            coordinates[0, row] = point * stdev
            fiber = pca.inverse_transform(coordinates).reshape(-1, 3)
            color = cmap((point + 2) / 4)
            for column, (h_index, v_index) in enumerate(projections):
                ax[row, column].plot(
                    fiber[:, h_index], fiber[:, v_index], color=color
                )

        for column, (h_index, v_index) in enumerate(projections):
            ax[row, column].set_xlabel(axis_names[h_index])
            ax[row, column].set_ylabel(axis_names[v_index], rotation=0)
            ax[row, column].set_title(f"PC{row + 1}")

    plt.tight_layout()
    plt.show()

    if save_location is not None:
        save_figure(save_location, save_key, figure)