# subcell_pipeline/simulation/readdy/loader.py
1"""Class for loading and shaping ReaDDy trajectories."""
3from typing import Any, Optional
5import numpy as np
6import readdy
7from io_collection.keys.check_key import check_key
8from io_collection.load.load_pickle import load_pickle
9from io_collection.save.save_pickle import save_pickle
10from tqdm import tqdm
12from subcell_pipeline.simulation.readdy.data_structures import (
13 FrameData,
14 ParticleData,
15 TopologyData,
16)
19class ReaddyLoader:
20 """
21 Load and shape data from a ReaDDy trajectory.
23 Trajectory is loaded from the simulation output h5 file of the .dat pickle
24 file. If a .dat pickle location and key are provided, the loaded trajectory
25 is saved to the given location for faster reloads.
26 """

    _readdy_trajectory: Optional[readdy.Trajectory]
    """ReaDDy trajectory object."""

    _trajectory: Optional[list[FrameData]]
    """List of FrameData for trajectory."""

    h5_file_path: str
    """Path to the ReaDDy .h5 file or .dat pickle file."""

    min_time_ix: int
    """First time index to include."""

    max_time_ix: int
    """Last time index to include."""

    time_inc: int
    """Include every time_inc timestep."""

    timestep: float
    """Real time for each simulation timestep."""

    pickle_location: Optional[str]
    """Location to save pickle file (AWS S3 bucket or local path)."""

    pickle_key: Optional[str]
    """Name of the pickle file at the pickle location."""

    def __init__(
        self,
        h5_file_path: str,
        min_time_ix: int = 0,
        max_time_ix: int = -1,
        time_inc: int = 1,
        timestep: float = 100.0,
        pickle_location: Optional[str] = None,
        pickle_key: Optional[str] = None,
    ):
        self._readdy_trajectory = None
        self._trajectory = None
        self.h5_file_path = h5_file_path
        self.min_time_ix = min_time_ix
        self.max_time_ix = max_time_ix
        self.time_inc = time_inc
        self.timestep = timestep
        self.pickle_location = pickle_location
        self.pickle_key = pickle_key

    def readdy_trajectory(self) -> readdy.Trajectory:
        """
        Lazy load the ReaDDy trajectory object.

        Note that loading ReaDDy trajectories requires a path to a local file.
        Loading currently does not support S3 locations.

        Returns
        -------
        :
            The ReaDDy trajectory object.
        """
        if self._readdy_trajectory is None:
            self._readdy_trajectory = readdy.Trajectory(self.h5_file_path)
        return self._readdy_trajectory

    @staticmethod
    def _frame_edges(time_ix: int, topology_records: Any) -> list[list[int]]:
        """
        Get all edges at the given time index as [particle1 id, particle2 id].

        The ``topology_records`` object is output from
        ``readdy.Trajectory(h5_file_path).read_observable_topologies()``.
        """
        result = []
        for top in topology_records[time_ix]:
            for e1, e2 in top.edges:
                if e1 <= e2:
                    ix1 = top.particles[e1]
                    ix2 = top.particles[e2]
                    result.append([ix1, ix2])
        return result

    def _shape_trajectory_data(self) -> list[FrameData]:
        """Shape data from a ReaDDy trajectory for analysis."""
        (
            _,
            topology_records,
        ) = self.readdy_trajectory().read_observable_topologies()  # type: ignore
        (
            times,
            types,
            ids,
            positions,
        ) = self.readdy_trajectory().read_observable_particles()  # type: ignore
        result = []
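        # Keep only frames within [min_time_ix, max_time_ix] (a negative
        # max_time_ix means no upper bound) whose recorded time is a multiple
        # of time_inc.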
        for time_ix in tqdm(range(len(times))):
            if (
                time_ix < self.min_time_ix
                or (self.max_time_ix >= 0 and time_ix > self.max_time_ix)
                or times[time_ix] % self.time_inc != 0
            ):
                continue
            frame = FrameData(time=self.timestep * time_ix)
            frame.edge_ids = ReaddyLoader._frame_edges(time_ix, topology_records)
            for index, top in enumerate(topology_records[time_ix]):
                frame.topologies[index] = TopologyData(
                    uid=index,
                    type_name=top.type,
                    particle_ids=top.particles,
                )
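            # For each particle in the frame, record its type name, position,
            # and the ids of its neighbors as derived from this frame's edge
            # list.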
            for p in range(len(ids[time_ix])):
                p_id = ids[time_ix][p]
                position = positions[time_ix][p]
                neighbor_ids = []
                for edge in frame.edge_ids:
                    if p_id == edge[0]:
                        neighbor_ids.append(edge[1])
                    elif p_id == edge[1]:
                        neighbor_ids.append(edge[0])
                frame.particles[p_id] = ParticleData(
                    uid=p_id,
                    type_name=self.readdy_trajectory().species_name(  # type: ignore
                        types[time_ix][p]
                    ),
                    position=np.array([position[0], position[1], position[2]]),
                    neighbor_ids=neighbor_ids,
                )
            result.append(frame)
        return result

    def trajectory(self) -> list[FrameData]:
        """
        Lazy load the shaped trajectory.

        Returns
        -------
        :
            The trajectory of data shaped for analysis.
        """
        if self._trajectory is not None:
            return self._trajectory

        if self.pickle_location is not None and self.pickle_key is not None:
            if check_key(self.pickle_location, self.pickle_key):
                print(f"Loading pickle file for ReaDDy data from {self.pickle_key}")
                self._trajectory = load_pickle(self.pickle_location, self.pickle_key)
            else:
                print(f"Loading ReaDDy data from h5 file {self.h5_file_path}")
175 print(f"Saving pickle file for ReaDDy data to {self.h5_file_path}")
                self._trajectory = self._shape_trajectory_data()
                save_pickle(self.pickle_location, self.pickle_key, self._trajectory)
        else:
            print(f"Loading ReaDDy data from h5 file {self.h5_file_path}")
            self._trajectory = self._shape_trajectory_data()

        return self._trajectory
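

# Example usage (a minimal sketch, not part of the pipeline; the h5 path,
# bucket, and key below are hypothetical placeholders):
#
#     loader = ReaddyLoader(
#         h5_file_path="outputs/actin_simulation.h5",
#         min_time_ix=0,
#         max_time_ix=-1,
#         time_inc=1,
#         timestep=100.0,
#         pickle_location="s3://example-bucket",
#         pickle_key="actin_simulation.pkl",
#     )
#     frames = loader.trajectory()
#     print(f"Loaded {len(frames)} frames")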