mirror of
https://github.com/deepseek-ai/ESFT
synced 2024-11-21 19:17:39 +00:00
Update benchmarks.py
This commit is contained in:
parent
38c8074be0
commit
26b4fc4a8a
@ -138,7 +138,7 @@ class IntentEvaluator(BaseEvaluator):
|
||||
prediction = json.loads(prediction)
|
||||
except:
|
||||
print(f"unable to parse prediction {prediction} of example with gt {ground_truth}")
|
||||
return {'intent_acc': 0, 'slots_f1': 0, 'em': 0}
|
||||
return 0.0
|
||||
|
||||
intent_em = prediction.get('intent', '') == ground_truth.get('intent', '')
|
||||
|
||||
@ -147,7 +147,7 @@ class IntentEvaluator(BaseEvaluator):
|
||||
pred_slots = {(k, str(tuple(sorted([str(i).replace(" ", "") for i in v]))) if isinstance(v, list) else v.replace(" ", "")) for k, v in prediction.get('slots', {}).items()}
|
||||
except:
|
||||
print(f"OK to parse prediction slots {prediction} of example with gt {ground_truth}, but failed in processing the contents.")
|
||||
return {'intent_acc': 0, 'slots_f1': 0, 'em': 0}
|
||||
return 0.0
|
||||
|
||||
correct_slots = pred_slots.intersection(gt_slots)
|
||||
slots_em = (len(correct_slots) == len(pred_slots)) and (len(correct_slots) == len(gt_slots))
|
||||
|
Loading…
Reference in New Issue
Block a user