import pickle

import networkx as nx

from biocypher._logger import logger
from biocypher.output.write._writer import _Writer
from biocypher.output.write.relational._csv import _PandasCSVWriter


class _NetworkXWriter(_Writer):
    """
    Class for writing nodes and edges to a networkx DiGraph.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Reuse the Pandas CSV writer in memory only; its stored DataFrames
        # are the source for the networkx graph.
        self.csv_writer = _PandasCSVWriter(*args, write_to_file=False, **kwargs)
        self.G = nx.DiGraph()

    def _construct_import_call(self) -> str:
        """Write the networkx graph to a pickle file and construct the Python
        code needed to load it again.

        Returns:
            str: Python code to load the pickled networkx graph.
        """
        logger.info(
            f"Writing networkx {self.G} to pickle file networkx_graph.pkl."
        )
        with open(f"{self.output_directory}/networkx_graph.pkl", "wb") as f:
            pickle.dump(self.G, f)

        import_call = "import pickle\n"
        import_call += "with open('./networkx_graph.pkl', 'rb') as f:\n\tG_loaded = pickle.load(f)"
        return import_call

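    # For reference, a sketch of what the generated import script (see
    # _get_import_script_name below) ends up containing; the relative path
    # assumes it is executed from the output directory next to the pickle,
    # which is an assumption about the surrounding BioCypher workflow:
    #
    #     import pickle
    #
    #     with open('./networkx_graph.pkl', 'rb') as f:
    #         G_loaded = pickle.load(f)
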
    def _get_import_script_name(self) -> str:
        """Function to return the name of the import script."""
        return "import_networkx.py"

    def _write_node_data(self, nodes) -> bool:
        """Write nodes to the in-memory CSV writer and add them to the graph."""
        passed = self.csv_writer._write_entities_to_file(nodes)
        self.add_to_networkx()
        return passed

    def _write_edge_data(self, edges) -> bool:
        """Write edges to the in-memory CSV writer and add them to the graph."""
        passed = self.csv_writer._write_entities_to_file(edges)
        self.add_to_networkx()
        return passed

    def add_to_networkx(self) -> bool:
        """Transfer the DataFrames stored by the CSV writer into the DiGraph.

        Node DataFrames are identified by a ``node_id`` column, edge
        DataFrames by ``source_id`` and ``target_id`` columns; all remaining
        columns become node or edge attributes.
        """
        all_dfs = self.csv_writer.stored_dfs
        node_dfs = [
            df
            for df in all_dfs.values()
            if df.columns.str.contains("node_id").any()
        ]
        edge_dfs = [
            df
            for df in all_dfs.values()
            if df.columns.str.contains("source_id").any()
            and df.columns.str.contains("target_id").any()
        ]
        for df in node_dfs:
            nodes = df.set_index("node_id").to_dict(orient="index")
            self.G.add_nodes_from(nodes.items())
        for df in edge_dfs:
            edges = df.set_index(["source_id", "target_id"]).to_dict(
                orient="index"
            )
            self.G.add_edges_from(
                (source, target, attrs)
                for (source, target), attrs in edges.items()
            )
        return True
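

# Minimal standalone sketch (not part of the BioCypher API) of the
# DataFrame-to-DiGraph mapping that ``add_to_networkx`` performs. The column
# names node_id / source_id / target_id mirror the ones the method looks for;
# the example data is made up for illustration only.
if __name__ == "__main__":
    import pandas as pd

    node_df = pd.DataFrame(
        {"node_id": ["p1", "p2"], "name": ["protein 1", "protein 2"]}
    )
    edge_df = pd.DataFrame(
        {"source_id": ["p1"], "target_id": ["p2"], "type": ["interacts_with"]}
    )

    G = nx.DiGraph()
    nodes = node_df.set_index("node_id").to_dict(orient="index")
    G.add_nodes_from(nodes.items())
    edges = edge_df.set_index(["source_id", "target_id"]).to_dict(orient="index")
    G.add_edges_from(
        (source, target, attrs) for (source, target), attrs in edges.items()
    )

    print(G.nodes(data=True))  # [('p1', {'name': 'protein 1'}), ...]
    print(G.edges(data=True))  # [('p1', 'p2', {'type': 'interacts_with'})]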