Coverage for subcell_pipeline/analysis/dimensionality_reduction/fiber_data.py: 18%

85 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Methods for fiber data merging and alignment.""" 

2 

3from typing import Optional 

4 

5import matplotlib.pyplot as plt 

6import numpy as np 

7import pandas as pd 

8from io_collection.keys.check_key import check_key 

9from io_collection.load.load_dataframe import load_dataframe 

10from io_collection.save.save_dataframe import save_dataframe 

11from io_collection.save.save_figure import save_figure 

12from io_collection.save.save_json import save_json 

13 

14 

def get_merged_data(
    bucket: str,
    series_name: str,
    condition_keys: list[str],
    random_seeds: list[int],
    align: bool = True,
) -> pd.DataFrame:
    """
    Load or create merged data for given conditions and random seeds.

    If merged data (aligned or unaligned) already exists, load the data.
    Otherwise, iterate through the conditions and seeds to merge the data, then
    save the merged data under the analysis key so later calls can reuse it.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys.
    random_seeds
        Random seeds for simulations.
    align
        True if data should be aligned, False otherwise.

    Returns
    -------
    :
        Merged data.
    """

    align_key = "all_samples_aligned" if align else "all_samples_unaligned"
    data_key = f"{series_name}/analysis/{series_name}_{align_key}.csv"

    # Return data, if merged data already exists.
    if check_key(bucket, data_key):
        print(
            f"Dataframe [ { data_key } ] already exists. Loading existing merged data."
        )
        return load_dataframe(bucket, data_key, dtype={"key": "str"})

    all_samples: list[pd.DataFrame] = []

    for condition_key in condition_keys:
        # An empty condition key means the sample files carry no key suffix.
        series_key = f"{series_name}_{condition_key}" if condition_key else series_name

        for seed in random_seeds:
            print(f"Loading samples for [ {condition_key} ] seed [ {seed} ]")

            sample_key = f"{series_name}/samples/{series_key}_{seed:06d}.csv"
            samples = load_dataframe(bucket, sample_key)
            samples["seed"] = seed
            samples["key"] = condition_key

            if align:
                # Modifies the coordinate columns of the samples in place.
                align_fibers(samples)

            all_samples.append(samples)

    # Each sample arrives with its own index; without ignore_index the merged
    # frame would carry repeated index labels. The merged data is saved without
    # its index, so a fresh index here keeps the first-run result consistent
    # with what later calls load back from the saved file.
    samples_dataframe = pd.concat(all_samples, ignore_index=True)
    save_dataframe(bucket, data_key, samples_dataframe, index=False)

    return samples_dataframe

79 

80 

def align_fibers(data: pd.DataFrame) -> None:
    """
    Align fibers for each time point in the data.

    The x, y, z position columns are overwritten in place. Rows belonging to
    time 0 keep their original coordinates; rows of every other time point are
    rotated together by ``align_fiber``. Each aligned group is written back to
    the exact rows it was read from, so the rows of a time point do not need to
    be contiguous in the dataframe. An empty dataframe is left untouched.

    Parameters
    ----------
    data
        Simulated fiber data.
    """

    if data.empty:
        return

    coords = data[["xpos", "ypos", "zpos"]].values
    aligned = coords.astype(float)

    # Positional row indices per time point; positions (rather than index
    # labels) stay unambiguous even if the dataframe index has duplicates.
    row_positions = data.groupby("time", sort=False).indices

    for time, positions in row_positions.items():
        # Time 0 is kept as-is (presumably the initial fiber is already in the
        # reference orientation -- TODO confirm).
        if time == 0:
            continue

        aligned[positions], _ = align_fiber(coords[positions])

    data["xpos"] = aligned[:, 0]
    data["ypos"] = aligned[:, 1]
    data["zpos"] = aligned[:, 2]

108 

109 

def align_fiber(coords: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Rotate x, y, z coordinates about the x axis onto the positive y axis.

    The point with the largest distance from the origin in the yz-plane is
    selected, and the angle that places this point on the positive y axis is
    computed. Every (y, z) pair is rotated by that angle, while the x
    coordinates pass through unchanged. For example, a furthest point of
    (0.5, 0, 1) ends up at (0.5, 1, 0), corresponding to an angle of pi / 2.

    Parameters
    ----------
    coords
        Array of x, y, and z positions.

    Returns
    -------
    :
        Array of aligned x, y, and z positions and the 2x2 rotation matrix
        that was applied to the y and z columns.
    """

    x_column = coords[:, 0:1]
    yz_columns = coords[:, 1:]

    # The reference point is the one furthest from (0, 0) in the yz-plane.
    radial_distances = np.sqrt((yz_columns**2).sum(axis=1))
    furthest = np.argmax(radial_distances)
    theta = np.arctan2(coords[furthest, 2], coords[furthest, 1])

    # Rotation matrix, applied from the right to row vectors of (y, z).
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rot = np.array([[cos_theta, -sin_theta], [sin_theta, cos_theta]])

    rotated_yz = yz_columns.dot(rot)

    return np.hstack((x_column, rotated_yz)), rot

139 

140 

def reshape_fibers(data: pd.DataFrame) -> tuple[np.ndarray, pd.DataFrame]:
    """
    Reshape data from tidy data format to array of fibers and fiber features.

    Rows are grouped by time, velocity, repeat, and simulator. Each group
    becomes one fiber, flattened point by point as (x0, y0, z0, x1, y1, z1,
    ...), and one row of features. Fibers are stacked along the first axis, so
    the fiber axis is kept even when the data holds only a single fiber.

    Parameters
    ----------
    data
        Simulated fiber data.

    Returns
    -------
    :
        Array of fibers and dataframe of fiber features.
    """

    all_features = []
    all_fibers = []

    for (time, velocity, repeat, simulator), group in data.groupby(
        ["time", "velocity", "repeat", "simulator"]
    ):
        # Flatten to 1D here rather than squeezing the stacked array later; a
        # blanket squeeze would also drop the fiber axis when there is only
        # one fiber.
        fiber = group[["xpos", "ypos", "zpos"]].values.reshape(-1)
        all_fibers.append(fiber)
        all_features.append(
            {
                "TIME": time,
                "VELOCITY": velocity,
                "REPEAT": repeat,
                "SIMULATOR": simulator.upper(),
            }
        )

    return np.array(all_fibers), pd.DataFrame(all_features)

174 

175 

def save_aligned_fibers(
    data: pd.DataFrame, time_map: dict, save_location: str, save_key: str
) -> None:
    """
    Save aligned fiber data.

    The data is grouped by simulator, repeat, key, and time. Only the group
    whose time equals the entry in the time map for its (simulator, key) pair
    is kept; each kept fiber is written as one JSON record holding its
    coordinate lists.

    Parameters
    ----------
    data
        Aligned fiber data.
    time_map
        Map of selected aligned time for each simulator and condition.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    group_columns = ["simulator", "repeat", "key", "time"]
    position_columns = ["xpos", "ypos", "zpos"]
    records = []

    for (simulator, repeat, key, time), fiber_rows in data.groupby(group_columns):
        # Keep only the time point selected for this simulator and condition.
        if time == time_map[(simulator, key)]:
            x_values, y_values, z_values = (
                fiber_rows[position_columns].to_numpy().T.tolist()
            )
            records.append(
                {
                    "simulator": simulator.upper(),
                    "repeat": int(repeat),
                    "key": key,
                    "x": x_values,
                    "y": y_values,
                    "z": z_values,
                }
            )

    save_json(save_location, save_key, records)

215 

216 

def plot_fibers_by_key_and_seed(
    data: pd.DataFrame,
    save_location: Optional[str] = None,
    save_key: str = "aligned_fibers_by_key_and_seed.png",
) -> None:
    """
    Plot simulated fiber data for each condition key and random seed.

    The figure has one row of subplots per condition key and one column per
    random seed. Within each subplot, every fiber (one per time point and
    simulator) is drawn as its y coordinates against its z coordinates, in red
    for the "readdy" simulator and blue otherwise.

    Parameters
    ----------
    data
        Simulated fiber data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    rows = data["key"].unique()
    cols = data["seed"].unique()

    # squeeze=False always returns a 2D array of axes, so the [row, column]
    # indexing below also works with a single condition key or a single seed.
    figure, axes = plt.subplots(
        len(rows),
        len(cols),
        figsize=(10, 6),
        sharey=True,
        sharex=True,
        squeeze=False,
    )

    for row_index, row in enumerate(rows):
        for col_index, col in enumerate(cols):
            panel = axes[row_index, col_index]

            # Label only the top row and the left column of the grid.
            if row_index == 0:
                panel.set_title(f"REPEAT = {col}")
            if col_index == 0:
                panel.set_ylabel(f"KEY = {row}")

            subset = data[(data["key"] == row) & (data["seed"] == col)]

            for (_, simulator), group in subset.groupby(["time", "simulator"]):
                color = "red" if simulator == "readdy" else "blue"
                coords = group[["xpos", "ypos", "zpos"]].values
                panel.plot(coords[:, 1], coords[:, 2], lw=0.5, color=color, alpha=0.5)

    plt.tight_layout()
    plt.show()

    if save_location is not None:
        save_figure(save_location, save_key, figure)