feat: adapt nlu prediction

This commit is contained in:
Mohamed Marrouchi 2024-09-23 11:35:01 +01:00
parent 0c02b51cf6
commit 88168795c0
11 changed files with 133 additions and 110 deletions

View File

@ -28,7 +28,10 @@ import { MenuService } from '@/cms/services/menu.service';
import { offlineEventText } from '@/extensions/channels/offline/__test__/events.mock';
import OfflineHandler from '@/extensions/channels/offline/index.channel';
import OfflineEventWrapper from '@/extensions/channels/offline/wrapper';
import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { LanguageModel } from '@/i18n/schemas/language.schema';
import { I18nService } from '@/i18n/services/i18n.service';
import { LanguageService } from '@/i18n/services/language.service';
import { LoggerService } from '@/logger/logger.service';
import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
@ -107,6 +110,7 @@ describe('BlockService', () => {
NlpEntityModel,
NlpSampleEntityModel,
NlpSampleModel,
LanguageModel,
]),
],
providers: [
@ -126,6 +130,7 @@ describe('BlockService', () => {
NlpEntityRepository,
NlpSampleEntityRepository,
NlpSampleRepository,
LanguageRepository,
BlockService,
CategoryService,
ContentTypeService,
@ -143,6 +148,7 @@ describe('BlockService', () => {
NlpSampleEntityService,
NlpSampleService,
NlpService,
LanguageService,
{
provide: PluginService,
useValue: {},

View File

@ -23,6 +23,10 @@ export const nlpEmptyFormated: DatasetType = {
name: 'product',
elements: ['pizza', 'sandwich'],
},
{
elements: ['en', 'fr'],
name: 'language',
},
],
entity_synonyms: [
{
@ -34,17 +38,33 @@ export const nlpEmptyFormated: DatasetType = {
export const nlpFormatted: DatasetType = {
common_examples: [
{ text: 'Hello', intent: 'greeting', entities: [] },
{
text: 'Hello',
intent: 'greeting',
entities: [
{
entity: 'language',
value: 'en',
},
],
},
{
text: 'i want to order a pizza',
intent: 'order',
entities: [{ entity: 'product', value: 'pizza', start: 19, end: 23 }],
entities: [
{ entity: 'product', value: 'pizza', start: 19, end: 23 },
{
entity: 'language',
value: 'en',
},
],
},
],
regex_features: [],
lookup_tables: [
{ name: 'intent', elements: ['greeting', 'order'] },
{ name: 'product', elements: ['pizza', 'sandwich'] },
{ name: 'language', elements: ['en', 'fr'] },
],
entity_synonyms: [
{

View File

@ -8,10 +8,14 @@
*/
import { HttpModule } from '@nestjs/axios';
import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { MongooseModule } from '@nestjs/mongoose';
import { Test, TestingModule } from '@nestjs/testing';
import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { LanguageModel } from '@/i18n/schemas/language.schema';
import { LanguageService } from '@/i18n/services/language.service';
import { LoggerService } from '@/logger/logger.service';
import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
@ -56,10 +60,24 @@ describe('NLP Default Helper', () => {
NlpValueModel,
NlpSampleModel,
NlpSampleEntityModel,
LanguageModel,
]),
HttpModule,
],
providers: [
NlpService,
NlpSampleService,
NlpSampleRepository,
NlpEntityService,
NlpEntityRepository,
NlpValueService,
NlpValueRepository,
NlpSampleEntityService,
NlpSampleEntityRepository,
LanguageService,
LanguageRepository,
EventEmitter2,
DefaultNlpHelper,
LoggerService,
{
provide: SettingService,
@ -76,17 +94,14 @@ describe('NLP Default Helper', () => {
})),
},
},
NlpService,
NlpSampleService,
NlpSampleRepository,
NlpEntityService,
NlpEntityRepository,
NlpValueService,
NlpValueRepository,
NlpSampleEntityService,
NlpSampleEntityRepository,
EventEmitter2,
DefaultNlpHelper,
{
provide: CACHE_MANAGER,
useValue: {
del: jest.fn(),
get: jest.fn(),
set: jest.fn(),
},
},
],
}).compile();
settingService = module.get<SettingService>(SettingService);
@ -103,15 +118,15 @@ describe('NLP Default Helper', () => {
expect(nlp).toBeDefined();
});
it('should format empty training set properly', () => {
it('should format empty training set properly', async () => {
const nlp = nlpService.getNLP();
const results = nlp.format([], entitiesMock);
const results = await nlp.format([], entitiesMock);
expect(results).toEqual(nlpEmptyFormated);
});
it('should format training set properly', () => {
it('should format training set properly', async () => {
const nlp = nlpService.getNLP();
const results = nlp.format(samplesMock, entitiesMock);
const results = await nlp.format(samplesMock, entitiesMock);
expect(results).toEqual(nlpFormatted);
});

View File

@ -13,21 +13,13 @@ import { Injectable } from '@nestjs/common';
import { LoggerService } from '@/logger/logger.service';
import BaseNlpHelper from '@/nlp/lib/BaseNlpHelper';
import { Nlp } from '@/nlp/lib/types';
import { NlpEntity, NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
import { NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
import { NlpSampleFull } from '@/nlp/schemas/nlp-sample.schema';
import { NlpValue } from '@/nlp/schemas/nlp-value.schema';
import { NlpEntityService } from '@/nlp/services/nlp-entity.service';
import { NlpSampleService } from '@/nlp/services/nlp-sample.service';
import { NlpService } from '@/nlp/services/nlp.service';
import {
CommonExample,
DatasetType,
EntitySynonym,
ExampleEntity,
LookupTable,
NlpParseResultType,
} from './types';
import { DatasetType, NlpParseResultType } from './types';
@Injectable()
export default class DefaultNlpHelper extends BaseNlpHelper {
@ -61,69 +53,16 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
* @param entities - All available entities
* @returns {Promise<DatasetType>} - A promise that resolves to the formatted RASA training set
*/
format(samples: NlpSampleFull[], entities: NlpEntityFull[]): DatasetType {
const entityMap = NlpEntity.getEntityMap(entities);
const valueMap = NlpValue.getValueMap(
NlpValue.getValuesFromEntities(entities),
async format(
samples: NlpSampleFull[],
entities: NlpEntityFull[],
): Promise<DatasetType> {
const nluData = await this.nlpSampleService.formatRasaNlu(
samples,
entities,
);
const common_examples: CommonExample[] = samples
.filter((s) => s.entities.length > 0)
.map((s) => {
const intent = s.entities.find(
(e) => entityMap[e.entity].name === 'intent',
);
if (!intent) {
throw new Error('Unable to find the `intent` nlp entity.');
}
const sampleEntities: ExampleEntity[] = s.entities
.filter((e) => entityMap[<string>e.entity].name !== 'intent')
.map((e) => {
const res: ExampleEntity = {
entity: entityMap[<string>e.entity].name,
value: valueMap[<string>e.value].value,
};
if ('start' in e && 'end' in e) {
Object.assign(res, {
start: e.start,
end: e.end,
});
}
return res;
});
return {
text: s.text,
intent: valueMap[intent.value].value,
entities: sampleEntities,
};
});
const lookup_tables: LookupTable[] = entities.map((e) => {
return {
name: e.name,
elements: e.values.map((v) => {
return v.value;
}),
};
});
const entity_synonyms = entities
.reduce((acc, e) => {
const synonyms = e.values.map((v) => {
return {
value: v.value,
synonyms: v.expressions,
};
});
return acc.concat(synonyms);
}, [] as EntitySynonym[])
.filter((s) => {
return s.synonyms.length > 0;
});
return {
common_examples,
regex_features: [],
lookup_tables,
entity_synonyms,
};
return nluData;
}
/**
@ -138,7 +77,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
entities: NlpEntityFull[],
): Promise<any> {
const self = this;
const nluData: DatasetType = self.format(samples, entities);
const nluData: DatasetType = await self.format(samples, entities);
// Train samples
const result = await this.httpService.axiosRef.post(
`${this.settings.endpoint}/train`,
@ -169,7 +108,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
entities: NlpEntityFull[],
): Promise<any> {
const self = this;
const nluTestData: DatasetType = self.format(samples, entities);
const nluTestData: DatasetType = await self.format(samples, entities);
// Evaluate model with test samples
return await this.httpService.axiosRef.post(
`${this.settings.endpoint}/evaluate`,

View File

@ -96,13 +96,23 @@ describe('LanguageController', () => {
});
describe('findPage', () => {
const pageQuery = getPageQuery<Language>();
const pageQuery = getPageQuery<Language>({ sort: ['code', 'asc'] });
it('should find languages', async () => {
jest.spyOn(languageService, 'findPage');
const result = await languageController.findPage(pageQuery, {});
expect(languageService.findPage).toHaveBeenCalledWith({}, pageQuery);
expect(result).toEqualPayload(languageFixtures);
// The page query above sorts by `code` ascending, so the expected payload is
// the fixtures ordered by `code`. Sort a shallow copy: `Array.prototype.sort`
// reorders in place, and `languageFixtures` is defined outside this test, so
// sorting it directly would leak the new order to any other user of it.
expect(result).toEqualPayload(
  [...languageFixtures].sort(({ code: codeA }, { code: codeB }) => {
    if (codeA < codeB) {
      return -1;
    }
    if (codeA > codeB) {
      return 1;
    }
    return 0;
  }),
);
});
});

View File

@ -93,7 +93,7 @@ export class NlpSampleController extends BaseController<
type ? { type } : {},
);
const entities = await this.nlpEntityService.findAllAndPopulate();
const result = this.nlpSampleService.formatRasaNlu(samples, entities);
const result = await this.nlpSampleService.formatRasaNlu(samples, entities);
// Sending the JSON data as a file
const buffer = Buffer.from(JSON.stringify(result));

View File

@ -7,12 +7,14 @@
* 3. SaaS Restriction: This software, or any derivative of it, may not be used to offer a competing product or service (SaaS) without prior written consent from Hexastack. Offering the software as a service or using it in a commercial cloud environment without express permission is strictly prohibited.
*/
import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { MongooseModule } from '@nestjs/mongoose';
import { Test, TestingModule } from '@nestjs/testing';
import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { Language, LanguageModel } from '@/i18n/schemas/language.schema';
import { LanguageService } from '@/i18n/services/language.service';
import { nlpSampleFixtures } from '@/utils/test/fixtures/nlpsample';
import { installNlpSampleEntityFixtures } from '@/utils/test/fixtures/nlpsampleentity';
import { getPageQuery } from '@/utils/test/pagination';
@ -68,7 +70,16 @@ describe('NlpSampleService', () => {
NlpSampleEntityService,
NlpEntityService,
NlpValueService,
LanguageService,
EventEmitter2,
{
provide: CACHE_MANAGER,
useValue: {
del: jest.fn(),
get: jest.fn(),
set: jest.fn(),
},
},
],
}).compile();
nlpSampleService = module.get<NlpSampleService>(NlpSampleService);

View File

@ -16,6 +16,7 @@ import {
ExampleEntity,
LookupTable,
} from '@/extensions/helpers/nlp/default/types';
import { LanguageService } from '@/i18n/services/language.service';
import { BaseService } from '@/utils/generics/base-service';
import { NlpSampleRepository } from '../repositories/nlp-sample.repository';
@ -33,7 +34,10 @@ export class NlpSampleService extends BaseService<
NlpSamplePopulate,
NlpSampleFull
> {
constructor(readonly repository: NlpSampleRepository) {
constructor(
readonly repository: NlpSampleRepository,
private readonly languageService: LanguageService,
) {
super(repository);
}
@ -56,10 +60,10 @@ export class NlpSampleService extends BaseService<
*
* @returns The formatted Rasa NLU training dataset.
*/
formatRasaNlu(
async formatRasaNlu(
samples: NlpSampleFull[],
entities: NlpEntityFull[],
): DatasetType {
): Promise<DatasetType> {
const entityMap = NlpEntity.getEntityMap(entities);
const valueMap = NlpValue.getValueMap(
NlpValue.getValuesFromEntities(entities),
@ -88,21 +92,34 @@ export class NlpSampleService extends BaseService<
});
}
return res;
})
// TODO : place language at the same level as the intent
.concat({
entity: 'language',
value: s.language.code,
});
return {
text: s.text,
intent: valueMap[intent.value].value,
entities: sampleEntities,
};
});
const lookup_tables: LookupTable[] = entities.map((e) => {
return {
name: e.name,
elements: e.values.map((v) => {
return v.value;
}),
};
});
const languages = await this.languageService.getLanguages();
const lookup_tables: LookupTable[] = entities
.map((e) => {
return {
name: e.name,
elements: e.values.map((v) => {
return v.value;
}),
};
})
.concat({
name: 'language',
elements: Object.keys(languages),
});
const entity_synonyms = entities
.reduce((acc, e) => {
const synonyms = e.values.map((v) => {

View File

@ -44,6 +44,7 @@ export const NlpImportDialog: FC<NlpImportDialogProps> = ({
QueryType.collection,
EntityType.NLP_SAMPLE,
]);
queryClient.removeQueries([QueryType.count, EntityType.NLP_SAMPLE]);
handleCloseDialog();
toast.success(t("message.success_save"));

View File

@ -157,12 +157,16 @@ const NlpDatasetSample: FC<NlpDatasetSampleProps> = ({
},
onSuccess: (result) => {
const traitEntities: INlpDatasetTraitEntity[] = result.entities.filter(
(e) => !("start" in e && "end" in e),
(e) => !("start" in e && "end" in e) && e.entity !== "language",
);
const keywordEntities = result.entities.filter(
(e) => "start" in e && "end" in e,
) as INlpDatasetKeywordEntity[];
const language = result.entities.find(
({ entity }) => entity === "language",
);
setValue("language", language?.value || "");
setValue("traitEntities", traitEntities);
setValue("keywordEntities", keywordEntities);
},

View File

@ -88,13 +88,13 @@ def parse(input: ParseInput, is_authenticated: Annotated[str, Depends(authentica
headers = {"Retry-After": "120"} # Suggest retrying after 2 minutes
return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content={"message": "Models are still loading, please retry later."}, headers=headers)
language = app.language_classifier.get_prediction(input.q) # type: ignore
lang = language.get("value")
intent_prediction = app.intent_classifiers[lang].get_prediction(
language_prediction = app.language_classifier.get_prediction(input.q) # type: ignore
language = language_prediction.get("value")
intent_prediction = app.intent_classifiers[language].get_prediction(
input.q) # type: ignore
slot_prediction = app.slot_fillers[lang].get_prediction(
slot_prediction = app.slot_fillers[language].get_prediction(
input.q) # type: ignore
slot_prediction.get("entities").append(language)
slot_prediction.get("entities").append(language_prediction)
return {
"text": input.q,