fix: refactor + unit tests

This commit is contained in:
Mohamed Marrouchi 2024-12-25 07:36:50 +01:00
parent f60b59aa54
commit 328c5cefb3
4 changed files with 255 additions and 236 deletions

View File

@ -6,17 +6,12 @@
* 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file). * 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file).
*/ */
import fs from 'fs';
import { CACHE_MANAGER } from '@nestjs/cache-manager'; import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { BadRequestException, NotFoundException } from '@nestjs/common'; import { BadRequestException, NotFoundException } from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter'; import { EventEmitter2 } from '@nestjs/event-emitter';
import { MongooseModule } from '@nestjs/mongoose'; import { MongooseModule } from '@nestjs/mongoose';
import { Test, TestingModule } from '@nestjs/testing'; import { Test, TestingModule } from '@nestjs/testing';
import { AttachmentRepository } from '@/attachment/repositories/attachment.repository';
import { AttachmentModel } from '@/attachment/schemas/attachment.schema';
import { AttachmentService } from '@/attachment/services/attachment.service';
import { HelperService } from '@/helper/helper.service'; import { HelperService } from '@/helper/helper.service';
import { LanguageRepository } from '@/i18n/repositories/language.repository'; import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { Language, LanguageModel } from '@/i18n/schemas/language.schema'; import { Language, LanguageModel } from '@/i18n/schemas/language.schema';
@ -50,7 +45,6 @@ import { NlpEntityService } from '../services/nlp-entity.service';
import { NlpSampleEntityService } from '../services/nlp-sample-entity.service'; import { NlpSampleEntityService } from '../services/nlp-sample-entity.service';
import { NlpSampleService } from '../services/nlp-sample.service'; import { NlpSampleService } from '../services/nlp-sample.service';
import { NlpValueService } from '../services/nlp-value.service'; import { NlpValueService } from '../services/nlp-value.service';
import { NlpService } from '../services/nlp.service';
import { NlpSampleController } from './nlp-sample.controller'; import { NlpSampleController } from './nlp-sample.controller';
@ -60,7 +54,6 @@ describe('NlpSampleController', () => {
let nlpSampleService: NlpSampleService; let nlpSampleService: NlpSampleService;
let nlpEntityService: NlpEntityService; let nlpEntityService: NlpEntityService;
let nlpValueService: NlpValueService; let nlpValueService: NlpValueService;
let attachmentService: AttachmentService;
let languageService: LanguageService; let languageService: LanguageService;
let byeJhonSampleId: string; let byeJhonSampleId: string;
let languages: Language[]; let languages: Language[];
@ -76,7 +69,6 @@ describe('NlpSampleController', () => {
MongooseModule.forFeature([ MongooseModule.forFeature([
NlpSampleModel, NlpSampleModel,
NlpSampleEntityModel, NlpSampleEntityModel,
AttachmentModel,
NlpEntityModel, NlpEntityModel,
NlpValueModel, NlpValueModel,
SettingModel, SettingModel,
@ -87,9 +79,7 @@ describe('NlpSampleController', () => {
LoggerService, LoggerService,
NlpSampleRepository, NlpSampleRepository,
NlpSampleEntityRepository, NlpSampleEntityRepository,
AttachmentService,
NlpEntityService, NlpEntityService,
AttachmentRepository,
NlpEntityRepository, NlpEntityRepository,
NlpValueService, NlpValueService,
NlpValueRepository, NlpValueRepository,
@ -98,7 +88,6 @@ describe('NlpSampleController', () => {
LanguageRepository, LanguageRepository,
LanguageService, LanguageService,
EventEmitter2, EventEmitter2,
NlpService,
HelperService, HelperService,
SettingRepository, SettingRepository,
SettingService, SettingService,
@ -131,7 +120,6 @@ describe('NlpSampleController', () => {
text: 'Bye Jhon', text: 'Bye Jhon',
}) })
).id; ).id;
attachmentService = module.get<AttachmentService>(AttachmentService);
languageService = module.get<LanguageService>(LanguageService); languageService = module.get<LanguageService>(LanguageService);
languages = await languageService.findAll(); languages = await languageService.findAll();
}); });
@ -315,83 +303,44 @@ describe('NlpSampleController', () => {
}); });
}); });
describe('import', () => { describe('importFile', () => {
it('should throw exception when attachment is not found', async () => { it('should throw exception when something is wrong with the upload', async () => {
const invalidattachmentId = ( const file = {
await attachmentService.findOne({ buffer: Buffer.from('', 'utf-8'),
name: 'store2.jpg', size: 0,
}) mimetype: 'text/csv',
).id; } as Express.Multer.File;
await attachmentService.deleteOne({ name: 'store2.jpg' }); await expect(nlpSampleController.importFile(file)).rejects.toThrow(
await expect( 'Bad Request Exception',
nlpSampleController.import(invalidattachmentId),
).rejects.toThrow(NotFoundException);
});
it('should throw exception when file location is not present', async () => {
const attachmentId = (
await attachmentService.findOne({
name: 'store1.jpg',
})
).id;
jest.spyOn(fs, 'existsSync').mockReturnValueOnce(false);
await expect(nlpSampleController.import(attachmentId)).rejects.toThrow(
NotFoundException,
); );
}); });
it('should return a failure if an error occurs when parsing csv file ', async () => { it('should return a failure if an error occurs when parsing csv file ', async () => {
const mockCsvDataWithErrors: string = `intent,entities,lang,question const mockCsvDataWithErrors: string = `intent,entities,lang,question
greeting,person,en`; greeting,person,en`;
jest.spyOn(fs, 'existsSync').mockReturnValueOnce(true);
jest.spyOn(fs, 'readFileSync').mockReturnValueOnce(mockCsvDataWithErrors);
const attachmentId = (
await attachmentService.findOne({
name: 'store1.jpg',
})
).id;
const mockParsedCsvDataWithErrors = { const buffer = Buffer.from(mockCsvDataWithErrors, 'utf-8');
data: [{ intent: 'greeting', entities: 'person', lang: 'en' }], const file = {
errors: [ buffer,
{ size: buffer.length,
type: 'FieldMismatch', mimetype: 'text/csv',
code: 'TooFewFields', } as Express.Multer.File;
message: 'Too few fields: expected 4 fields but parsed 3', await expect(nlpSampleController.importFile(file)).rejects.toThrow();
row: 0,
},
],
meta: {
delimiter: ',',
linebreak: '\n',
aborted: false,
truncated: false,
cursor: 49,
fields: ['intent', 'entities', 'lang', 'question'],
},
};
await expect(nlpSampleController.import(attachmentId)).rejects.toThrow(
new BadRequestException({
cause: mockParsedCsvDataWithErrors.errors,
description: 'Error while parsing CSV',
}),
);
}); });
it('should import data from a CSV file', async () => { it('should import data from a CSV file', async () => {
const attachmentId = (
await attachmentService.findOne({
name: 'store1.jpg',
})
).id;
const mockCsvData: string = [ const mockCsvData: string = [
`text,intent,language`, `text,intent,language`,
`How much does a BMW cost?,price,en`, `How much does a BMW cost?,price,en`,
].join('\n'); ].join('\n');
jest.spyOn(fs, 'existsSync').mockReturnValueOnce(true);
jest.spyOn(fs, 'readFileSync').mockReturnValueOnce(mockCsvData);
const result = await nlpSampleController.import(attachmentId); const buffer = Buffer.from(mockCsvData, 'utf-8');
const file = {
buffer,
size: buffer.length,
mimetype: 'text/csv',
} as Express.Multer.File;
const result = await nlpSampleController.importFile(file);
const intentEntityResult = await nlpEntityService.findOne({ const intentEntityResult = await nlpEntityService.findOne({
name: 'intent', name: 'intent',
}); });
@ -429,9 +378,10 @@ describe('NlpSampleController', () => {
expect(intentEntityResult).toEqualPayload(intentEntity); expect(intentEntityResult).toEqualPayload(intentEntity);
expect(priceValueResult).toEqualPayload(priceValue); expect(priceValueResult).toEqualPayload(priceValue);
expect(textSampleResult).toEqualPayload(textSample); expect(textSampleResult).toEqualPayload(textSample);
expect(result).toEqual({ success: true }); expect(result).toEqualPayload([textSample]);
}); });
}); });
describe('deleteMany', () => { describe('deleteMany', () => {
it('should delete multiple nlp samples', async () => { it('should delete multiple nlp samples', async () => {
const samplesToDelete = [ const samplesToDelete = [

View File

@ -6,8 +6,6 @@
* 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file). * 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file).
*/ */
import fs from 'fs';
import { join } from 'path';
import { Readable } from 'stream'; import { Readable } from 'stream';
import { import {
@ -31,10 +29,7 @@ import {
import { FileInterceptor } from '@nestjs/platform-express'; import { FileInterceptor } from '@nestjs/platform-express';
import { CsrfCheck } from '@tekuconcept/nestjs-csrf'; import { CsrfCheck } from '@tekuconcept/nestjs-csrf';
import { Response } from 'express'; import { Response } from 'express';
import Papa from 'papaparse';
import { AttachmentService } from '@/attachment/services/attachment.service';
import { config } from '@/config';
import { HelperService } from '@/helper/helper.service'; import { HelperService } from '@/helper/helper.service';
import { LanguageService } from '@/i18n/services/language.service'; import { LanguageService } from '@/i18n/services/language.service';
import { CsrfInterceptor } from '@/interceptors/csrf.interceptor'; import { CsrfInterceptor } from '@/interceptors/csrf.interceptor';
@ -47,18 +42,17 @@ import { PopulatePipe } from '@/utils/pipes/populate.pipe';
import { SearchFilterPipe } from '@/utils/pipes/search-filter.pipe'; import { SearchFilterPipe } from '@/utils/pipes/search-filter.pipe';
import { TFilterQuery } from '@/utils/types/filter.types'; import { TFilterQuery } from '@/utils/types/filter.types';
import { NlpSampleCreateDto, NlpSampleDto } from '../dto/nlp-sample.dto'; import { NlpSampleDto } from '../dto/nlp-sample.dto';
import { import {
NlpSample, NlpSample,
NlpSampleFull, NlpSampleFull,
NlpSamplePopulate, NlpSamplePopulate,
NlpSampleStub, NlpSampleStub,
} from '../schemas/nlp-sample.schema'; } from '../schemas/nlp-sample.schema';
import { NlpSampleEntityValue, NlpSampleState } from '../schemas/types'; import { NlpSampleState } from '../schemas/types';
import { NlpEntityService } from '../services/nlp-entity.service'; import { NlpEntityService } from '../services/nlp-entity.service';
import { NlpSampleEntityService } from '../services/nlp-sample-entity.service'; import { NlpSampleEntityService } from '../services/nlp-sample-entity.service';
import { NlpSampleService } from '../services/nlp-sample.service'; import { NlpSampleService } from '../services/nlp-sample.service';
import { NlpService } from '../services/nlp.service';
@UseInterceptors(CsrfInterceptor) @UseInterceptors(CsrfInterceptor)
@Controller('nlpsample') @Controller('nlpsample')
@ -70,11 +64,9 @@ export class NlpSampleController extends BaseController<
> { > {
constructor( constructor(
private readonly nlpSampleService: NlpSampleService, private readonly nlpSampleService: NlpSampleService,
private readonly attachmentService: AttachmentService,
private readonly nlpSampleEntityService: NlpSampleEntityService, private readonly nlpSampleEntityService: NlpSampleEntityService,
private readonly nlpEntityService: NlpEntityService, private readonly nlpEntityService: NlpEntityService,
private readonly logger: LoggerService, private readonly logger: LoggerService,
private readonly nlpService: NlpService,
private readonly languageService: LanguageService, private readonly languageService: LanguageService,
private readonly helperService: HelperService, private readonly helperService: HelperService,
) { ) {
@ -371,157 +363,11 @@ export class NlpSampleController extends BaseController<
return deleteResult; return deleteResult;
} }
private async parseAndSaveDataset(data: string) {
const allEntities = await this.nlpEntityService.findAll();
// Check if file location is present
if (allEntities.length === 0) {
throw new NotFoundException(
'No entities found, please create them first.',
);
}
// Parse local CSV file
const result: {
errors: any[];
data: Array<Record<string, string>>;
} = Papa.parse(data, {
header: true,
skipEmptyLines: true,
});
if (result.errors && result.errors.length > 0) {
this.logger.warn(
`Errors parsing the file: ${JSON.stringify(result.errors)}`,
);
throw new BadRequestException(result.errors, {
cause: result.errors,
description: 'Error while parsing CSV',
});
}
// Remove data with no intent
const filteredData = result.data.filter((d) => d.intent !== 'none');
const languages = await this.languageService.getLanguages();
const defaultLanguage = await this.languageService.getDefaultLanguage();
const nlpSamples: NlpSample[] = [];
// Reduce function to ensure executing promises one by one
for (const d of filteredData) {
try {
// Check if a sample with the same text already exists
const existingSamples = await this.nlpSampleService.find({
text: d.text,
});
// Skip if sample already exists
if (Array.isArray(existingSamples) && existingSamples.length > 0) {
continue;
}
// Fallback to default language if 'language' is missing or invalid
if (!d.language || !(d.language in languages)) {
if (d.language) {
this.logger.warn(
`Language "${d.language}" does not exist, falling back to default.`,
);
}
d.language = defaultLanguage.code;
}
// Create a new sample dto
const sample: NlpSampleCreateDto = {
text: d.text,
trained: false,
language: languages[d.language].id,
};
// Create a new sample entity dto
const entities: NlpSampleEntityValue[] = allEntities
.filter(({ name }) => name in d)
.map(({ name }) => ({
entity: name,
value: d[name],
}));
// Store any new entity/value
const storedEntities = await this.nlpEntityService.storeNewEntities(
sample.text,
entities,
['trait'],
);
// Store sample
const createdSample = await this.nlpSampleService.create(sample);
nlpSamples.push(createdSample);
// Map and assign the sample ID to each stored entity
const sampleEntities = storedEntities.map((storedEntity) => ({
...storedEntity,
sample: createdSample?.id,
}));
// Store sample entities
await this.nlpSampleEntityService.createMany(sampleEntities);
} catch (err) {
this.logger.error('Error occurred when extracting data. ', err);
}
}
return nlpSamples;
}
@CsrfCheck(true) @CsrfCheck(true)
@Post('import') @Post('import')
@UseInterceptors(FileInterceptor('file')) @UseInterceptors(FileInterceptor('file'))
async importFile(@UploadedFile() file: Express.Multer.File) { async importFile(@UploadedFile() file: Express.Multer.File) {
try {
const datasetContent = file.buffer.toString('utf-8'); const datasetContent = file.buffer.toString('utf-8');
return await this.parseAndSaveDataset(datasetContent); return await this.nlpSampleService.parseAndSaveDataset(datasetContent);
} catch (err) {
this.logger.error('Error processing file:', err);
}
}
/**
* @deprecated
* Imports NLP samples from a CSV file.
*
* @param file - The file path or ID of the CSV file to import.
*
* @returns A success message after the import process is completed.
*/
@CsrfCheck(true)
@Post('import/:file')
async import(
@Param('file')
file: string,
) {
// Check if file is present
const importedFile = await this.attachmentService.findOne(file);
if (!importedFile) {
throw new NotFoundException('Missing file!');
}
const filePath = importedFile
? join(config.parameters.uploadDir, importedFile.location)
: undefined;
// Check if file location is present
if (!fs.existsSync(filePath)) {
throw new NotFoundException('File does not exist');
}
const allEntities = await this.nlpEntityService.findAll();
// Check if file location is present
if (allEntities.length === 0) {
throw new NotFoundException(
'No entities found, please create them first.',
);
}
// Read file content
const data = fs.readFileSync(filePath, 'utf8');
await this.parseAndSaveDataset(data);
this.logger.log('Import process completed successfully.');
return { success: true };
} }
} }

View File

@ -7,6 +7,7 @@
*/ */
import { CACHE_MANAGER } from '@nestjs/cache-manager'; import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { BadRequestException, NotFoundException } from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter'; import { EventEmitter2 } from '@nestjs/event-emitter';
import { MongooseModule } from '@nestjs/mongoose'; import { MongooseModule } from '@nestjs/mongoose';
import { Test, TestingModule } from '@nestjs/testing'; import { Test, TestingModule } from '@nestjs/testing';
@ -27,7 +28,7 @@ import { NlpEntityRepository } from '../repositories/nlp-entity.repository';
import { NlpSampleEntityRepository } from '../repositories/nlp-sample-entity.repository'; import { NlpSampleEntityRepository } from '../repositories/nlp-sample-entity.repository';
import { NlpSampleRepository } from '../repositories/nlp-sample.repository'; import { NlpSampleRepository } from '../repositories/nlp-sample.repository';
import { NlpValueRepository } from '../repositories/nlp-value.repository'; import { NlpValueRepository } from '../repositories/nlp-value.repository';
import { NlpEntityModel } from '../schemas/nlp-entity.schema'; import { NlpEntity, NlpEntityModel } from '../schemas/nlp-entity.schema';
import { import {
NlpSampleEntity, NlpSampleEntity,
NlpSampleEntityModel, NlpSampleEntityModel,
@ -41,7 +42,10 @@ import { NlpSampleService } from './nlp-sample.service';
import { NlpValueService } from './nlp-value.service'; import { NlpValueService } from './nlp-value.service';
describe('NlpSampleService', () => { describe('NlpSampleService', () => {
let nlpEntityService: NlpEntityService;
let nlpSampleService: NlpSampleService; let nlpSampleService: NlpSampleService;
let nlpSampleEntityService: NlpSampleEntityService;
let languageService: LanguageService;
let nlpSampleEntityRepository: NlpSampleEntityRepository; let nlpSampleEntityRepository: NlpSampleEntityRepository;
let nlpSampleRepository: NlpSampleRepository; let nlpSampleRepository: NlpSampleRepository;
let languageRepository: LanguageRepository; let languageRepository: LanguageRepository;
@ -84,7 +88,11 @@ describe('NlpSampleService', () => {
}, },
], ],
}).compile(); }).compile();
nlpEntityService = module.get<NlpEntityService>(NlpEntityService);
nlpSampleService = module.get<NlpSampleService>(NlpSampleService); nlpSampleService = module.get<NlpSampleService>(NlpSampleService);
nlpSampleEntityService = module.get<NlpSampleEntityService>(
NlpSampleEntityService,
);
nlpSampleRepository = module.get<NlpSampleRepository>(NlpSampleRepository); nlpSampleRepository = module.get<NlpSampleRepository>(NlpSampleRepository);
nlpSampleEntityRepository = module.get<NlpSampleEntityRepository>( nlpSampleEntityRepository = module.get<NlpSampleEntityRepository>(
NlpSampleEntityRepository, NlpSampleEntityRepository,
@ -92,6 +100,7 @@ describe('NlpSampleService', () => {
nlpSampleEntityRepository = module.get<NlpSampleEntityRepository>( nlpSampleEntityRepository = module.get<NlpSampleEntityRepository>(
NlpSampleEntityRepository, NlpSampleEntityRepository,
); );
languageService = module.get<LanguageService>(LanguageService);
languageRepository = module.get<LanguageRepository>(LanguageRepository); languageRepository = module.get<LanguageRepository>(LanguageRepository);
noNlpSample = await nlpSampleService.findOne({ text: 'No' }); noNlpSample = await nlpSampleService.findOne({ text: 'No' });
nlpSampleEntity = await nlpSampleEntityRepository.findOne({ nlpSampleEntity = await nlpSampleEntityRepository.findOne({
@ -162,4 +171,104 @@ describe('NlpSampleService', () => {
expect(result.deletedCount).toEqual(1); expect(result.deletedCount).toEqual(1);
}); });
}); });
// Unit tests for NlpSampleService.parseAndSaveDataset. Collaborating services
// are stubbed with jest.spyOn so each case drives one branch of the import.
// NOTE(review): no spy is restored inside this block, so stubs installed by one
// test may still be active in the following ones — confirm whether a global
// restoreMocks/afterEach exists elsewhere in the file or jest config.
describe('parseAndSaveDataset', () => {
// With an empty entity list the import must abort before parsing anything.
it('should throw NotFoundException if no entities are found', async () => {
jest.spyOn(nlpEntityService, 'findAll').mockResolvedValue([]);
await expect(
nlpSampleService.parseAndSaveDataset(
'text,intent,language\nHello,none,en',
),
).rejects.toThrow(NotFoundException);
expect(nlpEntityService.findAll).toHaveBeenCalled();
});
// The data row opens a quoted field that is never closed; the service is
// expected to turn the resulting parse errors into a BadRequestException.
it('should throw BadRequestException if CSV parsing fails', async () => {
const invalidCSV = 'text,intent,language\n"Hello,none'; // Malformed CSV
jest
.spyOn(nlpEntityService, 'findAll')
.mockResolvedValue([{ name: 'intent' } as NlpEntity]);
jest.spyOn(languageService, 'getLanguages').mockResolvedValue({});
jest
.spyOn(languageService, 'getDefaultLanguage')
.mockResolvedValue({ code: 'en' } as Language);
await expect(
nlpSampleService.parseAndSaveDataset(invalidCSV),
).rejects.toThrow(BadRequestException);
});
// Two rows are supplied, one with intent "none"; only one sample may result.
it('should filter out rows with "none" as intent', async () => {
const mockData = 'text,intent,language\nHello,none,en\nHi,greet,en';
jest
.spyOn(nlpEntityService, 'findAll')
.mockResolvedValue([{ name: 'intent' } as NlpEntity]);
jest
.spyOn(languageService, 'getLanguages')
.mockResolvedValue({ en: { id: '1' } });
jest
.spyOn(languageService, 'getDefaultLanguage')
.mockResolvedValue({ code: 'en' } as Language);
// No pre-existing sample, so no row is skipped as a duplicate.
jest.spyOn(nlpSampleService, 'find').mockResolvedValue([]);
// create() is stubbed with a fixed payload: the text assertion below reflects
// this stub, while the length assertion is what demonstrates the filtering.
jest
.spyOn(nlpSampleService, 'create')
.mockResolvedValue({ id: '1', text: 'Hi' } as NlpSample);
jest.spyOn(nlpSampleEntityService, 'createMany').mockResolvedValue([]);
const result = await nlpSampleService.parseAndSaveDataset(mockData);
expect(result).toHaveLength(1);
expect(result[0].text).toEqual('Hi');
});
// "invalidLang" is not a key of the stubbed language map, so the row should
// be imported using the stubbed default language ("en").
// NOTE(review): the assertions only check that a sample was produced; they do
// not inspect the language passed to create() — consider asserting on it.
it('should fallback to the default language if the language is invalid', async () => {
const mockData = 'text,intent,language\nHi,greet,invalidLang';
jest
.spyOn(nlpEntityService, 'findAll')
.mockResolvedValue([{ name: 'intent' } as NlpEntity]);
jest
.spyOn(languageService, 'getLanguages')
.mockResolvedValue({ en: { id: '1' } });
jest
.spyOn(languageService, 'getDefaultLanguage')
.mockResolvedValue({ code: 'en' } as Language);
jest.spyOn(nlpSampleService, 'find').mockResolvedValue([]);
jest
.spyOn(nlpSampleService, 'create')
.mockResolvedValue({ id: '1', text: 'Hi' } as NlpSample);
jest.spyOn(nlpSampleEntityService, 'createMany').mockResolvedValue([]);
const result = await nlpSampleService.parseAndSaveDataset(mockData);
expect(result).toHaveLength(1);
expect(result[0].text).toEqual('Hi');
});
// Happy path: two valid rows produce two samples, in input order.
// NOTE(review): nlpEntityService.findAll is not stubbed in this test; it
// presumably relies on seeded fixtures or on a stub left over from an earlier
// test — confirm.
it('should successfully process and save valid dataset rows', async () => {
const mockData = 'text,intent,language\nHi,greet,en\nBye,bye,en';
const mockLanguages = { en: { id: '1' } };
jest
.spyOn(languageService, 'getLanguages')
.mockResolvedValue(mockLanguages);
jest
.spyOn(languageService, 'getDefaultLanguage')
.mockResolvedValue({ code: 'en' } as Language);
jest.spyOn(nlpSampleService, 'find').mockResolvedValue([]);
// Echo the DTO back with an incrementing id so each row's text is preserved.
let id = 0;
jest.spyOn(nlpSampleService, 'create').mockImplementation((s) => {
return Promise.resolve({ id: (++id).toString(), ...s } as NlpSample);
});
jest.spyOn(nlpSampleEntityService, 'createMany').mockResolvedValue([]);
const result = await nlpSampleService.parseAndSaveDataset(mockData);
// One createMany call per imported row.
expect(nlpSampleEntityService.createMany).toHaveBeenCalledTimes(2);
expect(result).toHaveLength(2);
expect(result[0].text).toEqual('Hi');
expect(result[1].text).toEqual('Bye');
});
});
}); });

View File

@ -6,8 +6,13 @@
* 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file). * 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file).
*/ */
import { Injectable } from '@nestjs/common'; import {
BadRequestException,
Injectable,
NotFoundException,
} from '@nestjs/common';
import { OnEvent } from '@nestjs/event-emitter'; import { OnEvent } from '@nestjs/event-emitter';
import Papa from 'papaparse';
import { Message } from '@/chat/schemas/message.schema'; import { Message } from '@/chat/schemas/message.schema';
import { Language } from '@/i18n/schemas/language.schema'; import { Language } from '@/i18n/schemas/language.schema';
@ -23,7 +28,10 @@ import {
NlpSampleFull, NlpSampleFull,
NlpSamplePopulate, NlpSamplePopulate,
} from '../schemas/nlp-sample.schema'; } from '../schemas/nlp-sample.schema';
import { NlpSampleState } from '../schemas/types'; import { NlpSampleEntityValue, NlpSampleState } from '../schemas/types';
import { NlpEntityService } from './nlp-entity.service';
import { NlpSampleEntityService } from './nlp-sample-entity.service';
@Injectable() @Injectable()
export class NlpSampleService extends BaseService< export class NlpSampleService extends BaseService<
@ -33,6 +41,8 @@ export class NlpSampleService extends BaseService<
> { > {
constructor( constructor(
readonly repository: NlpSampleRepository, readonly repository: NlpSampleRepository,
private readonly nlpSampleEntityService: NlpSampleEntityService,
private readonly nlpEntityService: NlpEntityService,
private readonly languageService: LanguageService, private readonly languageService: LanguageService,
private readonly logger: LoggerService, private readonly logger: LoggerService,
) { ) {
@ -50,6 +60,110 @@ export class NlpSampleService extends BaseService<
return await this.repository.deleteOne(id); return await this.repository.deleteOne(id);
} }
/**
 * Parses a CSV dataset string and saves its rows as NLP samples.
 *
 * Processing steps:
 * 1. Load every NLP entity; abort if none exists.
 * 2. Parse the CSV (first row is the header, empty lines are skipped); abort
 *    if the parser reports any error.
 * 3. Drop the rows whose `intent` column equals `'none'`.
 * 4. For each remaining row, one at a time: skip it when a sample with the
 *    same text already exists, resolve its language (falling back to the
 *    default language), store any new entity/value, create the sample and
 *    link the sample entities to it.
 *
 * A failure while processing a single row is logged and does not stop the
 * import of the remaining rows.
 *
 * @param data - The raw CSV dataset as a string.
 * @returns A promise that resolves to an array of created NLP samples.
 * @throws NotFoundException when no NLP entity exists yet.
 * @throws BadRequestException when the CSV content cannot be parsed.
 */
async parseAndSaveDataset(data: string): Promise<NlpSample[]> {
  // Own-property lookup: the `in` operator would also match inherited keys
  // such as "constructor" or "toString" coming from untrusted CSV content.
  const hasOwn = (obj: object, key: string): boolean =>
    Object.prototype.hasOwnProperty.call(obj, key);

  const allEntities = await this.nlpEntityService.findAll();
  // Entities must exist before samples can reference them
  if (allEntities.length === 0) {
    throw new NotFoundException(
      'No entities found, please create them first.',
    );
  }

  // Parse the CSV content
  const result: {
    errors: unknown[];
    data: Array<Record<string, string>>;
  } = Papa.parse(data, {
    header: true,
    skipEmptyLines: true,
  });

  if (result.errors && result.errors.length > 0) {
    this.logger.warn(
      `Errors parsing the file: ${JSON.stringify(result.errors)}`,
    );

    throw new BadRequestException(result.errors, {
      cause: result.errors,
      description: 'Error while parsing CSV',
    });
  }

  // Remove the rows flagged with the "none" intent
  const filteredData = result.data.filter((d) => d.intent !== 'none');
  const languages = await this.languageService.getLanguages();
  const defaultLanguage = await this.languageService.getDefaultLanguage();
  const nlpSamples: NlpSample[] = [];

  // Rows are processed sequentially so that the duplicate check of a row sees
  // the samples created by the previous ones
  for (const d of filteredData) {
    try {
      // Check if a sample with the same text already exists
      const existingSamples = await this.find({
        text: d.text,
      });

      // Skip if sample already exists
      if (Array.isArray(existingSamples) && existingSamples.length > 0) {
        continue;
      }

      // Fallback to default language if 'language' is missing or invalid
      if (!d.language || !hasOwn(languages, d.language)) {
        if (d.language) {
          this.logger.warn(
            `Language "${d.language}" does not exist, falling back to default.`,
          );
        }
        d.language = defaultLanguage.code;
      }

      // The default language itself may be absent from the language map;
      // fail this row with an explicit message rather than a TypeError
      if (!hasOwn(languages, d.language)) {
        throw new Error(
          `Unable to resolve language "${d.language}" for the sample.`,
        );
      }

      // Create a new sample dto
      const sample: NlpSampleCreateDto = {
        text: d.text,
        trained: false,
        language: languages[d.language].id,
      };

      // Build one entity/value pair per known entity present as a CSV column
      const entities: NlpSampleEntityValue[] = allEntities
        .filter(({ name }) => hasOwn(d, name))
        .map(({ name }) => ({
          entity: name,
          value: d[name],
        }));

      // Store any new entity/value
      const storedEntities = await this.nlpEntityService.storeNewEntities(
        sample.text,
        entities,
        ['trait'],
      );

      // Store sample
      const createdSample = await this.create(sample);
      nlpSamples.push(createdSample);

      // Map and assign the sample ID to each stored entity
      const sampleEntities = storedEntities.map((storedEntity) => ({
        ...storedEntity,
        sample: createdSample?.id,
      }));

      // Store sample entities
      await this.nlpSampleEntityService.createMany(sampleEntities);
    } catch (err) {
      // Best effort: report the failing row and carry on with the next one
      this.logger.error('Error occurred when extracting data. ', err);
    }
  }

  return nlpSamples;
}
/** /**
* When a language gets deleted, we need to set related samples to null * When a language gets deleted, we need to set related samples to null
* *