Coverage for subcell_pipeline/simulation/batch_simulations.py: 0%

70 statements  

coverage.py v7.5.3, created at 2024-08-29 15:14 +0000

1"""Methods for running simulations on AWS Batch.""" 

2 

3import re 

4from typing import Optional 

5 

6import boto3 

7from container_collection.batch.get_batch_logs import get_batch_logs 

8from container_collection.batch.make_batch_job import make_batch_job 

9from container_collection.batch.register_batch_job import register_batch_job 

10from container_collection.batch.submit_batch_job import submit_batch_job 

11from io_collection.keys.copy_key import copy_key 

12from io_collection.save.save_text import save_text 

13 

14 

def generate_configs_from_file(
    bucket: str,
    series_name: str,
    timestamp: str,
    random_seeds: list[int],
    config_file: str,
) -> None:
    """
    Generate configs from given file for each seed and save to S3 bucket.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    timestamp
        Current timestamp used to organize input and output files.
    random_seeds
        Random seeds for simulations.
    config_file
        Path to the config file.
    """

    with open(config_file) as f:
        contents = f.read()

    for index, seed in enumerate(random_seeds):
        config_key = f"{series_name}/{timestamp}/configs/{series_name}_{index}.cym"
        config_contents = contents.replace("{{RANDOM_SEED}}", str(seed))
        print(f"Saving config for seed {seed} to [ {config_key} ]")
        save_text(bucket, config_key, config_contents)

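# Example usage (illustrative sketch): the bucket, series name, timestamp, and
# config path below are hypothetical placeholders. The config file is expected
# to contain a "{{RANDOM_SEED}}" token, which is replaced with each seed before
# the config is saved to S3.
#
#     generate_configs_from_file(
#         bucket="my-simulation-bucket",
#         series_name="ACTIN_COMPRESSION",
#         timestamp="2024-08-29",
#         random_seeds=[1, 2, 3],
#         config_file="configs/actin_compression.cym",
#     )
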

def generate_configs_from_template(
    bucket: str,
    series_name: str,
    timestamp: str,
    random_seeds: list[int],
    config_files: list[str],
    pattern: str,
    key_map: dict[str, str],
) -> list[str]:
    """
    Generate configs for each given file for each seed and save to S3 bucket.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    timestamp
        Current timestamp used to organize input and output files.
    random_seeds
        Random seeds for simulations.
    config_files
        Paths to the config files.
    pattern
        Regex pattern to find config condition value.
    key_map
        Map of condition values to file keys.

    Returns
    -------
    :
        List of config group keys.
    """

    group_keys = []

    for config_file in config_files:
        with open(config_file) as f:
            contents = f.read()

        match = re.findall(pattern, contents)[0].strip()
        match_key = key_map[match]

        group_key = f"{series_name}_{match_key}"
        group_keys.append(group_key)

        for index, seed in enumerate(random_seeds):
            config_key = f"{series_name}/{timestamp}/configs/{group_key}_{index}.cym"
            config_contents = contents.replace("{{RANDOM_SEED}}", str(seed))
            print(f"Saving config for [ {match} ] for seed {seed} to [ {config_key} ]")
            save_text(bucket, config_key, config_contents)

    return group_keys

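# Example usage (illustrative sketch with hypothetical values): `pattern` is a
# regex whose first match in each config file is stripped and looked up in
# `key_map` to name the config group; the pattern and mapping below are
# placeholders, not values from the pipeline.
#
#     group_keys = generate_configs_from_template(
#         bucket="my-simulation-bucket",
#         series_name="ACTIN_COMPRESSION_VELOCITY",
#         timestamp="2024-08-29",
#         random_seeds=[1, 2, 3],
#         config_files=["configs/velocity_low.cym", "configs/velocity_high.cym"],
#         pattern=r"compression_velocity=(.*)\n",
#         key_map={"4.7": "0047", "15.0": "0150"},
#     )
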

def register_and_run_simulations(
    bucket: str,
    series_name: str,
    timestamp: str,
    group_keys: list[str],
    aws_account: str,
    aws_region: str,
    aws_user: str,
    image: str,
    vcpus: int,
    memory: int,
    job_queue: str,
    job_size: int,
) -> list[str]:
    """
    Register job definitions and submit jobs to AWS Batch.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    timestamp
        Current timestamp used to organize input and output files.
    group_keys
        List of config group keys.
    aws_account
        AWS account number.
    aws_region
        AWS region.
    aws_user
        User name prefix for job name and image.
    image
        Image name and version.
    vcpus
        Number of vCPUs for each job.
    memory
        Memory for each job.
    job_queue
        Job queue.
    job_size
        Job array size.

    Returns
    -------
    :
        List of job ARNs.
    """

    boto3.setup_default_session(region_name=aws_region)

    all_job_arns: list[str] = []
    registry = f"{aws_account}.dkr.ecr.{aws_region}.amazonaws.com"
    job_key = f"{bucket}/{series_name}/{timestamp}/"

    for group_key in group_keys:
        job_definition = make_batch_job(
            f"{aws_user}_{group_key}",
            f"{registry}/{aws_user}/{image}",
            vcpus,
            memory,
            [
                {"name": "SIMULATION_TYPE", "value": "AWS"},
                {"name": "BATCH_WORKING_URL", "value": job_key},
                {"name": "FILE_SET_NAME", "value": group_key},
            ],
            f"arn:aws:iam::{aws_account}:role/BatchJobRole",
        )

        job_definition_arn = register_batch_job(job_definition)

        print(f"Created job definition [ {job_definition_arn} ] for [ {group_key} ]")

        job_arns = submit_batch_job(
            group_key,
            job_definition_arn,
            aws_user,
            job_queue,
            job_size,
        )

        for job_arn in job_arns:
            print(f"Submitted job [ {job_arn} ] for [ {group_key} ]")

        all_job_arns.extend(job_arns)

    return all_job_arns

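# Example usage (illustrative sketch with hypothetical values): registers one
# job definition per config group and submits an array job for each, returning
# all job ARNs. The account number, user, image, and queue below are
# placeholders.
#
#     job_arns = register_and_run_simulations(
#         bucket="my-simulation-bucket",
#         series_name="ACTIN_COMPRESSION",
#         timestamp="2024-08-29",
#         group_keys=group_keys,
#         aws_account="123456789012",
#         aws_region="us-west-2",
#         aws_user="username",
#         image="simulation-image:1.0",
#         vcpus=1,
#         memory=7000,
#         job_queue="general-on-demand",
#         job_size=5,
#     )
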

def check_and_save_job_logs(
    bucket: str, series_name: str, job_arns: list[str], aws_region: str
) -> None:
    """
    Check job status and save CloudWatch logs for successfully completed jobs.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    job_arns
        List of job ARNs.
    aws_region
        AWS region.
    """

    boto3.setup_default_session(region_name=aws_region)

    batch_client = boto3.client("batch")
    responses = batch_client.describe_jobs(jobs=job_arns)["jobs"]

    for response in responses:
        if response["status"] != "SUCCEEDED":
            print(
                f"Job [ {response['jobId']} ] has status [ {response['status']} ]"
            )
        else:
            group_key = next(
                item
                for item in response["container"]["environment"]
                if item["name"] == "FILE_SET_NAME"
            )["value"]
            log_key = f"{series_name}/logs/{group_key}_{response['jobId']}.log"

            print(f"Saving logs for job [ {response['jobId']} ] to [ {log_key} ]")

            logs = get_batch_logs(response["jobArn"], " ")
            save_text(bucket, log_key, logs)

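# Example usage (illustrative sketch with hypothetical values): checks the
# status of previously submitted jobs and saves CloudWatch logs only for jobs
# that report SUCCEEDED.
#
#     check_and_save_job_logs(
#         bucket="my-simulation-bucket",
#         series_name="ACTIN_COMPRESSION",
#         job_arns=job_arns,
#         aws_region="us-west-2",
#     )
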

def copy_simulation_outputs(
    bucket: str,
    series_name: str,
    source_template: str,
    n_replicates: int,
    condition_keys: Optional[dict[str, str]] = None,
) -> None:
    """
    Copy simulation outputs from where they are saved to pipeline file structure.

    Parameters
    ----------
    bucket
        Name of S3 bucket for input and output files.
    series_name
        Name of simulation series.
    source_template
        Template string for source output files.
    n_replicates
        Number of simulation replicates.
    condition_keys
        Map of source to target condition keys.
    """

    if condition_keys is None:
        condition_keys = {"": ""}

    for index in range(n_replicates):
        for source_condition, target_condition in condition_keys.items():
            if source_condition == "" and target_condition == "":
                source_key = source_template % (index)
                target_key = f"{series_name}/outputs/{series_name}_{index}.h5"
            else:
                source_key = source_template % (source_condition, index)
                target_key = (
                    f"{series_name}/outputs/{series_name}_{target_condition}_{index}.h5"
                )

            print(f"Copying [ {source_key} ] to [ {target_key} ]")
            copy_key(bucket, source_key, target_key)
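
# Example usage (illustrative sketch with hypothetical values): `source_template`
# is an old-style format string that takes only the replicate index when
# `condition_keys` is omitted, or a (source condition, index) pair otherwise.
# The template and condition map below are placeholders.
#
#     copy_simulation_outputs(
#         bucket="my-simulation-bucket",
#         series_name="ACTIN_COMPRESSION",
#         source_template="simulations/actin_compression_%s_%d/outputs/output.h5",
#         n_replicates=5,
#         condition_keys={"4.7": "0047", "15.0": "0150"},
#     )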