# subcell_pipeline/analysis/compression_metrics/compression_analysis.py
1"""Methods compression metric analysis and plotting."""
3from typing import Any, Optional
5import numpy as np
6import pandas as pd
7from io_collection.keys.check_key import check_key
8from io_collection.load.load_dataframe import load_dataframe
9from io_collection.save.save_dataframe import save_dataframe
10from io_collection.save.save_figure import save_figure
11from matplotlib import pyplot as plt
13from subcell_pipeline.analysis.compression_metrics.compression_metric import (
14 CompressionMetric,
15)
16from subcell_pipeline.analysis.compression_metrics.constants import (
17 DEFAULT_COMPRESSION_DISTANCE,
18 SIMULATOR_COLOR_MAP,
19)


def get_compression_metric_data(
    bucket: str,
    series_name: str,
    condition_keys: list[str],
    random_seeds: list[int],
    metrics: list[CompressionMetric],
    recalculate: bool = False,
) -> pd.DataFrame:
    """
    Load or create merged data with metrics for given conditions and seeds.

    If merged data already exists, load the data. Otherwise, iterate through
    the conditions and seeds to merge the data.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    condition_keys
        List of condition keys.
    random_seeds
        Random seeds for simulations.
    metrics
        List of metrics to calculate.
    recalculate
        True if data should be recalculated, False otherwise.

    Returns
    -------
    :
        Merged dataframe with one row per fiber and its calculated metrics.
    """

    data_key = f"{series_name}/analysis/{series_name}_compression_metrics.csv"

    # Return data if merged data already exists.
    if check_key(bucket, data_key) and not recalculate:
        print(
            f"Dataframe [ {data_key} ] already exists. Loading existing merged data."
        )
        return load_dataframe(bucket, data_key, dtype={"key": "str"})

    all_metrics: list[pd.DataFrame] = []

    for condition_key in condition_keys:
        series_key = f"{series_name}_{condition_key}" if condition_key else series_name

        for seed in random_seeds:
            print(
                f"Loading samples and calculating metrics for "
                f"[ {condition_key} ] seed [ {seed} ]"
            )

            sample_key = f"{series_name}/samples/{series_key}_{seed:06d}.csv"
            samples = load_dataframe(bucket, sample_key)

            metric_data = calculate_compression_metrics(samples, metrics)
            metric_data["seed"] = seed
            metric_data["key"] = condition_key

            all_metrics.append(metric_data)

    metrics_dataframe = pd.concat(all_metrics)
    save_dataframe(bucket, data_key, metrics_dataframe, index=False)

    return metrics_dataframe
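

# Example usage: a minimal, illustrative sketch. The bucket, series, and
# condition names below are hypothetical, and CompressionMetric.NON_COPLANARITY
# stands in for whichever metrics the series actually defines.
#
#     data = get_compression_metric_data(
#         bucket="s3://subcell-working-bucket",
#         series_name="COMPRESSION_VELOCITY",
#         condition_keys=["0047", "0150"],
#         random_seeds=[1, 2, 3],
#         metrics=[CompressionMetric.NON_COPLANARITY],
#         recalculate=False,
#     )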


def calculate_compression_metrics(
    df: pd.DataFrame, metrics: list[CompressionMetric], **options: Any
) -> pd.DataFrame:
    """
    Calculate compression metrics for a single simulation condition and seed.

    Parameters
    ----------
    df
        Input data for a single simulator.
    metrics
        The list of metrics to calculate.
    **options
        Additional options for the calculation.

    Returns
    -------
    :
        Dataframe with calculated metrics.
    """
    time_values = df["time"].unique()
    df_metrics = pd.DataFrame(
        index=time_values, columns=[metric.value for metric in metrics]
    )

    for time, fiber_at_time in df.groupby("time"):
        polymer_trace = fiber_at_time[["xpos", "ypos", "zpos"]].values
        for metric in metrics:
            df_metrics.loc[time, metric.value] = metric.calculate_metric(
                polymer_trace=polymer_trace, **options
            )

    df_metrics = df_metrics.reset_index().rename(columns={"index": "time"})
    df_metrics["normalized_time"] = df_metrics["time"] / df_metrics["time"].max()

    return df_metrics
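

# This function only assumes each metric exposes a string `.value` (used as
# the output column name) and a `calculate_metric` method taking an N x 3
# `polymer_trace` array of (x, y, z) positions plus keyword options. A minimal
# sketch of a direct call (the metric member shown is hypothetical):
#
#     trace = samples.loc[samples["time"] == 0.0, ["xpos", "ypos", "zpos"]].values
#     value = CompressionMetric.NON_COPLANARITY.calculate_metric(polymer_trace=trace)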


def save_compression_metrics(
    data: pd.DataFrame, save_location: str, save_key: str
) -> None:
    """
    Save combined compression metrics data.

    Parameters
    ----------
    data
        Compression metrics data.
    save_location
        Location for output file (local path or S3 bucket).
    save_key
        Name key for output file.
    """

    save_dataframe(save_location, save_key, data, index=False)
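

# Example usage (illustrative location and key, matching the layout used by
# get_compression_metric_data above):
#
#     save_compression_metrics(
#         data,
#         save_location="s3://subcell-working-bucket",
#         save_key="COMPRESSION_VELOCITY/analysis/compression_metrics.csv",
#     )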


def plot_metrics_vs_time(
    df: pd.DataFrame,
    metrics: list[CompressionMetric],
    compression_distance: float = DEFAULT_COMPRESSION_DISTANCE,
    use_real_time: bool = False,
    save_location: Optional[str] = None,
    save_key_template: str = "compression_metrics_over_time_%s.png",
) -> None:
    """
    Plot individual metric values over time for each velocity.

    Parameters
    ----------
    df
        Input dataframe.
    metrics
        List of metrics to plot.
    compression_distance
        Compression distance in nm.
    use_real_time
        True to use real time for the x-axis, False otherwise.
    save_location
        Location for output file (local path or S3 bucket).
    save_key_template
        Name key template for output file.
    """

    num_velocities = df["velocity"].nunique()
    total_time = 1.0
    time_label = "Normalized Time"
    plt.rcParams.update({"font.size": 16})

    for metric in metrics:
        figure, axs = plt.subplots(
            1, num_velocities, figsize=(num_velocities * 5, 5), sharey=True, dpi=300
        )
        # np.atleast_1d guards the single-velocity case, where plt.subplots
        # returns a bare Axes instead of an array.
        axs = np.atleast_1d(axs).ravel()
        for ct, (velocity, df_velocity) in enumerate(df.groupby("velocity")):
            if use_real_time:
                # type checker is unable to infer the datatype of velocity
                total_time = compression_distance / velocity
                time_label = "Time (s)"
            for simulator, df_simulator in df_velocity.groupby("simulator"):
                for repeat, df_repeat in df_simulator.groupby("repeat"):
                    label = f"{simulator}" if repeat == 0 else "_nolegend_"
                    xvals = np.linspace(0, 1, df_repeat["time"].nunique()) * total_time
                    yvals = df_repeat.groupby("time")[metric.value].mean()

                    axs[ct].plot(
                        xvals,
                        yvals,
                        label=label,
                        color=SIMULATOR_COLOR_MAP[simulator],
                        alpha=0.6,
                    )
            axs[ct].set_title(f"Velocity: {velocity}")
            if ct == 0:
                axs[ct].legend()
        figure.supxlabel(time_label)
        figure.supylabel(metric.label())
        figure.tight_layout()

        if save_location is not None:
            save_key = save_key_template % metric.value
            save_figure(save_location, save_key, figure)
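

# Example call: an illustrative sketch assuming `data` is a combined dataframe
# with the "velocity", "simulator", "repeat", and "time" columns this function
# groups by (the metric name and save location are hypothetical):
#
#     plot_metrics_vs_time(
#         df=data,
#         metrics=[CompressionMetric.NON_COPLANARITY],
#         use_real_time=True,
#         save_location="s3://subcell-working-bucket",
#     )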


def plot_metric_distribution(
    df: pd.DataFrame,
    metrics: list[CompressionMetric],
    save_location: Optional[str] = None,
    save_key_template: str = "compression_metrics_histograms_%s.png",
) -> None:
    """
    Plot distribution of metric values for each velocity.

    Parameters
    ----------
    df
        Input dataframe.
    metrics
        List of metrics to plot.
    save_location
        Location for output file (local path or S3 bucket).
    save_key_template
        Name key template for output file.
    """

    num_velocities = df["velocity"].nunique()
    plt.rcParams.update({"font.size": 16})

    for metric in metrics:
        figure, axs = plt.subplots(
            1,
            num_velocities,
            figsize=(num_velocities * 5, 5),
            sharey=True,
            sharex=True,
            dpi=300,
        )
        # np.atleast_1d guards the single-velocity case, where plt.subplots
        # returns a bare Axes instead of an array.
        axs = np.atleast_1d(axs).ravel()
        for ct, (velocity, df_velocity) in enumerate(df.groupby("velocity")):
            metric_values = df_velocity[metric.value]
            bins = np.linspace(np.nanmin(metric_values), np.nanmax(metric_values), 20)
            for simulator, df_simulator in df_velocity.groupby("simulator"):
                axs[ct].hist(
                    df_simulator[metric.value],
                    label=f"{simulator}",
                    color=SIMULATOR_COLOR_MAP[simulator],
                    alpha=0.7,
                    bins=bins,
                )
            axs[ct].set_title(f"Velocity: {velocity}")
            if ct == 0:
                axs[ct].legend()
        figure.supxlabel(metric.label())
        figure.supylabel("Count")
        figure.tight_layout()

        if save_location is not None:
            save_key = save_key_template % metric.value
            save_figure(save_location, save_key, figure)
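

# Example call: same assumptions about the combined dataframe as for
# plot_metrics_vs_time above (metric name and save location are hypothetical):
#
#     plot_metric_distribution(
#         df=data,
#         metrics=[CompressionMetric.NON_COPLANARITY],
#         save_location="s3://subcell-working-bucket",
#     )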