import pickle

import networkx as nx

from biocypher._logger import logger
from biocypher.output.write._writer import _Writer
from biocypher.output.write.relational._csv import _PandasCSVWriter


class _NetworkXWriter(_Writer):
    """
    Class for writing nodes and edges to a networkx DiGraph.

    Entities are converted to pandas DataFrames by an internal
    ``_PandasCSVWriter`` created with ``write_to_file=False``; those
    DataFrames are then loaded into ``self.G``. The graph is pickled to the
    output directory when the import call is constructed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The pandas writer is reused only for its DataFrame conversion; it is
        # constructed with write_to_file=False.
        self.csv_writer = _PandasCSVWriter(
            *args, write_to_file=False, **kwargs
        )
        self.G = nx.DiGraph()

    def _construct_import_call(self) -> str:
        """Pickle the graph and return Python code that loads it back.

        Side effect: writes ``self.G`` to ``networkx_graph.pkl`` inside
        ``self.output_directory``.

        Returns:
            str: Python code that unpickles ``./networkx_graph.pkl`` (relative
            to the working directory of whoever runs it) into ``G_loaded``.
        """
        logger.info(
            f"Writing networkx {self.G} to pickle file networkx_graph.pkl."
        )
        with open(f"{self.output_directory}/networkx_graph.pkl", "wb") as f:
            pickle.dump(self.G, f)

        # NOTE: the generated script uses pickle.load, which must only be run
        # on a pickle file you trust (e.g. one produced by this writer).
        import_call = "import pickle\n"
        import_call += "with open('./networkx_graph.pkl', 'rb') as f:\n\tG_loaded = pickle.load(f)"
        return import_call

    def _get_import_script_name(self) -> str:
        """Function to return the name of the import script."""
        return "import_networkx.py"

    def _write_node_data(self, nodes) -> bool:
        """Convert ``nodes`` to DataFrames and sync them into ``self.G``.

        Returns:
            bool: The result reported by the underlying pandas writer.
        """
        return self._convert_and_add_to_graph(nodes)

    def _write_edge_data(self, edges) -> bool:
        """Convert ``edges`` to DataFrames and sync them into ``self.G``.

        Returns:
            bool: The result reported by the underlying pandas writer.
        """
        return self._convert_and_add_to_graph(edges)

    def _convert_and_add_to_graph(self, entities) -> bool:
        """Shared body of node/edge writing: convert, then load into graph."""
        passed = self.csv_writer._write_entities_to_file(entities)
        self.add_to_networkx()
        return passed

    def add_to_networkx(self) -> bool:
        """Load every DataFrame held by the pandas writer into ``self.G``.

        DataFrames with a ``node_id`` column are added as nodes; DataFrames
        with both ``source_id`` and ``target_id`` columns are added as edges.
        All remaining columns become node/edge attributes. Nodes are added
        before edges. Re-adding an existing node or edge updates its
        attributes rather than failing.

        NOTE(review): every stored DataFrame is re-processed on each call,
        not just the latest batch. If ``stored_dfs`` keeps growing across
        batches this is quadratic overall — confirm against _PandasCSVWriter.

        Returns:
            bool: Always True.
        """
        all_dfs = list(self.csv_writer.stored_dfs.values())

        # Exact column membership: a substring test would also match columns
        # such as "parent_node_id" and then fail in the key lookup below.
        node_dfs = [df for df in all_dfs if "node_id" in df.columns]
        edge_dfs = [
            df
            for df in all_dfs
            if "source_id" in df.columns and "target_id" in df.columns
        ]

        for df in node_dfs:
            self._add_node_df_to_graph(df)
        for df in edge_dfs:
            self._add_edge_df_to_graph(df)
        return True

    def _add_node_df_to_graph(self, df) -> None:
        """Add the rows of a node DataFrame (keyed by ``node_id``) to the graph."""
        # Pair ids with per-row attribute dicts instead of
        # set_index(...).to_dict(orient="index"), which raises ValueError on
        # duplicate ids; with this form a later row updates earlier attributes.
        attrs = df.drop(columns="node_id").to_dict(orient="records")
        self.G.add_nodes_from(zip(df["node_id"], attrs))

    def _add_edge_df_to_graph(self, df) -> None:
        """Add the rows of an edge DataFrame (``source_id``/``target_id``)."""
        # A DiGraph holds at most one edge per (source, target) pair; repeated
        # pairs update that edge's attributes instead of raising ValueError.
        attrs = df.drop(columns=["source_id", "target_id"]).to_dict(
            orient="records"
        )
        self.G.add_edges_from(zip(df["source_id"], df["target_id"], attrs))