from __future__ import annotations
import datetime
import inspect
import io
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any
import orjson
from jobflow.core.job import Job
from jobflow.core.store import JobStore
from maggma.core import Sort, Store
from maggma.stores import JSONStore
from maggma.utils import to_dt
from monty.io import zopen
from monty.json import MontyDecoder, jsanitize
from jobflow_remote.jobs.data import RemoteError
from jobflow_remote.utils.data import uuid_to_path
if TYPE_CHECKING:
from collections.abc import Iterator
JOB_INIT_ARGS = {k for k in inspect.signature(Job).parameters if k != "kwargs"}
"""A set of the arguments of the Job constructor which
can be used to detect additional custom arguments
"""
def get_job_path(
    job_id: str, index: int | None, base_path: str | Path | None = None
) -> str:
    """Generate the path of the folder where a Job is executed, based on its uuid and index."""
base_path = Path(base_path) if base_path else Path()
relative_path = uuid_to_path(job_id, index)
return str(base_path / relative_path)
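# Usage sketch (hypothetical uuid; the nested layout is defined by uuid_to_path):
#
#     get_job_path("d5f9b9c2-0d9f-4f3a-9c6e-1a2b3c4d5e6f", 1, "/scratch/jobs")
#
# returns the execution folder for index 1 of that job as a string rooted at
# /scratch/jobs.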
def get_remote_in_file(job, remote_store):
    """Serialize a Job and its remote JobStore to an in-memory JSON byte stream."""
d = jsanitize(
{"job": job, "store": remote_store},
strict=True,
allow_bson=True,
enum_values=True,
)
return io.BytesIO(orjson.dumps(d, default=default_orjson_serializer))
def default_orjson_serializer(obj: Any) -> Any:
    """Fallback serializer for orjson: cast subclasses of float to plain float."""
    type_obj = type(obj)
    if type_obj is not float and issubclass(type_obj, float):
        return float(obj)
    raise TypeError(f"Object of type {type_obj} is not JSON serializable")
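# Example (assuming numpy is installed: numpy floats subclass the builtin float
# and orjson does not serialize them natively):
#
#     import numpy as np
#     orjson.dumps({"e": np.float64(2.718)}, default=default_orjson_serializer)
#     # -> b'{"e":2.718}'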
def get_remote_store(
    store: JobStore, work_dir: str | Path, config_dict: dict | None
) -> JobStore:
    """Build a file-based JobStore mirroring the given JobStore, backed by files in work_dir."""
docs_store = get_single_store(
config_dict=config_dict, file_name="remote_job_data", dir_path=work_dir
)
additional_stores = {}
for k in store.additional_stores:
additional_stores[k] = get_single_store(
config_dict=config_dict,
file_name=f"additional_store_{k}",
dir_path=work_dir,
)
return JobStore(
docs_store=docs_store,
additional_stores=additional_stores,
save=store.save,
load=store.load,
)
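# Note: the worker writes the Job outputs to the file-backed JobStore returned
# here; after the files are transferred back, update_store() copies their
# contents into the final JobStore.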
default_remote_store = {"store": "maggma_json", "zip": False}
def get_single_store(
    config_dict: dict | None, file_name: str, dir_path: str | Path
) -> Store:
    """Create a single file-based Store in dir_path, of the type selected by config_dict."""
config_dict = config_dict or default_remote_store
store_type = config_dict.get("store", default_remote_store["store"])
total_file_name = get_single_store_file_name(config_dict, file_name)
file_path = os.path.join(dir_path, total_file_name)
if store_type == "maggma_json":
return StdJSONStore(file_path)
if store_type == "orjson":
return MinimalORJSONStore(file_path)
if store_type == "msgspec_json":
return MinimalMsgspecJSONStore(file_path)
if store_type == "msgpack":
return MinimalMsgpackStore(file_path)
if isinstance(store_type, dict):
store_type = dict(store_type)
store_type["path"] = file_path
store = MontyDecoder().process_decoded(store_type)
if not isinstance(store, Store):
raise TypeError(
f"Could not instantiate a proper store from remote config dict {store_type}"
)
return store
raise ValueError(f"remote store type not supported: {store_type}")
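# Illustrative config_dict values (a sketch). The dict form of "store" is any
# serialized MSONable maggma Store; the file path is injected as its "path"
# entry before decoding, so the referenced class is assumed to accept a "path"
# argument ("my_pkg.stores.MyFileStore" below is hypothetical):
#
#     {"store": "orjson", "zip": True}
#     {"store": {"@module": "my_pkg.stores", "@class": "MyFileStore"},
#      "extension": "dat"}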
def get_single_store_file_name(config_dict: dict | None, file_name: str) -> str:
    """Determine the full file name, with extension, for a store defined by config_dict."""
config_dict = config_dict or default_remote_store
store_type = config_dict.get("store", default_remote_store["store"])
if isinstance(store_type, str) and "json" in store_type:
ext = "json"
elif isinstance(store_type, str) and "msgpack" in store_type:
ext = "msgpack"
else:
ext = config_dict.get("extension") # type: ignore
if not ext:
raise ValueError(
f"Could not determine extension for remote store config dict: {config_dict}"
)
total_file_name = f"{file_name}.{ext}"
if config_dict.get("zip", False):
total_file_name += ".gz"
return total_file_name
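# For example:
#
#     get_single_store_file_name({"store": "orjson", "zip": True}, "remote_job_data")
#     # -> "remote_job_data.json.gz"
#     get_single_store_file_name({"store": "msgpack"}, "remote_job_data")
#     # -> "remote_job_data.msgpack"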
def get_remote_store_filenames(store: JobStore, config_dict: dict | None) -> list[str]:
    """List the names of the files that will back the remote JobStore and its additional stores."""
return [
get_single_store_file_name(config_dict=config_dict, file_name="remote_job_data")
] + [
get_single_store_file_name(
config_dict=config_dict, file_name=f"additional_store_{k}"
)
for k in store.additional_stores
]
def get_store_file_paths(store: JobStore) -> list[str]:
    """Get the paths of the files backing each Store of a file-based JobStore."""
    def get_single_path(base_store: Store):
paths = getattr(base_store, "paths", None)
if paths:
return paths[0]
path = getattr(base_store, "path", None)
if not path:
raise RuntimeError(f"Could not determine the path for {base_store}")
return path
store_paths = [get_single_path(store.docs_store)]
store_paths.extend(get_single_path(s) for s in store.additional_stores.values())
return store_paths
def update_store(store: JobStore, remote_store: JobStore, db_id: int) -> None:
    """Copy the Job output from the remote file-based JobStore into the output JobStore."""
try:
store.connect()
remote_store.connect()
additional_stores = set(store.additional_stores)
additional_remote_stores = set(remote_store.additional_stores)
        # Check that the additional stores of the two JobStores match.
        # A mismatch should only be possible in case of a bug, so this check
        # could potentially be removed.
if additional_stores ^ additional_remote_stores:
raise ValueError(
f"The additional stores in the local and remote JobStore do not "
f"match: {additional_stores ^ additional_remote_stores}"
)
        # Copy the data store by store, rather than going through the JobStore
        # API. This avoids deserializing the stores' contents and handling the
        # "save" argument.
for add_store_name, remote_add_store in remote_store.additional_stores.items():
add_store = store.additional_stores[add_store_name]
for d in remote_add_store.query():
data = dict(d)
data.pop("_id", None)
add_store.update(data)
        main_docs_list = list(remote_store.docs_store.query({}))
        if len(main_docs_list) != 1:
            raise RuntimeError(
                f"The downloaded output store contains {len(main_docs_list)} "
                "documents, while exactly one was expected"
            )
        main_doc = main_docs_list[0]
main_doc.pop("_id", None)
# Set the db_id here and not directly in the Job's metadata to prevent
# it from being propagated to its children/replacements.
if "db_id" not in main_doc["metadata"]:
main_doc["metadata"]["db_id"] = db_id
store.docs_store.update(main_doc, key=["uuid", "index"])
finally:
try:
store.close()
except Exception:
logging.exception(f"error while closing the store {store}")
try:
remote_store.close()
except Exception:
logging.exception(f"error while closing the remote store {remote_store}")
def resolve_job_dict_args(job_dict: dict, store: JobStore) -> dict:
"""
Resolve the references in a serialized Job.
Similar to Job.resolve_args, but without the need to deserialize the Job.
    The references are resolved in place.
Parameters
----------
job_dict
The serialized version of a Job.
store
The JobStore from where the references should be resolved.
Returns
-------
The updated version of the input dictionary with references resolved.
"""
from jobflow.core.reference import OnMissing, find_and_resolve_references
on_missing = OnMissing(job_dict["config"]["on_missing_references"])
cache: dict[str, Any] = {}
resolved_args = find_and_resolve_references(
job_dict["function_args"], store, cache=cache, on_missing=on_missing
)
resolved_kwargs = find_and_resolve_references(
job_dict["function_kwargs"], store, cache=cache, on_missing=on_missing
)
resolved_args = tuple(resolved_args)
# substitution is in place
job_dict["function_args"] = resolved_args
job_dict["function_kwargs"] = resolved_kwargs
missing_stores = check_additional_stores(job_dict, store)
if missing_stores:
raise RemoteError(
f"Additional stores {missing_stores!r} are not configured for this project.",
no_retry=True,
)
return job_dict
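# Sketch of the effect (hypothetical uuid): a serialized OutputReference,
# roughly of the form
#
#     {"@module": "jobflow.core.reference", "@class": "OutputReference",
#      "uuid": "abc123"}
#
# appearing in "function_args" or "function_kwargs" is replaced in place by the
# output value stored for that uuid in the JobStore.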
def check_additional_stores(job: dict | Job, store: JobStore) -> list[str]:
"""
    Check if all the required additional stores have been defined in
    the output JobStore. If some are missing, return their names.
Parameters
----------
job
A Job or its serialized version.
store
The JobStore from where the references should be resolved.
Returns
-------
The list of names of the missing additional stores.
An empty list if no store is missing.
"""
if isinstance(job, dict):
additional_store_names = set(job) - JOB_INIT_ARGS
else:
# TODO expose the _kwargs attribute in jobflow through an
# "additional_stores" property
additional_store_names = set(job._kwargs)
missing_stores = []
for store_name in additional_store_names:
# Exclude MSON fields
if store_name.startswith("@"):
continue
if store_name not in store.additional_stores:
missing_stores.append(store_name)
return missing_stores
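# For example (a sketch, assuming the usual jobflow pattern): a job created with
# @job(large_data="data") carries a "large_data" entry in its serialized dict.
# Since "large_data" is not an argument of the Job constructor, it is detected
# here and must be present among the configured additional stores.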
class StdJSONStore(JSONStore):
"""
Simple subclass of the JSONStore defining the serialization_default
that cannot be dumped to json.
"""
def __init__(self, paths, **kwargs) -> None:
super().__init__(
paths=paths,
serialization_default=default_orjson_serializer,
read_only=False,
**kwargs,
)
class MinimalFileStore(Store):
"""
A Minimal Store for access to a single file.
Only methods required by jobflow-remote are implemented.
"""
@property
def _collection(self):
raise NotImplementedError
def close(self) -> None:
self.update_file()
    def count(self, criteria: dict | None = None) -> int:
        # criteria is ignored: the count always refers to the whole collection
        return len(self.data)
def query(
self,
criteria: dict | None = None,
properties: dict | list | None = None,
sort: dict[str, Sort | int] | None = None,
skip: int = 0,
limit: int = 0,
) -> Iterator[dict]:
        if criteria or properties or sort or skip or limit:
raise NotImplementedError(
"Query only implemented to return the whole set of docs"
)
return iter(self.data)
def ensure_index(self, key: str, unique: bool = False) -> bool:
raise NotImplementedError
def groupby(
self,
keys: list[str] | str,
criteria: dict | None = None,
properties: dict | list | None = None,
sort: dict[str, Sort | int] | None = None,
skip: int = 0,
limit: int = 0,
) -> Iterator[tuple[dict, list[dict]]]:
raise NotImplementedError
def __init__(
self,
path: str,
**kwargs,
) -> None:
"""
Args:
path: paths for json files to turn into a Store.
"""
self.path = path
self.kwargs = kwargs
self.default_sort = None
self.data: list[dict] = []
super().__init__(**kwargs)
def connect(self, force_reset: bool = False) -> None:
"""Loads the files into the collection in memory."""
# create the .json file if it does not exist
if not Path(self.path).exists():
self.update_file()
else:
self.data = self.read_file()
def update(self, docs: list[dict] | dict, key: list | str | None = None) -> None:
"""
Update documents into the Store.
For a file-writable JSONStore, the json file is updated.
Args:
docs: the document or list of documents to update
key: field name(s) to determine uniqueness for a
document, can be a list of multiple fields,
a single field, or None if the Store's key
field is to be used
"""
if not isinstance(docs, (list, tuple)):
docs = [docs]
self.data.extend(docs)
    def update_file(self):
        """Write the in-memory collection to the backing file. Implemented by subclasses."""
        raise NotImplementedError
    def read_file(self) -> list:
        """Read the list of documents from the backing file. Implemented by subclasses."""
        raise NotImplementedError
def remove_docs(self, criteria: dict):
"""
Remove docs matching the query dictionary.
For a file-writable JSONStore, the json file is updated.
Args:
criteria: query dictionary to match
"""
raise NotImplementedError
    def __hash__(self):
        return hash((self.path, self.last_updated_field))
def __eq__(self, other: object) -> bool:
"""
Check equality for JSONStore.
Args:
other: other JSONStore to compare with
"""
if not isinstance(other, type(self)):
return False
fields = ["path", "last_updated_field"]
return all(getattr(self, f) == getattr(other, f) for f in fields)
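# Typical lifecycle of the minimal stores (a sketch using a hypothetical path):
#
#     store = MinimalORJSONStore("/tmp/remote_job_data.json")
#     store.connect()                               # read the file if present
#     store.update({"uuid": "abc123", "index": 1})  # append in memory
#     store.close()                                 # write the file via update_file()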
class MinimalORJSONStore(MinimalFileStore):
    """Minimal file-based Store reading and writing JSON through orjson."""
@property
def name(self) -> str:
return f"json://{self.path}"
def update_file(self) -> None:
"""Updates the json file when a write-like operation is performed."""
with zopen(self.path, "wb") as f:
for d in self.data:
d.pop("_id", None)
bytesdata = orjson.dumps(
self.data,
default=default_orjson_serializer,
)
f.write(bytesdata)
def read_file(self) -> list:
"""
Helper method to read the contents of a JSON file and generate
a list of docs.
"""
with zopen(self.path, "rb") as f:
data = f.read()
if not data:
return []
objects = orjson.loads(data)
objects = [objects] if not isinstance(objects, list) else objects
        # datetime objects deserialize to str. Try to convert the last_updated
        # field back to datetime.
        # TODO - there may still be problems caused if a JSONStore is init'ed from
        # documents that don't contain a last_updated field.
        # See Store.last_updated in store.py.
for obj in objects:
if obj.get(self.last_updated_field):
obj[self.last_updated_field] = to_dt(obj[self.last_updated_field])
return objects
class MinimalMsgspecJSONStore(MinimalFileStore):
    """Minimal file-based Store reading and writing JSON through msgspec."""
@property
def name(self) -> str:
return f"json://{self.path}"
def update_file(self) -> None:
"""Updates the json file when a write-like operation is performed."""
import msgspec
with zopen(self.path, "wb") as f:
for d in self.data:
d.pop("_id", None)
bytesdata = msgspec.json.encode(
self.data,
)
f.write(bytesdata)
def read_file(self) -> list:
"""
Helper method to read the contents of a JSON file and generate
a list of docs.
"""
import msgspec
with zopen(self.path, "rb") as f:
data = f.read()
if not data:
return []
objects = msgspec.json.decode(data)
objects = [objects] if not isinstance(objects, list) else objects
        # datetime objects deserialize to str. Try to convert the last_updated
        # field back to datetime.
        # TODO - there may still be problems caused if a JSONStore is init'ed from
        # documents that don't contain a last_updated field.
        # See Store.last_updated in store.py.
for obj in objects:
if obj.get(self.last_updated_field):
obj[self.last_updated_field] = to_dt(obj[self.last_updated_field])
return objects
def decode_datetime(obj):
    """msgpack object_hook: convert dicts tagged with "__datetime__" back to datetime objects."""
    if "__datetime__" in obj:
        obj = datetime.datetime.strptime(obj["as_str"], "%Y%m%dT%H:%M:%S.%f")
    return obj
def encode_datetime(obj):
    """msgpack default: encode datetime objects as dicts tagged with "__datetime__"."""
    if isinstance(obj, datetime.datetime):
        return {"__datetime__": True, "as_str": obj.strftime("%Y%m%dT%H:%M:%S.%f")}
    return obj
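# Round-trip example:
#
#     dt = datetime.datetime(2024, 1, 2, 3, 4, 5, 600000)
#     encode_datetime(dt)
#     # -> {"__datetime__": True, "as_str": "20240102T03:04:05.600000"}
#     decode_datetime(encode_datetime(dt)) == dt
#     # -> True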
class MinimalMsgpackStore(MinimalFileStore):
    """Minimal file-based Store reading and writing msgpack files."""
@property
def name(self) -> str:
return f"msgpack://{self.path}"
def update_file(self) -> None:
"""Updates the msgpack file when a write-like operation is performed."""
import msgpack
with zopen(self.path, "wb") as f:
msgpack.pack(self.data, f, default=encode_datetime, use_bin_type=True)
def read_file(self) -> list:
"""
Helper method to read the contents of a msgpack file and generate
a list of docs.
"""
import msgpack
with zopen(self.path, "rb") as f:
objects = msgpack.unpack(f, object_hook=decode_datetime, raw=False)
objects = [objects] if not isinstance(objects, list) else objects
        # datetime objects deserialize to str. Try to convert the last_updated
        # field back to datetime.
        # TODO - there may still be problems caused if a JSONStore is init'ed from
        # documents that don't contain a last_updated field.
        # See Store.last_updated in store.py.
for obj in objects:
if obj.get(self.last_updated_field):
obj[self.last_updated_field] = to_dt(obj[self.last_updated_field])
return objects