#!/usr/bin/env python
#
# Copyright 2021, Heidelberg University Clinic
#
# File author(s): Sebastian Lobentanzer
# ...
#
# Distributed under MIT licence, see the file `LICENSE`.
#
"""
BioCypher get module. Used to download and cache data from external sources.
"""
from __future__ import annotations

from typing import Optional
import shutil

import requests

from ._logger import logger

logger.debug(f"Loading module {__name__}.")

from abc import ABC
from datetime import datetime, timedelta
from tempfile import TemporaryDirectory

import os
import json
import ftplib

import pooch

from ._misc import to_list, is_nested


class Resource(ABC):
def __init__(
self,
name: str,
url_s: str | list[str],
lifetime: int = 0,
):
"""
A Resource is a file, a list of files, an API request, or a list of API
requests, any of which can be downloaded from the given URL(s) and
cached locally. This class implements checks of the minimum requirements
for a resource, to be implemented by a biocypher adapter.
Args:
name (str): The name of the resource.
url_s (str | list[str]): The URL or URLs of the resource.
lifetime (int): The lifetime of the resource in days. If 0, the
resource is considered to be permanent.
"""
self.name = name
self.url_s = url_s
self.lifetime = lifetime


class FileDownload(Resource):
    def __init__(
        self,
        name: str,
        url_s: str | list[str],
        lifetime: int = 0,
        is_dir: bool = False,
    ):
        """
        Represents basic information for a file download.

        Args:
            name (str): The name of the file download.
            url_s (str | list[str]): The URL(s) of the file download.
            lifetime (int): The lifetime of the file download in days. If 0,
                the file download is cached indefinitely.
            is_dir (bool): Whether the URL points to a directory or not.
        """
        super().__init__(name, url_s, lifetime)
        self.is_dir = is_dir


class APIRequest(Resource):
    def __init__(self, name: str, url_s: str | list[str], lifetime: int = 0):
        """
        Represents basic information for an API request.

        Args:
            name (str): The name of the API request.
            url_s (str | list[str]): The URL(s) of the API endpoint.
            lifetime (int): The lifetime of the API request in days. If 0,
                the API request is cached indefinitely.
        """
        super().__init__(name, url_s, lifetime)
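

# A construction sketch (hedged): the names and URLs below are hypothetical
# and only illustrate the two concrete Resource types defined above.
#
#     file_resource = FileDownload(
#         name="example-files",
#         url_s="https://example.org/data.csv",  # hypothetical URL
#         lifetime=7,  # re-download after seven days
#     )
#     api_resource = APIRequest(
#         name="example-api",
#         url_s="https://example.org/api/items",  # hypothetical URL
#         lifetime=1,
#     )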


class Downloader:
    def __init__(self, cache_dir: Optional[str] = None) -> None:
        """
        The Downloader manages resources that can be downloaded and cached
        locally. It manages the lifetime of downloaded resources by keeping
        a JSON record of the download date of each resource.

        Args:
            cache_dir (str): The directory where the resources are cached. If
                not given, a temporary directory is created.
        """
        if cache_dir is None:
            # Keep a reference to the TemporaryDirectory instance; if it were
            # garbage-collected, the directory it manages would be removed.
            self._tmp_dir = TemporaryDirectory()
            cache_dir = self._tmp_dir.name
        self.cache_dir = cache_dir
        self.cache_file = os.path.join(self.cache_dir, "cache.json")
        self.cache_dict = self._load_cache_dict()

    def download(self, *resources: Resource) -> list[str]:
        """
        Download one or multiple resources. Load from cache if the resource
        is already downloaded and the cache has not expired.

        Args:
            resources (Resource): The resource(s) to download or load from
                cache.

        Returns:
            list[str]: The path or paths to the resource(s) that were
                downloaded or loaded from cache.
        """
paths = []
for resource in resources:
paths.append(self._download_or_cache(resource))
# flatten list if it is nested
if is_nested(paths):
paths = [path for sublist in paths for path in sublist]
return paths
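
    # Usage sketch (hedged): a round trip through the cache. The cache
    # directory and resource are hypothetical; the second call is served
    # from the cache as long as the resource lifetime has not expired.
    #
    #     downloader = Downloader(cache_dir="./.cache")
    #     paths = downloader.download(file_resource)  # downloads
    #     paths = downloader.download(file_resource)  # loads from cache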

    def _download_or_cache(self, resource: Resource, cache: bool = True):
        """
        Download a resource if it is not cached or has exceeded its lifetime.

        Args:
            resource (Resource): The resource to download.
            cache (bool): Whether to use a cached version if available. If
                False, the resource is downloaded regardless of cache state.

        Returns:
            list[str]: The path or paths to the downloaded resource(s).
        """
expired = self._is_cache_expired(resource)
if expired or not cache:
self._delete_expired_cache(resource)
if isinstance(resource, FileDownload):
logger.info(f"Asking for download of resource {resource.name}.")
paths = self._download_files(cache, resource)
            elif isinstance(resource, APIRequest):
                logger.info(
                    f"Asking for download of API request {resource.name}."
                )
                paths = self._download_api_request(resource)
else:
raise TypeError(f"Unknown resource type: {type(resource)}")
else:
paths = self.get_cached_version(resource)
self._update_cache_record(resource)
return paths
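
    # Control flow above, in short: if the cache is expired or bypassed, any
    # stale copy is deleted and the resource is fetched again via the
    # type-specific helper; otherwise the cached paths are returned. In both
    # cases the cache record is refreshed afterwards.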

    def _is_cache_expired(self, resource: Resource) -> bool:
        """
        Check whether the cache of a resource or API request has expired.

        Args:
            resource (Resource): The resource or API request to check.

        Returns:
            bool: True if the cache is expired, False if not.
        """
cache_record = self._get_cache_record(resource)
if cache_record:
download_time = datetime.strptime(
cache_record.get("date_downloaded"), "%Y-%m-%d %H:%M:%S.%f"
)
lifetime = timedelta(days=resource.lifetime)
expired = download_time + lifetime < datetime.now()
else:
expired = True
return expired
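
    # For reference, the cache record consulted above lives in ``cache.json``
    # and is written by ``_update_cache_record``. One entry, sketched with
    # illustrative values:
    #
    #     {
    #         "example-files": {
    #             "url": ["https://example.org/data.csv"],
    #             "date_downloaded": "2024-01-01 12:00:00.000000",
    #             "lifetime": 7
    #         }
    #     }
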
    def _delete_expired_cache(self, resource: Resource):
        cache_resource_path = os.path.join(self.cache_dir, resource.name)
        # os.path.isdir() is False for paths that do not exist, so no
        # separate existence check is needed.
        if os.path.isdir(cache_resource_path):
            shutil.rmtree(cache_resource_path)

    def _download_files(self, cache: bool, file_download: FileDownload):
        """
        Download a file or directory resource and return the path(s).

        Args:
            cache (bool): Whether to cache the resource or not.
            file_download (FileDownload): The resource to download.

        Returns:
            list[str]: The path or paths to the downloaded resource(s).
        """
if file_download.is_dir:
files = self._get_files(file_download)
file_download.url_s = [
file_download.url_s + "/" + file for file in files
]
file_download.is_dir = False
paths = self._download_or_cache(file_download, cache)
elif isinstance(file_download.url_s, list):
paths = []
for url in file_download.url_s:
fname = url[url.rfind("/") + 1 :].split("?")[0]
paths.append(
self._retrieve(
url=url,
fname=fname,
path=os.path.join(self.cache_dir, file_download.name),
)
)
else:
paths = []
fname = file_download.url_s[
file_download.url_s.rfind("/") + 1 :
].split("?")[0]
results = self._retrieve(
url=file_download.url_s,
fname=fname,
path=os.path.join(self.cache_dir, file_download.name),
)
if isinstance(results, list):
paths.extend(results)
else:
paths.append(results)
# sometimes a compressed file contains multiple files
# TODO ask for a list of files in the archive to be used from the
# adapter
return paths
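
    # Filename derivation used above, for illustration (hypothetical URL):
    # "https://example.org/pub/data.zip?version=2" yields the fname
    # "data.zip", i.e. everything after the last "/" with any query string
    # stripped.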

    def _download_api_request(self, api_request: APIRequest):
        """
        Execute an API request, cache the JSON response, and return the
        path(s) to the cached file(s).

        Args:
            api_request (APIRequest): The API request to execute and cache.

        Returns:
            list[str]: The path(s) to the cached API request response(s).
        """
urls = (
api_request.url_s
if isinstance(api_request.url_s, list)
else [api_request.url_s]
)
paths = []
for url in urls:
fname = url[url.rfind("/") + 1 :].rsplit(".", 1)[0]
logger.info(
f"Asking for caching API of {api_request.name} {fname}."
)
            response = requests.get(url=url)
            # raise_for_status() is a no-op for successful responses and
            # raises requests.HTTPError for 4xx/5xx status codes.
            response.raise_for_status()
            response_data = response.json()
api_path = os.path.join(
self.cache_dir, api_request.name, f"{fname}.json"
)
os.makedirs(os.path.dirname(api_path), exist_ok=True)
with open(api_path, "w") as f:
json.dump(response_data, f)
logger.info(f"Caching API request to {api_path}.")
paths.append(api_path)
return paths
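
    # Resulting cache layout, sketched with hypothetical names: an APIRequest
    # named "example-api" for the URL ".../api/items" is cached at
    # "<cache_dir>/example-api/items.json", next to "<cache_dir>/cache.json".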

    def get_cached_version(self, resource: Resource) -> list[str]:
        """
        Get the cached version of a resource.

        Args:
            resource (Resource): The resource to get the cached version of.

        Returns:
            list[str]: The paths to the cached resource(s).
        """
cached_location = os.path.join(self.cache_dir, resource.name)
logger.info(f"Use cached version from {cached_location}.")
paths = []
for file in os.listdir(cached_location):
paths.append(os.path.join(cached_location, file))
return paths

    def _retrieve(
        self,
        url: str,
        fname: str,
        path: str,
        known_hash: Optional[str] = None,
    ):
        """
        Retrieve a file from a URL using Pooch. Infer the type of the file
        from its extension and use the appropriate processor.

        Args:
            url (str): The URL to retrieve the file from.
            fname (str): The name of the file.
            path (str): The path to store the file at.
            known_hash (str): The expected hash of the file, if known; passed
                through to ``pooch.retrieve``.

        Returns:
            The return value of ``pooch.retrieve``: a list of paths if a
                processor extracted an archive, a single path otherwise.
        """
if fname.endswith(".zip"):
return pooch.retrieve(
url=url,
known_hash=known_hash,
fname=fname,
path=path,
processor=pooch.Unzip(),
progressbar=True,
)
elif fname.endswith(".tar.gz"):
return pooch.retrieve(
url=url,
known_hash=known_hash,
fname=fname,
path=path,
processor=pooch.Untar(),
progressbar=True,
)
elif fname.endswith(".gz"):
return pooch.retrieve(
url=url,
known_hash=known_hash,
fname=fname,
path=path,
processor=pooch.Decompress(),
progressbar=True,
)
else:
return pooch.retrieve(
url=url,
known_hash=known_hash,
fname=fname,
path=path,
progressbar=True,
)
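
    # Processor selection above, summarized: ".zip" -> pooch.Unzip (returns a
    # list of extracted paths), ".tar.gz" -> pooch.Untar (list of extracted
    # paths), ".gz" -> pooch.Decompress (single decompressed path); any other
    # extension is stored as-is.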

    def _get_files(self, file_download: FileDownload):
        """
        List the files contained in a directory resource.

        Args:
            file_download (FileDownload): The directory resource.

        Returns:
            list: The files contained in the directory.
        """
        if file_download.url_s.startswith("ftp://"):
            # remove protocol
            url = file_download.url_s[6:]
            # get base url
            url = url[: url.find("/")]
            # get directory (remove initial slash as well); avoid shadowing
            # the built-in ``dir``
            directory = file_download.url_s[7 + len(url) :]
            # list files via anonymous FTP login
            ftp = ftplib.FTP(url)
            ftp.login()
            ftp.cwd(directory)
            files = ftp.nlst()
            ftp.quit()
else:
raise NotImplementedError(
"Only FTP directories are supported at the moment."
)
return files
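
    # URL dissection above, for illustration (hypothetical URL):
    # "ftp://ftp.example.org/pub/release" yields the host "ftp.example.org"
    # and the directory "pub/release", which is then listed via ftplib.
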
def _load_cache_dict(self):
"""
Load the cache dictionary from the cache file. Create an empty cache
file if it does not exist.
"""
if not os.path.exists(self.cache_dir):
logger.info(f"Creating cache directory {self.cache_dir}.")
os.makedirs(self.cache_dir)
if not os.path.exists(self.cache_file):
logger.info(f"Creating cache file {self.cache_file}.")
with open(self.cache_file, "w") as f:
json.dump({}, f)
with open(self.cache_file, "r") as f:
logger.info(f"Loading cache file {self.cache_file}.")
return json.load(f)

    def _get_cache_record(self, resource: Resource):
        """
        Get the cache record of a resource.

        Args:
            resource (Resource): The resource to get the cache record of.

        Returns:
            The cache record of the resource, or an empty dict if there is
                none.
        """
        return self.cache_dict.get(resource.name, {})

    def _update_cache_record(self, resource: Resource):
        """
        Update the cache record of a resource.

        Args:
            resource (Resource): The resource to update the cache record of.
        """
        cache_record = {
            "url": to_list(resource.url_s),
            "date_downloaded": str(datetime.now()),
            "lifetime": resource.lifetime,
        }
        self.cache_dict[resource.name] = cache_record
with open(self.cache_file, "w") as f:
json.dump(self.cache_dict, f, default=str)