Merge pull request #792 from Hexastack/feat/annotate-sample-with-keyword-entities
Some checks failed
Build and Push Docker API Image / build-and-push (push) Has been cancelled
Build and Push Docker Base Image / build-and-push (push) Has been cancelled
Build and Push Docker UI Image / build-and-push (push) Has been cancelled

feat: nlu keyword entity annotation
This commit is contained in:
Med Marrouchi
2025-03-06 14:08:19 +01:00
committed by GitHub
10 changed files with 519 additions and 36 deletions

View File

@@ -31,6 +31,7 @@ import { CsrfCheck } from '@tekuconcept/nestjs-csrf';
import { Response } from 'express';
import { HelperService } from '@/helper/helper.service';
import { HelperType } from '@/helper/types';
import { LanguageService } from '@/i18n/services/language.service';
import { CsrfInterceptor } from '@/interceptors/csrf.interceptor';
import { LoggerService } from '@/logger/logger.service';
@@ -74,6 +75,28 @@ export class NlpSampleController extends BaseController<
super(nlpSampleService);
}
/**
 * Annotates all matching NLP samples with the given keyword entity.
 *
 * @param entityId - Identifier of the keyword entity to annotate with.
 * @returns An acknowledgement object once annotation has been triggered.
 * @throws NotFoundException when no entity exists for the given id.
 * @throws BadRequestException when the entity is not a keyword-lookup entity.
 */
@CsrfCheck(true)
@Post('annotate/:entityId')
async annotateWithKeywordEntity(@Param('entityId') entityId: string) {
  const entity = await this.nlpEntityService.findOneAndPopulate(entityId);

  if (!entity) {
    throw new NotFoundException('Unable to find the keyword entity.');
  }

  // Only entities configured with the "keywords" lookup strategy can be
  // used for literal keyword annotation.
  const isKeywordEntity = entity.lookups.includes('keywords');
  if (!isKeywordEntity) {
    throw new BadRequestException(
      'Cannot annotate samples with a non-keyword entity',
    );
  }

  await this.nlpSampleService.annotateWithKeywordEntity(entity);

  return { success: true };
}
/**
* Exports the NLP samples in a formatted JSON file, using the Rasa NLU format.
*
@@ -91,7 +114,7 @@ export class NlpSampleController extends BaseController<
type ? { type } : {},
);
const entities = await this.nlpEntityService.findAllAndPopulate();
const helper = await this.helperService.getDefaultNluHelper();
const helper = await this.helperService.getDefaultHelper(HelperType.NLU);
const result = await helper.format(samples, entities);
// Sending the JSON data as a file
@@ -173,27 +196,10 @@ export class NlpSampleController extends BaseController<
*/
@Get('message')
async message(@Query('text') text: string) {
const helper = await this.helperService.getDefaultNluHelper();
const helper = await this.helperService.getDefaultHelper(HelperType.NLU);
return helper.predict(text);
}
/**
* Fetches the samples and entities for a given sample type.
*
* @param type - The sample type (e.g., 'train', 'test').
* @returns An object containing the samples and entities.
* @private
*/
private async getSamplesAndEntitiesByType(type: NlpSample['type']) {
const samples = await this.nlpSampleService.findAndPopulate({
type,
});
const entities = await this.nlpEntityService.findAllAndPopulate();
return { samples, entities };
}
/**
* Initiates the training process for the NLP service using the 'train' sample type.
*
@@ -202,10 +208,10 @@ export class NlpSampleController extends BaseController<
@Get('train')
async train() {
const { samples, entities } =
await this.getSamplesAndEntitiesByType('train');
await this.nlpSampleService.getAllSamplesAndEntitiesByType('train');
try {
const helper = await this.helperService.getDefaultNluHelper();
const helper = await this.helperService.getDefaultHelper(HelperType.NLU);
const response = await helper.train?.(samples, entities);
// Mark samples as trained
await this.nlpSampleService.updateMany(
@@ -229,9 +235,9 @@ export class NlpSampleController extends BaseController<
@Get('evaluate')
async evaluate() {
const { samples, entities } =
await this.getSamplesAndEntitiesByType('test');
await this.nlpSampleService.getAllSamplesAndEntitiesByType('test');
const helper = await this.helperService.getDefaultNluHelper();
const helper = await this.helperService.getDefaultHelper(HelperType.NLU);
return await helper.evaluate?.(samples, entities);
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Hexastack. All rights reserved.
* Copyright © 2025 Hexastack. All rights reserved.
*
* Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms:
* 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission.
@@ -25,6 +25,7 @@ import {
} from '@/utils/test/test';
import { TFixtures } from '@/utils/test/types';
import { NlpSampleEntityCreateDto } from '../dto/nlp-sample-entity.dto';
import { NlpEntityRepository } from '../repositories/nlp-entity.repository';
import { NlpSampleEntityRepository } from '../repositories/nlp-sample-entity.repository';
import { NlpValueRepository } from '../repositories/nlp-value.repository';
@@ -201,7 +202,15 @@ describe('NlpSampleEntityService', () => {
});
it('should throw an error if stored entity or value cannot be found', async () => {
const sample = { id: 1, text: 'Hello world' } as any as NlpSample;
const sample: NlpSample = {
id: 's1',
text: 'Hello world',
language: null,
trained: false,
type: 'train',
createdAt: new Date(),
updatedAt: new Date(),
};
const entities = [
{ entity: 'greeting', value: 'Hello', start: 0, end: 5 },
];
@@ -214,4 +223,235 @@ describe('NlpSampleEntityService', () => {
).rejects.toThrow('Unable to find the stored entity or value');
});
});
// Unit tests for NlpSampleEntityService.extractKeywordEntities():
// matching a value's keyword and expressions against a sample's text with
// case-insensitive, whole-word semantics, returning character offsets
// (start inclusive, end exclusive).
describe('extractKeywordEntities', () => {
  it('should extract entities when keywords are found', () => {
    const sample = {
      id: 's1',
      text: 'Hello world, AI is amazing!',
    } as NlpSample;
    const value = {
      id: 'v1',
      entity: 'e1',
      value: 'AI',
      expressions: ['amazing'],
    } as NlpValue;
    // 'AI' spans [13, 15) and the expression 'amazing' spans [19, 26).
    const expected: NlpSampleEntityCreateDto[] = [
      {
        sample: 's1',
        entity: 'e1',
        value: 'v1',
        start: 13,
        end: 15,
      },
      {
        sample: 's1',
        entity: 'e1',
        value: 'v1',
        start: 19,
        end: 26,
      },
    ];
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual(expected);
  });
  it('should be case-insensitive', () => {
    const sample = {
      id: 's2',
      text: 'I love ai and artificial intelligence.',
    } as NlpSample;
    const value = {
      id: 'v2',
      entity: 'e2',
      value: 'AI',
      expressions: [],
    } as unknown as NlpValue;
    // Lower-case 'ai' at [7, 9) matches the upper-case keyword 'AI';
    // the 'ai' inside 'artificial' must NOT match (not a whole word).
    const expected: NlpSampleEntityCreateDto[] = [
      {
        sample: 's2',
        entity: 'e2',
        value: 'v2',
        start: 7,
        end: 9,
      },
    ];
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual(expected);
  });
  it('should extract multiple occurrences of the same keyword', () => {
    const sample = {
      id: 's3',
      text: 'AI AI AI is everywhere.',
    } as NlpSample;
    const value = {
      id: 'v3',
      entity: 'e3',
      value: 'AI',
      expressions: [],
    } as unknown as NlpValue;
    // Every standalone occurrence is reported, in document order.
    const expected: NlpSampleEntityCreateDto[] = [
      {
        sample: 's3',
        entity: 'e3',
        value: 'v3',
        start: 0,
        end: 2,
      },
      {
        sample: 's3',
        entity: 'e3',
        value: 'v3',
        start: 3,
        end: 5,
      },
      {
        sample: 's3',
        entity: 'e3',
        value: 'v3',
        start: 6,
        end: 8,
      },
    ];
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual(expected);
  });
  it('should handle empty expressions array correctly', () => {
    const sample = {
      id: 's4',
      text: 'Data science is great.',
    } as NlpSample;
    const value = {
      id: 'v4',
      entity: 'e4',
      value: 'science',
      expressions: [],
    } as unknown as NlpValue;
    // With no expressions, only the primary keyword is searched for.
    const expected: NlpSampleEntityCreateDto[] = [
      {
        sample: 's4',
        entity: 'e4',
        value: 'v4',
        start: 5,
        end: 12,
      },
    ];
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual(expected);
  });
  it('should return an empty array if no matches are found', () => {
    const sample = { id: 'sample5', text: 'Hello world!' } as NlpSample;
    const value = {
      id: 'v5',
      entity: 'e5',
      value: 'Python',
      expressions: [],
    } as unknown as NlpValue;
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual([]);
  });
  it('should match keywords as whole words only', () => {
    const sample = {
      id: 'sample6',
      text: 'Technical claim.',
    } as NlpSample;
    const value = {
      id: 'v6',
      entity: 'e6',
      value: 'AI',
      expressions: [],
    } as unknown as NlpValue;
    // Should not match "AI-powered" since it's not a standalone word
    // NOTE(review): here the substrings 'ai' inside 'Technical' and 'claim'
    // must be rejected by the word-boundary lookarounds.
    const expected: NlpSampleEntityCreateDto[] = [];
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual(expected);
  });
  it('should handle special characters in the text correctly', () => {
    const sample = { id: 's7', text: 'Hello, AI. AI? AI!' } as NlpSample;
    const value = {
      id: 'v7',
      entity: 'e7',
      value: 'AI',
      expressions: [],
    } as unknown as NlpValue;
    // Punctuation is not a letter, so each 'AI' still counts as a whole word.
    const expected: NlpSampleEntityCreateDto[] = [
      {
        sample: 's7',
        entity: 'e7',
        value: 'v7',
        start: 7,
        end: 9,
      },
      {
        sample: 's7',
        entity: 'e7',
        value: 'v7',
        start: 11,
        end: 13,
      },
      {
        sample: 's7',
        entity: 'e7',
        value: 'v7',
        start: 15,
        end: 17,
      },
    ];
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual(expected);
  });
  it('should handle regex special characters in keyword values correctly', () => {
    const sample = {
      id: 's10',
      text: 'Find the,AI, in this text.',
    } as NlpSample;
    const value = {
      id: 'v10',
      entity: 'e10',
      value: 'AI',
      expressions: [],
    } as unknown as NlpValue;
    // NOTE(review): despite the title, the keyword here ('AI') contains no
    // regex metacharacters — only the surrounding text has commas. A keyword
    // such as 'C++' would be a stronger fixture for this case.
    const expected: NlpSampleEntityCreateDto[] = [
      {
        sample: 's10',
        entity: 'e10',
        value: 'v10',
        start: 9,
        end: 11,
      },
    ];
    expect(
      nlpSampleEntityService.extractKeywordEntities(sample, value),
    ).toEqual(expected);
  });
});
});

View File

@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Hexastack. All rights reserved.
* Copyright © 2025 Hexastack. All rights reserved.
*
* Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms:
* 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission.
@@ -10,13 +10,15 @@ import { Injectable } from '@nestjs/common';
import { BaseService } from '@/utils/generics/base-service';
import { NlpSampleEntityCreateDto } from '../dto/nlp-sample-entity.dto';
import { NlpSampleEntityRepository } from '../repositories/nlp-sample-entity.repository';
import {
NlpSampleEntity,
NlpSampleEntityFull,
NlpSampleEntityPopulate,
} from '../schemas/nlp-sample-entity.schema';
import { NlpSample } from '../schemas/nlp-sample.schema';
import { NlpSample, NlpSampleStub } from '../schemas/nlp-sample.schema';
import { NlpValue } from '../schemas/nlp-value.schema';
import { NlpSampleEntityValue } from '../schemas/types';
import { NlpEntityService } from './nlp-entity.service';
@@ -76,4 +78,41 @@ export class NlpSampleEntityService extends BaseService<
return await this.createMany(sampleEntities);
}
/**
 * Extracts entities from a given text sample by matching keywords defined in `NlpValue`.
 * The function uses regular expressions to locate each keyword and returns an array of matches.
 *
 * Matching is case-insensitive and Unicode-aware: a keyword only matches when
 * it is not immediately preceded or followed by another letter (whole-word
 * semantics). Offsets are character based: `start` inclusive, `end` exclusive.
 *
 * @param sample - The text sample from which entities should be extracted.
 * @param value - The entity value containing the primary keyword and its expressions.
 * @returns - An array of extracted entity matches, including their positions.
 */
extractKeywordEntities<S extends NlpSampleStub>(
  sample: S,
  value: NlpValue,
): NlpSampleEntityCreateDto[] {
  // Escape regex metacharacters so keywords like "C++" or "node.js" match literally.
  const escapeRegExp = (keyword: string) =>
    keyword.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const keywords = [value.value, ...value.expressions].map(escapeRegExp);
  // Group the alternation with (?:...) so the word-boundary lookarounds apply
  // to EVERY keyword. Without the group, `(?<!\p{L})a|b(?!\p{L})` binds the
  // lookbehind only to the first alternative and the lookahead only to the
  // last one, allowing partial-word matches for the other keywords.
  const regex = `(?<!\\p{L})(?:${keywords.join('|')})(?!\\p{L})`;
  const regexPattern = new RegExp(regex, 'giu');
  const matches: NlpSampleEntityCreateDto[] = [];
  let match: RegExpExecArray | null;

  // Find all matches in the text using the regex pattern
  while ((match = regexPattern.exec(sample.text)) !== null) {
    matches.push({
      sample: sample.id,
      entity: value.entity,
      value: value.id,
      start: match.index,
      end: match.index + match[0].length,
    });

    // Prevent infinite loops when using a regex with an empty match
    if (match.index === regexPattern.lastIndex) {
      regexPattern.lastIndex++;
    }
  }

  return matches;
}
}

View File

@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Hexastack. All rights reserved.
* Copyright © 2025 Hexastack. All rights reserved.
*
* Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms:
* 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission.
@@ -24,11 +24,16 @@ import {
rootMongooseTestModule,
} from '@/utils/test/test';
import { NlpSampleEntityCreateDto } from '../dto/nlp-sample-entity.dto';
import { NlpEntityRepository } from '../repositories/nlp-entity.repository';
import { NlpSampleEntityRepository } from '../repositories/nlp-sample-entity.repository';
import { NlpSampleRepository } from '../repositories/nlp-sample.repository';
import { NlpValueRepository } from '../repositories/nlp-value.repository';
import { NlpEntity, NlpEntityModel } from '../schemas/nlp-entity.schema';
import {
NlpEntity,
NlpEntityFull,
NlpEntityModel,
} from '../schemas/nlp-entity.schema';
import {
NlpSampleEntity,
NlpSampleEntityModel,
@@ -276,4 +281,74 @@ describe('NlpSampleService', () => {
expect(result[1].text).toEqual('Bye');
});
});
// Unit tests for NlpSampleService.annotateWithKeywordEntity(): the service
// pre-filters samples with a Mongo $regex built from the value's keyword and
// expressions, then delegates extraction/persistence per matching sample.
describe('annotateWithKeywordEntity', () => {
  it('should annotate samples when matching samples exist', async () => {
    const entity = {
      id: 'entity-id',
      name: 'entity_name',
      values: [
        {
          id: 'value-id',
          value: 'keyword',
          expressions: ['synonym1', 'synonym2'],
        },
      ],
    } as NlpEntityFull;
    const sampleText = 'This is a test sample with keyword in it.';
    const samples = [{ id: 'sample-id', text: sampleText }] as NlpSample[];
    const extractedMatches = [
      { sample: 'sample-id', entity: 'test_entity', value: 'keyword' },
    ] as NlpSampleEntityCreateDto[];
    const findSpy = jest
      .spyOn(nlpSampleService, 'find')
      .mockResolvedValue(samples);
    const extractSpy = jest
      .spyOn(nlpSampleEntityService, 'extractKeywordEntities')
      .mockReturnValue(extractedMatches);
    const findOrCreateSpy = jest
      .spyOn(nlpSampleEntityService, 'findOneOrCreate')
      .mockResolvedValue({} as NlpSampleEntity);

    await nlpSampleService.annotateWithKeywordEntity(entity);

    // The lookup query must OR together the keyword and all its expressions,
    // case-insensitively, restricted to train/test samples.
    expect(findSpy).toHaveBeenCalledWith({
      text: { $regex: '\\b(keyword|synonym1|synonym2)\\b', $options: 'i' },
      type: ['train', 'test'],
    });
    expect(extractSpy).toHaveBeenCalledWith(samples[0], entity.values[0]);
    // findOneOrCreate(criteria, dto) is called with the match as both
    // arguments, making the annotation idempotent.
    expect(findOrCreateSpy).toHaveBeenCalledWith(
      extractedMatches[0],
      extractedMatches[0],
    );
  });
  it('should not annotate when no matching samples are found', async () => {
    const entity = {
      id: 'entity-id',
      name: 'test_entity',
      values: [
        {
          value: 'keyword',
          expressions: ['synonym1', 'synonym2'],
        },
      ],
    } as NlpEntityFull;
    jest.spyOn(nlpSampleService, 'find').mockResolvedValue([]);
    const extractSpy = jest.spyOn(
      nlpSampleEntityService,
      'extractKeywordEntities',
    );

    await nlpSampleService.annotateWithKeywordEntity(entity);

    // With an empty pre-filter result, extraction must be skipped entirely.
    expect(extractSpy).not.toHaveBeenCalled();
  });
});
});

View File

@@ -1,5 +1,5 @@
/*
* Copyright © 2024 Hexastack. All rights reserved.
* Copyright © 2025 Hexastack. All rights reserved.
*
* Licensed under the GNU Affero General Public License v3.0 (AGPLv3) with the following additional terms:
* 1. The name "Hexabot" is a trademark of Hexastack. You may not use this name in derivative works without express written permission.
@@ -21,8 +21,10 @@ import { LoggerService } from '@/logger/logger.service';
import { BaseService } from '@/utils/generics/base-service';
import { THydratedDocument } from '@/utils/types/filter.types';
import { NlpSampleEntityCreateDto } from '../dto/nlp-sample-entity.dto';
import { NlpSampleCreateDto, TNlpSampleDto } from '../dto/nlp-sample.dto';
import { NlpSampleRepository } from '../repositories/nlp-sample.repository';
import { NlpEntityFull } from '../schemas/nlp-entity.schema';
import {
NlpSample,
NlpSampleFull,
@@ -50,6 +52,22 @@ export class NlpSampleService extends BaseService<
super(repository);
}
/**
 * Loads, for a given sample type, every matching (populated) sample together
 * with the full list of populated NLP entities.
 *
 * @param type - The sample type (e.g., 'train', 'test').
 * @returns An object containing the samples and entities.
 */
public async getAllSamplesAndEntitiesByType(type: NlpSample['type']) {
  const samples = await this.findAndPopulate({ type });
  const entities = await this.nlpEntityService.findAllAndPopulate();

  return { samples, entities };
}
/**
* Deletes an NLP sample by its ID and cascades the operation if needed.
*
@@ -165,6 +183,53 @@ export class NlpSampleService extends BaseService<
return nlpSamples;
}
/**
 * Iterates through all text samples stored in the database,
 * checks if the given keyword exists within each sample, and if so, appends it as an entity.
 * The function ensures that duplicate entities are not added and logs the updates.
 *
 * @param entity The keyword entity (with its values populated) to annotate samples with.
 */
async annotateWithKeywordEntity(entity: NlpEntityFull) {
  // Escape regex metacharacters so keywords such as "C++" are matched
  // literally inside the Mongo $regex (plain alphanumeric keywords are
  // left unchanged, so the query string is identical to before).
  const escapeRegExp = (keyword: string) =>
    keyword.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');

  for (const value of entity.values) {
    // For each value, get any sample that may contain the keyword or any of its synonyms
    const keywords = [value.value, ...value.expressions].map(escapeRegExp);
    const samples = await this.find({
      text: { $regex: `\\b(${keywords.join('|')})\\b`, $options: 'i' },
      type: ['train', 'test'],
    });

    if (samples.length > 0) {
      this.logger.debug(
        `Annotating ${entity.name} - ${value.value} in ${samples.length} sample(s) ...`,
      );

      for (const sample of samples) {
        try {
          // NOTE(review): the query uses \b boundaries while extraction uses
          // \p{L} lookarounds; they can disagree on non-ASCII letters, in
          // which case extraction may legitimately find no match below.
          const matches: NlpSampleEntityCreateDto[] =
            this.nlpSampleEntityService.extractKeywordEntities(sample, value);

          if (!matches.length) {
            throw new Error('Something went wrong, unable to match keywords');
          }

          // findOneOrCreate keeps annotation idempotent: an already-annotated
          // sample is not duplicated on re-runs.
          const updates = matches.map((dto) =>
            this.nlpSampleEntityService.findOneOrCreate(dto, dto),
          );
          await Promise.all(updates);
          this.logger.debug(
            `Successfully annotated sample with ${updates.length} matches: ${sample.text}`,
          );
        } catch (err) {
          // Best-effort: a failure on one sample must not abort the whole run,
          // but keep the error object for diagnosis instead of discarding it.
          this.logger.error(`Failed to annotate sample: ${sample.text}`, err);
        }
      }
    }
  }
}
/**
* When a language gets deleted, we need to set related samples to null
*