added batch processing; added time measurements
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,2 +1,2 @@
|
||||
**__pycache__/
|
||||
.venv/
|
||||
old-src/
|
||||
|
||||
@@ -3,7 +3,7 @@ from neo4j import GraphDatabase
|
||||
import argparse
|
||||
import configparser
|
||||
import os
|
||||
from methods_mesh import create_graph_config, import_ontology, postprocess_mesh
|
||||
from methods_moi import create_graph_config, import_ontology, postprocess_mesh
|
||||
|
||||
|
||||
# define parameters - pass ontology file dir and db conf as arguments when running the script
|
||||
|
||||
@@ -92,10 +92,10 @@ password = conf_file['neo4j']['password']
|
||||
|
||||
if __name__ == "__main__":
|
||||
# postprocess mesh to umls
|
||||
run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mesh_to_umls_optimised)
|
||||
#run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mesh_to_umls_optimised)
|
||||
# postprocess clinicaltrials.gov to mesh
|
||||
run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_ct_to_mesh)
|
||||
#run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_ct_to_mesh)
|
||||
# postprocess mdm to umls
|
||||
run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mdm_to_umls)
|
||||
#run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_mdm_to_umls)
|
||||
# postprocess clinicaltrials.gov to mdm
|
||||
run_cypher_query_for_postprocessing(uri=uri, user=username, password=password, cypher_query=cypher_ct_to_mdm)
|
||||
|
||||
@@ -5,22 +5,56 @@ from neo4j import GraphDatabase
|
||||
|
||||
# Define a function to add nodes and relationships recursively
|
||||
def add_nodes_from_dict(tx, parent_node_label, parent_node_str_id, current_dict):
|
||||
# Collect data for batching
|
||||
# dict_node_data = []
|
||||
# dict_rels_data = []
|
||||
# list_node_data = []
|
||||
# list_rels_data = []
|
||||
phase_data = []
|
||||
phase_rels_data = []
|
||||
condition_data = []
|
||||
condition_rels_data = []
|
||||
keyword_data = []
|
||||
keyword_rels_data = []
|
||||
|
||||
for key, value in current_dict.items(): # iterate over each key-value pair in dictionary
|
||||
|
||||
if key == "phases":
|
||||
|
||||
# Create a node for each phase
|
||||
for index, phase in enumerate(value):
|
||||
phase_node_str_id = f"{parent_node_str_id}_{key}_{index}"
|
||||
tx.run(f"MERGE (n:phase {{str_id: $str_id, name: $phase_name}})",
|
||||
str_id=phase_node_str_id, phase_name=phase)
|
||||
tx.run(
|
||||
f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:phase {{str_id: $child_str_id}}) "
|
||||
f"MERGE (a)-[:{key}]->(b)",
|
||||
parent_str_id=parent_node_str_id,
|
||||
child_str_id=phase_node_str_id
|
||||
)
|
||||
phase_data.append({"str_id": phase_node_str_id, "name": phase}) #new
|
||||
phase_rels_data.append({"parent_str_id": parent_node_str_id, "child_str_id": phase_node_str_id, "rel_type": key})
|
||||
|
||||
if phase_data:
|
||||
tx.run("""
|
||||
UNWIND $data AS row
|
||||
MERGE (n:phase {str_id: row.str_id})
|
||||
SET n.name = row.name
|
||||
""", data=phase_data)
|
||||
|
||||
if phase_rels_data:
|
||||
tx.run("""
|
||||
UNWIND $rels AS rel
|
||||
MATCH (a:{parent_label} {{str_id: rel.parent_str_id}}),
|
||||
(b:phase {{str_id: rel.child_str_id}})
|
||||
MERGE (a)-[:{rel_type}]->(b)
|
||||
""".format(parent_label=parent_node_label, rel_type=key), rels=phase_rels_data)
|
||||
|
||||
#tx.run(f"MERGE (n:phase {{str_id: $str_id, name: $phase_name}})",
|
||||
# str_id=phase_node_str_id, phase_name=phase)
|
||||
#tx.run(
|
||||
# f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:phase {{str_id: $child_str_id}}) "
|
||||
# f"MERGE (a)-[:{key}]->(b)",
|
||||
# parent_str_id=parent_node_str_id,
|
||||
# child_str_id=phase_node_str_id
|
||||
#)
|
||||
if isinstance(value, dict): # if value of key is a dict, then create new node:
|
||||
# Create a new node for the nested dictionary
|
||||
new_node_str_id = f"{parent_node_str_id}_{key}" # concatenate the parent_node_str_id and key to a new id
|
||||
|
||||
# todo: move batch processing to end of function
|
||||
tx.run(f"MERGE (n:{key} {{str_id: $str_id}})", str_id=new_node_str_id) # create node with key as label
|
||||
|
||||
# Create a relationship from the parent node to the new node
|
||||
@@ -38,26 +72,66 @@ def add_nodes_from_dict(tx, parent_node_label, parent_node_str_id, current_dict)
|
||||
# Create a node for each condition
|
||||
for index, condition in enumerate(value):
|
||||
condition_node_str_id = f"{parent_node_str_id}_{key}_{index}"
|
||||
tx.run(f"MERGE (n:condition {{str_id: $str_id, name: $condition_name}})",
|
||||
str_id=condition_node_str_id, condition_name=condition)
|
||||
tx.run(
|
||||
f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:condition {{str_id: $child_str_id}}) "
|
||||
f"MERGE (a)-[:{key}]->(b)",
|
||||
parent_str_id=parent_node_str_id,
|
||||
child_str_id=condition_node_str_id
|
||||
)
|
||||
condition_data.append({"str_id": condition_node_str_id, "name": condition}) #new
|
||||
condition_rels_data.append({"parent_str_id": parent_node_str_id, "child_str_id": condition_node_str_id, "rel_type": key}) #new
|
||||
|
||||
if condition_data:
|
||||
tx.run("""
|
||||
UNWIND $data AS row
|
||||
MERGE (n:condition {str_id: row.str_id})
|
||||
SET n.name = row.name
|
||||
""", data=condition_data)
|
||||
|
||||
if condition_rels_data:
|
||||
tx.run("""
|
||||
UNWIND $data AS row
|
||||
MATCH (a:{parent_label} {{str_id: row.parent_str_id}}), (b:condition {{str_id: row.child_str_id}})
|
||||
MERGE (a)-[:{rel_type}]->(b)
|
||||
""".format(parent_label=parent_node_label, rel_type=key), data=condition_rels_data)
|
||||
|
||||
#tx.run(f"MERGE (n:condition {{str_id: $str_id, name: $condition_name}})",
|
||||
# str_id=condition_node_str_id, condition_name=condition)
|
||||
# todo: rels?
|
||||
#tx.run(
|
||||
# f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:condition {{str_id: $child_str_id}}) "
|
||||
# f"MERGE (a)-[:{key}]->(b)",
|
||||
# parent_str_id=parent_node_str_id,
|
||||
# child_str_id=condition_node_str_id
|
||||
#)
|
||||
|
||||
elif key == "keywords":
|
||||
# Create a node for each keyword
|
||||
for index, keyword in enumerate(value):
|
||||
keyword_node_str_id = f"{parent_node_str_id}_{key}_{index}"
|
||||
tx.run(f"MERGE (n:keyword {{str_id: $str_id, name: $keyword_name}})", str_id=keyword_node_str_id,
|
||||
keyword_name=keyword)
|
||||
tx.run(
|
||||
f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:keyword {{str_id: $child_str_id}}) "
|
||||
f"MERGE (a)-[:{key}]->(b)",
|
||||
parent_str_id=parent_node_str_id,
|
||||
child_str_id=keyword_node_str_id
|
||||
)
|
||||
keyword_data.append({"str_id": keyword_node_str_id, "name": keyword})
|
||||
keyword_rels_data.append({
|
||||
"parent_str_id": parent_node_str_id,
|
||||
"child_str_id": keyword_node_str_id,
|
||||
"rel_type": key
|
||||
})
|
||||
|
||||
if keyword_data:
|
||||
tx.run("""
|
||||
UNWIND $data AS row
|
||||
MERGE (n:keyword {str_id: row.str_id})
|
||||
SET n.name = row.name
|
||||
""", data=keyword_data)
|
||||
|
||||
if keyword_rels_data:
|
||||
tx.run("""
|
||||
UNWIND $data AS row
|
||||
MATCH (a:{parent_label} {{str_id: row.parent_str_id}}), (b:keyword {{str_id: row.child_str_id}})
|
||||
MERGE (a)-[:{rel_type}]->(b)
|
||||
""".format(parent_label=parent_node_label, rel_type=key), data=keyword_rels_data)
|
||||
|
||||
#tx.run(f"MERGE (n:keyword {{str_id: $str_id, name: $keyword_name}})", str_id=keyword_node_str_id,
|
||||
# keyword_name=keyword)
|
||||
#tx.run(
|
||||
# f"MATCH (a:{parent_node_label} {{str_id: $parent_str_id}}), (b:keyword {{str_id: $child_str_id}}) "
|
||||
# f"MERGE (a)-[:{key}]->(b)",
|
||||
# parent_str_id=parent_node_str_id,
|
||||
# child_str_id=keyword_node_str_id
|
||||
#)
|
||||
|
||||
# if list doesn't contain any nested dictionaries, make it a value in the node
|
||||
if not any(isinstance(item, dict) for item in value):
|
||||
@@ -101,25 +175,54 @@ def add_nodes_from_dict(tx, parent_node_label, parent_node_str_id, current_dict)
|
||||
|
||||
|
||||
# Connect to Neo4j and create the graph
|
||||
def create_graph_from_directory(uri, user, password, directory_path):
|
||||
def create_graph_from_directory(uri, user, password, directory_path, batch_size=100):
|
||||
driver = GraphDatabase.driver(uri, auth=(user, password))
|
||||
|
||||
for filename in os.listdir(directory_path):
|
||||
if filename.endswith('.json'):
|
||||
with driver.session() as session:
|
||||
session.run("CREATE INDEX IF NOT EXISTS FOR (n:ClinicalTrialsEntry) ON (n.str_id)")
|
||||
session.run("CREATE INDEX IF NOT EXISTS FOR (n:phase) ON (n.str_id)")
|
||||
session.run("CREATE INDEX IF NOT EXISTS FOR (n:condition) ON (n.str_id)")
|
||||
session.run("CREATE INDEX IF NOT EXISTS FOR (n:keyword) ON (n.str_id)")
|
||||
session.run("CREATE INDEX IF NOT EXISTS FOR (n:reference) ON (n.str_id)")
|
||||
|
||||
json_files = [file for file in os.listdir(directory_path) if file.endswith(".json")]
|
||||
total_files = len(json_files)
|
||||
|
||||
for i in range(0, total_files, batch_size):
|
||||
batch_files = json_files[i:i + batch_size]
|
||||
print(f"Processing batch {i//batch_size + 1}: {len(batch_files)} files")
|
||||
|
||||
with driver.session() as session:
|
||||
for filename in batch_files:
|
||||
if filename.endswith(".json"):
|
||||
file_path = os.path.join(directory_path, filename)
|
||||
try:
|
||||
with open(file_path, 'r') as file:
|
||||
json_data = json.load(file)
|
||||
|
||||
with driver.session() as session:
|
||||
root_node_label = 'ClinicalTrialsEntry'
|
||||
root_node_str_id = json_data['protocolSection']['identificationModule']['nctId']
|
||||
session.execute_write(
|
||||
lambda tx: tx.run(f"MERGE (n:{root_node_label} {{str_id: $str_id}})", str_id=root_node_str_id))
|
||||
session.execute_write(lambda tx: tx.run(f"MERGE (n:{root_node_label} {{str_id: $str_id}})", str_id=root_node_str_id))
|
||||
session.execute_write(add_nodes_from_dict, root_node_label, root_node_str_id, json_data)
|
||||
|
||||
print(f"Successfully imported: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Failed to import {filename}: {e}")
|
||||
print(f"Failed to import: {filename}: {e}")
|
||||
|
||||
#for filename in os.listdir(directory_path):
|
||||
# if filename.endswith('.json'):
|
||||
# file_path = os.path.join(directory_path, filename)
|
||||
# try:
|
||||
# with open(file_path, 'r') as file:
|
||||
# json_data = json.load(file)
|
||||
|
||||
# with driver.session() as session:
|
||||
# root_node_label = 'ClinicalTrialsEntry'
|
||||
# root_node_str_id = json_data['protocolSection']['identificationModule']['nctId']
|
||||
# session.execute_write(
|
||||
# lambda tx: tx.run(f"MERGE (n:{root_node_label} {{str_id: $str_id}})", str_id=root_node_str_id))
|
||||
# session.execute_write(add_nodes_from_dict, root_node_label, root_node_str_id, json_data)
|
||||
|
||||
# print(f"Successfully imported: {filename}")
|
||||
# except Exception as e:
|
||||
# print(f"Failed to import {filename}: {e}")
|
||||
|
||||
driver.close()
|
||||
|
||||
@@ -2,8 +2,9 @@ import argparse
|
||||
import logging
|
||||
import configparser
|
||||
from ct2neo4j import create_graph_from_directory
|
||||
import time
|
||||
|
||||
STUDY2NEO4J_VERSION: str = "0.1"
|
||||
STUDY2NEO4J_VERSION: str = "0.2"
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -15,6 +16,7 @@ parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--conf', required=True, type=str,
|
||||
help='Configuration file with database connection parameters')
|
||||
parser.add_argument('-f', '--files', required=True, type=str, help='Directory with json files')
|
||||
parser.add_argument('-b', '--batch-size', required=False, type=int, default=100, help='Batch size')
|
||||
|
||||
# parse parameters
|
||||
args = parser.parse_args()
|
||||
@@ -29,4 +31,7 @@ password = conf_file['neo4j']['password']
|
||||
|
||||
# start study2neo4j
|
||||
if __name__ == "__main__":
|
||||
create_graph_from_directory(uri=uri, user=username, password=password, directory_path=json_file_path)
|
||||
start_time = time.time()
|
||||
create_graph_from_directory(uri=uri, user=username, password=password, directory_path=json_file_path, batch_size=int(args.batch_size))
|
||||
end_time = time.time()
|
||||
print("--- %s minutes ---" % ((end_time - start_time)/60))
|
||||
|
||||
Reference in New Issue
Block a user