diff --git a/.gitignore b/.gitignore index d1731f7..f7d0361 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ **__pycache__/ -.venv/ +old-src/ diff --git a/src/moi/moi.py b/src/moi/moi.py index defe79b..538985a 100644 --- a/src/moi/moi.py +++ b/src/moi/moi.py @@ -3,7 +3,7 @@ from neo4j import GraphDatabase import argparse import configparser import os -from methods_mesh import create_graph_config, import_ontology, postprocess_mesh +from methods_moi import create_graph_config, import_ontology, postprocess_mesh # define parameters - pass ontology file dir and db conf as arguments when running the script diff --git a/src/postprocessing/postprocess.py b/src/postprocessing/postprocess.py index 3e97fed..0ed46a5 100644 --- a/src/postprocessing/postprocess.py +++ b/src/postprocessing/postprocess.py @@ -92,10 +92,10 @@ password = conf_file['neo4j']['password'] if __name__ == "__main__": # postprocess mesh to umls - run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mesh_to_umls_optimised) + #run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mesh_to_umls_optimised) # postprocess clinicaltrials.gov to mesh - run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_ct_to_mesh) + #run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_ct_to_mesh) # postprocess mdm to umls - run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mdm_to_umls) + #run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mdm_to_umls) # postprocess clinicaltrials.gov to mdm run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_ct_to_mdm) diff --git a/src/study2neo4j/ct2neo4j.py b/src/study2neo4j/ct2neo4j.py index 3780bab..f38d221 100644 --- a/src/study2neo4j/ct2neo4j.py +++ b/src/study2neo4j/ct2neo4j.py @@ -5,22 +5,56 @@ from neo4j import GraphDatabase # Define a function to add nodes and relationships recursively def add_nodes_from_dict(tx, parent_node_label, parent_node_str_id, current_dict): + # Collect data for batching + # dict_node_data = [] + # dict_rels_data = [] + # list_node_data = [] + # list_rels_data = [] + phase_data = [] + phase_rels_data = [] + condition_data = [] + condition_rels_data = [] + keyword_data = [] + keyword_rels_data = [] + for key, value in current_dict.items(): # iterate over each key-value pair in dictionary + if key == "phases": + # Create a node for each phase for index, phase in enumerate(value): phase_node_str_id = f"{parent_node_str_id}_{key}_{index}" - tx.run(f"MERGE (n:phase {{str_id: $str_id, name: $phase_name}})", - str_id=phase_node_str_id, phase_name=phase) - tx.run( - f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:phase {{str_id: $child_str_id}}) " - f"MERGE (a)-[:{key}]->(b)", - parent_str_id=parent_node_str_id, - child_str_id=phase_node_str_id - ) + phase_data.append({"str_id": phase_node_str_id, "name": phase}) #new + phase_rels_data.append({"parent_str_id": parent_node_str_id, "child_str_id": phase_node_str_id, "rel_type": key}) + + if phase_data: + tx.run(""" + UNWIND $data AS row + MERGE (n:phase {str_id: row.str_id}) + SET n.name = row.name + """, data=phase_data) + + if phase_rels_data: + tx.run(""" + UNWIND $rels AS rel + MATCH (a:{parent_label} {{str_id: rel.parent_str_id}}), + (b:phase {{str_id: rel.child_str_id}}) + MERGE (a)-[:{rel_type}]->(b) + """.format(parent_label=parent_node_label, rel_type=key), rels=phase_rels_data) + + #tx.run(f"MERGE (n:phase {{str_id: $str_id, name: $phase_name}})", + # str_id=phase_node_str_id, phase_name=phase) + #tx.run( + # f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:phase {{str_id: $child_str_id}}) " + # f"MERGE (a)-[:{key}]->(b)", + # parent_str_id=parent_node_str_id, + # child_str_id=phase_node_str_id + #) if isinstance(value, dict): # if value of key is a dict, then create new node: # Create a new node for the nested dictionary new_node_str_id = f"{parent_node_str_id}_{key}" # concatenate the parent_node_str_id and key to a new id + + # todo: move batch processing to end of function tx.run(f"MERGE (n:{key} {{str_id: $str_id}})", str_id=new_node_str_id) # create node with key as label # Create a relationship from the parent node to the new node @@ -38,26 +72,66 @@ def add_nodes_from_dict(tx, parent_node_label, parent_node_str_id, current_dict) # Create a node for each condition for index, condition in enumerate(value): condition_node_str_id = f"{parent_node_str_id}_{key}_{index}" - tx.run(f"MERGE (n:condition {{str_id: $str_id, name: $condition_name}})", - str_id=condition_node_str_id, condition_name=condition) - tx.run( - f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:condition {{str_id: $child_str_id}}) " - f"MERGE (a)-[:{key}]->(b)", - parent_str_id=parent_node_str_id, - child_str_id=condition_node_str_id - ) + condition_data.append({"str_id": condition_node_str_id, "name": condition}) #new + condition_rels_data.append({"parent_str_id": parent_node_str_id, "child_str_id": condition_node_str_id, "rel_type": key}) #new + + if condition_data: + tx.run(""" + UNWIND $data AS row + MERGE (n:condition {str_id: row.str_id}) + SET n.name = row.name + """, data=condition_data) + + if condition_rels_data: + tx.run(""" + UNWIND $data AS row + MATCH (a:{parent_label} {{str_id: row.parent_str_id}}), (b:condition {{str_id: row.child_str_id}}) + MERGE (a)-[:{rel_type}]->(b) + """.format(parent_label=parent_node_label, rel_type=key), data=condition_rels_data) + + #tx.run(f"MERGE (n:condition {{str_id: $str_id, name: $condition_name}})", + # str_id=condition_node_str_id, condition_name=condition) + # todo: rels? + #tx.run( + # f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:condition {{str_id: $child_str_id}}) " + # f"MERGE (a)-[:{key}]->(b)", + # parent_str_id=parent_node_str_id, + # child_str_id=condition_node_str_id + #) + elif key == "keywords": # Create a node for each keyword for index, keyword in enumerate(value): keyword_node_str_id = f"{parent_node_str_id}_{key}_{index}" - tx.run(f"MERGE (n:keyword {{str_id: $str_id, name: $keyword_name}})", str_id=keyword_node_str_id, - keyword_name=keyword) - tx.run( - f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:keyword {{str_id: $child_str_id}}) " - f"MERGE (a)-[:{key}]->(b)", - parent_str_id=parent_node_str_id, - child_str_id=keyword_node_str_id - ) + keyword_data.append({"str_id": keyword_node_str_id, "name": keyword}) + keyword_rels_data.append({ + "parent_str_id": parent_node_str_id, + "child_str_id": keyword_node_str_id, + "rel_type": key + }) + + if keyword_data: + tx.run(""" + UNWIND $data AS row + MERGE (n:keyword {str_id: row.str_id}) + SET n.name = row.name + """, data=keyword_data) + + if keyword_rels_data: + tx.run(""" + UNWIND $data AS row + MATCH (a:{parent_label} {{str_id: row.parent_str_id}}), (b:keyword {{str_id: row.child_str_id}}) + MERGE (a)-[:{rel_type}]->(b) + """.format(parent_label=parent_node_label, rel_type=key), data=keyword_rels_data) + + #tx.run(f"MERGE (n:keyword {{str_id: $str_id, name: $keyword_name}})", str_id=keyword_node_str_id, + # keyword_name=keyword) + #tx.run( + # f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:keyword {{str_id: $child_str_id}}) " + # f"MERGE (a)-[:{key}]->(b)", + # parent_str_id=parent_node_str_id, + # child_str_id=keyword_node_str_id + #) # if list doesn't contain any nested dictionaries, make it a value in the node if not any(isinstance(item, dict) for item in value): @@ -101,25 +175,54 @@ def add_nodes_from_dict(tx, parent_node_label, parent_node_str_id, current_dict) # Connect to Neo4j and create the graph -def create_graph_from_directory(uri, user, password, directory_path): +def create_graph_from_directory(uri, user, password, directory_path, batch_size=100): driver = GraphDatabase.driver(uri, auth=(user, password)) - for filename in os.listdir(directory_path): - if filename.endswith('.json'): - file_path = os.path.join(directory_path, filename) - try: - with open(file_path, 'r') as file: - json_data = json.load(file) + with driver.session() as session: + session.run("CREATE INDEX IF NOT EXISTS FOR (n:ClinicalTrialsEntry) ON (n.str_id)") + session.run("CREATE INDEX IF NOT EXISTS FOR (n:phase) ON (n.str_id)") + session.run("CREATE INDEX IF NOT EXISTS FOR (n:condition) ON (n.str_id)") + session.run("CREATE INDEX IF NOT EXISTS FOR (n:keyword) ON (n.str_id)") + session.run("CREATE INDEX IF NOT EXISTS FOR (n:reference) ON (n.str_id)") - with driver.session() as session: - root_node_label = 'ClinicalTrialsEntry' - root_node_str_id = json_data['protocolSection']['identificationModule']['nctId'] - session.execute_write( - lambda tx: tx.run(f"MERGE (n:{root_node_label} {{str_id: $str_id}})", str_id=root_node_str_id)) - session.execute_write(add_nodes_from_dict, root_node_label, root_node_str_id, json_data) + json_files = [file for file in os.listdir(directory_path) if file.endswith(".json")] + total_files = len(json_files) - print(f"Successfully imported: {filename}") - except Exception as e: - print(f"Failed to import {filename}: {e}") + for i in range(0, total_files, batch_size): + batch_files = json_files[i:i + batch_size] + print(f"Processing batch {i//batch_size + 1}: {len(batch_files)} files") + + with driver.session() as session: + for filename in batch_files: + if filename.endswith(".json"): + file_path = os.path.join(directory_path, filename) + try: + with open(file_path, 'r') as file: + json_data = json.load(file) + root_node_label = 'ClinicalTrialsEntry' + root_node_str_id = json_data['protocolSection']['identificationModule']['nctId'] + session.execute_write(lambda tx: tx.run(f"MERGE (n:{root_node_label} {{str_id: $str_id}})", str_id=root_node_str_id)) + session.execute_write(add_nodes_from_dict, root_node_label, root_node_str_id, json_data) + print(f"Successfully imported: {filename}") + except Exception as e: + print(f"Failed to import: {filename}: {e}") + + #for filename in os.listdir(directory_path): + # if filename.endswith('.json'): + # file_path = os.path.join(directory_path, filename) + # try: + # with open(file_path, 'r') as file: + # json_data = json.load(file) + + # with driver.session() as session: + # root_node_label = 'ClinicalTrialsEntry' + # root_node_str_id = json_data['protocolSection']['identificationModule']['nctId'] + # session.execute_write( + # lambda tx: tx.run(f"MERGE (n:{root_node_label} {{str_id: $str_id}})", str_id=root_node_str_id)) + # session.execute_write(add_nodes_from_dict, root_node_label, root_node_str_id, json_data) + + # print(f"Successfully imported: {filename}") + # except Exception as e: + # print(f"Failed to import {filename}: {e}") driver.close() diff --git a/src/study2neo4j/run.py b/src/study2neo4j/run.py index 54c300c..22aec15 100644 --- a/src/study2neo4j/run.py +++ b/src/study2neo4j/run.py @@ -2,8 +2,9 @@ import argparse import logging import configparser from ct2neo4j import create_graph_from_directory +import time -STUDY2NEO4J_VERSION: str = "0.1" +STUDY2NEO4J_VERSION: str = "0.2" logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -15,6 +16,7 @@ parser = argparse.ArgumentParser() parser.add_argument('-c', '--conf', required=True, type=str, help='Configuration file with database connection parameters') parser.add_argument('-f', '--files', required=True, type=str, help='Directory with json files') +parser.add_argument('-b', '--batch-size', required=False, type=int, default=100, help='Batch size') # parse parameters args = parser.parse_args() @@ -29,4 +31,7 @@ password = conf_file['neo4j']['password'] # start study2neo4j if __name__ == "__main__": - create_graph_from_directory(uri=uri, user=username, password=password, directory_path=json_file_path) + start_time = time.time() + create_graph_from_directory(uri=uri, user=username, password=password, directory_path=json_file_path, batch_size=int(args.batch_size)) + end_time = time.time() + print("--- %s minutes ---" % ((end_time - start_time)/60))