feat: add a penalty factor & fix unit tests

This commit is contained in:
MohamedAliBouhaouala 2025-04-23 12:52:46 +01:00
parent c2c0bd32a5
commit ed959a1051
3 changed files with 123 additions and 31 deletions

View File

@ -52,8 +52,10 @@ import {
blockGetStarted,
blockProductListMock,
blocks,
mockModifiedNlpBlock,
mockNlpBlock,
mockNlpPatternsSetOne,
mockNlpPatternsSetThree,
mockNlpPatternsSetTwo,
} from '@/utils/test/mocks/block';
import {
@ -385,25 +387,46 @@ describe('BlockService', () => {
});
describe('matchBestNLP', () => {
const nlpPenaltyFactor = 2;
it('should return the block with the highest NLP score', async () => {
const blocks = [mockNlpBlock, blockGetStarted]; // You can add more blocks with different patterns and scores
const matchedPatterns = [mockNlpPatternsSetOne, mockNlpPatternsSetTwo];
const nlp = mockNlpEntitiesSetOne;
// Spy on calculateBlockScore to check if it's called
const calculateBlockScoreSpy = jest
.spyOn(blockService, 'calculateBlockScore')
.mockImplementation((patterns) => {
// Return different scores based on the block patterns
if (patterns === mockNlpPatternsSetOne) {
return Promise.resolve(1.499);
} else {
return Promise.resolve(0);
}
});
const calculateBlockScoreSpy = jest.spyOn(
blockService,
'calculateBlockScore',
);
const bestBlock = await blockService.matchBestNLP(
blocks,
matchedPatterns,
nlp,
nlpPenaltyFactor,
);
// // Ensure calculateBlockScore was called at least once for each block
expect(calculateBlockScoreSpy).toHaveBeenCalledTimes(2); // Called for each block
// Restore the spy after the test
calculateBlockScoreSpy.mockRestore();
// Assert that the block with the highest NLP score is selected
expect(bestBlock).toEqual(mockNlpBlock);
});
it('should return the block with the highest NLP score applying penalties', async () => {
const blocks = [mockNlpBlock, mockModifiedNlpBlock]; // You can add more blocks with different patterns and scores
const matchedPatterns = [mockNlpPatternsSetOne, mockNlpPatternsSetThree];
const nlp = mockNlpEntitiesSetOne;
// Spy on calculateBlockScore to check if it's called
const calculateBlockScoreSpy = jest.spyOn(
blockService,
'calculateBlockScore',
);
const bestBlock = await blockService.matchBestNLP(
blocks,
matchedPatterns,
nlp,
nlpPenaltyFactor,
);
// Ensure calculateBlockScore was called at least once for each block
@ -412,7 +435,7 @@ describe('BlockService', () => {
// Restore the spy after the test
calculateBlockScoreSpy.mockRestore();
// Assert that the block with the highest NLP score is selected
expect(bestBlock).toEqual(mockNlpBlock);
expect(bestBlock).toEqual(mockModifiedNlpBlock);
});
it('should return undefined if no blocks match or the list is empty', async () => {
@ -424,6 +447,7 @@ describe('BlockService', () => {
blocks,
matchedPatterns,
nlp,
nlpPenaltyFactor,
);
// Assert that undefined is returned when no blocks are available
@ -432,6 +456,7 @@ describe('BlockService', () => {
});
describe('calculateBlockScore', () => {
const nlpPenaltyFactor = 0.9;
it('should calculate the correct NLP score for a block', async () => {
const nlpCacheMap: NlpCacheMap = new Map();
@ -439,14 +464,38 @@ describe('BlockService', () => {
mockNlpPatternsSetOne,
mockNlpEntitiesSetOne,
nlpCacheMap,
nlpPenaltyFactor,
);
const score2 = await blockService.calculateBlockScore(
mockNlpPatternsSetTwo,
mockNlpEntitiesSetOne,
nlpCacheMap,
nlpPenaltyFactor,
);
expect(score).toBeGreaterThan(0);
expect(score2).toBe(0);
expect(score).toBeGreaterThan(score2);
});
it('should calculate the correct NLP score for a block and apply penalties ', async () => {
const nlpCacheMap: NlpCacheMap = new Map();
const score = await blockService.calculateBlockScore(
mockNlpPatternsSetOne,
mockNlpEntitiesSetOne,
nlpCacheMap,
nlpPenaltyFactor,
);
const score2 = await blockService.calculateBlockScore(
mockNlpPatternsSetThree,
mockNlpEntitiesSetOne,
nlpCacheMap,
nlpPenaltyFactor,
);
expect(score).toBeGreaterThan(0);
expect(score2).toBeGreaterThan(0);
expect(score).toBeGreaterThan(score2);
});
@ -456,6 +505,7 @@ describe('BlockService', () => {
mockNlpPatternsSetTwo,
mockNlpEntitiesSetOne,
nlpCacheMap,
nlpPenaltyFactor,
);
expect(score).toBe(0); // No matching entity, so score should be 0
@ -472,6 +522,7 @@ describe('BlockService', () => {
mockNlpPatternsSetOne,
mockNlpEntitiesSetOne,
nlpCacheMap,
nlpPenaltyFactor,
);
const cacheSizeBefore = nlpCacheMap.size;
const entityCallsBefore = entityServiceSpy.mock.calls.length;
@ -482,6 +533,7 @@ describe('BlockService', () => {
mockNlpPatternsSetOne,
mockNlpEntitiesSetOne,
nlpCacheMap,
nlpPenaltyFactor,
);
const cacheSizeAfter = nlpCacheMap.size;
const entityCallsAfter = entityServiceSpy.mock.calls.length;

View File

@ -200,6 +200,7 @@ export class BlockService extends BaseService<
// to the accumulator array `acc`, which is returned as the final result.
// This ensures that only blocks with valid matches are kept, and blocks with no matches are excluded,
// all while iterating through the list only once.
const matchesWithPatterns = filteredBlocks.reduce<MatchResult[]>(
(acc, b) => {
const matchedPattern = this.matchNLP(nlp, b);
@ -212,6 +213,8 @@ export class BlockService extends BaseService<
[],
);
// @TODO Make nluPenaltyFactor configurable in UI settings
const nluPenaltyFactor = 0.95;
// Log the matched patterns
this.logger.debug(
`Matched patterns: ${JSON.stringify(matchesWithPatterns.map((p) => p.matchedPattern))}`,
@ -223,6 +226,7 @@ export class BlockService extends BaseService<
matchesWithPatterns.map((m) => m.block),
matchesWithPatterns.map((p) => p.matchedPattern),
nlp,
nluPenaltyFactor,
)) as BlockFull | undefined;
}
}
@ -346,14 +350,12 @@ export class BlockService extends BaseService<
const nlpPatterns = block.patterns?.filter((p) => {
return Array.isArray(p);
}) as NlpPattern[][];
// No nlp patterns found
if (nlpPatterns.length === 0) {
return undefined;
}
// Find NLP pattern match based on best guessed entities
return nlpPatterns.find((entities: NlpPattern[]) => {
const pattern = nlpPatterns.find((entities: NlpPattern[]) => {
return entities.every((ev: NlpPattern) => {
if (ev.match === 'value') {
return nlp.entities.find((e) => {
@ -369,22 +371,28 @@ export class BlockService extends BaseService<
}
});
});
this.logger.log(`THE PATTERN ${JSON.stringify(pattern)}`);
return pattern;
}
/**
* Matches the best block based on NLP pattern scoring.
* The function calculates the NLP score for each block based on the matched patterns and selected entity weights,
* and returns the block with the highest score.
* Selects the best-matching block based on NLP pattern scoring.
*
* @param blocks - Array of blocks to match with patterns
* @param matchedPatterns - Array of matched NLP patterns corresponding to each block
* @param nlp - The NLP parsed entities to compare against
* @returns The block with the highest NLP score, or undefined if no valid block is found
* This function evaluates each block by calculating a score derived from its matched NLP patterns,
* the parsed NLP entities, and a penalty factor. It compares the scores across all blocks and
* returns the one with the highest calculated score.
*
* @param blocks - An array of candidate blocks to evaluate.
* @param matchedPatterns - A two-dimensional array of matched NLP patterns corresponding to each block.
* @param nlp - The parsed NLP entities used for scoring.
* @param nlpPenaltyFactor - A numeric penalty factor applied during scoring to influence block selection.
* @returns The block with the highest NLP score, or undefined if no valid block is found.
*/
async matchBestNLP(
blocks: (Block | BlockFull)[] | undefined,
matchedPatterns: NlpPattern[][],
nlp: NLU.ParseEntities,
nlpPenaltyFactor: number,
): Promise<Block | BlockFull | undefined> {
if (!blocks || blocks.length === 0) return undefined;
if (blocks.length === 1) return blocks[0];
@ -398,12 +406,12 @@ export class BlockService extends BaseService<
for (let i = 0; i < blocks.length; i++) {
const block = blocks[i];
const patterns = matchedPatterns[i];
// If compatible, calculate the NLP score for this block
const nlpScore = await this.calculateBlockScore(
patterns,
nlp,
nlpCacheMap,
nlpPenaltyFactor,
);
if (nlpScore > highestScore) {
@ -419,19 +427,25 @@ export class BlockService extends BaseService<
}
/**
* Calculates the NLP score for a single block based on the matched patterns and parsed NLP entities.
* The score is calculated by matching each entity in the pattern with the parsed NLP entities and evaluating
* their confidence and weight from the database.
* Computes the NLP score for a given block using its matched NLP patterns and parsed NLP entities.
*
* @param patterns - The NLP patterns matched for the block
* @param nlp - The parsed NLP entities
* @param nlpCacheMap - A cache for storing previously fetched entity data to avoid redundant DB calls
* @returns The calculated NLP score for the block
* Each pattern is evaluated against the parsed NLP entities to determine matches based on entity name,
* value, and confidence. A score is computed using the entity's weight and the confidence level of the match.
* A penalty factor is optionally applied for entity-level matches to adjust the scoring.
*
* The function uses a cache (`nlpCacheMap`) to avoid redundant database lookups for entity metadata.
*
* @param patterns - The NLP patterns associated with the block.
* @param nlp - The parsed NLP entities from the user input.
* @param nlpCacheMap - A cache to store and reuse fetched entity metadata (e.g., weights and valid values).
* @param nlpPenaltyFactor - A multiplier applied to scores when the pattern match type is 'entity'.
* @returns A numeric score representing how well the block matches the given NLP context.
*/
async calculateBlockScore(
patterns: NlpPattern[],
nlp: NLU.ParseEntities,
nlpCacheMap: NlpCacheMap,
nlpPenaltyFactor: number,
): Promise<number> {
let nlpScore = 0;
@ -472,9 +486,10 @@ export class BlockService extends BaseService<
entityData?.values.some((v) => v === e.value) &&
(pattern.match !== 'value' || e.value === pattern.value),
);
return matchedEntity?.confidence
? matchedEntity.confidence * entityData.weight
? matchedEntity.confidence *
entityData.weight *
(pattern.match === 'entity' ? nlpPenaltyFactor : 1)
: 0;
}),
);

View File

@ -272,6 +272,18 @@ export const mockNlpPatternsSetTwo: NlpPattern[] = [
},
];
export const mockNlpPatternsSetThree: NlpPattern[] = [
{
entity: 'intent',
match: 'value',
value: 'greeting',
},
{
entity: 'firstname',
match: 'entity',
},
];
export const mockNlpBlock = {
...baseBlockInstance,
name: 'Mock Nlp',
@ -299,6 +311,19 @@ export const mockNlpBlock = {
message: ['Good to see you again '],
} as unknown as BlockFull;
export const mockModifiedNlpBlock = {
...baseBlockInstance,
name: 'Modified Mock Nlp',
patterns: [
'Hello',
'/we*lcome/',
{ label: 'Modified Mock Nlp', value: 'MODIFIED_MOCK_NLP' },
[...mockNlpPatternsSetThree],
],
trigger_labels: customerLabelsMock,
message: ['Hello there'],
} as unknown as BlockFull;
const patternsProduct: Pattern[] = [
'produit',
[