
import json
import os

import numpy as np
from sklearn.metrics import average_precision_score, f1_score, recall_score

current_dir_path = os.path.dirname(os.path.realpath(__file__))

def GroundingDifference(annFile, resFile):
    """Compare predicted grounding labels against ground-truth annotations.

    Both files are JSON lists of entries with a "question_id" and a
    "single_grounding" field. Entries whose question_id starts with "Viz"
    belong to the VizWiz split; all others belong to the VQAv2 split.
    Returns a dict of F1 / average-precision / recall scores (in percent),
    or None if the submission size does not match the annotations.
    """
    Anns_vqa = []
    Anns_vizwiz = []
    Ress_vqa = []
    Ress_vizwiz = []

    with open(annFile, 'r') as annF, open(resFile, 'r') as resF:
        anns = json.load(annF)
        ress = json.load(resF)

    Anns_labels = np.array([ann["single_grounding"] for ann in anns])
    Res_labels = np.array([res["single_grounding"] for res in ress])

    # Split annotations and results into VizWiz and VQAv2 subsets.
    for ann in anns:
        if ann["question_id"].startswith("Viz"):
            Anns_vizwiz.append(ann)
        else:
            Anns_vqa.append(ann)
    for res in ress:
        if res["question_id"].startswith("Viz"):
            Ress_vizwiz.append(res)
        else:
            Ress_vqa.append(res)

    Anns_vizwiz_labels = np.array([ann["single_grounding"] for ann in Anns_vizwiz])
    Anns_vqa_labels = np.array([ann["single_grounding"] for ann in Anns_vqa])

    Ress_vizwiz_labels = np.array([res["single_grounding"] for res in Ress_vizwiz])
    Ress_vqa_labels = np.array([res["single_grounding"] for res in Ress_vqa])

    if len(Anns_labels) != len(Res_labels):
        print("Unsuccessful submission! The number of predictions in your results file "
              "does not match the number of ground-truth annotations.")
        return None

    # Predicted scores are thresholded at 0.5 to obtain binary labels.
    # Note: the *_precision entries report sklearn's average_precision_score.
    results = {}
    results["overall_f1"] = round(100 * f1_score(Anns_labels, Res_labels > 0.5), 2)
    results["overall_precision"] = round(100 * average_precision_score(Anns_labels, Res_labels > 0.5), 2)
    results["overall_recall"] = round(100 * recall_score(Anns_labels, Res_labels > 0.5), 2)
    results["vqav2_f1"] = round(100 * f1_score(Anns_vqa_labels, Ress_vqa_labels > 0.5), 2)
    results["vqa_precision"] = round(100 * average_precision_score(Anns_vqa_labels, Ress_vqa_labels > 0.5), 2)
    results["vqa_recall"] = round(100 * recall_score(Anns_vqa_labels, Ress_vqa_labels > 0.5), 2)
    results["vizwiz_f1"] = round(100 * f1_score(Anns_vizwiz_labels, Ress_vizwiz_labels > 0.5), 2)
    results["vizwiz_precision"] = round(100 * average_precision_score(Anns_vizwiz_labels, Ress_vizwiz_labels > 0.5), 2)
    results["vizwiz_recall"] = round(100 * recall_score(Anns_vizwiz_labels, Ress_vizwiz_labels > 0.5), 2)
    return results
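

# Illustrative sketch, not part of the official challenge toolkit: based on the
# fields read in GroundingDifference() above, each annotation/result entry is
# assumed to be a JSON object such as
#   {"question_id": "VizWiz_test_00000001", "single_grounding": 0.87}
# where question_ids starting with "Viz" are scored under the VizWiz metrics
# and all others under the VQAv2 metrics. The helper below (hypothetical name
# and output path) writes a length-matched dummy submission in that format.
def _write_dummy_results(annFile, outFile="Results/dummy.json"):
    with open(annFile, 'r') as annF:
        anns = json.load(annF)
    # Reuse every ground-truth question_id and attach a constant score so the
    # submission length matches the annotations.
    dummy = [{"question_id": ann["question_id"], "single_grounding": 0.0} for ann in anns]
    os.makedirs(os.path.dirname(outFile) or ".", exist_ok=True)
    with open(outFile, 'w') as outF:
        json.dump(dummy, outF)
    return outFile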


phase_splits = {
    "test-dev2023": ["test-dev"],
    "test-standard2023": ["test"],
    "test-challenge2023": ["test"]
}


def evaluate(resFile, phase_codename, **kwargs):
    """Evaluate a submission file for the given challenge phase codename."""
    result = []
    splits = phase_splits[phase_codename]
    print(phase_codename)
    for split in splits:
        annFile = os.path.join(current_dir_path, "Annotations", split + ".json")
        # GroundingDifference() returns a dict of metrics for this split
        # (or None if the submission is malformed).
        result.append({split: GroundingDifference(annFile, resFile)})
    output = {"result": result}
    output["submission_result"] = result
    print(result)
    print("Completed evaluation for Test Phase")
    return output

if __name__ == "__main__":
    evaluate("Results/test.json", "test-standard2023")