import os
import glob
import pandas as pd

from biocypher._logger import logger
from biocypher.output.write._batch_writer import parse_label, _BatchWriter


class _Neo4jBatchWriter(_BatchWriter):
    """
    Class for writing node and edge representations to disk using the
    format specified by Neo4j for the use of admin import. Each batch
    writer instance has a fixed representation that needs to be passed
    at instantiation via the :py:attr:`schema` argument. The instance
    also expects an ontology adapter via :py:attr:`ontology_adapter` to be able
    to convert and extend the hierarchy.

    This class inherits from the abstract class "_BatchWriter" and implements the
    Neo4j-specific methods:

        - _write_node_headers
        - _write_edge_headers
        - _construct_import_call
        - _write_array_string
    """

    def __init__(self, *args, **kwargs):
        """
        Constructor.

        Forwards all positional and keyword arguments unchanged to the
        parent ``_BatchWriter`` constructor; no Neo4j-specific setup happens
        in this method. In particular, the Neo4j version is not inspected
        at construction time: :py:meth:`_construct_import_call` emits a
        shell script that determines the version when the script is run.
        """

        # NOTE(review): an earlier remark here stated that this constructor
        # "should read the configuration and setup import_call_bin_prefix".
        # Nothing in this method does so -- presumably the parent constructor
        # handles it; confirm against ``_BatchWriter.__init__``.
        super().__init__(*args, **kwargs)

    def _get_default_import_call_bin_prefix(self):
        """
        Provide the default prefix that is put in front of the
        ``neo4j-admin`` executable name.

        Returns:
            str: The constant ``"bin/"`` (a relative path; presumably
            relative to the Neo4j installation directory -- TODO confirm).
        """

        return "bin/"

    def _write_array_string(self, string_list):
        """
        Abstract method to output.write the string representation of an array into a .csv file
        as required by the neo4j admin-import.

        Args:
            string_list (list): list of ontology strings

        Returns:
            str: The string representation of an array for the neo4j admin import
        """
        string = self.adelim.join(string_list)
        return f"{self.quote}{string}{self.quote}"

    def _write_node_headers(self):
        """
        Write one header-only CSV file per node label, as required by the
        Neo4j admin import, and register it (together with the pattern for
        its part files) for the import call.

        Each header consists of ``:ID``, the typed property columns and
        ``:LABEL``. If a header file for the label already exists and its
        column list differs from the new one -- in content *or* in order --
        the part files already written for that label are rewritten to the
        new layout before the header file is replaced.

        Returns:
            bool: The return value. True for success, False otherwise.
        """
        # load headers from data parse
        if not self.node_property_dict:
            logger.error(
                "Header information not found. Was the data parsed first?",
            )
            return False

        # Accepted schema type names -> Neo4j admin-import column type.
        # Properties of any other type are written without a type suffix.
        type_suffixes = {
            "int": "long",
            "long": "long",
            "integer": "long",
            "int[]": "long[]",
            "long[]": "long[]",
            "integer[]": "long[]",
            "float": "double",
            "double": "double",
            "dbl": "double",
            "float[]": "double[]",
            "double[]": "double[]",
            # TODO Neo4j boolean support / spelling?
            "bool": "boolean",
            "boolean": "boolean",
            "bool[]": "boolean[]",
            "boolean[]": "boolean[]",
            "str[]": "string[]",
            "string[]": "string[]",
        }

        for label, props in self.node_property_dict.items():
            # MeDaX dev remark: FHIR data yields case-sensitive labels, e.g.
            # 'Procedure' and 'procedure' are two distinct node types, because
            # Resources are converted to more specific node classes using
            # their "resourceType" attribute.

            # translate label to PascalCase
            pascal_label = self.translator.name_sentence_to_pascal(
                parse_label(label)
            )

            header = f"{pascal_label}-header.csv"
            header_path = os.path.join(self.outdir, header)
            parts = f"{pascal_label}-part.*"

            # Remember the previous column layout (if any) before the
            # header file is replaced.
            existing_header = None
            if os.path.exists(header_path):
                logger.warning(
                    f"Header file `{header_path}` already exists. Overwriting.",
                )
                with open(header_path, "r", encoding="utf-8") as existing:
                    content = existing.read().strip()
                # an empty header file carries no layout to migrate from
                if content:
                    existing_header = content.split(self.delim)

            # concatenate key:type in props
            props_list = []
            for k, v in props.items():
                suffix = type_suffixes.get(v) if isinstance(v, str) else None
                props_list.append(f"{k}:{suffix}" if suffix else f"{k}")

            out_list = [":ID", *props_list, ":LABEL"]

            # Compare as ordered lists: part files have no header row, so a
            # pure reordering of columns also requires rewriting them.
            if existing_header and existing_header != out_list:
                self._migrate_node_part_files(
                    pascal_label, existing_header, out_list
                )

            # Replace the header only after the part files were handled, so
            # a failure above does not leave a truncated header behind.
            with open(header_path, "w", encoding="utf-8") as f:
                f.write(self.delim.join(out_list))

            # add file path to neo4 admin import statement (import call file
            # path may be different from actual file path)
            import_call_header_path = os.path.join(
                self.import_call_file_prefix,
                header,
            )
            import_call_parts_path = os.path.join(
                self.import_call_file_prefix,
                parts,
            )
            self.import_call_nodes.add(
                (import_call_header_path, import_call_parts_path)
            )

        return True

    def _migrate_node_part_files(self, pascal_label, existing_header, out_list):
        """
        Rewrite the existing part files of one node label so that their
        columns follow ``out_list`` instead of ``existing_header``.

        The highest-numbered part file is left untouched (presumably it was
        written with the new layout already -- TODO confirm with the part
        writer). A failure on one file is logged and does not stop the
        remaining files from being processed.

        Args:
            pascal_label (str): PascalCase label used as file name prefix
            existing_header (list): column names of the previous header
            out_list (list): column names of the new header
        """
        prefix = f"{pascal_label}-part"
        suffix = ".csv"
        # escape so that special characters in the path or label are not
        # interpreted as glob wildcards
        part_files = glob.glob(
            os.path.join(
                glob.escape(self.outdir),
                f"{glob.escape(prefix)}*{suffix}",
            )
        )

        # Find the highest numbered part file. The number is whatever sits
        # between the known "<label>-part" prefix and ".csv", so labels
        # that themselves contain "part" (e.g. "Department") work too.
        highest_part = None
        highest_number = -1
        for part_file in part_files:
            file_name = os.path.basename(part_file)
            try:
                number = int(file_name[len(prefix) : -len(suffix)])
            except ValueError:
                # not a numbered part file; it can not be the latest one
                continue
            if number > highest_number:
                highest_number = number
                highest_part = part_file

        for part_file in part_files:
            if part_file == highest_part:
                logger.debug(f"Skipping the highest part file: {highest_part}")
                continue
            try:
                df = self.adapt_csv_to_new_header(
                    existing_header, out_list, part_file
                )
                # part files never contain a header row
                df.to_csv(part_file, sep=self.delim, index=False, header=False)
            except Exception as e:
                # deliberate best effort: keep going with the other files
                logger.error(f"Error updating {part_file}: {e}")
            else:
                logger.info(
                    f"Updated {part_file} with new columns in correct positions"
                )

    def _write_edge_headers(self):
        """
        Write one header-only CSV file per edge label, as required by the
        Neo4j admin import, and register it (together with the pattern for
        its part files) for the import call.

        Each header consists of ``:START_ID``, an optional ``id`` column,
        the typed property columns, ``:END_ID`` and ``:TYPE``. The ``id``
        column is omitted for the labels ``IS_SOURCE_OF``, ``IS_TARGET_OF``
        and ``IS_PART_OF`` and for schema entries whose ``use_id`` is false.

        If a header file for the label already exists and its column list
        differs from the new one -- in content *or* in order -- the part
        files already written for that label are rewritten to the new
        layout before the header file is replaced.

        Returns:
            bool: The return value. True for success, False otherwise.
        """
        # load headers from data parse
        if not self.edge_property_dict:
            logger.error(
                "Header information not found. Was the data parsed first?",
            )
            return False

        # Accepted schema type names -> Neo4j admin-import column type
        # (same aliases as for node headers, including "dbl").
        # Properties of any other type are written without a type suffix.
        type_suffixes = {
            "int": "long",
            "long": "long",
            "integer": "long",
            "int[]": "long[]",
            "long[]": "long[]",
            "integer[]": "long[]",
            "float": "double",
            "double": "double",
            "dbl": "double",
            "float[]": "double[]",
            "double[]": "double[]",
            # TODO does Neo4j support bool?
            "bool": "boolean",
            "boolean": "boolean",
            "bool[]": "boolean[]",
            "boolean[]": "boolean[]",
            "str[]": "string[]",
            "string[]": "string[]",
        }

        for label, props in self.edge_property_dict.items():
            # translate label to PascalCase
            pascal_label = self.translator.name_sentence_to_pascal(
                parse_label(label)
            )

            # paths
            header = f"{pascal_label}-header.csv"
            header_path = os.path.join(self.outdir, header)
            parts = f"{pascal_label}-part.*"

            # concatenate key:type in props
            props_list = []
            for k, v in props.items():
                suffix = type_suffixes.get(v) if isinstance(v, str) else None
                props_list.append(f"{k}:{suffix}" if suffix else f"{k}")

            # Decide whether the relationship gets an "id" column.
            skip_id = label in ["IS_SOURCE_OF", "IS_TARGET_OF", "IS_PART_OF"]
            if not skip_id:
                extended_schema = (
                    self.translator.ontology.mapping.extended_schema
                )
                if extended_schema.get(label):
                    schema_label = label
                else:
                    # find label in schema by label_as_edge
                    schema_label = next(
                        (
                            k
                            for k, v in extended_schema.items()
                            if v.get("label_as_edge") == label
                        ),
                        None,
                    )
                # equality (not identity) is kept on purpose so that a
                # missing "use_id" key (None) still means "use the id"
                if (
                    schema_label
                    and extended_schema.get(schema_label).get("use_id")
                    == False  # noqa: E712
                ):
                    skip_id = True

            out_list = [":START_ID"]
            if not skip_id:
                out_list.append("id")
            out_list.extend(props_list)
            out_list.extend([":END_ID", ":TYPE"])

            # Remember the previous column layout (if any) before the
            # header file is replaced.
            existing_header = None
            if os.path.exists(header_path):
                logger.warning(
                    f"Header file `{header_path}` already exists. Overwriting.",
                )
                with open(header_path, "r", encoding="utf-8") as existing:
                    content = existing.read().strip()
                # an empty header file carries no layout to migrate from
                if content:
                    existing_header = content.split(self.delim)

            # Compare as ordered lists: part files have no header row, so a
            # pure reordering of columns also requires rewriting them.
            if existing_header and existing_header != out_list:
                self._migrate_edge_part_files(
                    pascal_label, existing_header, out_list
                )

            # Replace the header only after the part files were handled, so
            # a failure above does not leave a truncated header behind.
            with open(header_path, "w", encoding="utf-8") as f:
                f.write(self.delim.join(out_list))

            # add file path to neo4 admin import statement (import call file
            # path may be different from actual file path)
            import_call_header_path = os.path.join(
                self.import_call_file_prefix,
                header,
            )
            import_call_parts_path = os.path.join(
                self.import_call_file_prefix,
                parts,
            )
            self.import_call_edges.add(
                (import_call_header_path, import_call_parts_path)
            )

        return True

    def _migrate_edge_part_files(self, pascal_label, existing_header, out_list):
        """
        Rewrite the existing part files of one edge label so that their
        columns follow ``out_list`` instead of ``existing_header``.

        The highest-numbered part file is left untouched (presumably it was
        written with the new layout already -- TODO confirm with the part
        writer). A failure on one file is logged and does not stop the
        remaining files from being processed.

        Args:
            pascal_label (str): PascalCase label used as file name prefix
            existing_header (list): column names of the previous header
            out_list (list): column names of the new header
        """
        prefix = f"{pascal_label}-part"
        suffix = ".csv"
        # escape so that special characters in the path or label are not
        # interpreted as glob wildcards
        part_files = glob.glob(
            os.path.join(
                glob.escape(self.outdir),
                f"{glob.escape(prefix)}*{suffix}",
            )
        )

        # Find the highest numbered part file. The number is whatever sits
        # between the known "<label>-part" prefix and ".csv", so labels
        # that themselves contain "part" (e.g. "IsPartOf") work too.
        highest_part = None
        highest_number = -1
        for part_file in part_files:
            file_name = os.path.basename(part_file)
            try:
                number = int(file_name[len(prefix) : -len(suffix)])
            except ValueError:
                # not a numbered part file; it can not be the latest one
                continue
            if number > highest_number:
                highest_number = number
                highest_part = part_file

        for part_file in part_files:
            if part_file == highest_part:
                logger.debug(f"Skipping the highest part file: {highest_part}")
                continue
            try:
                df = self.adapt_csv_to_new_header(
                    existing_header, out_list, part_file
                )
                # part files never contain a header row
                df.to_csv(part_file, sep=self.delim, index=False, header=False)
            except Exception as e:
                # deliberate best effort: keep going with the other files
                logger.error(f"Error updating {part_file}: {e}")
            else:
                logger.info(
                    f"Updated {part_file} with new columns in correct positions"
                )

    def _get_import_script_name(self) -> str:
        """
        Return the file name used for the neo4j admin import shell script.

        Returns:
            str: The constant script name ``neo4j-admin-import-call.sh``.
        """
        return "neo4j-admin-import-call.sh"

    def _construct_import_call(self) -> str:
        """
        Function to construct the import call detailing folder and
        individual node and edge headers and data files, as well as
        delimiters and database name. Built after all data has been
        processed to ensure that nodes are called before any edges.

        Returns:
            str: a bash command for neo4j-admin import
        """
        import_call_neo4j_v4 = self._get_import_call(
            "import", "--database=", "--force="
        )
        import_call_neo4j_v5 = self._get_import_call(
            "database import full", "", "--overwrite-destination="
        )
        neo4j_version_check = f"version=$({self._get_default_import_call_bin_prefix()}neo4j-admin --version | cut -d '.' -f 1)"

        import_script = f"#!/bin/bash\n{neo4j_version_check}\nif [[ $version -ge 5 ]]; then\n\t{import_call_neo4j_v5}\nelse\n\t{import_call_neo4j_v4}\nfi"
        return import_script

    def _get_import_call(
        self, import_cmd: str, database_cmd: str, wipe_cmd: str
    ) -> str:
        """Get parametrized import call for Neo4j 4 or 5+.

        Args:
            import_cmd (str): The import command to use.
            database_cmd (str): The database command to use.
            wipe_cmd (str): The wipe command to use.

        Returns:
            str: The import call.
        """
        import_call = f"{self.import_call_bin_prefix}neo4j-admin {import_cmd} "

        import_call += f"{database_cmd}{self.db_name} "

        import_call += f'--delimiter="{self.escaped_delim}" '

        import_call += f'--array-delimiter="{self.escaped_adelim}" '

        if self.quote == "'":
            import_call += f'--quote="{self.quote}" '
        else:
            import_call += f"--quote='{self.quote}' "

        if self.wipe:
            import_call += f"{wipe_cmd}true "
        if self.skip_bad_relationships:
            import_call += "--skip-bad-relationships=true "
        if self.skip_duplicate_nodes:
            import_call += "--skip-duplicate-nodes=true "

        # append node import calls
        for header_path, parts_path in self.import_call_nodes:
            import_call += f'--nodes="{header_path},{parts_path}" '

        # append edge import calls
        for header_path, parts_path in self.import_call_edges:
            import_call += f'--relationships="{header_path},{parts_path}" '

        return import_call




    def adapt_csv_to_new_header(self, old_header, new_header, csv_file_path):
        """
        Adapt a CSV table to a new header structure, placing new columns in their correct positions.
        
        Parameters:
        old_header (list): The original header columns
        new_header (list): The new header columns
        csv_file_path (str): Path to the CSV file
        
        Returns:
        pandas.DataFrame: CSV data with the new header structure
        """
        
        # Step 1: Read the CSV data without headers
        df = pd.read_csv(csv_file_path, sep=self.delim, header=None)
        
        # Step 2: If the file is empty, return empty DataFrame with new headers
        if df.empty:
            return pd.DataFrame(columns=new_header)
        
        # Step 3: If column count doesn't match old_header length, handle the mismatch
        if len(df.columns) != len(old_header):
            print(f"Warning: CSV columns count ({len(df.columns)}) doesn't match the provided old header count ({len(old_header)})")
            # If file has fewer columns than old_header, pad with NaN
            if len(df.columns) < len(old_header):
                for i in range(len(df.columns), len(old_header)):
                    df[i] = None
            # If file has more columns than old_header, truncate
            else:
                df = df.iloc[:, :len(old_header)]
        
        # Step 4: Assign old header names to the dataframe
        df.columns = old_header
        
        # Step 5: Create a new DataFrame with the correct structure
        new_df = pd.DataFrame(columns=new_header)
        
        # Step 6: For each column in the new header, find its position in the old header
        for new_col_idx, new_col in enumerate(new_header):
            if new_col in old_header:
                # If column exists in old header, copy data
                new_df[new_col] = df[new_col]
            else:
                # If new column, add empty column
                new_df[new_col] = None
        
        # Step 7: Ensure columns are in the exact order of new_header
        new_df = new_df[new_header]
        
        return new_df