Coverage for subcell_pipeline/analysis/dimensionality_reduction/pca_dim_reduction.py: 0%

81 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Methods for dimensionality reduction using PCA.""" 

2 

3import random 

4from typing import Optional 

5 

6import matplotlib.pyplot as plt 

7import numpy as np 

8import pandas as pd 

9from io_collection.save.save_dataframe import save_dataframe 

10from io_collection.save.save_figure import save_figure 

11from io_collection.save.save_json import save_json 

12from sklearn.decomposition import PCA 

13 

14from subcell_pipeline.analysis.dimensionality_reduction.fiber_data import reshape_fibers 

15 

16 

def run_pca(data: pd.DataFrame) -> tuple[pd.DataFrame, PCA]:
    """
    Fit a two-component PCA to fiber data and project the fibers onto it.

    Parameters
    ----------
    data
        Simulated fiber data, passed unchanged to ``reshape_fibers``.

    Returns
    -------
    :
        Dataframe with the ``PCA1`` and ``PCA2`` projections placed in front of
        the feature columns returned by ``reshape_fibers``, and the fitted PCA
        object.
    """

    fibers, features = reshape_fibers(data)

    fitted_pca = PCA(n_components=2).fit(fibers)
    projections = pd.DataFrame(
        fitted_pca.transform(fibers), columns=["PCA1", "PCA2"]
    )

    # NOTE(review): the column-wise concat pairs rows by index label, so this
    # assumes `features` carries the default 0..n-1 index -- confirm against
    # reshape_fibers.
    return pd.concat([projections, features], axis=1), fitted_pca

44 

45 

def save_pca_results(
    pca_results: pd.DataFrame, save_location: str, save_key: str, resample: bool = True
) -> None:
    """
    Write the PCA results dataframe to the given location.

    Parameters
    ----------
    pca_results
        PCA trajectory data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    resample
        True to shuffle the row order (full-size sample drawn with a fixed
        random state of 1) before saving, False to save the rows as given.
    """

    # The shuffle is taken from a copy, so the caller's dataframe is untouched.
    to_save = (
        pca_results.copy().sample(frac=1.0, random_state=1)
        if resample
        else pca_results
    )

    save_dataframe(save_location, save_key, to_save, index=False)

68 

69 

def save_pca_trajectories(
    pca_results: pd.DataFrame, save_location: str, save_key: str
) -> None:
    """
    Save one PCA trajectory per simulator, repeat, and velocity as JSON.

    Rows are grouped by the ``SIMULATOR``, ``REPEAT``, and ``VELOCITY`` columns.
    Each group becomes one entry holding its ``PCA1`` values as ``x`` and its
    ``PCA2`` values as ``y``. Entries are shuffled with a fixed seed before
    saving.

    Parameters
    ----------
    pca_results
        PCA trajectory data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    grouped = pca_results.groupby(["SIMULATOR", "REPEAT", "VELOCITY"])

    trajectories = [
        {
            "simulator": simulator.upper(),
            "replicate": int(repeat),
            "velocity": velocity,
            "x": group["PCA1"].tolist(),
            "y": group["PCA2"].tolist(),
        }
        for (simulator, repeat, velocity), group in grouped
    ]

    # Seeded generator keeps the shuffled order reproducible between runs.
    random.Random(1).shuffle(trajectories)
    save_json(save_location, save_key, trajectories)

103 

104 

def save_pca_transforms(
    pca: PCA, points: list[list[float]], save_location: str, save_key: str
) -> None:
    """
    Save fibers reconstructed from points along each principal component.

    Each point is placed on one component (the other component set to zero),
    mapped back through the PCA inverse transform, and reshaped into rows of
    three values that are stored as the ``x``, ``y``, and ``z`` lists.

    Parameters
    ----------
    pca
        PCA object.
    points
        List of inverse transform points: the first entry holds the points for
        component 1 and the second entry the points for component 2.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    # Unpacking enforces that exactly two point lists were supplied.
    pc1_points, pc2_points = points

    transforms = []

    # Component 1 entries are emitted first, followed by component 2 entries.
    for component, component_points in enumerate((pc1_points, pc2_points), start=1):
        for point in component_points:
            location = [point, 0] if component == 1 else [0, point]
            fiber = pca.inverse_transform(location).reshape(-1, 3)
            transforms.append(
                {
                    "component": component,
                    "point": point,
                    "x": fiber[:, 0].tolist(),
                    "y": fiber[:, 1].tolist(),
                    "z": fiber[:, 2].tolist(),
                }
            )

    save_json(save_location, save_key, transforms)

152 

153 

def plot_pca_feature_scatter(
    data: pd.DataFrame,
    features: dict,
    pca: PCA,
    save_location: Optional[str] = None,
    save_key: str = "pca_feature_scatter.png",
) -> None:
    """
    Plot scatter of PCA components colored by the given features.

    One panel is drawn per feature, all sharing both axes. Axis labels include
    the percentage of variance explained by each of the two components.

    Parameters
    ----------
    data
        PCA results data. Must contain ``PCA1`` and ``PCA2`` columns and one
        column per key in ``features``.
    features
        Map of feature name to coloring. The coloring may be:

        - a dict, used to map the feature values to colors;
        - a tuple of (mapping, colormap), where the mapping is applied to the
          feature values and the result is colored with the colormap;
        - anything else, which is passed as the colormap for the raw feature
          values.
    pca
        PCA object.
    save_location
        Location for output file (local path or S3 bucket). The figure is only
        saved when a location is given.
    save_key
        Name key for output file.
    """

    # squeeze=False keeps the axes in a 2D array even for a single feature;
    # otherwise plt.subplots returns a bare Axes that cannot be indexed.
    figure, ax = plt.subplots(
        1, len(features), figsize=(10, 3), sharey=True, sharex=True, squeeze=False
    )

    # Labels are the same for every panel, so build them once.
    xlabel = f"PCA1 ({(pca.explained_variance_ratio_[0] * 100):.1f} %)"
    ylabel = f"PCA2 ({(pca.explained_variance_ratio_[1] * 100):.1f} %)"

    for axis, (feature, colors) in zip(ax[0], features.items()):
        if isinstance(colors, dict):
            color_kwargs = {"c": data[feature].map(colors)}
        elif isinstance(colors, tuple):
            color_kwargs = {"c": data[feature].map(colors[0]), "cmap": colors[1]}
        else:
            color_kwargs = {"c": data[feature], "cmap": colors}

        axis.scatter(data["PCA1"], data["PCA2"], s=2, **color_kwargs)

        axis.set_title(feature)
        axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)

    plt.tight_layout()
    plt.show()

    if save_location is not None:
        save_figure(save_location, save_key, figure)

216 

217 

def plot_pca_inverse_transform(
    pca: PCA,
    pca_results: pd.DataFrame,
    save_location: Optional[str] = None,
    save_key: str = "pca_inverse_transform.png",
) -> None:
    """
    Plot fibers reconstructed by walking along each principal component.

    The figure has one row per component (PC1 on top, PC2 below) and one column
    per projection (X-Y, Y-Z, X-Z). Each component is sampled at eight steps
    from -2 to 1.5 population standard deviations of its values in
    ``pca_results``, with the other component held at zero. Lines are colored
    by step using the ``RdBu_r`` colormap.

    Parameters
    ----------
    pca
        PCA object.
    pca_results
        PCA results data.
    save_location
        Location for output file (local path or S3 bucket). The figure is only
        saved when a location is given.
    save_key
        Name key for output file.
    """

    figure, ax = plt.subplots(2, 3, figsize=(10, 6))

    stdevs = (
        pca_results["PCA1"].std(ddof=0),
        pca_results["PCA2"].std(ddof=0),
    )
    cmap = plt.colormaps.get_cmap("RdBu_r")

    # (horizontal, vertical) coordinate indices for each column of panels.
    projections = ((0, 1), (1, 2), (0, 2))
    axis_names = ("X", "Y", "Z")

    for point in np.arange(-2, 2, 0.5):
        # Rescale the step from [-2, 2] onto the colormap's [0, 1] range.
        color = cmap((point + 2) / 4)

        for row, stdev in enumerate(stdevs):
            # Move along one component only; the other stays at zero.
            location = [0, 0]
            location[row] = point * stdev
            fiber = pca.inverse_transform(location).reshape(-1, 3)

            for column, (horizontal, vertical) in enumerate(projections):
                ax[row, column].plot(
                    fiber[:, horizontal], fiber[:, vertical], color=color
                )

    for row in (0, 1):
        for column, (horizontal, vertical) in enumerate(projections):
            panel = ax[row, column]
            panel.set_xlabel(axis_names[horizontal])
            panel.set_ylabel(axis_names[vertical], rotation=0)
            panel.set_title(f"PC{row + 1}")

    plt.tight_layout()
    plt.show()

    if save_location is not None:
        save_figure(save_location, save_key, figure)