From 26b4fc4a8abc1790c0b4f8a3d0eba4df9afd2886 Mon Sep 17 00:00:00 2001 From: Zihan Wang <112086423+ZihanWang314@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:33:47 +0800 Subject: [PATCH] Update benchmarks.py --- benchmarks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks.py b/benchmarks.py index e4a8bdd..cb81c4d 100644 --- a/benchmarks.py +++ b/benchmarks.py @@ -138,7 +138,7 @@ class IntentEvaluator(BaseEvaluator): prediction = json.loads(prediction) except: print(f"unable to parse prediction {prediction} of example with gt {ground_truth}") - return {'intent_acc': 0, 'slots_f1': 0, 'em': 0} + return 0.0 intent_em = prediction.get('intent', '') == ground_truth.get('intent', '') @@ -147,7 +147,7 @@ class IntentEvaluator(BaseEvaluator): pred_slots = {(k, str(tuple(sorted([str(i).replace(" ", "") for i in v]))) if isinstance(v, list) else v.replace(" ", "")) for k, v in prediction.get('slots', {}).items()} except: print(f"OK to parse prediction slots {prediction} of example with gt {ground_truth}, but failed in processing the contents.") - return {'intent_acc': 0, 'slots_f1': 0, 'em': 0} + return 0.0 correct_slots = pred_slots.intersection(gt_slots) slots_em = (len(correct_slots) == len(pred_slots)) and (len(correct_slots) == len(gt_slots))