2025-04-16 22:12:19 +02:00

77 lines
2.5 KiB
Python

import pickle
import networkx as nx
from biocypher._logger import logger
from biocypher.output.write._writer import _Writer
from biocypher.output.write.relational._csv import _PandasCSVWriter
class _NetworkXWriter(_Writer):
    """
    Class for writing nodes and edges to a networkx DiGraph.

    Entities are first collected as pandas DataFrames by a
    ``_PandasCSVWriter`` created with ``write_to_file=False``, then mirrored
    into ``self.G``. The graph is pickled to ``networkx_graph.pkl`` in the
    output directory when the import call is constructed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The pandas writer converts entities into DataFrames
        # (``stored_dfs``); write_to_file=False presumably keeps it from
        # writing CSV files itself — TODO confirm in _PandasCSVWriter.
        self.csv_writer = _PandasCSVWriter(*args, write_to_file=False, **kwargs)
        self.G = nx.DiGraph()

    def _construct_import_call(self) -> str:
        """Pickle the graph and return Python code that loads it again.

        Side effect: writes ``self.G`` to
        ``<output_directory>/networkx_graph.pkl``.

        Returns:
            str: Python code that unpickles ``./networkx_graph.pkl``
            (resolved against the current working directory) into a
            variable named ``G_loaded``.
        """
        logger.info(
            f"Writing networkx {self.G} to pickle file networkx_graph.pkl."
        )
        with open(f"{self.output_directory}/networkx_graph.pkl", "wb") as f:
            pickle.dump(self.G, f)
        # NOTE: pickle.load can execute arbitrary code; the generated script
        # is only safe to run on a pickle produced by this writer.
        import_call = "import pickle\n"
        import_call += "with open('./networkx_graph.pkl', 'rb') as f:\n\tG_loaded = pickle.load(f)"
        return import_call

    def _get_import_script_name(self) -> str:
        """Return the file name of the generated import script."""
        return "import_networkx.py"

    def _write_node_data(self, nodes) -> bool:
        """Collect ``nodes`` in the pandas writer, then sync ``self.G``.

        Returns:
            bool: The value returned by the pandas writer's
            ``_write_entities_to_file``.
        """
        passed = self.csv_writer._write_entities_to_file(nodes)
        self.add_to_networkx()
        return passed

    def _write_edge_data(self, edges) -> bool:
        """Collect ``edges`` in the pandas writer, then sync ``self.G``.

        Returns:
            bool: The value returned by the pandas writer's
            ``_write_entities_to_file``.
        """
        passed = self.csv_writer._write_entities_to_file(edges)
        self.add_to_networkx()
        return passed

    def add_to_networkx(self) -> bool:
        """Add every DataFrame stored by the pandas writer to ``self.G``.

        A DataFrame is treated as node data if it has a ``node_id`` column,
        and as edge data if it has both ``source_id`` and ``target_id``
        columns (exact column names). All other columns become node/edge
        attributes. Nodes are added before edges.

        Every stored DataFrame is re-added on each call. Because ``self.G``
        is a DiGraph, re-adding an existing node or edge updates its
        attribute dict instead of duplicating it; for repeated ids, values
        from later rows override earlier ones.

        Returns:
            bool: Always ``True``.
        """
        all_dfs = list(self.csv_writer.stored_dfs.values())
        # Exact membership rather than substring matching, which would
        # misclassify a DataFrame whose column merely contains "node_id"
        # and fails outright on non-string column labels.
        node_dfs = [df for df in all_dfs if "node_id" in df.columns]
        edge_dfs = [
            df
            for df in all_dfs
            if "source_id" in df.columns and "target_id" in df.columns
        ]

        for df in node_dfs:
            self.G.add_nodes_from(self._iter_node_tuples(df))
        for df in edge_dfs:
            self.G.add_edges_from(self._iter_edge_tuples(df))
        return True

    @staticmethod
    def _iter_node_tuples(df):
        """Yield ``(node_id, attrs)`` for each row of a node DataFrame."""
        # Row-wise records instead of set_index(...).to_dict(orient="index"),
        # which raises ValueError as soon as a node_id occurs twice.
        for record in df.to_dict(orient="records"):
            node_id = record.pop("node_id")
            yield node_id, record

    @staticmethod
    def _iter_edge_tuples(df):
        """Yield ``(source, target, attrs)`` for each row of an edge DataFrame."""
        # Same reasoning as _iter_node_tuples: a repeated (source, target)
        # pair must not crash the conversion.
        for record in df.to_dict(orient="records"):
            source = record.pop("source_id")
            target = record.pop("target_id")
            yield source, target, record