2025-04-16 22:12:19 +02:00

114 lines
3.6 KiB
Python

#!/usr/bin/env python
#
# Copyright 2021, Heidelberg University Clinic
#
# File author(s): Sebastian Lobentanzer
# Michael Hartung
#
# Distributed under MIT licence, see the file `LICENSE`.
#
"""
BioCypher 'offline' module. Handles the writing of node and edge representations
suitable for import into a DBMS.
"""
from biocypher._logger import logger
from biocypher.output.write.graph._rdf import _RDFWriter
from biocypher.output.write.graph._neo4j import _Neo4jBatchWriter
from biocypher.output.write.graph._arangodb import _ArangoDBBatchWriter
from biocypher.output.write.graph._networkx import _NetworkXWriter
from biocypher.output.write.relational._csv import _PandasCSVWriter
from biocypher.output.write.relational._sqlite import _SQLiteBatchWriter
from biocypher.output.write.relational._postgresql import _PostgreSQLBatchWriter
logger.debug(f"Loading module {__name__}.")
from typing import TYPE_CHECKING
from biocypher._config import config as _config
__all__ = ["get_writer", "DBMS_TO_CLASS"]
if TYPE_CHECKING:
from biocypher._translate import Translator
from biocypher._deduplicate import Deduplicator
# Lookup table from every accepted spelling of a DBMS / output-format name to
# the writer class that handles it. Each writer class is listed once together
# with all of its aliases; the comprehension flattens the groups into a single
# alias -> class mapping.
DBMS_TO_CLASS = {
    alias: writer_cls
    for writer_cls, aliases in (
        (_Neo4jBatchWriter, ("neo", "neo4j", "Neo4j")),
        (_PostgreSQLBatchWriter, ("postgres", "postgresql", "PostgreSQL")),
        (_ArangoDBBatchWriter, ("arango", "arangodb", "ArangoDB")),
        (_SQLiteBatchWriter, ("sqlite", "sqlite3")),
        (_RDFWriter, ("rdf", "RDF")),
        (_PandasCSVWriter, ("csv", "CSV", "pandas", "Pandas")),
        (_NetworkXWriter, ("networkx", "NetworkX")),
    )
    for alias in aliases
}
def get_writer(
    dbms: str,
    translator: "Translator",
    deduplicator: "Deduplicator",
    output_directory: str,
    strict_mode: bool,
):
    """
    Return an instance of the writer class registered for ``dbms``.

    The writer class is looked up in DBMS_TO_CLASS and instantiated with the
    arguments passed in plus the settings read from the configuration section
    named after ``dbms``.

    Args:
        dbms: the database management system; for options, see DBMS_TO_CLASS.
        translator: the Translator object.
        deduplicator: the Deduplicator object.
        output_directory: the directory to write the output files to.
        strict_mode: whether to use strict mode.

    Returns:
        instance: an instance of the selected writer class.

    Raises:
        ValueError: if ``dbms`` is not a key of DBMS_TO_CLASS.
    """
    # Use .get() rather than subscription: indexing raised a bare KeyError for
    # unknown names, which made the descriptive ValueError below unreachable.
    # Validate before reading the configuration so an unknown name always
    # fails with this message.
    writer = DBMS_TO_CLASS.get(dbms)
    if writer is None:
        raise ValueError(f"Unknown dbms: {dbms}")

    dbms_config = _config(dbms)

    # Every writer receives the full set of keyword arguments; settings that
    # are absent from the config section are passed as None.
    return writer(
        translator=translator,
        deduplicator=deduplicator,
        delimiter=dbms_config.get("delimiter"),
        array_delimiter=dbms_config.get("array_delimiter"),
        quote=dbms_config.get("quote_character"),
        output_directory=output_directory,
        db_name=dbms_config.get("database_name"),
        import_call_bin_prefix=dbms_config.get("import_call_bin_prefix"),
        import_call_file_prefix=dbms_config.get("import_call_file_prefix"),
        wipe=dbms_config.get("wipe"),
        strict_mode=strict_mode,
        skip_bad_relationships=dbms_config.get(
            "skip_bad_relationships"
        ),  # neo4j
        skip_duplicate_nodes=dbms_config.get(
            "skip_duplicate_nodes"
        ),  # neo4j
        db_user=dbms_config.get("user"),  # psql
        db_password=dbms_config.get("password"),  # psql
        db_port=dbms_config.get("port"),  # psql
        rdf_format=dbms_config.get("rdf_format"),  # rdf
        rdf_namespaces=dbms_config.get("rdf_namespaces"),  # rdf
    )