Source code for jobflow_remote.jobs.run

from __future__ import annotations

import datetime
import glob
import logging
import os
import subprocess
import threading
import time
import traceback
from datetime import timezone
from multiprocessing import Manager, Process
from typing import TYPE_CHECKING

from jobflow import JobStore, initialize_logger
from jobflow.core.flow import get_flow
from monty.design_patterns import singleton
from monty.os import cd
from monty.serialization import dumpfn, loadfn
from monty.shutil import decompress_file

from jobflow_remote.jobs.batch import LocalBatchManager
from jobflow_remote.jobs.data import (
    BATCH_INFO_FILENAME,
    IN_FILENAME,
    OUT_FILENAME,
    JobDoc,
)
from jobflow_remote.remote.data import get_job_path, get_store_file_paths
from jobflow_remote.utils.log import initialize_remote_run_log

if TYPE_CHECKING:
    from pathlib import Path

    from jobflow.core.job import Job

logger = logging.getLogger(__name__)


@singleton
class JfrState:
    """State of the current job being executed."""

    job_doc: JobDoc | None = None

    def reset(self):
        """Reset the current state."""
        self.job_doc = None


CURRENT_JOBDOC: JfrState = JfrState()
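

# Illustrative sketch (not part of the original module): because JfrState is a
# singleton, code executed inside a remotely running Job can look up the JobDoc
# of the job currently being run. The helper name below is hypothetical; the
# attribute is populated only when the runner ships the JobDoc in the input file.
def _example_get_current_job_doc() -> JobDoc | None:
    """Return the JobDoc of the job being executed, or None outside a remote run."""
    return JfrState().job_doc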


def run_remote_job(run_dir: str | Path = ".") -> None:
    """Run the job."""
    initialize_remote_run_log()

    start_time = datetime.datetime.now(timezone.utc)
    with cd(run_dir):
        error = None
        try:
            dumpfn({"start_time": start_time}, OUT_FILENAME)
            in_data = loadfn(IN_FILENAME)

            job: Job = in_data["job"]
            store = in_data["store"]
            job_doc_dict = in_data.get("job_doc", None)

            if isinstance(job.function, dict):
                raise RuntimeError(  # noqa: TRY004,TRY301
                    f"The function in the Job could not be deserialized: {job.function}.\n"
                    "Check if this function is actually available in the worker's python environment"
                )

            if job_doc_dict:
                job_doc_dict["job"] = job
                JfrState().job_doc = JobDoc.model_validate(job_doc_dict)

            store.connect()

            initialize_logger()

            try:
                response = job.run(store=store)
            finally:
                # Some jobs may have compressed the jfremote and store files
                # while being executed; decompress them if needed.
                decompress_files(store)

            # Close the store explicitly, as minimal stores may require it.
            try:
                store.close()
            except Exception:
                logger.exception("Error while closing the store")

            # The output of the response has already been stored in the store.
            response.output = None

            # Convert to Flow the dynamic responses before dumping the output.
            # This is required so that the response does not need to be
            # deserialized and converted to Flows by the runner.
            if response.addition:
                response.addition = get_flow(response.addition)
            if response.detour:
                response.detour = get_flow(response.detour)
            if response.replace:
                response.replace = get_flow(response.replace)

            output = {
                "response": response,
                "error": error,
                "start_time": start_time,
                "end_time": datetime.datetime.now(timezone.utc),
            }
            dumpfn(output, OUT_FILENAME)
        except Exception:
            # replicate the dump to catch potential errors in
            # serializing/dumping the response.
            error = traceback.format_exc()
            output = {
                "response": None,
                "error": error,
                "start_time": start_time,
                "end_time": datetime.datetime.now(timezone.utc),
            }
            dumpfn(output, OUT_FILENAME)
        finally:
            JfrState().reset()
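

# Illustrative sketch (not part of the original module): run_remote_job expects a
# directory already prepared by the runner, i.e. one containing the serialized
# IN_FILENAME with its "job" and "store" entries. It is normally invoked by
# jobflow-remote's generated execution scripts rather than called directly; the
# hypothetical wrapper below only shows the contract of the function.
def _example_execute_prepared_dir(run_dir: str | Path) -> dict:
    """Run the job serialized in ``run_dir`` and return the dumped output."""
    run_remote_job(run_dir=run_dir)
    # run_remote_job always writes OUT_FILENAME, either with the Response or
    # with the formatted traceback under the "error" key.
    return loadfn(os.path.join(run_dir, OUT_FILENAME))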


def ping(start_time, interval=600, filename=BATCH_INFO_FILENAME):
    """Periodically dump a heartbeat with the start time and the last ping time."""
    while True:
        dumpfn(
            {
                "start_time": start_time,
                "last_ping_time": datetime.datetime.now(timezone.utc),
            },
            fn=filename,
        )
        time.sleep(interval)


PING_TIME = 600


def run_batch_jobs(
    base_run_dir: str | Path,
    files_dir: str | Path,
    batch_uid: str,
    max_time: float | None = None,
    max_wait: float = 60,
    max_jobs: int | None = None,
    parallel_jobs: int | None = None,
    sleep_time: float | None = None,
    batch_info_fname: str | Path = BATCH_INFO_FILENAME,
) -> None:
    """Run batch jobs, either sequentially or in parallel processes."""
    # Here we assume that we are in the batch work directory where a batch
    # process is executed/submitted
    start_time = datetime.datetime.now(timezone.utc)
    dumpfn({"start_time": start_time}, batch_info_fname)
    threading.Thread(
        target=ping,
        args=(start_time, PING_TIME, batch_info_fname),  # dump every 600 seconds
        daemon=True,
    ).start()
    parallel_jobs = parallel_jobs or 1
    if parallel_jobs == 1:
        run_single_batch_jobs(
            base_run_dir=base_run_dir,
            files_dir=files_dir,
            batch_uid=batch_uid,
            max_time=max_time,
            max_wait=max_wait,
            max_jobs=max_jobs,
            sleep_time=sleep_time,
        )
    else:
        with Manager() as manager:
            multiprocess_lock = manager.Lock()
            parallel_ids = manager.dict()
            batch_manager = LocalBatchManager(
                files_dir=files_dir,
                batch_uid=batch_uid,
                multiprocess_lock=multiprocess_lock,
            )
            processes = [
                Process(
                    target=run_single_batch_jobs,
                    args=(
                        base_run_dir,
                        files_dir,
                        batch_uid,
                        max_time,
                        max_wait,
                        max_jobs,
                        batch_manager,
                        parallel_ids,
                        sleep_time,
                    ),
                )
                for _ in range(parallel_jobs)
            ]
            for p in processes:
                p.start()
                time.sleep(0.5)
            for p in processes:
                p.join()

    dumpfn(
        {
            "start_time": start_time,
            "last_ping_time": datetime.datetime.now(timezone.utc),
            "end_time": datetime.datetime.now(timezone.utc),
        },
        batch_info_fname,
    )
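

# Illustrative sketch (not part of the original module): a batch entry point would
# typically call run_batch_jobs from the batch working directory, pointing it at
# the directory used by the runner to exchange the list of jobs to execute. The
# paths, uid and limits below are placeholders, not real values.
def _example_batch_entry_point() -> None:
    """Consume batch jobs with two worker processes for at most one hour."""
    run_batch_jobs(
        base_run_dir="/scratch/jfr/run",        # placeholder path
        files_dir="/scratch/jfr/batch_files",   # placeholder path
        batch_uid="batch-0001",                 # placeholder uid
        max_time=3600,      # stop starting new jobs after one hour
        max_wait=120,       # stop if no job shows up for two minutes
        parallel_jobs=2,    # run two jobs concurrently via multiprocessing
    )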


def run_single_batch_jobs(
    base_run_dir: str | Path,
    files_dir: str | Path,
    batch_uid: str,
    max_time: float | None = None,
    max_wait: float = 60,
    max_jobs: int | None = None,
    batch_manager: LocalBatchManager | None = None,
    parallel_ids: dict | None = None,
    sleep_time: float | None = None,
) -> None:
    """Fetch batch jobs one at a time and run them in a single process."""
    initialize_remote_run_log()

    # TODO the ID should be somehow linked to the queue job
    if not batch_manager:
        batch_manager = LocalBatchManager(files_dir=files_dir, batch_uid=batch_uid)

    if parallel_ids:
        parallel_ids[os.getpid()] = False

    t0 = time.time()
    wait = 0.0
    sleep_time = sleep_time or 10.0
    count = 0
    while True:
        if max_time and max_time < time.time() - t0:
            logger.info("Stopping due to max_time")
            return

        if max_wait and wait > max_wait:
            # if many jobs run in parallel do not shut down here, unless all
            # the other jobs are also stopped
            if parallel_ids:
                for pid, pid_is_running in parallel_ids.items():
                    if pid_is_running:
                        try:
                            os.kill(pid, 0)  # throws OSError if the process is dead
                        except OSError:  # means this process is dead!
                            parallel_ids[pid] = False
                if not any(parallel_ids.values()):
                    logger.info(
                        f"No jobs available for more than {max_wait} seconds and all other jobs are stopped. Stopping."
                    )
                    return
            else:
                logger.info(
                    f"No jobs available for more than {max_wait} seconds. Stopping."
                )
                return

        if max_jobs and count >= max_jobs:
            logger.info(f"Maximum number of jobs reached ({max_jobs}). Stopping.")
            return

        job_str = batch_manager.get_job()
        if not job_str:
            time.sleep(sleep_time)
            wait += sleep_time
        else:
            wait = 0.0
            count += 1
            job_id, _index = job_str.split("_")
            index: int = int(_index)
            logger.info(f"Starting job with id {job_id} and index {index}")
            job_path = get_job_path(job_id=job_id, index=index, base_path=base_run_dir)
            if parallel_ids:
                parallel_ids[os.getpid()] = True
            try:
                with cd(job_path):
                    result = subprocess.run(
                        ["bash", "submit.sh"],  # noqa: S603, S607
                        check=True,
                        text=True,
                        capture_output=True,
                    )
                    if result.returncode:
                        logger.warning(
                            f"Process for job with id {job_id} and index {index} finished with an error"
                        )
                    batch_manager.set_job_finished(job_id, index)
            except Exception:
                logger.exception(
                    f"Error while running job with id {job_id} and index {index}"
                )
            else:
                logger.info(f"Completed job with id {job_id} and index {index}")
            if parallel_ids:
                parallel_ids[os.getpid()] = False
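

# Illustrative sketch (not part of the original module): the shutdown logic above
# relies on signal 0, which only performs error checking without delivering a
# signal, to tell whether a sibling worker process is still alive. A stand-alone
# version of that liveness check:
def _example_pid_is_alive(pid: int) -> bool:
    """Return True if a process with this PID exists."""
    try:
        os.kill(pid, 0)  # signal 0: no signal is sent, only existence is checked
    except OSError:
        # ESRCH (no such process); for sibling workers spawned by the same user
        # an EPERM should not occur, so OSError is treated as "dead" here.
        return False
    return True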


def decompress_files(store: JobStore) -> None:
    """Decompress the output and store files if only compressed versions exist."""
    file_names = [OUT_FILENAME]
    file_names.extend(os.path.basename(p) for p in get_store_file_paths(store))

    for fn in file_names:
        # If the file is already present do not decompress it, even if
        # a compressed version is present.
        if os.path.isfile(fn):
            continue
        for f in glob.glob(fn + ".*"):
            decompress_file(f)
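

# Illustrative sketch (not part of the original module): decompress_files acts only
# on files that exist exclusively in compressed form. The example below assumes
# monty's compress_file replaces the original file with its gzipped counterpart,
# so the subsequent decompress_files call restores the uncompressed output file.
def _example_restore_compressed_outputs(store: JobStore) -> None:
    """Compress OUT_FILENAME (if present) and then undo it via decompress_files."""
    from monty.shutil import compress_file

    if os.path.isfile(OUT_FILENAME):
        compress_file(OUT_FILENAME)  # leaves only a ".gz" version of the file
    decompress_files(store)          # recreates the uncompressed file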