Coverage for subcell_pipeline/simulation/readdy/loader.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Class for loading and shaping ReaDDy trajectories.""" 

2 

3from typing import Any, Optional 

4 

5import numpy as np 

6import readdy 

7from io_collection.keys.check_key import check_key 

8from io_collection.load.load_pickle import load_pickle 

9from io_collection.save.save_pickle import save_pickle 

10from tqdm import tqdm 

11 

12from subcell_pipeline.simulation.readdy.data_structures import ( 

13 FrameData, 

14 ParticleData, 

15 TopologyData, 

16) 

17 

18 

class ReaddyLoader:
    """
    Load and shape data from a ReaDDy trajectory.

    Trajectory is loaded from the simulation output h5 file or the .dat pickle
    file. If a .dat pickle location and key are provided, the loaded trajectory
    is saved to the given location for faster reloads.
    """

    _readdy_trajectory: Optional[readdy.Trajectory]
    """ReaDDy trajectory object."""

    _trajectory: Optional[list[FrameData]]
    """List of FrameData for trajectory."""

    h5_file_path: str
    """Path to the ReaDDy .h5 file or .dat pickle file."""

    min_time_ix: int
    """First time index to include."""

    max_time_ix: int
    """Last time index to include."""

    time_inc: int
    """Include every time_inc timestep."""

    timestep: float
    """Real time for each simulation timestep."""

    pickle_location: Optional[str]
    """Location to save pickle file (AWS S3 bucket or local path)."""

    pickle_key: Optional[str]
    """Name of pickle file (AWS S3 bucket or local path)."""

    def __init__(
        self,
        h5_file_path: str,
        min_time_ix: int = 0,
        max_time_ix: int = -1,
        time_inc: int = 1,
        timestep: float = 100.0,
        pickle_location: Optional[str] = None,
        pickle_key: Optional[str] = None,
    ):
        self._readdy_trajectory = None
        self._trajectory = None
        self.h5_file_path = h5_file_path
        self.min_time_ix = min_time_ix
        self.max_time_ix = max_time_ix
        self.time_inc = time_inc
        self.timestep = timestep
        self.pickle_location = pickle_location
        self.pickle_key = pickle_key

    def readdy_trajectory(self) -> readdy.Trajectory:
        """
        Lazy load the ReaDDy trajectory object.

        Note that loading ReaDDy trajectories requires a path to a local file.
        Loading currently does not support S3 locations.

        Returns
        -------
        :
            The ReaDDy trajectory object.
        """
        if self._readdy_trajectory is None:
            self._readdy_trajectory = readdy.Trajectory(self.h5_file_path)
        return self._readdy_trajectory

    @staticmethod
    def _frame_edges(time_ix: int, topology_records: Any) -> list[list[int]]:
        """
        Get all edges at the given time index as [particle1 id, particle2 id].

        The ``topology_records`` object is output from
        ``readdy.Trajectory(h5_file_path).read_observable_topologies()``.
        """
        edges: list[list[int]] = []
        for topology in topology_records[time_ix]:
            for e1, e2 in topology.edges:
                # Keep only one orientation of each edge (presumably each
                # undirected edge is listed in both directions).
                if e1 <= e2:
                    edges.append([topology.particles[e1], topology.particles[e2]])
        return edges

    def _shape_trajectory_data(self) -> list[FrameData]:
        """
        Shape data from a ReaDDy trajectory for analysis.

        Returns
        -------
        :
            List of shaped frames, one per included time index.
        """
        (
            _,
            topology_records,
        ) = self.readdy_trajectory().read_observable_topologies()  # type: ignore
        (
            times,
            types,
            ids,
            positions,
        ) = self.readdy_trajectory().read_observable_particles()  # type: ignore
        result = []
        for time_ix in tqdm(range(len(times))):
            # Skip frames outside [min_time_ix, max_time_ix] (max < 0 means
            # no upper bound) and frames whose recorded time value is not a
            # multiple of time_inc. NOTE(review): this filters on the time
            # VALUE, not the index -- confirm that is intended.
            if (
                time_ix < self.min_time_ix
                or (self.max_time_ix >= 0 and time_ix > self.max_time_ix)
                or times[time_ix] % self.time_inc != 0
            ):
                continue
            frame = FrameData(time=self.timestep * time_ix)
            frame.edge_ids = ReaddyLoader._frame_edges(time_ix, topology_records)
            for index, topology in enumerate(topology_records[time_ix]):
                frame.topologies[index] = TopologyData(
                    uid=index,
                    type_name=topology.type,
                    particle_ids=topology.particles,
                )
            # Build the adjacency map once per frame (O(P + E)) instead of
            # scanning every edge for every particle (O(P * E)).
            neighbor_map: dict[Any, list[Any]] = {}
            for id1, id2 in frame.edge_ids:
                neighbor_map.setdefault(id1, []).append(id2)
                if id1 != id2:
                    neighbor_map.setdefault(id2, []).append(id1)
            for p_id, p_type, p_position in zip(
                ids[time_ix], types[time_ix], positions[time_ix]
            ):
                frame.particles[p_id] = ParticleData(
                    uid=p_id,
                    type_name=self.readdy_trajectory().species_name(  # type: ignore
                        p_type
                    ),
                    position=np.array(
                        [p_position[0], p_position[1], p_position[2]]
                    ),
                    neighbor_ids=neighbor_map.get(p_id, []),
                )
            result.append(frame)
        return result

    def trajectory(self) -> list[FrameData]:
        """
        Lazy load the shaped trajectory.

        Returns
        -------
        :
            The trajectory of data shaped for analysis.
        """

        if self._trajectory is not None:
            return self._trajectory

        if self.pickle_location is not None and self.pickle_key is not None:
            if check_key(self.pickle_location, self.pickle_key):
                print(f"Loading pickle file for ReaDDy data from {self.pickle_key}")
                self._trajectory = load_pickle(self.pickle_location, self.pickle_key)
            else:
                print(f"Loading ReaDDy data from h5 file {self.h5_file_path}")
                # Fix: report the pickle key being saved, not the h5 path.
                print(f"Saving pickle file for ReaDDy data to {self.pickle_key}")
                self._trajectory = self._shape_trajectory_data()
                save_pickle(self.pickle_location, self.pickle_key, self._trajectory)
        else:
            print(f"Loading ReaDDy data from h5 file {self.h5_file_path}")
            self._trajectory = self._shape_trajectory_data()

        return self._trajectory