from abc import ABC, abstractmethod from types import GeneratorType from typing import Union, Optional from collections import OrderedDict, defaultdict import os import re import glob from more_itertools import peekable from biocypher._create import BioCypherEdge, BioCypherNode, BioCypherRelAsNode from biocypher._logger import logger from biocypher._translate import Translator from biocypher._deduplicate import Deduplicator from biocypher.output.write._writer import _Writer class _BatchWriter(_Writer, ABC): """Abstract batch writer class""" @abstractmethod def _get_default_import_call_bin_prefix(self): """ Abstract method to provide the default string for the import call bin prefix. Returns: str: The database-specific string for the path to the import call bin prefix """ raise NotImplementedError( "Database writer must override '_get_default_import_call_bin_prefix'" ) @abstractmethod def _write_array_string(self, string_list): """ Abstract method to write the string representation of an array into a .csv file. Different databases require different formats of array to optimize import speed. Args: string_list (list): list of ontology strings Returns: str: The database-specific string representation of an array """ raise NotImplementedError( "Database writer must override '_write_array_string'" ) @abstractmethod def _write_node_headers(self): """ Abstract method that takes care of importing properties of a graph entity that is represented as a node as per the definition in the `schema_config.yaml` Returns: bool: The return value. True for success, False otherwise. """ raise NotImplementedError( "Database writer must override '_write_node_headers'" ) @abstractmethod def _write_edge_headers(self): """ Abstract method to write a database import-file for a graph entity that is represented as an edge as per the definition in the `schema_config.yaml`, containing only the header for this type of edge. Returns: bool: The return value. True for success, False otherwise. """ raise NotImplementedError( "Database writer must override '_write_edge_headers'" ) @abstractmethod def _construct_import_call(self) -> str: """ Function to construct the import call detailing folder and individual node and edge headers and data files, as well as delimiters and database name. Built after all data has been processed to ensure that nodes are called before any edges. Returns: str: A bash command for csv import. """ raise NotImplementedError( "Database writer must override '_construct_import_call'" ) @abstractmethod def _get_import_script_name(self) -> str: """ Returns the name of the import script. The name will be chosen based on the used database. Returns: str: The name of the import script (ending in .sh) """ raise NotImplementedError( "Database writer must override '_get_import_script_name'" ) def __init__( self, translator: "Translator", deduplicator: "Deduplicator", delimiter: str, array_delimiter: str = ",", quote: str = '"', output_directory: Optional[str] = None, db_name: str = "neo4j", import_call_bin_prefix: Optional[str] = None, import_call_file_prefix: Optional[str] = None, wipe: bool = True, strict_mode: bool = False, skip_bad_relationships: bool = False, skip_duplicate_nodes: bool = False, db_user: str = None, db_password: str = None, db_host: str = None, db_port: str = None, rdf_format: str = None, rdf_namespaces: dict = {}, ): """ Abtract parent class for writing node and edge representations to disk using the format specified by each database type. The database-specific functions are implemented by the respective child-classes. This abstract class contains all methods expected by a bach writer instance, some of which need to be overwritten by the child classes. Each batch writer instance has a fixed representation that needs to be passed at instantiation via the :py:attr:`schema` argument. The instance also expects an ontology adapter via :py:attr:`ontology_adapter` to be able to convert and extend the hierarchy. Requires the following methods to be overwritten by database-specific writer classes: - _write_node_headers - _write_edge_headers - _construct_import_call - _write_array_string - _get_import_script_name Args: translator: Instance of :py:class:`Translator` to enable translation of nodes and manipulation of properties. deduplicator: Instance of :py:class:`Deduplicator` to enable deduplication of nodes and edges. delimiter: The delimiter to use for the CSV files. array_delimiter: The delimiter to use for array properties. quote: The quote character to use for the CSV files. output_directory: Path for exporting CSV files. db_name: Name of the database that will be used in the generated commands. import_call_bin_prefix: Path prefix for the admin import call binary. import_call_file_prefix: Path prefix for the data files (headers and parts) in the import call. wipe: Whether to force import (removing existing DB content). (Specific to Neo4j.) strict_mode: Whether to enforce source, version, and license properties. skip_bad_relationships: Whether to skip relationships that do not have a valid start and end node. (Specific to Neo4j.) skip_duplicate_nodes: Whether to skip duplicate nodes. (Specific to Neo4j.) db_user: The database user. db_password: The database password. db_host: The database host. Defaults to localhost. db_port: The database port. rdf_format: The format of RDF. rdf_namespaces: The namespaces for RDF. """ super().__init__( translator=translator, deduplicator=deduplicator, output_directory=output_directory, strict_mode=strict_mode, ) self.db_name = db_name self.db_user = db_user self.db_password = db_password self.db_host = db_host or "localhost" self.db_port = db_port self.rdf_format = rdf_format self.rdf_namespaces = rdf_namespaces self.delim, self.escaped_delim = self._process_delimiter(delimiter) self.adelim, self.escaped_adelim = self._process_delimiter( array_delimiter ) self.quote = quote self.skip_bad_relationships = skip_bad_relationships self.skip_duplicate_nodes = skip_duplicate_nodes if import_call_bin_prefix is None: self.import_call_bin_prefix = ( self._get_default_import_call_bin_prefix() ) else: self.import_call_bin_prefix = import_call_bin_prefix self.wipe = wipe self.strict_mode = strict_mode self.translator = translator self.deduplicator = deduplicator self.node_property_dict = {} self.edge_property_dict = {} self.import_call_nodes = set() self.import_call_edges = set() self.outdir = output_directory self._import_call_file_prefix = import_call_file_prefix self.parts = {} # dict to store the paths of part files for each label # TODO not memory efficient, but should be fine for most cases; is # there a more elegant solution? @property def import_call_file_prefix(self): """ Property for output directory path. """ if self._import_call_file_prefix is None: return self.outdir else: return self._import_call_file_prefix def _process_delimiter(self, delimiter: str) -> str: """ Return escaped characters in case of receiving their string representation (e.g. tab for '\t'). """ if delimiter == "\\t": return "\t", "\\t" else: return delimiter, delimiter def write_nodes( self, nodes, batch_size: int = int(1e6), force: bool = False ): """ Wrapper for writing nodes and their headers. Args: nodes (BioCypherNode): a list or generator of nodes in :py:class:`BioCypherNode` format batch_size (int): The batch size for writing nodes. force (bool): Whether to force writing nodes even if their type is not present in the schema. Returns: bool: The return value. True for success, False otherwise. """ # TODO check represented_as # write node data passed = self._write_node_data(nodes, batch_size, force) if not passed: logger.error("Error while writing node data.") return False # pass property data to header writer per node type written passed = self._write_node_headers() if not passed: logger.error("Error while writing node headers.") return False return True def write_edges( self, edges: Union[list, GeneratorType], batch_size: int = int(1e6), ) -> bool: """ Wrapper for writing edges and their headers. Args: edges (BioCypherEdge): a list or generator of edges in :py:class:`BioCypherEdge` or :py:class:`BioCypherRelAsNode` format Returns: bool: The return value. True for success, False otherwise. """ passed = False edges = list(edges) # force evaluation to handle empty generator if edges: nodes_flat = [] edges_flat = [] for edge in edges: if isinstance(edge, BioCypherRelAsNode): # check if relationship has already been written, if so skip if self.deduplicator.rel_as_node_seen(edge): continue nodes_flat.append(edge.get_node()) edges_flat.append(edge.get_source_edge()) edges_flat.append(edge.get_target_edge()) else: # check if relationship has already been written, if so skip if self.deduplicator.edge_seen(edge): continue edges_flat.append(edge) if nodes_flat and edges_flat: passed = self.write_nodes(nodes_flat) and self._write_edge_data( edges_flat, batch_size, ) else: passed = self._write_edge_data(edges_flat, batch_size) else: # is this a problem? if the generator or list is empty, we # don't write anything. logger.debug( "No edges to write, possibly due to no matched Biolink classes.", ) pass if not passed: logger.error("Error while writing edge data.") return False # pass property data to header writer per edge type written passed = self._write_edge_headers() if not passed: logger.error("Error while writing edge headers.") return False return True def _write_node_data(self, nodes, batch_size, force: bool = False): """ Writes biocypher nodes to CSV conforming to the headers created with `_write_node_headers()`, and is actually required to be run before calling `_write_node_headers()` to set the :py:attr:`self.node_property_dict` for passing the node properties to the instance. Expects list or generator of nodes from the :py:class:`BioCypherNode` class. Args: nodes (BioCypherNode): a list or generator of nodes in :py:class:`BioCypherNode` format Returns: bool: The return value. True for success, False otherwise. """ if isinstance(nodes, GeneratorType) or isinstance(nodes, peekable): logger.debug("Writing node CSV from generator.") bins = defaultdict(list) # dict to store a list for each # label that is passed in bin_l = {} # dict to store the length of each list for # batching cutoff reference_props = defaultdict( dict, ) # dict to store a dict of properties # for each label to check for consistency and their type # for now, relevant for `int` labels = {} # dict to store the additional labels for each # primary graph constituent from biolink hierarchy for node in nodes: # check if node has already been written, if so skip if self.deduplicator.node_seen(node): continue _id = node.get_id() label = node.get_label() # check for non-id if not _id: logger.warning(f"Node {label} has no id; skipping.") continue if not label in bins.keys(): # start new list all_labels = None bins[label].append(node) bin_l[label] = 1 # get properties from config if present if ( label in self.translator.ontology.mapping.extended_schema ): cprops = self.translator.ontology.mapping.extended_schema.get( label ).get( "properties", ) else: cprops = None if cprops: d = dict(cprops) # add id and preferred id to properties; these are # created in node creation (`_create.BioCypherNode`) d["id"] = "str" d["preferred_id"] = "str" # add strict mode properties if self.strict_mode: d["source"] = "str" d["version"] = "str" d["licence"] = "str" else: d = dict(node.get_properties()) # encode property type for k, v in d.items(): if d[k] is not None: d[k] = type(v).__name__ # else use first encountered node to define properties for # checking; could later be by checking all nodes but much # more complicated, particularly involving batch writing # (would require "do-overs"). for now, we output a warning # if node properties diverge from reference properties (in # write_single_node_list_to_file) TODO if it occurs, ask # user to select desired properties and restart the process reference_props[label] = d # get label hierarchy # multiple labels: if not force: all_labels = self.translator.ontology.get_ancestors( label ) else: all_labels = None if all_labels: # convert to pascal case all_labels = [ self.translator.name_sentence_to_pascal(label) for label in all_labels ] # remove duplicates all_labels = list(OrderedDict.fromkeys(all_labels)) # order alphabetically all_labels.sort() # concatenate with array delimiter all_labels = self._write_array_string(all_labels) else: all_labels = self.translator.name_sentence_to_pascal( label ) labels[label] = all_labels else: # add to list bins[label].append(node) bin_l[label] += 1 if not bin_l[label] < batch_size: # batch size controlled here passed = self._write_single_node_list_to_file( bins[label], label, reference_props[label], labels[label], ) if not passed: return False bins[label] = [] bin_l[label] = 0 # after generator depleted, write remainder of bins for label, nl in bins.items(): #print("YO") #print(label) #print(nl) passed = self._write_single_node_list_to_file( nl, label, reference_props[label], labels[label], ) if not passed: return False # use complete bin list to write header files # TODO if a node type has varying properties # (ie missingness), we'd need to collect all possible # properties in the generator pass # save config or first-node properties to instance attribute for label in reference_props.keys(): self.node_property_dict[label] = reference_props[label] return True else: if type(nodes) is not list: logger.error("Nodes must be passed as list or generator.") return False else: def gen(nodes): yield from nodes return self._write_node_data(gen(nodes), batch_size=batch_size) def _write_single_node_list_to_file( self, node_list: list, label: str, prop_dict: dict, labels: str, ): """ This function takes one list of biocypher nodes and writes them to a Neo4j admin import compatible CSV file. Args: node_list (list): list of BioCypherNodes to be written label (str): the primary label of the node prop_dict (dict): properties of node class passed from parsing function and their types labels (str): string of one or several concatenated labels for the node class Returns: bool: The return value. True for success, False otherwise. """ if not all(isinstance(n, BioCypherNode) for n in node_list): logger.error("Nodes must be passed as type BioCypherNode.") return False # from list of nodes to list of strings lines = [] for n in node_list: # check for deviations in properties # node properties n_props = n.get_properties() n_keys = list(n_props.keys()) # reference properties ref_props = list(prop_dict.keys()) # compare lists order invariant if not set(ref_props) == set(n_keys): onode = n.get_id() oprop1 = set(ref_props).difference(n_keys) oprop2 = set(n_keys).difference(ref_props) logger.error( f"At least one node of the class {n.get_label()} " f"has more or fewer properties than another. " f"Offending node: {onode!r}, offending property: " f"{max([oprop1, oprop2])}. " f"All reference properties: {ref_props}, " f"All node properties: {n_keys}.", ) return False line = [n.get_id()] if ref_props: plist = [] # make all into strings, put actual strings in quotes for k, v in prop_dict.items(): p = n_props.get(k) if p is None: # TODO make field empty instead of ""? plist.append("") elif v in [ "int", "integer", "long", "float", "double", "dbl", "bool", "boolean", ]: plist.append(str(p)) else: if isinstance(p, list): plist.append(self._write_array_string(p)) else: plist.append(f"{self.quote}{str(p)}{self.quote}") line.append(self.delim.join(plist)) line.append(labels) lines.append(self.delim.join(line) + "\n") # avoid writing empty files if lines: self._write_next_part(label, lines) return True def _write_edge_data(self, edges, batch_size): """ Writes biocypher edges to CSV conforming to the headers created with `_write_edge_headers()`, and is actually required to be run before calling `_write_node_headers()` to set the :py:attr:`self.edge_property_dict` for passing the edge properties to the instance. Expects list or generator of edges from the :py:class:`BioCypherEdge` class. Args: edges (BioCypherEdge): a list or generator of edges in :py:class:`BioCypherEdge` format Returns: bool: The return value. True for success, False otherwise. Todo: - currently works for mixed edges but in practice often is called on one iterable containing one type of edge only """ if isinstance(edges, GeneratorType): logger.debug("Writing edge CSV from generator.") bins = defaultdict(list) # dict to store a list for each # label that is passed in bin_l = {} # dict to store the length of each list for # batching cutoff reference_props = defaultdict( dict, ) # dict to store a dict of properties # for each label to check for consistency and their type # for now, relevant for `int` for edge in edges: if not (edge.get_source_id() and edge.get_target_id()): logger.error( "Edge must have source and target node. " f"Caused by: {edge}", ) continue label = edge.get_label() if not label in bins.keys(): # start new list bins[label].append(edge) bin_l[label] = 1 # get properties from config if present # check whether label is in ontology_adapter.leaves # (may not be if it is an edge that carries the # "label_as_edge" property) cprops = None if ( label in self.translator.ontology.mapping.extended_schema ): cprops = self.translator.ontology.mapping.extended_schema.get( label ).get( "properties", ) else: # try via "label_as_edge" for ( k, v, ) in ( self.translator.ontology.mapping.extended_schema.items() ): if isinstance(v, dict): if v.get("label_as_edge") == label: cprops = v.get("properties") break if cprops: d = cprops # add strict mode properties if self.strict_mode: d["source"] = "str" d["version"] = "str" d["licence"] = "str" else: d = dict(edge.get_properties()) # encode property type for k, v in d.items(): if d[k] is not None: d[k] = type(v).__name__ # else use first encountered edge to define # properties for checking; could later be by # checking all edges but much more complicated, # particularly involving batch writing (would # require "do-overs"). for now, we output a warning # if edge properties diverge from reference # properties (in write_single_edge_list_to_file) # TODO reference_props[label] = d else: # add to list bins[label].append(edge) bin_l[label] += 1 if not bin_l[label] < batch_size: # batch size controlled here passed = self._write_single_edge_list_to_file( bins[label], label, reference_props[label], ) if not passed: return False bins[label] = [] bin_l[label] = 0 # after generator depleted, write remainder of bins for label, nl in bins.items(): passed = self._write_single_edge_list_to_file( nl, label, reference_props[label], ) if not passed: return False # use complete bin list to write header files # TODO if a edge type has varying properties # (ie missingness), we'd need to collect all possible # properties in the generator pass # save first-edge properties to instance attribute for label in reference_props.keys(): self.edge_property_dict[label] = reference_props[label] return True else: if type(edges) is not list: logger.error("Edges must be passed as list or generator.") return False else: def gen(edges): yield from edges return self._write_edge_data(gen(edges), batch_size=batch_size) def _write_single_edge_list_to_file( self, edge_list: list, label: str, prop_dict: dict, ): """ This function takes one list of biocypher edges and writes them to a Neo4j admin import compatible CSV file. Args: edge_list (list): list of BioCypherEdges to be written label (str): the label (type) of the edge prop_dict (dict): properties of node class passed from parsing function and their types Returns: bool: The return value. True for success, False otherwise. """ if not all(isinstance(n, BioCypherEdge) for n in edge_list): logger.error("Edges must be passed as type BioCypherEdge.") return False # from list of edges to list of strings lines = [] for e in edge_list: # check for deviations in properties # edge properties e_props = e.get_properties() e_keys = list(e_props.keys()) ref_props = list(prop_dict.keys()) # compare list order invariant if not set(ref_props) == set(e_keys): oedge = f"{e.get_source_id()}-{e.get_target_id()}" oprop1 = set(ref_props).difference(e_keys) oprop2 = set(e_keys).difference(ref_props) logger.error( f"At least one edge of the class {e.get_label()} " f"has more or fewer properties than another. " f"Offending edge: {oedge!r}, offending property: " f"{max([oprop1, oprop2])}. " f"All reference properties: {ref_props}, " f"All edge properties: {e_keys}.", ) return False plist = [] # make all into strings, put actual strings in quotes for k, v in prop_dict.items(): p = e_props.get(k) if p is None: # TODO make field empty instead of ""? plist.append("") elif v in [ "int", "integer", "long", "float", "double", "dbl", "bool", "boolean", ]: plist.append(str(p)) else: if isinstance(p, list): plist.append(self._write_array_string(p)) else: plist.append(self.quote + str(p) + self.quote) entries = [e.get_source_id()] skip_id = False schema_label = None if label in ["IS_SOURCE_OF", "IS_TARGET_OF", "IS_PART_OF"]: skip_id = True elif not self.translator.ontology.mapping.extended_schema.get( label ): # find label in schema by label_as_edge for ( k, v, ) in self.translator.ontology.mapping.extended_schema.items(): if v.get("label_as_edge") == label: schema_label = k break else: schema_label = label if schema_label: if ( self.translator.ontology.mapping.extended_schema.get( schema_label ).get("use_id") == False ): skip_id = True if not skip_id: entries.append(e.get_id() or "") if ref_props: entries.append(self.delim.join(plist)) entries.append(e.get_target_id()) entries.append( self.translator.name_sentence_to_pascal( e.get_label(), ) ) lines.append( self.delim.join(entries) + "\n", ) # avoid writing empty files if lines: self._write_next_part(label, lines) return True def _write_next_part(self, label: str, lines: list): """ This function writes a list of strings to a new part file. Args: label (str): the label (type) of the edge; internal representation sentence case -> needs to become PascalCase for disk representation lines (list): list of strings to be written Returns: bool: The return value. True for success, False otherwise. """ # translate label to PascalCase label_pascal = self.translator.name_sentence_to_pascal( parse_label(label) ) # list files in self.outdir files = glob.glob( os.path.join(self.outdir, f"{label_pascal}-part*.csv") ) # find file with highest part number if not files: next_part = 0 else: next_part = ( max( [ int(f.split(".")[-2].split("-")[-1].replace("part", "")) for f in files ], ) + 1 ) # write to file padded_part = str(next_part).zfill(3) logger.info( f"Writing {len(lines)} entries to {label_pascal}-part{padded_part}.csv", ) # store name only in case import_call_file_prefix is set part = f"{label_pascal}-part{padded_part}.csv" file_path = os.path.join(self.outdir, part) with open(file_path, "w", encoding="utf-8") as f: # concatenate with delimiter f.writelines(lines) if not self.parts.get(label): self.parts[label] = [part] else: self.parts[label].append(part) def get_import_call(self) -> str: """ Function to return the import call detailing folder and individual node and edge headers and data files, as well as delimiters and database name. Returns: str: a bash command for the database import """ return self._construct_import_call() def write_import_call(self) -> str: """ Function to write the import call detailing folder and individual node and edge headers and data files, as well as delimiters and database name, to the export folder as txt. Returns: str: The path of the file holding the import call. """ file_path = os.path.join(self.outdir, self._get_import_script_name()) logger.info( f"Writing {self.db_name + ' ' if self.db_name else ''}import call to `{file_path}`." ) with open(file_path, "w", encoding="utf-8") as f: f.write(self._construct_import_call()) return file_path def parse_label(label: str) -> str: """ Check if the label is compliant with Neo4j naming conventions, https://neo4j.com/docs/cypher-manual/current/syntax/naming/, and if not, remove non-compliant characters. Args: label (str): The label to check Returns: str: The compliant label """ # Check if the name contains only alphanumeric characters, underscore, or dollar sign # and dot (for class hierarchy of BioCypher) allowed_chars = r"a-zA-Z0-9_$ ." matches = re.findall(f"[{allowed_chars}]", label) non_matches = re.findall(f"[^{allowed_chars}]", label) if non_matches: non_matches = list(set(non_matches)) logger.warning( f"Label is not compliant with Neo4j naming rules. Removed non compliant characters: {non_matches}" ) def first_character_compliant(character: str) -> bool: return character.isalpha() or character == "$" if not first_character_compliant(matches[0]): for c in matches: if first_character_compliant(c): matches = matches[matches.index(c) :] break logger.warning( "Label does not start with an alphabetic character or with $. Removed non compliant characters." ) return "".join(matches).strip()