#!/usr/bin/env python
# coding: utf-8

from pathlib import Path

#import networkx as nx
import yaml
from collections import defaultdict

# Generate the node and edge sections of a YAML schema config from a graph.

def write_automated_schema(graph, filePath, mSchemaPath):
    """Generate a YAML schema file describing the node and edge types in *graph*.

    If *filePath* already exists it is loaded (via loadManualSchema) and
    extended. Otherwise, if *mSchemaPath* is given, that manual schema is
    the starting point.

    Every node label found in the graph is added with all of its attribute
    keys typed as ``str``. Every (source, target) label pair connected by an
    edge gets a default ``"<Source> to <Target> association"`` entry, unless
    the schema already defines that pair under one of the recognised
    association phrasings (compared case-insensitively against loaded
    entries). The result overwrites *filePath*.

    Node labels and edge source labels are normalized with
    ``str.capitalize()``, which also lowercases the rest of the string
    (``'MedicationRequest'`` -> ``'Medicationrequest'``). Edge target labels
    are used as-is.

    Args:
        graph: graph with a networkx-style interface: ``nodes()``,
            ``nodes[n]`` attribute dicts, and ``edges(data=True)``.
        filePath: output path; also read as the base schema when it exists.
        mSchemaPath: optional manual schema path, used only when *filePath*
            does not exist yet.
    """
    schemaData = {
        'nodes': {},
        'edges': {}
    }

    if Path(filePath).exists():
        schemaData = loadManualSchema(filePath)
    elif mSchemaPath:
        print("using the manual schema")
        schemaData = loadManualSchema(mSchemaPath)

    # Build the complete schema before opening the output. Opening with 'w'
    # truncates filePath, which may be the very schema just loaded, so an
    # error during collection must not leave it half-written or empty.
    _collect_node_schema(graph, schemaData['nodes'])
    _collect_edge_schema(graph, schemaData['edges'])

    with open(filePath, 'w') as out:
        for name, entry in schemaData['nodes'].items():
            out.write(_format_node_entry(name, entry))
        out.write('\n')
        for name, entry in schemaData['edges'].items():
            out.write(_format_edge_entry(name, entry))


# Association phrasings that, when already present in the schema for a
# (source, target) label pair, suppress generating a default "to" association.
_KNOWN_EDGE_VERBS = ('to', 'derived from', 'has member', 'reasoned by', 'is')


def _yaml_value(value):
    """Render a schema value as text for the hand-written YAML output.

    ``None`` becomes an empty string (read back as YAML null). Anything else
    goes through ``str`` so non-string values loaded from an existing schema
    (ints, lists, ...) do not break string concatenation.
    """
    return '' if value is None else str(value)


def _graph_node_label(graph, node):
    """Return the raw (not yet capitalized) schema label for *node*.

    Nodes labelled ``'resource'`` use their ``resourceType`` instead. When no
    usable label exists, the node id (``str(node)``) is used.
    """
    attrs = graph.nodes[node]
    label = attrs.get('label')
    if label == 'resource':
        label = attrs.get('resourceType')
    if label is None:
        label = str(node)
    return str(label)


def _collect_node_schema(graph, nodes_schema):
    """Add every node label in *graph* to *nodes_schema*, typing each attribute as 'str'."""
    for node in graph.nodes():
        label = _graph_node_label(graph, node).capitalize()
        entry = nodes_schema.get(label)
        if not isinstance(entry, dict):
            entry = nodes_schema[label] = {}
        properties = entry.get('properties')
        # A loaded schema may contain `properties:` with no value (None).
        if not isinstance(properties, dict):
            properties = entry['properties'] = {}
        for key in graph.nodes[node]:
            properties[key] = 'str'


def _collect_edge_schema(graph, edges_schema):
    """Add a default association entry for each edge label pair not yet in *edges_schema*."""
    # loadManualSchema capitalizes keys (lowercasing everything after the
    # first character), so loaded entries are matched case-insensitively.
    # Entries generated below are matched exactly.
    loaded_keys = {str(key).lower() for key in edges_schema}

    def _defined(key):
        return key in edges_schema or key.lower() in loaded_keys

    for u, v, attrs in graph.edges(data=True):
        source_label = _graph_node_label(graph, u).capitalize()
        # The target label is intentionally left uncapitalized, as before.
        target_label = _graph_node_label(graph, v)

        if any(_defined(source_label + ' ' + verb + ' ' + target_label + ' association')
               for verb in _KNOWN_EDGE_VERBS):
            continue

        edges_schema[source_label + ' to ' + target_label + ' association'] = {
            'is_a': 'association',
            'represented_as': 'edge',
            'label_in_input': source_label + '_to_' + target_label,
            # Record attribute names typed as 'str' (as for nodes), not the
            # attribute values of whichever edge happened to come first.
            'properties': {prop: 'str' for prop in attrs},
        }


def _property_lines(indent, properties):
    """Return the YAML lines for a ``properties`` entry at *indent*; children get double indent."""
    if properties is None:
        return [indent + 'properties:']
    if not isinstance(properties, dict):
        return [indent + 'properties: ' + _yaml_value(properties)]
    child = indent * 2
    return [indent + 'properties:'] + [
        child + str(prop) + ': ' + _yaml_value(prop_type)
        for prop, prop_type in properties.items()
    ]


def _format_node_entry(name, entry):
    """Render one node schema entry as a YAML block, filling defaults for missing keys."""
    lines = [
        str(name) + ':',
        '    is_a: ' + _yaml_value(entry.get('is_a', 'named thing')),
        '    represented_as: ' + _yaml_value(entry.get('represented_as', 'node')),
        '    preferred_id: ' + _yaml_value(entry.get('preferred_id', 'fhir_id')),
        # Written exactly once: an existing value wins, otherwise the label.
        '    label_in_input: ' + _yaml_value(entry.get('label_in_input', name)),
    ]
    lines.extend(_property_lines('    ', entry.get('properties')))
    return '\n'.join(lines) + '\n\n'


def _format_edge_entry(name, entry):
    """Render one edge schema entry as a YAML block, preserving its key order."""
    lines = [str(name) + ':']
    for key, value in entry.items():
        if key == 'properties':
            lines.extend(_property_lines('  ', value))
        else:
            lines.append('  ' + str(key) + ': ' + _yaml_value(value))
    return '\n'.join(lines) + '\n\n'

def loadManualSchema(path):
    """Load a YAML schema file into ``{'nodes': {...}, 'edges': {...}}``.

    Every top-level entry except ``Title`` is classified by its
    ``represented_as`` value. ``'node'`` entries go under ``'nodes'``; all
    others go under ``'edges'``. An entry without ``represented_as`` is
    treated as a node, the same default write_automated_schema writes out.
    Entries whose value is not a mapping cannot describe a node or edge and
    are skipped. If two keys normalize to the same label, the later one wins.

    Keys are normalized with ``str.capitalize()``, which also lowercases
    everything after the first character. For example, ``'Patient has
    member Observation association'`` becomes ``'Patient has member
    observation association'``.

    Args:
        path: path of the YAML file to read.

    Returns:
        dict with ``'nodes'`` and ``'edges'`` sub-dicts mapping each
        normalized label to the entry's mapping as parsed from YAML.
    """
    schemaData = {
        'nodes': {},
        'edges': {}
    }

    with open(path, 'r') as handle:
        # safe_load returns None for an empty document; treat that as an
        # empty schema instead of failing on None.items().
        data = yaml.safe_load(handle) or {}

    for label, attrs in data.items():
        if label == 'Title':
            continue
        if not isinstance(attrs, dict):
            # NOTE(review): scalar/list entries used to raise TypeError here;
            # they cannot be node/edge definitions, so they are ignored.
            continue
        section = 'nodes' if attrs.get('represented_as', 'node') == 'node' else 'edges'
        schemaData[section][str(label).capitalize()] = attrs

    return schemaData