
class PolicyBase(BaseModel):
    """Evaluate a configurable set of risk metrics and accumulate the results.

    Each call to :meth:`compute` evaluates one category of metrics
    (``'performance'``, ``'drift'`` or ``'fairness'``). It appends one row per
    metric to ``all_metrics_results`` and then refreshes ``scores`` from every
    recorded row, so results accumulate across calls on the same instance.
    """

    # Metric names used by ``compute`` when its ``list_metrics`` argument is None.
    list_metrics: List[str] = []
    # NOTE(review): not read in this class; presumably per-metric keyword
    # arguments — confirm against callers.
    list_metrics_kwargs: List[dict] = None
    # NOTE(review): not read in this class; presumably maps metric name to a
    # threshold — confirm against callers.
    list_threshold: dict = {}
    # Nested view of the results: ``{category: {risk_name: value}}``.
    scores: dict = defaultdict(dict)
    # Flat log of recorded results; each row has keys 'category', 'risk', 'value'.
    # Annotated so it is a declared model field (pydantic v2 rejects
    # non-annotated attributes on a model).
    all_metrics_results: List[dict] = []

    def compute(self, references=None, predictions=None, prediction_scores=None,
                sensitive_features=None, productions=None, list_metrics=None,
                type=None, **kwargs):
        """Evaluate the requested metrics for one category and return ``scores``.

        Args:
            references: Reference data. Must be non-None for ``'performance'``
                and ``'drift'``; forwarded to ``Fairness`` for ``'fairness'``.
            predictions: Must be non-None for ``'performance'``; forwarded to
                ``Fairness`` for ``'fairness'``.
            prediction_scores: Passed to performance metrics and to ``Fairness``.
            sensitive_features: Passed to ``Fairness``.
            productions: Must be non-None for ``'drift'``; passed to drift
                metrics together with ``references``.
            list_metrics: Metric names to evaluate. Defaults to
                ``self.list_metrics`` when None.
            type: ``'performance'``, ``'drift'`` or ``'fairness'``; any other
                value records nothing. (The name shadows the builtin but is
                kept for backward compatibility with keyword callers.)
            **kwargs: Forwarded to each metric's ``compute`` or to ``Fairness``.

        Returns:
            ``self.scores``, refreshed from every row in ``all_metrics_results``.

        Raises:
            Errors raised while constructing or running ``Fairness`` propagate.
            Failures of individual performance/drift metrics are logged and
            skipped.
        """
        if list_metrics is None:
            list_metrics = self.list_metrics

        if type == 'performance':
            if references is not None and predictions is not None:
                self._run_metrics('performance', list_metrics,
                                  (references, predictions, prediction_scores), kwargs)
        elif type == 'drift':
            if references is not None and productions is not None:
                self._run_metrics('drift', list_metrics,
                                  (references, productions), kwargs)
        elif type == 'fairness':
            fairness_evaluator = Fairness(references=references, predictions=predictions,
                                          sensitive_features=sensitive_features,
                                          prediction_scores=prediction_scores, **kwargs)
            # fairness_metrics and demography_metrics are instance-level
            # attributes; keep only the metrics that were requested.
            requested = set(list_metrics)
            fairness_evaluator.fairness_metrics = {
                k: v for k, v in fairness_evaluator.fairness_metrics.items() if k in requested
            }
            fairness_evaluator.demography_metrics = {
                k: v for k, v in fairness_evaluator.demography_metrics.items() if k in requested
            }
            self._record('fairness', fairness_evaluator.compute())

        # Rebuild the nested view for every category, not only fairness.
        # ``setdefault`` also works when ``scores`` was supplied as a plain dict
        # instead of the default ``defaultdict(dict)``.
        for row in self.all_metrics_results:
            self.scores.setdefault(row['category'], {})[row['risk']] = row['value']

        return self.scores

    def _run_metrics(self, category, metric_names, args, kwargs):
        """Call ``compute(*args, **kwargs)`` on each named metric and record it.

        Unknown names and metrics that raise are logged and skipped, so one bad
        metric does not prevent the others from being evaluated.
        """
        import logging

        logger = logging.getLogger(__name__)
        for metric_name in metric_names:
            try:
                metric = metrics_to_class_name[metric_name]
            except KeyError:
                logger.warning("Unknown %s metric %r; skipping it", category, metric_name)
                continue
            try:
                self._record(category, metric.compute(*args, **kwargs))
            except Exception:
                # Metrics are pluggable, so any error type is possible. Stay
                # best-effort, but keep the traceback instead of printing it away.
                logger.exception("%s metric %r failed; skipping it", category, metric_name)

    def _record(self, category, result):
        """Append the first ``(risk, value)`` entry of ``result`` as a result row."""
        # NOTE(review): as before, only the first entry of a multi-entry result
        # is kept — confirm metrics and Fairness.compute return a single entry.
        risk, value = list(result.items())[0]
        self.all_metrics_results.append({'category': category, 'risk': risk, 'value': value})
    
def calculating_risk_status(scores,list_of_thresholds):
    """Label entries of ``scores`` by comparing against ``list_of_thresholds``.

    For each name in ``list(scores['risk'])`` that is a key of
    ``list_of_thresholds``: when ``scores['risk']`` is not greater than that
    threshold, ``scores[name]`` is set to ``'not at risk'``. ``scores`` is
    mutated in place; no return statement appears in the lines shown.

    NOTE(review): this function looks unfinished — see the inline notes.
    """
    # NOTE(review): list() of a string yields its characters. If scores['risk']
    # is a single metric name, this iterates letters rather than metrics —
    # confirm the intended shape of ``scores``.
    metrics  = list(scores['risk'])
    for metric in metrics:
        # Only metrics with a configured threshold are labelled.
        if metric in list_of_thresholds.keys():
            # NOTE(review): compares scores['risk'] itself, not a per-metric
            # value; presumably a per-metric score was intended — verify.
            if scores['risk'] > list_of_thresholds[metric]:
                # TODO(review): placeholder assignment; presumably meant to mark
                # the metric as at risk. It also requires scores[metric] to
                # already be a mutable mapping.
                scores[metric][''] = ''
            else:
                scores[metric] = 'not at risk'