mirror of
https://github.com/hexastack/hexabot
synced 2025-02-22 04:17:48 +00:00
feat: adapt nlu prediction
This commit is contained in:
parent
0c02b51cf6
commit
88168795c0
@ -28,7 +28,10 @@ import { MenuService } from '@/cms/services/menu.service';
|
||||
import { offlineEventText } from '@/extensions/channels/offline/__test__/events.mock';
|
||||
import OfflineHandler from '@/extensions/channels/offline/index.channel';
|
||||
import OfflineEventWrapper from '@/extensions/channels/offline/wrapper';
|
||||
import { LanguageRepository } from '@/i18n/repositories/language.repository';
|
||||
import { LanguageModel } from '@/i18n/schemas/language.schema';
|
||||
import { I18nService } from '@/i18n/services/i18n.service';
|
||||
import { LanguageService } from '@/i18n/services/language.service';
|
||||
import { LoggerService } from '@/logger/logger.service';
|
||||
import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
|
||||
import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
|
||||
@ -107,6 +110,7 @@ describe('BlockService', () => {
|
||||
NlpEntityModel,
|
||||
NlpSampleEntityModel,
|
||||
NlpSampleModel,
|
||||
LanguageModel,
|
||||
]),
|
||||
],
|
||||
providers: [
|
||||
@ -126,6 +130,7 @@ describe('BlockService', () => {
|
||||
NlpEntityRepository,
|
||||
NlpSampleEntityRepository,
|
||||
NlpSampleRepository,
|
||||
LanguageRepository,
|
||||
BlockService,
|
||||
CategoryService,
|
||||
ContentTypeService,
|
||||
@ -143,6 +148,7 @@ describe('BlockService', () => {
|
||||
NlpSampleEntityService,
|
||||
NlpSampleService,
|
||||
NlpService,
|
||||
LanguageService,
|
||||
{
|
||||
provide: PluginService,
|
||||
useValue: {},
|
||||
|
@ -23,6 +23,10 @@ export const nlpEmptyFormated: DatasetType = {
|
||||
name: 'product',
|
||||
elements: ['pizza', 'sandwich'],
|
||||
},
|
||||
{
|
||||
elements: ['en', 'fr'],
|
||||
name: 'language',
|
||||
},
|
||||
],
|
||||
entity_synonyms: [
|
||||
{
|
||||
@ -34,17 +38,33 @@ export const nlpEmptyFormated: DatasetType = {
|
||||
|
||||
export const nlpFormatted: DatasetType = {
|
||||
common_examples: [
|
||||
{ text: 'Hello', intent: 'greeting', entities: [] },
|
||||
{
|
||||
text: 'Hello',
|
||||
intent: 'greeting',
|
||||
entities: [
|
||||
{
|
||||
entity: 'language',
|
||||
value: 'en',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
text: 'i want to order a pizza',
|
||||
intent: 'order',
|
||||
entities: [{ entity: 'product', value: 'pizza', start: 19, end: 23 }],
|
||||
entities: [
|
||||
{ entity: 'product', value: 'pizza', start: 19, end: 23 },
|
||||
{
|
||||
entity: 'language',
|
||||
value: 'en',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
regex_features: [],
|
||||
lookup_tables: [
|
||||
{ name: 'intent', elements: ['greeting', 'order'] },
|
||||
{ name: 'product', elements: ['pizza', 'sandwich'] },
|
||||
{ name: 'language', elements: ['en', 'fr'] },
|
||||
],
|
||||
entity_synonyms: [
|
||||
{
|
||||
|
@ -8,10 +8,14 @@
|
||||
*/
|
||||
|
||||
import { HttpModule } from '@nestjs/axios';
|
||||
import { CACHE_MANAGER } from '@nestjs/cache-manager';
|
||||
import { EventEmitter2 } from '@nestjs/event-emitter';
|
||||
import { MongooseModule } from '@nestjs/mongoose';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { LanguageRepository } from '@/i18n/repositories/language.repository';
|
||||
import { LanguageModel } from '@/i18n/schemas/language.schema';
|
||||
import { LanguageService } from '@/i18n/services/language.service';
|
||||
import { LoggerService } from '@/logger/logger.service';
|
||||
import { NlpEntityRepository } from '@/nlp/repositories/nlp-entity.repository';
|
||||
import { NlpSampleEntityRepository } from '@/nlp/repositories/nlp-sample-entity.repository';
|
||||
@ -56,10 +60,24 @@ describe('NLP Default Helper', () => {
|
||||
NlpValueModel,
|
||||
NlpSampleModel,
|
||||
NlpSampleEntityModel,
|
||||
LanguageModel,
|
||||
]),
|
||||
HttpModule,
|
||||
],
|
||||
providers: [
|
||||
NlpService,
|
||||
NlpSampleService,
|
||||
NlpSampleRepository,
|
||||
NlpEntityService,
|
||||
NlpEntityRepository,
|
||||
NlpValueService,
|
||||
NlpValueRepository,
|
||||
NlpSampleEntityService,
|
||||
NlpSampleEntityRepository,
|
||||
LanguageService,
|
||||
LanguageRepository,
|
||||
EventEmitter2,
|
||||
DefaultNlpHelper,
|
||||
LoggerService,
|
||||
{
|
||||
provide: SettingService,
|
||||
@ -76,17 +94,14 @@ describe('NLP Default Helper', () => {
|
||||
})),
|
||||
},
|
||||
},
|
||||
NlpService,
|
||||
NlpSampleService,
|
||||
NlpSampleRepository,
|
||||
NlpEntityService,
|
||||
NlpEntityRepository,
|
||||
NlpValueService,
|
||||
NlpValueRepository,
|
||||
NlpSampleEntityService,
|
||||
NlpSampleEntityRepository,
|
||||
EventEmitter2,
|
||||
DefaultNlpHelper,
|
||||
{
|
||||
provide: CACHE_MANAGER,
|
||||
useValue: {
|
||||
del: jest.fn(),
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
settingService = module.get<SettingService>(SettingService);
|
||||
@ -103,15 +118,15 @@ describe('NLP Default Helper', () => {
|
||||
expect(nlp).toBeDefined();
|
||||
});
|
||||
|
||||
it('should format empty training set properly', () => {
|
||||
it('should format empty training set properly', async () => {
|
||||
const nlp = nlpService.getNLP();
|
||||
const results = nlp.format([], entitiesMock);
|
||||
const results = await nlp.format([], entitiesMock);
|
||||
expect(results).toEqual(nlpEmptyFormated);
|
||||
});
|
||||
|
||||
it('should format training set properly', () => {
|
||||
it('should format training set properly', async () => {
|
||||
const nlp = nlpService.getNLP();
|
||||
const results = nlp.format(samplesMock, entitiesMock);
|
||||
const results = await nlp.format(samplesMock, entitiesMock);
|
||||
expect(results).toEqual(nlpFormatted);
|
||||
});
|
||||
|
||||
|
@ -13,21 +13,13 @@ import { Injectable } from '@nestjs/common';
|
||||
import { LoggerService } from '@/logger/logger.service';
|
||||
import BaseNlpHelper from '@/nlp/lib/BaseNlpHelper';
|
||||
import { Nlp } from '@/nlp/lib/types';
|
||||
import { NlpEntity, NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
|
||||
import { NlpEntityFull } from '@/nlp/schemas/nlp-entity.schema';
|
||||
import { NlpSampleFull } from '@/nlp/schemas/nlp-sample.schema';
|
||||
import { NlpValue } from '@/nlp/schemas/nlp-value.schema';
|
||||
import { NlpEntityService } from '@/nlp/services/nlp-entity.service';
|
||||
import { NlpSampleService } from '@/nlp/services/nlp-sample.service';
|
||||
import { NlpService } from '@/nlp/services/nlp.service';
|
||||
|
||||
import {
|
||||
CommonExample,
|
||||
DatasetType,
|
||||
EntitySynonym,
|
||||
ExampleEntity,
|
||||
LookupTable,
|
||||
NlpParseResultType,
|
||||
} from './types';
|
||||
import { DatasetType, NlpParseResultType } from './types';
|
||||
|
||||
@Injectable()
|
||||
export default class DefaultNlpHelper extends BaseNlpHelper {
|
||||
@ -61,69 +53,16 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
|
||||
* @param entities - All available entities
|
||||
* @returns {Promise<DatasetType>} - A promise that resolves to the formatted RASA training set
|
||||
*/
|
||||
format(samples: NlpSampleFull[], entities: NlpEntityFull[]): DatasetType {
|
||||
const entityMap = NlpEntity.getEntityMap(entities);
|
||||
const valueMap = NlpValue.getValueMap(
|
||||
NlpValue.getValuesFromEntities(entities),
|
||||
async format(
|
||||
samples: NlpSampleFull[],
|
||||
entities: NlpEntityFull[],
|
||||
): Promise<DatasetType> {
|
||||
const nluData = await this.nlpSampleService.formatRasaNlu(
|
||||
samples,
|
||||
entities,
|
||||
);
|
||||
|
||||
const common_examples: CommonExample[] = samples
|
||||
.filter((s) => s.entities.length > 0)
|
||||
.map((s) => {
|
||||
const intent = s.entities.find(
|
||||
(e) => entityMap[e.entity].name === 'intent',
|
||||
);
|
||||
if (!intent) {
|
||||
throw new Error('Unable to find the `intent` nlp entity.');
|
||||
}
|
||||
const sampleEntities: ExampleEntity[] = s.entities
|
||||
.filter((e) => entityMap[<string>e.entity].name !== 'intent')
|
||||
.map((e) => {
|
||||
const res: ExampleEntity = {
|
||||
entity: entityMap[<string>e.entity].name,
|
||||
value: valueMap[<string>e.value].value,
|
||||
};
|
||||
if ('start' in e && 'end' in e) {
|
||||
Object.assign(res, {
|
||||
start: e.start,
|
||||
end: e.end,
|
||||
});
|
||||
}
|
||||
return res;
|
||||
});
|
||||
return {
|
||||
text: s.text,
|
||||
intent: valueMap[intent.value].value,
|
||||
entities: sampleEntities,
|
||||
};
|
||||
});
|
||||
const lookup_tables: LookupTable[] = entities.map((e) => {
|
||||
return {
|
||||
name: e.name,
|
||||
elements: e.values.map((v) => {
|
||||
return v.value;
|
||||
}),
|
||||
};
|
||||
});
|
||||
const entity_synonyms = entities
|
||||
.reduce((acc, e) => {
|
||||
const synonyms = e.values.map((v) => {
|
||||
return {
|
||||
value: v.value,
|
||||
synonyms: v.expressions,
|
||||
};
|
||||
});
|
||||
return acc.concat(synonyms);
|
||||
}, [] as EntitySynonym[])
|
||||
.filter((s) => {
|
||||
return s.synonyms.length > 0;
|
||||
});
|
||||
return {
|
||||
common_examples,
|
||||
regex_features: [],
|
||||
lookup_tables,
|
||||
entity_synonyms,
|
||||
};
|
||||
return nluData;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -138,7 +77,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
|
||||
entities: NlpEntityFull[],
|
||||
): Promise<any> {
|
||||
const self = this;
|
||||
const nluData: DatasetType = self.format(samples, entities);
|
||||
const nluData: DatasetType = await self.format(samples, entities);
|
||||
// Train samples
|
||||
const result = await this.httpService.axiosRef.post(
|
||||
`${this.settings.endpoint}/train`,
|
||||
@ -169,7 +108,7 @@ export default class DefaultNlpHelper extends BaseNlpHelper {
|
||||
entities: NlpEntityFull[],
|
||||
): Promise<any> {
|
||||
const self = this;
|
||||
const nluTestData: DatasetType = self.format(samples, entities);
|
||||
const nluTestData: DatasetType = await self.format(samples, entities);
|
||||
// Evaluate model with test samples
|
||||
return await this.httpService.axiosRef.post(
|
||||
`${this.settings.endpoint}/evaluate`,
|
||||
|
@ -96,13 +96,23 @@ describe('LanguageController', () => {
|
||||
});
|
||||
|
||||
describe('findPage', () => {
|
||||
const pageQuery = getPageQuery<Language>();
|
||||
const pageQuery = getPageQuery<Language>({ sort: ['code', 'asc'] });
|
||||
it('should find languages', async () => {
|
||||
jest.spyOn(languageService, 'findPage');
|
||||
const result = await languageController.findPage(pageQuery, {});
|
||||
|
||||
expect(languageService.findPage).toHaveBeenCalledWith({}, pageQuery);
|
||||
expect(result).toEqualPayload(languageFixtures);
|
||||
expect(result).toEqualPayload(
  // Sort a copy: Array.prototype.sort reorders the array in place, which
  // would silently change the shared `languageFixtures` for any assertion
  // that runs after this one. The page query sorts by `code` ascending, so
  // the expected payload is ordered the same way.
  [...languageFixtures].sort(({ code: codeA }, { code: codeB }) => {
    if (codeA < codeB) {
      return -1;
    }
    if (codeA > codeB) {
      return 1;
    }
    return 0;
  }),
);
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -93,7 +93,7 @@ export class NlpSampleController extends BaseController<
|
||||
type ? { type } : {},
|
||||
);
|
||||
const entities = await this.nlpEntityService.findAllAndPopulate();
|
||||
const result = this.nlpSampleService.formatRasaNlu(samples, entities);
|
||||
const result = await this.nlpSampleService.formatRasaNlu(samples, entities);
|
||||
|
||||
// Sending the JSON data as a file
|
||||
const buffer = Buffer.from(JSON.stringify(result));
|
||||
|
@ -7,12 +7,14 @@
|
||||
* 3. SaaS Restriction: This software, or any derivative of it, may not be used to offer a competing product or service (SaaS) without prior written consent from Hexastack. Offering the software as a service or using it in a commercial cloud environment without express permission is strictly prohibited.
|
||||
*/
|
||||
|
||||
import { CACHE_MANAGER } from '@nestjs/cache-manager';
|
||||
import { EventEmitter2 } from '@nestjs/event-emitter';
|
||||
import { MongooseModule } from '@nestjs/mongoose';
|
||||
import { Test, TestingModule } from '@nestjs/testing';
|
||||
|
||||
import { LanguageRepository } from '@/i18n/repositories/language.repository';
|
||||
import { Language, LanguageModel } from '@/i18n/schemas/language.schema';
|
||||
import { LanguageService } from '@/i18n/services/language.service';
|
||||
import { nlpSampleFixtures } from '@/utils/test/fixtures/nlpsample';
|
||||
import { installNlpSampleEntityFixtures } from '@/utils/test/fixtures/nlpsampleentity';
|
||||
import { getPageQuery } from '@/utils/test/pagination';
|
||||
@ -68,7 +70,16 @@ describe('NlpSampleService', () => {
|
||||
NlpSampleEntityService,
|
||||
NlpEntityService,
|
||||
NlpValueService,
|
||||
LanguageService,
|
||||
EventEmitter2,
|
||||
{
|
||||
provide: CACHE_MANAGER,
|
||||
useValue: {
|
||||
del: jest.fn(),
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
},
|
||||
},
|
||||
],
|
||||
}).compile();
|
||||
nlpSampleService = module.get<NlpSampleService>(NlpSampleService);
|
||||
|
@ -16,6 +16,7 @@ import {
|
||||
ExampleEntity,
|
||||
LookupTable,
|
||||
} from '@/extensions/helpers/nlp/default/types';
|
||||
import { LanguageService } from '@/i18n/services/language.service';
|
||||
import { BaseService } from '@/utils/generics/base-service';
|
||||
|
||||
import { NlpSampleRepository } from '../repositories/nlp-sample.repository';
|
||||
@ -33,7 +34,10 @@ export class NlpSampleService extends BaseService<
|
||||
NlpSamplePopulate,
|
||||
NlpSampleFull
|
||||
> {
|
||||
constructor(readonly repository: NlpSampleRepository) {
|
||||
constructor(
|
||||
readonly repository: NlpSampleRepository,
|
||||
private readonly languageService: LanguageService,
|
||||
) {
|
||||
super(repository);
|
||||
}
|
||||
|
||||
@ -56,10 +60,10 @@ export class NlpSampleService extends BaseService<
|
||||
*
|
||||
* @returns A promise that resolves to the formatted Rasa NLU training dataset.
|
||||
*/
|
||||
formatRasaNlu(
|
||||
async formatRasaNlu(
|
||||
samples: NlpSampleFull[],
|
||||
entities: NlpEntityFull[],
|
||||
): DatasetType {
|
||||
): Promise<DatasetType> {
|
||||
const entityMap = NlpEntity.getEntityMap(entities);
|
||||
const valueMap = NlpValue.getValueMap(
|
||||
NlpValue.getValuesFromEntities(entities),
|
||||
@ -88,21 +92,34 @@ export class NlpSampleService extends BaseService<
|
||||
});
|
||||
}
|
||||
return res;
|
||||
})
|
||||
// TODO : place language at the same level as the intent
|
||||
.concat({
|
||||
entity: 'language',
|
||||
value: s.language.code,
|
||||
});
|
||||
|
||||
return {
|
||||
text: s.text,
|
||||
intent: valueMap[intent.value].value,
|
||||
entities: sampleEntities,
|
||||
};
|
||||
});
|
||||
const lookup_tables: LookupTable[] = entities.map((e) => {
|
||||
return {
|
||||
name: e.name,
|
||||
elements: e.values.map((v) => {
|
||||
return v.value;
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const languages = await this.languageService.getLanguages();
|
||||
const lookup_tables: LookupTable[] = entities
|
||||
.map((e) => {
|
||||
return {
|
||||
name: e.name,
|
||||
elements: e.values.map((v) => {
|
||||
return v.value;
|
||||
}),
|
||||
};
|
||||
})
|
||||
.concat({
|
||||
name: 'language',
|
||||
elements: Object.keys(languages),
|
||||
});
|
||||
const entity_synonyms = entities
|
||||
.reduce((acc, e) => {
|
||||
const synonyms = e.values.map((v) => {
|
||||
|
@ -44,6 +44,7 @@ export const NlpImportDialog: FC<NlpImportDialogProps> = ({
|
||||
QueryType.collection,
|
||||
EntityType.NLP_SAMPLE,
|
||||
]);
|
||||
queryClient.removeQueries([QueryType.count, EntityType.NLP_SAMPLE]);
|
||||
|
||||
handleCloseDialog();
|
||||
toast.success(t("message.success_save"));
|
||||
|
@ -157,12 +157,16 @@ const NlpDatasetSample: FC<NlpDatasetSampleProps> = ({
|
||||
},
|
||||
onSuccess: (result) => {
|
||||
const traitEntities: INlpDatasetTraitEntity[] = result.entities.filter(
|
||||
(e) => !("start" in e && "end" in e),
|
||||
(e) => !("start" in e && "end" in e) && e.entity !== "language",
|
||||
);
|
||||
const keywordEntities = result.entities.filter(
|
||||
(e) => "start" in e && "end" in e,
|
||||
) as INlpDatasetKeywordEntity[];
|
||||
const language = result.entities.find(
|
||||
({ entity }) => entity === "language",
|
||||
);
|
||||
|
||||
setValue("language", language?.value || "");
|
||||
setValue("traitEntities", traitEntities);
|
||||
setValue("keywordEntities", keywordEntities);
|
||||
},
|
||||
|
10
nlu/main.py
10
nlu/main.py
@ -88,13 +88,13 @@ def parse(input: ParseInput, is_authenticated: Annotated[str, Depends(authentica
|
||||
headers = {"Retry-After": "120"} # Suggest retrying after 2 minutes
|
||||
return JSONResponse(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content={"message": "Models are still loading, please retry later."}, headers=headers)
|
||||
|
||||
language = app.language_classifier.get_prediction(input.q) # type: ignore
|
||||
lang = language.get("value")
|
||||
intent_prediction = app.intent_classifiers[lang].get_prediction(
|
||||
language_prediction = app.language_classifier.get_prediction(input.q) # type: ignore
|
||||
language = language_prediction.get("value")
|
||||
intent_prediction = app.intent_classifiers[language].get_prediction(
|
||||
input.q) # type: ignore
|
||||
slot_prediction = app.slot_fillers[lang].get_prediction(
|
||||
slot_prediction = app.slot_fillers[language].get_prediction(
|
||||
input.q) # type: ignore
|
||||
slot_prediction.get("entities").append(language)
|
||||
slot_prediction.get("entities").append(language_prediction)
|
||||
|
||||
return {
|
||||
"text": input.q,
|
||||
|
Loading…
Reference in New Issue
Block a user