# medax_pipeline/import_fhir_to_nx_diGraph.py

from biocypher import BioCypher
import networkx as nx
import os
import sys
import uuid
import gc
from dotenv import load_dotenv
from graphCreation import create_graph
from graphCreation.process_references import process_references
from graphCreation.property_convolution import property_convolution
from schema_config_generation import write_automated_schema
from fhirImport import getPatientEverything, getBundle

def load_multiple_fhir_patients(n):
    init = True
    ids = []

    # Collect up to n patient ids, following the bundle's "next" links.
    # Pagination gets its own firstRequest flag; init is reused further down
    # to decide when a fresh graph must be created for the next batch.
    firstRequest = True
    nextIds = True
    while len(ids) < n and nextIds:
        if firstRequest:
            complexPatients = os.getenv('COMPLEX_PATIENTS')
            if complexPatients and complexPatients.upper() != 'TRUE':
                bundle = getBundle(None, '/Patient?_count=' + str(n))
            else:
                bundle = getBundle(None, '/Patient?_has:Observation:subject:status=final&_count=' + str(n))
            firstRequest = False
        else:
            bundle = getBundle(None, nextLink)
        if 'entry' not in bundle.json():
            print("ERROR -- No data found in the FHIR bundle. Check the request and whether the server is up and responding.")
            sys.exit(1)
        for entry in bundle.json()['entry']:
            ids.append(entry['resource']['id'])
        nextIds = False
        for link in bundle.json()['link']:
            if link['relation'] == "next":
                nextLink = link['url']
                nextIds = True

    if len(ids) < n:
        n = len(ids)
    batchSize = int(os.getenv('BATCH_SIZE'))
    c = 0
    print(len(ids))

    # Fetch the $everything bundle for each patient id and add it to the graph.
    for id in ids:
        c += 1
        bundle = getPatientEverything(id).json()
        bundle = replace_single_quotes(bundle)  # maybe not needed for German data
        if init:
            graph = nx.DiGraph()
            init = False
        create_graph.add_json_to_networkx(bundle, id + '_bundle', graph)
        if c % 50 == 0:
            print("---------- ", c, " patients loaded ----------", flush=True)
        # After each batch: reduce the graph, hand it to BioCypher, start over.
        if c % batchSize == 0 or c == n:
            print(c, " patients imported, reducing graph", flush=True)
            process_references(graph)
            property_convolution(graph)
            lastChunk = c == n
            runBioCypher(graph, lastChunk)
            init = True
            print(graph)
            del graph
            gc.collect()  # parentheses were missing, so collection never actually ran
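
# For reference, the paging loop above relies only on this slice of a FHIR
# searchset bundle (hypothetical, abbreviated example):
# {
#   "resourceType": "Bundle",
#   "type": "searchset",
#   "link": [{"relation": "next", "url": "https://fhir.example.org/Patient?page=2"}],
#   "entry": [{"resource": {"resourceType": "Patient", "id": "123"}}]
# }
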
def replace_single_quotes(obj):
    if isinstance(obj, str):  # strings: double every single quote
        return obj.replace("'", "''")
    elif isinstance(obj, dict):  # dictionaries: process each key-value pair
        return {key: replace_single_quotes(value) for key, value in obj.items()}
    elif isinstance(obj, list):  # lists: process each item
        return [replace_single_quotes(item) for item in obj]
    else:
        return obj  # leave other data types unchanged
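
# A minimal sketch of the escaping on hypothetical data; doubling single
# quotes matches the quote="'" setting of the generated neo4j-admin call:
# >>> replace_single_quotes({'name': "O'Brien", 'notes': ["it's fine"]})
# {'name': "O''Brien", 'notes': ["it''s fine"]}
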
def main():
    # Create the networkX graph and run the improvement scripts.
    load_dotenv()  # imported above but previously never called; without it a .env file is not read
    print("Creating the graph...", flush=True)
    nPatients = int(os.getenv('NUMBER_OF_PATIENTS'))
    load_multiple_fhir_patients(nPatients)
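
# The pipeline is configured entirely through environment variables; a
# minimal, hypothetical .env for a local run:
#   NUMBER_OF_PATIENTS=100
#   BATCH_SIZE=50
#   COMPLEX_PATIENTS=TRUE
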
def runBioCypher(nxGraph, final):
    # Derive the lists of node and edge types.
    print("Generate auto schema...", flush=True)
    write_automated_schema(nxGraph, 'config/automated_schema.yaml', 'config/manual_schema_config.yaml')

    # Create the BioCypher driver.
    bc = BioCypher(
        biocypher_config_path="config/biocypher_config.yaml",
    )
    # bc.show_ontology_structure()  # very extensive

    # BioCypher preparation
    def node_generator():
        for node in nxGraph.nodes():
            label = nxGraph.nodes[node].get('label')
            if label == "resource":
                # Resources are relabelled with their FHIR resourceType.
                label = nxGraph.nodes[node].get('resourceType')
                nxGraph.nodes[node]['label'] = label.capitalize()
                label = label.capitalize()
            unq_id = nxGraph.nodes[node].get('unique_id', False)
            if nxGraph.nodes[node].get('label') in ['search', 'meta', 'link']:
                # print("skipped a node: ", nxGraph.nodes[node].get('label'))
                continue
            label = nxGraph.nodes[node].get('label')
            if label == 'dummy':
                # print("SKIPPED dummy node: ", unq_id)
                continue
            yield (
                nxGraph.nodes[node].get('unique_id', node),  # the unique_id attribute if it exists, otherwise the identifier used by nx
                label,
                nxGraph.nodes[node]  # all node attributes as properties
            )

    def edge_generator():
        for edge in nxGraph.edges(data=True):
            source, target, attributes = edge
            sLabel = nxGraph.nodes[source].get('label')
            if sLabel == 'resource':
                sLabel = nxGraph.nodes[source].get('resourceType')
            tLabel = nxGraph.nodes[target].get('label')
            if tLabel == 'resource':
                tLabel = nxGraph.nodes[target].get('resourceType')
            label = sLabel.capitalize() + '_to_' + tLabel
            yield (
                attributes.get('id', str(uuid.uuid4())),  # edge id if it exists, otherwise a fresh UUID
                nxGraph.nodes[source].get('unique_id', source),
                nxGraph.nodes[target].get('unique_id', target),
                label,
                attributes  # all edge attributes
            )

    # Import nodes and edges.
    bc.write_nodes(node_generator())
    bc.write_edges(edge_generator())

    # Write the import script ourselves -- BioCypher would only consider the last batch as input.
    if final:
        print("CREATING THE SCRIPT")
        generate_neo4j_import_script()
        with open('/neo4j_import/shell-scipt-complete', 'w') as f:  # marker file name kept as-is; downstream steps may check for it
            f.write('Import completed successfully')
        print("FHIR import completed successfully")
def generate_neo4j_import_script(directory_path="/neo4j_import/", output_file="neo4j-admin-import-call.sh"):
    """
    Reads files in a directory and generates a Neo4j import shell script.

    Args:
        directory_path (str): Path to the directory containing CSV files
        output_file (str): Name of the output shell script file

    Returns:
        str: Path to the generated shell script
    """
    # Get all files in the directory
    all_files = os.listdir(directory_path)

    # Dictionary to store entity types (nodes and relationships)
    entity_types = {}

    # Find all header files and use them to identify entity types
    for filename in all_files:
        if '-header.csv' in filename:
            entity_name = filename.split('-header.csv')[0]
            # A relationship header contains "To" and "Association" in its name
            is_relationship = "To" in entity_name and "Association" in entity_name
            if is_relationship:
                entity_type = "relationships"
            else:
                entity_type = "nodes"
            # Initialize the entity if not already present
            if entity_name not in entity_types:
                entity_types[entity_name] = {
                    "type": entity_type,
                    "header": f"/neo4j_import/{filename}",
                    "has_parts": False
                }

    # Check for part files for each entity
    for entity_name in entity_types:
        part_pattern = f"{entity_name}-part"
        for filename in all_files:
            if part_pattern in filename:
                entity_types[entity_name]["has_parts"] = True
                break

    # Generate the import commands; entities without part files carry no data
    # and are skipped.
    nodes_command = ""
    relationships_command = ""
    for entity_name, info in entity_types.items():
        if info["has_parts"]:
            # Header file plus a wildcard over all part files
            command = f" --{info['type']}=\"{info['header']},/neo4j_import/{entity_name}-part.*\""
            if info['type'] == "nodes":
                nodes_command += command
            else:  # relationships
                relationships_command += command

    # Create the shell script content
    script_content = """#!/bin/bash
version=$(bin/neo4j-admin --version | cut -d '.' -f 1)
if [[ $version -ge 5 ]]; then
\tbin/neo4j-admin database import full neo4j --delimiter="\\t" --array-delimiter="|" --quote="'" --overwrite-destination=true --skip-bad-relationships=true --skip-duplicate-nodes=true{nodes}{relationships}
else
\tbin/neo4j-admin import --database=neo4j --delimiter="\\t" --array-delimiter="|" --quote="'" --force=true --skip-bad-relationships=true --skip-duplicate-nodes=true{nodes}{relationships}
fi
""".format(nodes=nodes_command, relationships=relationships_command)

    # Write the script to file and make it executable
    script_path = os.path.join(directory_path, output_file)
    with open(script_path, 'w') as f:
        f.write(script_content)
    os.chmod(script_path, 0o755)
    print("Shell import script created", flush=True)
    return script_path
if __name__ == "__main__":
    main()