import os
import glob
import pandas as pd
from biocypher._logger import logger
from biocypher.output.write._batch_writer import parse_label, _BatchWriter


class _Neo4jBatchWriter(_BatchWriter):
"""
Class for writing node and edge representations to disk using the
format specified by Neo4j for the use of admin import. Each batch
writer instance has a fixed representation that needs to be passed
at instantiation via the :py:attr:`schema` argument. The instance
also expects an ontology adapter via :py:attr:`ontology_adapter` to be able
to convert and extend the hierarchy.
This class inherits from the abstract class "_BatchWriter" and implements the
Neo4j-specific methods:
- _write_node_headers
- _write_edge_headers
- _construct_import_call
- _write_array_string
"""

    def __init__(self, *args, **kwargs):
        """
        Constructor.

        Checks the version of Neo4j and adds a command scope if version >= 5.
        """
        # Should read the configuration and set up import_call_bin_prefix.
        super().__init__(*args, **kwargs)

    def _get_default_import_call_bin_prefix(self):
        """
        Method to provide the default string for the import call bin prefix.

        Returns:
            str: The default location of the neo4j-admin binary.
        """
        return "bin/"

    def _write_array_string(self, string_list):
        """
        Write the string representation of an array into a .csv file as
        required by the neo4j admin import.

        Args:
            string_list (list): list of ontology strings

        Returns:
            str: The string representation of an array for the neo4j admin
                import
        """
        string = self.adelim.join(string_list)
        return f"{self.quote}{string}{self.quote}"

    def _write_node_headers(self):
        """
        Writes a single CSV file for a graph entity that is represented
        as a node as per the definition in the `schema_config.yaml`,
        containing only the header for this type of node.

        Returns:
            bool: The return value. True for success, False otherwise.
        """
# load headers from data parse
if not self.node_property_dict:
logger.error(
"Header information not found. Was the data parsed first?",
)
return False
for label, props in self.node_property_dict.items():
_id = ":ID"
            # MeDaX dev remark: FHIR data gives us case-sensitive labels,
            # e.g. 'Procedure' and 'procedure' are two distinct node types,
            # because we convert Resources to more specific node classes
            # using their "resourceType" attribute.
# translate label to PascalCase
pascal_label = self.translator.name_sentence_to_pascal(
parse_label(label)
)
header = f"{pascal_label}-header.csv"
header_path = os.path.join(
self.outdir,
header,
)
parts = f"{pascal_label}-part.*"
existing_header = False
# check if file already exists
if os.path.exists(header_path):
logger.warning(
f"Header file `{header_path}` already exists. Overwriting.",
)
with open(header_path, "r", encoding="utf-8") as existing:
existing_header = existing.read().strip().split(self.delim)
# concatenate key:value in props
props_list = []
for k, v in props.items():
if v in ["int", "long", "integer"]:
props_list.append(f"{k}:long")
elif v in ["int[]", "long[]", "integer[]"]:
props_list.append(f"{k}:long[]")
elif v in ["float", "double", "dbl"]:
props_list.append(f"{k}:double")
elif v in ["float[]", "double[]"]:
props_list.append(f"{k}:double[]")
elif v in ["bool", "boolean"]:
# TODO Neo4j boolean support / spelling?
props_list.append(f"{k}:boolean")
elif v in ["bool[]", "boolean[]"]:
props_list.append(f"{k}:boolean[]")
elif v in ["str[]", "string[]"]:
props_list.append(f"{k}:string[]")
else:
props_list.append(f"{k}")
# create list of lists and flatten
out_list = [[_id], props_list, [":LABEL"]]
out_list = [val for sublist in out_list for val in sublist]
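
            # For a node type with properties {"name": "str", "synonyms":
            # "str[]"} and the delimiter ";", the resulting header row would
            # be (illustrative):
            #
            #     :ID;name;synonyms:string[];:LABEL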
            with open(header_path, "w", encoding="utf-8") as f:
                # if a header already existed with different columns, the
                # previously written part files have to be adapted to the
                # new column layout
                if existing_header and set(existing_header) != set(out_list):
                    # get part files associated with this header
                    base_name = os.path.basename(header_path).replace(
                        "-header.csv", ""
                    )
                    part_files = glob.glob(
                        os.path.join(
                            os.path.dirname(header_path),
                            f"{base_name}-part*.csv",
                        )
                    )
                    # find the highest numbered part file without full sorting
                    highest_part = None
                    highest_number = -1
                    for part_file in part_files:
                        try:
                            # extract the number from the file name
                            # (format like "<label>-part123.csv")
                            file_name = os.path.basename(part_file)
                            number_part = file_name.split("part")[1].split(".")[0]
                            number = int(number_part)
                            if number > highest_number:
                                highest_number = number
                                highest_part = part_file
                        except (IndexError, ValueError):
                            # skip files that don't match the expected pattern
                            continue
                    # update each part file with the new columns, skipping
                    # the most recent (highest numbered) part file
                    for part_file in part_files:
                        if part_file == highest_part:
                            logger.info(
                                f"Skipping the highest part file: {highest_part}"
                            )
                            continue
                        try:
                            df = self.adapt_csv_to_new_header(
                                existing_header, out_list, part_file
                            )
                            # write back WITHOUT the header; headers live in
                            # the separate *-header.csv file
                            df.to_csv(
                                part_file,
                                sep=self.delim,
                                index=False,
                                header=False,
                            )
                            logger.info(
                                f"Updated {part_file} with new columns in "
                                "correct positions"
                            )
                        except Exception as e:
                            logger.error(f"Error updating {part_file}: {e}")
                # write the new header
                row = self.delim.join(out_list)
                f.write(row)
            # add file path to neo4j-admin import statement (import call
            # file path may be different from actual file path)
import_call_header_path = os.path.join(
self.import_call_file_prefix,
header,
)
import_call_parts_path = os.path.join(
self.import_call_file_prefix,
parts,
)
self.import_call_nodes.add(
(import_call_header_path, import_call_parts_path)
)
return True
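
    # After writing, the output directory contains one header file and one
    # or more part files per node type, e.g. (illustrative file names and
    # contents):
    #
    #     Procedure-header.csv   ->  :ID;code;status;:LABEL
    #     Procedure-part000.csv  ->  p1;'P123';'completed';Procedure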

    def _write_edge_headers(self):
        """
        Writes a single CSV file for a graph entity that is represented
        as an edge as per the definition in the `schema_config.yaml`,
        containing only the header for this type of edge.

        Returns:
            bool: The return value. True for success, False otherwise.
        """
# load headers from data parse
if not self.edge_property_dict:
logger.error(
"Header information not found. Was the data parsed first?",
)
return False
for label, props in self.edge_property_dict.items():
# translate label to PascalCase
pascal_label = self.translator.name_sentence_to_pascal(
parse_label(label)
)
# paths
header = f"{pascal_label}-header.csv"
header_path = os.path.join(
self.outdir,
header,
)
parts = f"{pascal_label}-part.*"
# concatenate key:value in props
props_list = []
for k, v in props.items():
if v in ["int", "long", "integer"]:
props_list.append(f"{k}:long")
elif v in ["int[]", "long[]", "integer[]"]:
props_list.append(f"{k}:long[]")
elif v in ["float", "double"]:
props_list.append(f"{k}:double")
elif v in ["float[]", "double[]"]:
props_list.append(f"{k}:double[]")
                elif v in ["bool", "boolean"]:
                    # TODO does Neo4j support bool?
                    props_list.append(f"{k}:boolean")
elif v in ["bool[]", "boolean[]"]:
props_list.append(f"{k}:boolean[]")
elif v in ["str[]", "string[]"]:
props_list.append(f"{k}:string[]")
else:
props_list.append(f"{k}")
skip_id = False
schema_label = None
if label in ["IS_SOURCE_OF", "IS_TARGET_OF", "IS_PART_OF"]:
skip_id = True
elif not self.translator.ontology.mapping.extended_schema.get(
label
):
# find label in schema by label_as_edge
for (
k,
v,
) in self.translator.ontology.mapping.extended_schema.items():
if v.get("label_as_edge") == label:
schema_label = k
break
else:
schema_label = label
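
            # Illustrative example (schema names are assumptions): a schema
            # entry "perturbed by" configured with
            # `label_as_edge: PERTURBED_BY` arrives here under the label
            # "PERTURBED_BY"; the loop above recovers the schema entry so
            # that settings such as `use_id` can be read below.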
out_list = [":START_ID"]
if schema_label:
if (
self.translator.ontology.mapping.extended_schema.get(
schema_label
).get("use_id")
                    is False
):
skip_id = True
if not skip_id:
out_list.append("id")
out_list.extend(props_list)
out_list.extend([":END_ID", ":TYPE"])
existing_header = False
# check if file already exists
if os.path.exists(header_path):
logger.warning(
f"Header file `{header_path}` already exists. Overwriting.",
)
with open(header_path, "r", encoding="utf-8") as existing:
existing_header = existing.read().strip().split(self.delim)
            with open(header_path, "w", encoding="utf-8") as f:
                # if a header already existed with different columns, the
                # previously written part files have to be adapted to the
                # new column layout
                if existing_header and set(existing_header) != set(out_list):
                    # get part files associated with this header
                    base_name = os.path.basename(header_path).replace(
                        "-header.csv", ""
                    )
                    part_files = glob.glob(
                        os.path.join(
                            os.path.dirname(header_path),
                            f"{base_name}-part*.csv",
                        )
                    )
                    # find the highest numbered part file without full sorting
                    highest_part = None
                    highest_number = -1
                    for part_file in part_files:
                        try:
                            # extract the number from the file name
                            # (format like "<label>-part123.csv")
                            file_name = os.path.basename(part_file)
                            number_part = file_name.split("part")[1].split(".")[0]
                            number = int(number_part)
                            if number > highest_number:
                                highest_number = number
                                highest_part = part_file
                        except (IndexError, ValueError):
                            # skip files that don't match the expected pattern
                            continue
                    # update each part file with the new columns, skipping
                    # the most recent (highest numbered) part file
                    for part_file in part_files:
                        if part_file == highest_part:
                            logger.info(
                                f"Skipping the highest part file: {highest_part}"
                            )
                            continue
                        try:
                            df = self.adapt_csv_to_new_header(
                                existing_header, out_list, part_file
                            )
                            # write back WITHOUT the header; headers live in
                            # the separate *-header.csv file
                            df.to_csv(
                                part_file,
                                sep=self.delim,
                                index=False,
                                header=False,
                            )
                            logger.info(
                                f"Updated {part_file} with new columns in "
                                "correct positions"
                            )
                        except Exception as e:
                            logger.error(f"Error updating {part_file}: {e}")
                # write the new header
                row = self.delim.join(out_list)
                f.write(row)
            # add file path to neo4j-admin import statement (import call
            # file path may be different from actual file path)
import_call_header_path = os.path.join(
self.import_call_file_prefix,
header,
)
import_call_parts_path = os.path.join(
self.import_call_file_prefix,
parts,
)
self.import_call_edges.add(
(import_call_header_path, import_call_parts_path)
)
return True

    def _get_import_script_name(self) -> str:
        """
        Returns the name of the neo4j admin import script.

        Returns:
            str: The name of the import script (ending in .sh)
        """
        return "neo4j-admin-import-call.sh"

    def _construct_import_call(self) -> str:
        """
        Function to construct the import call detailing folder and
        individual node and edge headers and data files, as well as
        delimiters and database name. Built after all data has been
        processed to ensure that nodes are called before any edges.

        Returns:
            str: a bash command for neo4j-admin import
        """
import_call_neo4j_v4 = self._get_import_call(
"import", "--database=", "--force="
)
import_call_neo4j_v5 = self._get_import_call(
"database import full", "", "--overwrite-destination="
)
neo4j_version_check = f"version=$({self._get_default_import_call_bin_prefix()}neo4j-admin --version | cut -d '.' -f 1)"
import_script = f"#!/bin/bash\n{neo4j_version_check}\nif [[ $version -ge 5 ]]; then\n\t{import_call_neo4j_v5}\nelse\n\t{import_call_neo4j_v4}\nfi"
return import_script
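
    # The generated script looks roughly like this (bin prefix, database
    # name and options depend on the configuration):
    #
    #     #!/bin/bash
    #     version=$(bin/neo4j-admin --version | cut -d '.' -f 1)
    #     if [[ $version -ge 5 ]]; then
    #         bin/neo4j-admin database import full neo4j ...
    #     else
    #         bin/neo4j-admin import --database=neo4j ...
    #     fi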

    def _get_import_call(
self, import_cmd: str, database_cmd: str, wipe_cmd: str
) -> str:
"""Get parametrized import call for Neo4j 4 or 5+.
Args:
import_cmd (str): The import command to use.
database_cmd (str): The database command to use.
wipe_cmd (str): The wipe command to use.
Returns:
str: The import call.
"""
import_call = f"{self.import_call_bin_prefix}neo4j-admin {import_cmd} "
import_call += f"{database_cmd}{self.db_name} "
import_call += f'--delimiter="{self.escaped_delim}" '
import_call += f'--array-delimiter="{self.escaped_adelim}" '
if self.quote == "'":
import_call += f'--quote="{self.quote}" '
else:
import_call += f"--quote='{self.quote}' "
if self.wipe:
import_call += f"{wipe_cmd}true "
if self.skip_bad_relationships:
import_call += "--skip-bad-relationships=true "
if self.skip_duplicate_nodes:
import_call += "--skip-duplicate-nodes=true "
# append node import calls
for header_path, parts_path in self.import_call_nodes:
import_call += f'--nodes="{header_path},{parts_path}" '
# append edge import calls
for header_path, parts_path in self.import_call_edges:
import_call += f'--relationships="{header_path},{parts_path}" '
return import_call
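
    # With default-ish settings, the produced call could look like this
    # (illustrative paths, delimiters and database name):
    #
    #     bin/neo4j-admin import --database=neo4j --delimiter=";" \
    #         --array-delimiter="|" --quote="'" \
    #         --nodes="Procedure-header.csv,Procedure-part.*"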

    def adapt_csv_to_new_header(self, old_header, new_header, csv_file_path):
        """
        Adapt a CSV table to a new header structure, placing new columns in
        their correct positions.

        Args:
            old_header (list): The original header columns
            new_header (list): The new header columns
            csv_file_path (str): Path to the CSV file

        Returns:
            pandas.DataFrame: CSV data with the new header structure
        """
# Step 1: Read the CSV data without headers
df = pd.read_csv(csv_file_path, sep=self.delim, header=None)
# Step 2: If the file is empty, return empty DataFrame with new headers
if df.empty:
return pd.DataFrame(columns=new_header)
        # Step 3: If the column count doesn't match the old_header length,
        # handle the mismatch
        if len(df.columns) != len(old_header):
            logger.warning(
                f"CSV column count ({len(df.columns)}) doesn't match the "
                f"provided old header count ({len(old_header)})"
            )
# If file has fewer columns than old_header, pad with NaN
if len(df.columns) < len(old_header):
for i in range(len(df.columns), len(old_header)):
df[i] = None
# If file has more columns than old_header, truncate
else:
df = df.iloc[:, :len(old_header)]
# Step 4: Assign old header names to the dataframe
df.columns = old_header
# Step 5: Create a new DataFrame with the correct structure
new_df = pd.DataFrame(columns=new_header)
# Step 6: For each column in the new header, find its position in the old header
for new_col_idx, new_col in enumerate(new_header):
if new_col in old_header:
# If column exists in old header, copy data
new_df[new_col] = df[new_col]
else:
# If new column, add empty column
new_df[new_col] = None
# Step 7: Ensure columns are in the exact order of new_header
new_df = new_df[new_header]
return new_df
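
    # Worked example (illustrative): with old_header = ["id", "name"],
    # new_header = ["id", "age", "name"] and a part file row "1;Alice",
    # the returned frame has the columns id, age, name with "age" empty,
    # so the rewritten row reads "1;;Alice".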