Coverage for subcell_pipeline/analysis/compression_metrics/compression_analysis.py: 0%

90 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

"""Methods for compression metric analysis and plotting."""

2 

3from typing import Any, Optional 

4 

5import numpy as np 

6import pandas as pd 

7from io_collection.keys.check_key import check_key 

8from io_collection.load.load_dataframe import load_dataframe 

9from io_collection.save.save_dataframe import save_dataframe 

10from io_collection.save.save_figure import save_figure 

11from matplotlib import pyplot as plt 

12 

13from subcell_pipeline.analysis.compression_metrics.compression_metric import ( 

14 CompressionMetric, 

15) 

16from subcell_pipeline.analysis.compression_metrics.constants import ( 

17 DEFAULT_COMPRESSION_DISTANCE, 

18 SIMULATOR_COLOR_MAP, 

19) 

20 

21 

def get_compression_metric_data(
    bucket: str,
    series_name: str,
    condition_keys: list[str],
    random_seeds: list[int],
    metrics: list[CompressionMetric],
    recalculate: bool = False,
) -> pd.DataFrame:
    """
    Load or create merged data with metrics for given conditions and seeds.

    If a previously merged dataframe exists (and recalculation is not
    requested), it is loaded and returned directly. Otherwise, sampled data
    is loaded for every condition/seed pair, metrics are calculated for
    each, and the combined result is saved and returned.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys.
    random_seeds
        Random seeds for simulations.
    metrics
        List of metrics to calculate.
    recalculate
        True if data should be recalculated, False otherwise.

    Returns
    -------
    :
        Merged dataframe with one row per fiber with calculated metrics.
    """

    data_key = f"{series_name}/analysis/{series_name}_compression_metrics.csv"

    # Short-circuit with previously merged data unless recalculation is
    # explicitly requested.
    if check_key(bucket, data_key) and not recalculate:
        print(
            f"Dataframe [ { data_key } ] already exists. Loading existing merged data."
        )
        return load_dataframe(bucket, data_key, dtype={"key": "str"})

    collected: list[pd.DataFrame] = []

    for condition_key in condition_keys:
        # An empty condition key means the series name is used on its own.
        series_key = series_name if not condition_key else f"{series_name}_{condition_key}"

        for seed in random_seeds:
            print(
                f"Loading samples and calculating metrics for "
                f"[ {condition_key} ] seed [ {seed} ]"
            )

            sample_key = f"{series_name}/samples/{series_key}_{seed:06d}.csv"
            samples = load_dataframe(bucket, sample_key)

            # Tag each per-seed result so rows remain identifiable after
            # concatenation.
            seed_metrics = calculate_compression_metrics(samples, metrics)
            seed_metrics["seed"] = seed
            seed_metrics["key"] = condition_key

            collected.append(seed_metrics)

    merged = pd.concat(collected)
    save_dataframe(bucket, data_key, merged, index=False)

    return merged

90 

91 

def calculate_compression_metrics(
    df: pd.DataFrame, metrics: list["CompressionMetric"], **options: Any
) -> pd.DataFrame:
    """
    Calculate compression metrics for a single simulation condition and seed.

    Parameters
    ----------
    df
        Input data for a single simulator. Must contain "time", "xpos",
        "ypos", and "zpos" columns, with one row per fiber point.
    metrics
        The list of metrics to calculate. Each metric provides a ``value``
        name and a ``calculate_metric`` method.
    **options
        Additional options passed through to each metric calculation.

    Returns
    -------
    :
        Dataframe with one row per time point and one column per calculated
        metric, plus "time" and "normalized_time" columns.
    """
    time_values = df["time"].unique()
    # One row per time point, one column per requested metric.
    df_metrics = pd.DataFrame(
        index=time_values, columns=[metric.value for metric in metrics]
    )

    for time, fiber_at_time in df.groupby("time"):
        # Polymer trace is the (n_points, 3) array of fiber coordinates at
        # this time point.
        polymer_trace = fiber_at_time[["xpos", "ypos", "zpos"]].values
        for metric in metrics:
            df_metrics.loc[time, metric.value] = metric.calculate_metric(
                polymer_trace=polymer_trace, **options
            )

    df_metrics = df_metrics.reset_index().rename(columns={"index": "time"})
    # Normalize time to [0, 1] so runs of different durations are comparable.
    df_metrics["normalized_time"] = df_metrics["time"] / df_metrics["time"].max()

    return df_metrics

128 

129 

def save_compression_metrics(
    data: pd.DataFrame, save_location: str, save_key: str
) -> None:
    """
    Save combined compression metrics data.

    Thin wrapper over ``save_dataframe`` that writes the metrics table
    without its index.

    Parameters
    ----------
    data
        Compression metrics data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """
    save_dataframe(save_location, save_key, data, index=False)

147 

148 

def plot_metrics_vs_time(
    df: pd.DataFrame,
    metrics: list[CompressionMetric],
    compression_distance: float = DEFAULT_COMPRESSION_DISTANCE,
    use_real_time: bool = False,
    save_location: Optional[str] = None,
    save_key_template: str = "compression_metrics_over_time_%s.png",
) -> None:
    """
    Plot individual metric values over time for each velocity.

    One figure is produced per metric, with one subplot per velocity and one
    line per simulator repeat.

    Parameters
    ----------
    df
        Input dataframe. Must contain "velocity", "simulator", "repeat",
        and "time" columns plus one column per metric value.
    metrics
        List of metrics to plot.
    compression_distance
        Compression distance in nm.
    use_real_time
        True to use real time for the x-axis, False otherwise.
    save_location
        Location for output file (local path or S3 bucket). If None, figures
        are not saved.
    save_key_template
        Name key template for output file.
    """

    num_velocities = df["velocity"].nunique()
    total_time = 1.0
    time_label = "Normalized Time"
    plt.rcParams.update({"font.size": 16})

    for metric in metrics:
        # squeeze=False guarantees axs is an array even when there is only
        # one velocity; otherwise plt.subplots returns a bare Axes and the
        # ravel() call below raises AttributeError.
        figure, axs = plt.subplots(
            1,
            num_velocities,
            figsize=(num_velocities * 5, 5),
            sharey=True,
            dpi=300,
            squeeze=False,
        )
        axs = axs.ravel()
        for ct, (velocity, df_velocity) in enumerate(df.groupby("velocity")):
            if use_real_time:
                # type checker is unable to infer the datatype of velocity
                total_time = compression_distance / velocity
                time_label = "Time (s)"
            for simulator, df_simulator in df_velocity.groupby("simulator"):
                for repeat, df_repeat in df_simulator.groupby("repeat"):
                    # Only the first repeat contributes a legend entry so
                    # each simulator appears exactly once in the legend.
                    if repeat == 0:
                        label = f"{simulator}"
                    else:
                        label = "_nolegend_"
                    xvals = np.linspace(0, 1, df_repeat["time"].nunique()) * total_time
                    yvals = df_repeat.groupby("time")[metric.value].mean()

                    # type checker is unable to infer the datatype of velocity
                    axs[ct].plot(
                        xvals,
                        yvals,
                        label=label,
                        color=SIMULATOR_COLOR_MAP[simulator],
                        alpha=0.6,
                    )
            axs[ct].set_title(f"Velocity: {velocity}")
            if ct == 0:
                axs[ct].legend()
        figure.supxlabel(time_label)
        figure.supylabel(metric.label())
        figure.tight_layout()

        if save_location is not None:
            save_key = save_key_template % metric.value
            save_figure(save_location, save_key, figure)

218 

219 

def plot_metric_distribution(
    df: pd.DataFrame,
    metrics: list[CompressionMetric],
    save_location: Optional[str] = None,
    save_key_template: str = "compression_metrics_histograms_%s.png",
) -> None:
    """
    Plot distribution of metric values for each velocity.

    One figure is produced per metric, with one histogram subplot per
    velocity and one overlaid histogram per simulator.

    Parameters
    ----------
    df
        Input dataframe. Must contain "velocity" and "simulator" columns
        plus one column per metric value.
    metrics
        List of metrics to plot.
    save_location
        Location for output file (local path or S3 bucket). If None, figures
        are not saved.
    save_key_template
        Name key template for output file.
    """

    num_velocities = df["velocity"].nunique()
    plt.rcParams.update({"font.size": 16})

    for metric in metrics:
        # squeeze=False guarantees axs is an array even when there is only
        # one velocity; otherwise plt.subplots returns a bare Axes and the
        # ravel() call below raises AttributeError.
        figure, axs = plt.subplots(
            1,
            num_velocities,
            figsize=(num_velocities * 5, 5),
            sharey=True,
            sharex=True,
            dpi=300,
            squeeze=False,
        )
        axs = axs.ravel()
        for ct, (velocity, df_velocity) in enumerate(df.groupby("velocity")):
            # Shared bins across simulators within a velocity so overlaid
            # histograms are directly comparable.
            metric_values = df_velocity[metric.value]
            bins = np.linspace(np.nanmin(metric_values), np.nanmax(metric_values), 20)
            for simulator, df_simulator in df_velocity.groupby("simulator"):
                axs[ct].hist(
                    df_simulator[metric.value],
                    label=f"{simulator}",
                    color=SIMULATOR_COLOR_MAP[simulator],
                    alpha=0.7,
                    bins=bins,
                )
            axs[ct].set_title(f"Velocity: {velocity}")
            if ct == 0:
                axs[ct].legend()
        figure.supxlabel(metric.label())
        figure.supylabel("Count")
        figure.tight_layout()

        if save_location is not None:
            save_key = save_key_template % metric.value
            save_figure(save_location, save_key, figure)