From 88168795c0fb1f8b15427d9fb44d5f0e475d894c Mon Sep 17 00:00:00 2001
From: Mohamed Marrouchi
Date: Mon, 23 Sep 2024 11:35:01 +0100
Subject: [PATCH] feat: adapt nlu prediction

---
 api/src/chat/services/bot.service.spec.ts     |  6 ++
 .../nlp/default/__test__/index.mock.ts        | 24 +++++-
 .../nlp/default/__test__/index.spec.ts        | 45 ++++++----
 .../helpers/nlp/default/index.nlp.helper.ts   | 85 +++----------------
 .../controllers/language.controller.spec.ts   | 14 ++-
 .../nlp/controllers/nlp-sample.controller.ts  |  2 +-
 .../nlp/services/nlp-sample.service.spec.ts   | 11 +++
 api/src/nlp/services/nlp-sample.service.ts    | 39 ++++++---
 .../src/components/nlp/NlpImportDialog.tsx    |  1 +
 .../nlp/components/NlpTrainForm.tsx           |  6 +-
 nlu/main.py                                   | 10 +--
 11 files changed, 133 insertions(+), 110 deletions(-)

diff --git a/api/src/chat/services/bot.service.spec.ts b/api/src/chat/services/bot.service.spec.ts
index 7e87e4dc..26113d63 100644
--- a/api/src/chat/services/bot.service.spec.ts
+++ b/api/src/chat/services/bot.service.spec.ts
@@ -28,7 +28,10 @@ import { MenuService } from '@/cms/services/menu.service';
 import { offlineEventText } from '@/extensions/channels/offline/__test__/events.mock';
 import OfflineHandler from '@/extensions/channels/offline/index.channel';
 import OfflineEventWrapper from '@/extensions/channels/offline/wrapper';
+import { LanguageRepository } from '@/i18n/repositories/language.repository';
+import { LanguageModel } from '@/i18n/schemas/language.schema';
 import { I18nService } from '@/i18n/services/i18n.service';
+import { LanguageService } from '@/i18n/services/language.service';
 import { LoggerService } from '@/logger/logger.service';
 import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
 import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
@@ -107,6 +110,7 @@ describe('BlockService', () => {
           NlpEntityModel,
           NlpSampleEntityModel,
           NlpSampleModel,
+          LanguageModel,
         ]),
       ],
       providers: [
@@ -126,6 +130,7 @@ describe('BlockService', () => {
         NlpEntityRepository,
         NlpSampleEntityRepository,
         NlpSampleRepository,
+        LanguageRepository,
         BlockService,
         CategoryService,
         ContentTypeService,
@@ -143,6 +148,7 @@ describe('BlockService', () => {
         NlpSampleEntityService,
         NlpSampleService,
         NlpService,
+        LanguageService,
         {
           provide: PluginService,
           useValue: {},
diff --git a/api/src/extensions/helpers/nlp/default/__test__/index.mock.ts b/api/src/extensions/helpers/nlp/default/__test__/index.mock.ts
index 22f257df..c30acb8e 100644
--- a/api/src/extensions/helpers/nlp/default/__test__/index.mock.ts
+++ b/api/src/extensions/helpers/nlp/default/__test__/index.mock.ts
@@ -23,6 +23,10 @@ export const nlpEmptyFormated: DatasetType = {
       name: 'product',
       elements: ['pizza', 'sandwich'],
     },
+    {
+      elements: ['en', 'fr'],
+      name: 'language',
+    },
   ],
   entity_synonyms: [
     {
@@ -34,17 +38,33 @@ export const nlpEmptyFormated: DatasetType = {
 
 export const nlpFormatted: DatasetType = {
   common_examples: [
-    { text: 'Hello', intent: 'greeting', entities: [] },
+    {
+      text: 'Hello',
+      intent: 'greeting',
+      entities: [
+        {
+          entity: 'language',
+          value: 'en',
+        },
+      ],
+    },
     {
       text: 'i want to order a pizza',
       intent: 'order',
-      entities: [{ entity: 'product', value: 'pizza', start: 19, end: 23 }],
+      entities: [
+        { entity: 'product', value: 'pizza', start: 19, end: 23 },
+        {
+          entity: 'language',
+          value: 'en',
+        },
+      ],
     },
   ],
   regex_features: [],
   lookup_tables: [
     { name: 'intent', elements: ['greeting', 'order'] },
     { name: 'product', elements: ['pizza', 'sandwich'] },
+    { name: 'language', elements: ['en', 'fr'] },
   ],
   entity_synonyms: [
     {
diff --git a/api/src/extensions/helpers/nlp/default/__test__/index.spec.ts b/api/src/extensions/helpers/nlp/default/__test__/index.spec.ts
index ac8027d0..ca9dd05c 100644
--- a/api/src/extensions/helpers/nlp/default/__test__/index.spec.ts
+++ b/api/src/extensions/helpers/nlp/default/__test__/index.spec.ts
@@ -8,10 +8,14 @@
  */
 
 import { HttpModule } from '@nestjs/axios';
+import { CACHE_MANAGER } from '@nestjs/cache-manager';
 import { EventEmitter2 } from '@nestjs/event-emitter';
 import { MongooseModule } from '@nestjs/mongoose';
 import { Test, TestingModule } from '@nestjs/testing';
 
+import { LanguageRepository } from '@/i18n/repositories/language.repository';
+import { LanguageModel } from '@/i18n/schemas/language.schema';
+import { LanguageService } from '@/i18n/services/language.service';
 import { LoggerService } from '@/logger/logger.service';
 import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
 import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
@@ -56,10 +60,24 @@ describe('NLP Default Helper', () => {
           NlpValueModel,
           NlpSampleModel,
           NlpSampleEntityModel,
+          LanguageModel,
         ]),
         HttpModule,
       ],
       providers: [
+        NlpService,
+        NlpSampleService,
+        NlpSampleRepository,
+        NlpEntityService,
+        NlpEntityRepository,
+        NlpValueService,
+        NlpValueRepository,
+        NlpSampleEntityService,
+        NlpSampleEntityRepository,
+        LanguageService,
+        LanguageRepository,
+        EventEmitter2,
+        DefaultNlpHelper,
         LoggerService,
         {
           provide: SettingService,
@@ -76,17 +94,14 @@ describe('NLP Default Helper', () => {
             })),
           },
         },
-        NlpService,
-        NlpSampleService,
-        NlpSampleRepository,
-        NlpEntityService,
-        NlpEntityRepository,
-        NlpValueService,
-        NlpValueRepository,
-        NlpSampleEntityService,
-        NlpSampleEntityRepository,
-        EventEmitter2,
-        DefaultNlpHelper,
+        {
+          provide: CACHE_MANAGER,
+          useValue: {
+            del: jest.fn(),
+            get: jest.fn(),
+            set: jest.fn(),
+          },
+        },
       ],
     }).compile();
     settingService = module.get(SettingService);
@@ -103,15 +118,15 @@ describe('NLP Default Helper', () => {
     expect(nlp).toBeDefined();
   });
 
-  it('should format empty training set properly', () => {
+  it('should format empty training set properly', async () => {
     const nlp = nlpService.getNLP();
-    const results = nlp.format([], entitiesMock);
+    const results = await nlp.format([], entitiesMock);
     expect(results).toEqual(nlpEmptyFormated);
   });
 
-  it('should format training set properly', () => {
+  it('should format training set properly', async () => {
     const nlp = nlpService.getNLP();
-    const results = nlp.format(samplesMock, entitiesMock);
+    const results = await nlp.format(samplesMock, entitiesMock);
     expect(results).toEqual(nlpFormatted);
   });
 
diff --git a/api/src/extensions/helpers/nlp/default/index.nlp.helper.ts b/api/src/extensions/helpers/nlp/default/index.nlp.helper.ts
index 059dfdaf..f884c52c 100644
--- a/api/src/extensions/helpers/nlp/default/index.nlp.helper.ts
+++ b/api/src/extensions/helpers/nlp/default/index.nlp.helper.ts
@@ -13,21 +13,13 @@ import { Injectable } from '@nestjs/common';
 
 import { LoggerService } from '@/logger/logger.service';
 import BaseNlpHelper from '@/nlp/lib/BaseNlpHelper';
 import { Nlp } from '@/nlp/lib/types';
-import { NlpEntity, NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
+import { NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
 import { NlpSampleFull } from '@/nlp/schemas/nlp-sample.schema';
-import { NlpValue } from '@/nlp/schemas/nlp-value.schema';
 import { NlpEntityService } from '@/nlp/services/nlp-entity.service';
 import { NlpSampleService } from '@/nlp/services/nlp-sample.service';
 import { NlpService } from '@/nlp/services/nlp.service';
 
-import {
-  CommonExample,
-  DatasetType,
-  EntitySynonym,
-  ExampleEntity,
-  LookupTable,
-  NlpParseResultType,
-} from './types';
+import { DatasetType, NlpParseResultType } from './types';
 
 @Injectable()
 export default class DefaultNlpHelper extends BaseNlpHelper {
@@ -61,69 +53,16 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
    * @param entities - All available entities
    * @returns {DatasetType} - The formatted RASA training set
    */
-  format(samples: NlpSampleFull[], entities: NlpEntityFull[]): DatasetType {
-    const entityMap = NlpEntity.getEntityMap(entities);
-    const valueMap = NlpValue.getValueMap(
-      NlpValue.getValuesFromEntities(entities),
+  async format(
+    samples: NlpSampleFull[],
+    entities: NlpEntityFull[],
+  ): Promise<DatasetType> {
+    const nluData = await this.nlpSampleService.formatRasaNlu(
+      samples,
+      entities,
     );
 
-    const common_examples: CommonExample[] = samples
-      .filter((s) => s.entities.length > 0)
-      .map((s) => {
-        const intent = s.entities.find(
-          (e) => entityMap[e.entity].name === 'intent',
-        );
-        if (!intent) {
-          throw new Error('Unable to find the `intent` nlp entity.');
-        }
-        const sampleEntities: ExampleEntity[] = s.entities
-          .filter((e) => entityMap[e.entity].name !== 'intent')
-          .map((e) => {
-            const res: ExampleEntity = {
-              entity: entityMap[e.entity].name,
-              value: valueMap[e.value].value,
-            };
-            if ('start' in e && 'end' in e) {
-              Object.assign(res, {
-                start: e.start,
-                end: e.end,
-              });
-            }
-            return res;
-          });
-        return {
-          text: s.text,
-          intent: valueMap[intent.value].value,
-          entities: sampleEntities,
-        };
-      });
-    const lookup_tables: LookupTable[] = entities.map((e) => {
-      return {
-        name: e.name,
-        elements: e.values.map((v) => {
-          return v.value;
-        }),
-      };
-    });
-    const entity_synonyms = entities
-      .reduce((acc, e) => {
-        const synonyms = e.values.map((v) => {
-          return {
-            value: v.value,
-            synonyms: v.expressions,
-          };
-        });
-        return acc.concat(synonyms);
-      }, [] as EntitySynonym[])
-      .filter((s) => {
-        return s.synonyms.length > 0;
-      });
-    return {
-      common_examples,
-      regex_features: [],
-      lookup_tables,
-      entity_synonyms,
-    };
+    return nluData;
   }
 
   /**
@@ -138,7 +77,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
     entities: NlpEntityFull[],
   ): Promise {
     const self = this;
-    const nluData: DatasetType = self.format(samples, entities);
+    const nluData: DatasetType = await self.format(samples, entities);
     // Train samples
     const result = await this.httpService.axiosRef.post(
       `${this.settings.endpoint}/train`,
@@ -169,7 +108,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
     entities: NlpEntityFull[],
   ): Promise {
     const self = this;
-    const nluTestData: DatasetType = self.format(samples, entities);
+    const nluTestData: DatasetType = await self.format(samples, entities);
     // Evaluate model with test samples
     return await this.httpService.axiosRef.post(
       `${this.settings.endpoint}/evaluate`,
diff --git a/api/src/i18n/controllers/language.controller.spec.ts b/api/src/i18n/controllers/language.controller.spec.ts
index 24e07c4c..f4d66b21 100644
--- a/api/src/i18n/controllers/language.controller.spec.ts
+++ b/api/src/i18n/controllers/language.controller.spec.ts
@@ -96,13 +96,23 @@ describe('LanguageController', () => {
   });
 
   describe('findPage', () => {
-    const pageQuery = getPageQuery();
+    const pageQuery = getPageQuery({ sort: ['code', 'asc'] });
     it('should find languages', async () => {
       jest.spyOn(languageService, 'findPage');
       const result = await languageController.findPage(pageQuery, {});
 
       expect(languageService.findPage).toHaveBeenCalledWith({}, pageQuery);
-      expect(result).toEqualPayload(languageFixtures);
+      expect(result).toEqualPayload(
+        languageFixtures.sort(({ code: codeA }, { code: codeB }) => {
+          if (codeA < codeB) {
+            return -1;
+          }
+          if (codeA > codeB) {
+            return 1;
+          }
+          return 0;
+        }),
+      );
     });
   });
 
diff --git a/api/src/nlp/controllers/nlp-sample.controller.ts b/api/src/nlp/controllers/nlp-sample.controller.ts
index 8e28fe18..50a118db 100644
--- a/api/src/nlp/controllers/nlp-sample.controller.ts
+++ b/api/src/nlp/controllers/nlp-sample.controller.ts
@@ -93,7 +93,7 @@ export class NlpSampleController extends BaseController<
       type ? { type } : {},
     );
     const entities = await this.nlpEntityService.findAllAndPopulate();
-    const result = this.nlpSampleService.formatRasaNlu(samples, entities);
+    const result = await this.nlpSampleService.formatRasaNlu(samples, entities);
     // Sending the JSON data as a file
     const buffer = Buffer.from(JSON.stringify(result));
 
diff --git a/api/src/nlp/services/nlp-sample.service.spec.ts b/api/src/nlp/services/nlp-sample.service.spec.ts
index 5b80719d..970a933a 100644
--- a/api/src/nlp/services/nlp-sample.service.spec.ts
+++ b/api/src/nlp/services/nlp-sample.service.spec.ts
@@ -7,12 +7,14 @@
  * 3. SaaS Restriction: This software, or any derivative of it, may not be used to offer a competing product or service (SaaS) without prior written consent from Hexastack. Offering the software as a service or using it in a commercial cloud environment without express permission is strictly prohibited.
  */
 
+import { CACHE_MANAGER } from '@nestjs/cache-manager';
 import { EventEmitter2 } from '@nestjs/event-emitter';
 import { MongooseModule } from '@nestjs/mongoose';
 import { Test, TestingModule } from '@nestjs/testing';
 
 import { LanguageRepository } from '@/i18n/repositories/language.repository';
 import { Language, LanguageModel } from '@/i18n/schemas/language.schema';
+import { LanguageService } from '@/i18n/services/language.service';
 import { nlpSampleFixtures } from '@/utils/test/fixtures/nlpsample';
 import { installNlpSampleEntityFixtures } from '@/utils/test/fixtures/nlpsampleentity';
 import { getPageQuery } from '@/utils/test/pagination';
@@ -68,7 +70,16 @@ describe('NlpSampleService', () => {
         NlpSampleEntityService,
         NlpEntityService,
         NlpValueService,
+        LanguageService,
         EventEmitter2,
+        {
+          provide: CACHE_MANAGER,
+          useValue: {
+            del: jest.fn(),
+            get: jest.fn(),
+            set: jest.fn(),
+          },
+        },
       ],
     }).compile();
     nlpSampleService = module.get(NlpSampleService);
diff --git a/api/src/nlp/services/nlp-sample.service.ts b/api/src/nlp/services/nlp-sample.service.ts
index 4ebb131e..34030091 100644
--- a/api/src/nlp/services/nlp-sample.service.ts
+++ b/api/src/nlp/services/nlp-sample.service.ts
@@ -16,6 +16,7 @@ import {
   ExampleEntity,
   LookupTable,
 } from '@/extensions/helpers/nlp/default/types';
+import { LanguageService } from '@/i18n/services/language.service';
 import { BaseService } from '@/utils/generics/base-service';
 
 import { NlpSampleRepository } from '../repositories/nlp-sample.repository';
@@ -33,7 +34,10 @@ export class NlpSampleService extends BaseService<
   NlpSamplePopulate,
   NlpSampleFull
 > {
-  constructor(readonly repository: NlpSampleRepository) {
+  constructor(
+    readonly repository: NlpSampleRepository,
+    private readonly languageService: LanguageService,
+  ) {
     super(repository);
   }
 
@@ -56,10 +60,10 @@ export class NlpSampleService extends BaseService<
    *
    * @returns The formatted Rasa NLU training dataset.
    */
-  formatRasaNlu(
+  async formatRasaNlu(
     samples: NlpSampleFull[],
     entities: NlpEntityFull[],
-  ): DatasetType {
+  ): Promise<DatasetType> {
     const entityMap = NlpEntity.getEntityMap(entities);
     const valueMap = NlpValue.getValueMap(
       NlpValue.getValuesFromEntities(entities),
@@ -88,21 +92,34 @@ export class NlpSampleService extends BaseService<
             });
           }
           return res;
+        })
+        // TODO: place language at the same level as the intent
+        .concat({
+          entity: 'language',
+          value: s.language.code,
         });
+
       return {
         text: s.text,
         intent: valueMap[intent.value].value,
         entities: sampleEntities,
       };
     });
-    const lookup_tables: LookupTable[] = entities.map((e) => {
-      return {
-        name: e.name,
-        elements: e.values.map((v) => {
-          return v.value;
-        }),
-      };
-    });
+
+    const languages = await this.languageService.getLanguages();
+    const lookup_tables: LookupTable[] = entities
+      .map((e) => {
+        return {
+          name: e.name,
+          elements: e.values.map((v) => {
+            return v.value;
+          }),
+        };
+      })
+      .concat({
+        name: 'language',
+        elements: Object.keys(languages),
+      });
     const entity_synonyms = entities
       .reduce((acc, e) => {
         const synonyms = e.values.map((v) => {
diff --git a/frontend/src/components/nlp/NlpImportDialog.tsx b/frontend/src/components/nlp/NlpImportDialog.tsx
index f6289b19..0be43dd8 100644
--- a/frontend/src/components/nlp/NlpImportDialog.tsx
+++ b/frontend/src/components/nlp/NlpImportDialog.tsx
@@ -44,6 +44,7 @@ export const NlpImportDialog: FC = ({
       QueryType.collection,
       EntityType.NLP_SAMPLE,
     ]);
+    queryClient.removeQueries([QueryType.count, EntityType.NLP_SAMPLE]);
 
     handleCloseDialog();
     toast.success(t("message.success_save"));
diff --git a/frontend/src/components/nlp/components/NlpTrainForm.tsx b/frontend/src/components/nlp/components/NlpTrainForm.tsx
index efe61662..2a90e3b6 100644
--- a/frontend/src/components/nlp/components/NlpTrainForm.tsx
+++ b/frontend/src/components/nlp/components/NlpTrainForm.tsx
@@ -157,12 +157,16 @@ const NlpDatasetSample: FC = ({
     },
     onSuccess: (result) => {
       const traitEntities: INlpDatasetTraitEntity[] = result.entities.filter(
-        (e) => !("start" in e && "end" in e),
+        (e) => !("start" in e && "end" in e) && e.entity !== "language",
       );
       const keywordEntities = result.entities.filter(
         (e) => "start" in e && "end" in e,
       ) as INlpDatasetKeywordEntity[];
 
+      const language = result.entities.find(
+        ({ entity }) => entity === "language",
+      );
+
+      setValue("language", language?.value || "");
       setValue("traitEntities", traitEntities);
       setValue("keywordEntities", keywordEntities);
     },
diff --git a/nlu/main.py b/nlu/main.py
index f7e4f8ba..52f12ea2 100644
--- a/nlu/main.py
+++ b/nlu/main.py
@@ -88,13 +88,13 @@ def parse(input: ParseInput, is_authenticated: Annotated[str, Depends(authentica
         headers = {"Retry-After": "120"} # Suggest retrying after 2 minutes
         return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content={"message": "Models are still loading, please retry later."}, headers=headers)
 
-    language = app.language_classifier.get_prediction(input.q) # type: ignore
-    lang = language.get("value")
-    intent_prediction = app.intent_classifiers[lang].get_prediction(
+    language_prediction = app.language_classifier.get_prediction(input.q) # type: ignore
+    language = language_prediction.get("value")
+    intent_prediction = app.intent_classifiers[language].get_prediction(
         input.q) # type: ignore
-    slot_prediction = app.slot_fillers[lang].get_prediction(
+    slot_prediction = app.slot_fillers[language].get_prediction(
         input.q) # type: ignore
-    slot_prediction.get("entities").append(language)
+    slot_prediction.get("entities").append(language_prediction)
 
     return {
         "text": input.q,