Merge pull request #1048 from Hexastack/feat/default-nlu-penalty-factor-config

Add default NLU penalty factor and async support for NLP matching
This commit is contained in:
Med Marrouchi
2025-05-29 15:32:23 +01:00
committed by GitHub
6 changed files with 79 additions and 14 deletions

View File

@@ -47,6 +47,7 @@ import { NlpValueService } from '@/nlp/services/nlp-value.service';
import { NlpService } from '@/nlp/services/nlp.service';
import { PluginService } from '@/plugins/plugins.service';
import { SettingService } from '@/setting/services/setting.service';
import { FALLBACK_DEFAULT_NLU_PENALTY_FACTOR } from '@/utils/constants/nlp';
import {
blockFixtures,
installBlockFixtures,
@@ -196,6 +197,7 @@ describe('BlockService', () => {
})),
getSettings: jest.fn(() => ({
contact: { company_name: 'Your company name' },
chatbot_settings: { default_nlu_penalty_factor: 0.95 },
})),
},
},
@@ -467,9 +469,11 @@ describe('BlockService', () => {
blockService,
'calculateNluPatternMatchScore',
);
const bestBlock = blockService.matchBestNLP(
blocks,
mockNlpGreetingNameEntities,
FALLBACK_DEFAULT_NLU_PENALTY_FACTOR,
);
// Ensure calculateBlockScore was called at least once for each block
@@ -509,7 +513,11 @@ describe('BlockService', () => {
blockService,
'calculateNluPatternMatchScore',
);
const bestBlock = blockService.matchBestNLP(blocks, nlp);
const bestBlock = blockService.matchBestNLP(
blocks,
nlp,
FALLBACK_DEFAULT_NLU_PENALTY_FACTOR,
);
// Ensure calculateBlockScore was called at least once for each block
expect(calculateBlockScoreSpy).toHaveBeenCalledTimes(3); // Called for each block
@@ -530,6 +538,7 @@ describe('BlockService', () => {
const bestBlock = blockService.matchBestNLP(
blocks,
mockNlpGreetingNameEntities,
FALLBACK_DEFAULT_NLU_PENALTY_FACTOR,
);
// Assert that undefined is returned when no blocks are available
@@ -542,6 +551,7 @@ describe('BlockService', () => {
const matchingScore = blockService.calculateNluPatternMatchScore(
mockNlpGreetingNamePatterns,
mockNlpGreetingNameEntities,
FALLBACK_DEFAULT_NLU_PENALTY_FACTOR,
);
expect(matchingScore).toBeGreaterThan(0);
@@ -551,15 +561,29 @@ describe('BlockService', () => {
const scoreWithoutPenalty = blockService.calculateNluPatternMatchScore(
mockNlpGreetingNamePatterns,
mockNlpGreetingNameEntities,
FALLBACK_DEFAULT_NLU_PENALTY_FACTOR,
);
const scoreWithPenalty = blockService.calculateNluPatternMatchScore(
mockNlpGreetingAnyNamePatterns,
mockNlpGreetingNameEntities,
FALLBACK_DEFAULT_NLU_PENALTY_FACTOR,
);
expect(scoreWithoutPenalty).toBeGreaterThan(scoreWithPenalty);
});
it('should handle invalid case for penalty factor values', async () => {
// Test with invalid penalty (should use fallback)
const scoreWithInvalidPenalty =
blockService.calculateNluPatternMatchScore(
mockNlpGreetingAnyNamePatterns,
mockNlpGreetingNameEntities,
-1,
);
expect(scoreWithInvalidPenalty).toBeGreaterThan(0); // Should use fallback value
});
});
describe('matchPayload', () => {

View File

@@ -20,6 +20,7 @@ import { NlpService } from '@/nlp/services/nlp.service';
import { PluginService } from '@/plugins/plugins.service';
import { PluginType } from '@/plugins/types';
import { SettingService } from '@/setting/services/setting.service';
import { FALLBACK_DEFAULT_NLU_PENALTY_FACTOR } from '@/utils/constants/nlp';
import { BaseService } from '@/utils/generics/base-service';
import { getRandomElement } from '@/utils/helpers/safeRandom';
@@ -180,8 +181,23 @@ export class BlockService extends BaseService<
const scoredEntities =
await this.nlpService.computePredictionScore(nlp);
const settings = await this.settingService.getSettings();
let penaltyFactor =
settings.chatbot_settings?.default_nlu_penalty_factor;
if (!penaltyFactor) {
this.logger.warn(
'Using fallback NLU penalty factor value: %s',
FALLBACK_DEFAULT_NLU_PENALTY_FACTOR,
);
penaltyFactor = FALLBACK_DEFAULT_NLU_PENALTY_FACTOR;
}
if (scoredEntities.entities.length > 0) {
block = this.matchBestNLP(filteredBlocks, scoredEntities);
block = this.matchBestNLP(
filteredBlocks,
scoredEntities,
penaltyFactor,
);
}
}
}
@@ -351,6 +367,7 @@ export class BlockService extends BaseService<
matchBestNLP<B extends BlockStub>(
blocks: B[],
scoredEntities: NLU.ScoredEntities,
penaltyFactor: number,
): B | undefined {
const bestMatch = blocks.reduce(
(bestMatch, block) => {
@@ -365,10 +382,10 @@ export class BlockService extends BaseService<
const score = this.calculateNluPatternMatchScore(
patterns,
scoredEntities,
penaltyFactor,
);
return Math.max(maxScore, score);
}, 0);
return score > bestMatch.score ? { block, score } : bestMatch;
},
{ block: undefined, score: 0 },
@@ -390,14 +407,13 @@ export class BlockService extends BaseService<
*
* @param patterns - A list of patterns to evaluate against the NLU prediction.
* @param prediction - The scored entities resulting from NLU inference.
* @param [penaltyFactor=0.95] - Optional penalty factor to apply for generic matches (default is 0.95).
*
* @returns The total aggregated match score based on matched patterns and their computed scores.
*/
calculateNluPatternMatchScore(
patterns: NlpPattern[],
prediction: NLU.ScoredEntities,
penaltyFactor = 0.95,
penaltyFactor: number,
): number {
if (!patterns.length || !prediction.entities.length) {
return 0;

View File

@@ -24,6 +24,18 @@ export const DEFAULT_SETTINGS = [
},
weight: 1,
},
{
group: 'chatbot_settings',
label: 'default_nlu_penalty_factor',
value: 0.95,
type: SettingType.number,
config: {
min: 0,
max: 1,
step: 0.01,
},
weight: 2,
},
{
group: 'chatbot_settings',
label: 'default_llm_helper',
@@ -36,7 +48,7 @@ export const DEFAULT_SETTINGS = [
idKey: 'name',
labelKey: 'name',
},
weight: 2,
weight: 3,
},
{
group: 'chatbot_settings',
@@ -50,14 +62,14 @@ export const DEFAULT_SETTINGS = [
idKey: 'name',
labelKey: 'name',
},
weight: 3,
weight: 4,
},
{
group: 'chatbot_settings',
label: 'global_fallback',
value: true,
type: SettingType.checkbox,
weight: 4,
weight: 5,
},
{
group: 'chatbot_settings',
@@ -72,7 +84,7 @@ export const DEFAULT_SETTINGS = [
idKey: 'id',
labelKey: 'name',
},
weight: 5,
weight: 6,
},
{
group: 'chatbot_settings',
@@ -82,7 +94,7 @@ export const DEFAULT_SETTINGS = [
"I'm really sorry but i don't quite understand what you are saying :(",
] as string[],
type: SettingType.multiple_text,
weight: 6,
weight: 7,
translatable: true,
},
{

View File

@@ -0,0 +1,9 @@
/*
* Copyright © 2025 Hexastack. All rights reserved.
*
* Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms:
* 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission.
* 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file).
*/
export const FALLBACK_DEFAULT_NLU_PENALTY_FACTOR = 0.95;