Skip to content

Commit

Permalink
Adding a list of generation and a limit of 3 generations (#9)
Browse files Browse the repository at this point in the history
- Adding field lastGenerationsAsc of type LastGeneration in User
- Adding a method canGenerateImage in User to limit to 3 generations

The list of last generation is a field populate at each generation.
It automaticaly handle the wipe of the list thanks to the method addGenerationId based on the date and month of the previous generation
  • Loading branch information
BaptisteLecat authored May 17, 2024
1 parent 19d81c7 commit 688ed1b
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 19 deletions.
73 changes: 66 additions & 7 deletions src/modules/generation/controllers/generation.controller.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
import { Body, Controller, Get, Inject, NotFoundException, Param, Post, Request, UseGuards } from '@nestjs/common';
import {
Body,
Controller,
Get,
HttpException,
Inject,
NotFoundException,
Param,
Post,
Request,
UseGuards,
} from '@nestjs/common';
import { PromptGeneratorService } from '../services/prompt_generator.service';
import { CustomLogging } from '../../logging/custom-logging';
import { ApiBearerAuth, ApiBody, ApiParam } from '@nestjs/swagger';
Expand All @@ -9,43 +20,91 @@ import { GenerationsService } from '../services/generations.service';
import { LocationsService } from '../../locations/locations.service';
import { CreateGenerationDto } from '../dto/create-generation.dto';
import { UsersService } from '../../users/users.service';
import { LastGeneration } from '../../users/entities/lastGeneration.entity';

@UseGuards(ApiKeyAuthGuard, JwtAuthGuard)
@ApiBearerAuth('JWT-auth')
@ApiBearerAuth('ApiKey')
@Controller({
version: '1',
path: 'locations/:locationId/generations'
path: 'locations/:locationId/generations',
})
export class GenerationController {
private readonly logger = new CustomLogging(GenerationController.name);
constructor(@Inject(PromptGeneratorService) private readonly promptGenerationService: PromptGeneratorService, @Inject(GenerationsService) private readonly generationService: GenerationsService, @Inject(LocationsService) private readonly locationService: LocationsService, @Inject(UsersService) private readonly userService : UsersService) { }
constructor(
@Inject(PromptGeneratorService)
private readonly promptGenerationService: PromptGeneratorService,
@Inject(GenerationsService)
private readonly generationService: GenerationsService,
@Inject(LocationsService)
private readonly locationService: LocationsService,
@Inject(UsersService) private readonly userService: UsersService,
) {}

@ApiParam({ name: 'locationId', type: String })
@ApiBody({ type: CreateGenerationDto })
@Post()
async getGeneration(@Request() request, @Param('locationId') locationId: string, @Body() createGeneration : CreateGenerationDto): Promise<Generation> {
async getGeneration(
@Request() request,
@Param('locationId') locationId: string,
@Body() createGeneration: CreateGenerationDto,
): Promise<Generation> {
let generation: Generation | null = null;
const appUser = request.user;
locationId = decodeURIComponent(locationId);
console.log(locationId);
this.logger.log('Getting user data');
const user = await this.userService.findOne(appUser.id);

this.logger.log('Checking if the user can generate an image');
const canGenerate = user.canGenerateImage();
if (!canGenerate) {
this.logger.log('User can not generate an image');
//Throw a 429 exception Too Many Requests
throw new HttpException(
'You have reached the maximum number of generations per day',
429,
);
}

this.logger.log('Checking if the location exists');
const location = await this.locationService.findOne(locationId, user.id);
if (!location) {
this.logger.log('Location does not exist');
throw new NotFoundException( `Location with id ${locationId} does not exist`);
throw new NotFoundException(
`Location with id ${locationId} does not exist`,
);
}
this.logger.log('Generating prompt');
let imageStyle = '';
if (user.styles !== undefined && user.styles.length !== 0) {
this.logger.log(`User style: ${user.styles}`);
imageStyle = `${user.styles.join(', ')}`;
}
const prompt = await this.promptGenerationService.generatePrompt(imageStyle, createGeneration.time, location.city, createGeneration.weather);
const prompt = await this.promptGenerationService.generatePrompt(
imageStyle,
createGeneration.time,
location.city,
createGeneration.weather,
);
this.logger.log('Starting image generation');
generation = await this.generationService.create(null, user.id, location.id, prompt);
generation = await this.generationService.create(
null,
user.id,
location.id,
prompt,
);
console.log(`Generation ID : ${generation.id}`);
if (generation) {
this.logger.log(
'Inserting generation ID in user data : lastGenerationsIdAsc',
);
user.addGenerationId(
new LastGeneration(generation.id, generation.createdAt),
);
this.logger.log('Updating user data');
await this.userService.update(user);
}
return generation;
}
}
33 changes: 33 additions & 0 deletions src/modules/users/entities/lastGeneration.entity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import { Timestamp } from '@google-cloud/firestore';

export class LastGeneration {
generationId: string;
insertedTimestamp: Timestamp;

public constructor(generationId: string, insertedTimestamp: Timestamp) {
this.generationId = generationId;
this.insertedTimestamp = insertedTimestamp;
}

static fromFirestoreDocument(id: any, data: any): LastGeneration {
return new LastGeneration(id, data);
}

static fromJson(data: any): LastGeneration {
return new LastGeneration(data.generationId, data.insertedTimestamp);
}

toFirestoreDocument(): any {
return {
generationId: this.generationId,
insertedTimestamp: this.insertedTimestamp,
};
}

toJson(): any {
return {
generationId: this.generationId,
insertedTimestamp: this.insertedTimestamp,
};
}
}
93 changes: 88 additions & 5 deletions src/modules/users/entities/user.entity.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Style } from "../../styles/entities/style.entity";
import { Style } from '../../styles/entities/style.entity';
import { LastGeneration } from './lastGeneration.entity';

export class User {
id: string;
Expand All @@ -7,33 +8,79 @@ export class User {
email: string;
styles?: Style[];
frequencies?: string[];
lastGenerationsAsc?: LastGeneration[];

public constructor(id: string, firstname: string, lastname: string, email: string, styles?: Style[], frequencies?: string[]) {
public constructor(
id: string,
firstname: string,
lastname: string,
email: string,
styles?: Style[],
frequencies?: string[],
lastGenerationsAsc?: LastGeneration[],
) {
this.id = id;
this.firstname = firstname;
this.lastname = lastname;
this.email = email;
this.styles = styles;
this.frequencies = frequencies;
this.lastGenerationsAsc = lastGenerationsAsc;
}

static fromFirestoreDocument(id: any, data: any): User {
return new User(id, data.firstname, data.lastname, data.email, data.styles, data.frequencies);
return new User(
id,
data.firstname,
data.lastname,
data.email,
data.styles,
data.frequencies,
data.lastGenerationsAsc != null
? data.lastGenerationsAsc.map(
(lastGeneration: { generationId: any; insertedTimestamp: any }) =>
LastGeneration.fromFirestoreDocument(
lastGeneration.generationId,
lastGeneration.insertedTimestamp,
),
)
: [],
);
}

static fromJson(data: any): User {
return new User(data.id, data.firstname, data.lastname, data.email, data.styles, data.frequencies);
return new User(
data.id,
data.firstname,
data.lastname,
data.email,
data.styles,
data.frequencies,
data.lastGenerationsAsc,
);
}

toFirestoreDocument(): any {
return {
const firestoreDocument: any = {
id: this.id,
firstname: this.firstname,
lastname: this.lastname,
email: this.email,
styles: this.styles ? this.styles : [],
frequencies: this.frequencies ? this.frequencies : [],
lastGenerationsAsc: this.lastGenerationsAsc
? this.lastGenerationsAsc.map((lastGeneration) =>
lastGeneration.toFirestoreDocument(),
)
: [],
};
//Remove all undefined values
Object.keys(firestoreDocument).forEach(
(key) =>
firestoreDocument[key] === undefined && delete firestoreDocument[key],
);

return firestoreDocument;
}

toJson(): any {
Expand All @@ -44,6 +91,42 @@ export class User {
email: this.email,
styles: this.styles ? this.styles : [],
frequencies: this.frequencies ? this.frequencies : [],
lastGenerationsAsc: this.lastGenerationsAsc
? this.lastGenerationsAsc.map((lastGeneration) =>
lastGeneration.toJson(),
)
: [],
};
}

addGenerationId(lastGeneration: LastGeneration) {
if (this.lastGenerationsAsc === undefined) {
this.lastGenerationsAsc = [];
}

//If the day/month of the first generation is the same as the day/month of the generation to add, we add the generation to the list
//Otherwise, we wipe the list and add the generation
if (this.lastGenerationsAsc.length !== 0) {
console.log(this.lastGenerationsAsc);
const lastGenerationDate =
this.lastGenerationsAsc[0].insertedTimestamp.toDate();
const generationDate = lastGeneration.insertedTimestamp.toDate();
if (
lastGenerationDate.getDate() === generationDate.getDate() &&
lastGenerationDate.getMonth() === generationDate.getMonth()
) {
this.lastGenerationsAsc.unshift(lastGeneration);
} else {
this.lastGenerationsAsc = [lastGeneration];
}
} else {
this.lastGenerationsAsc = [lastGeneration];
}
}

canGenerateImage(): boolean {
if (this.lastGenerationsAsc.length < 3) {
return true;
}
}
}
39 changes: 32 additions & 7 deletions src/modules/users/users.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,55 @@ import { Inject, Injectable } from '@nestjs/common';
import { FirebaseProvider } from '../../providers/firebase.provider';
import { UserConverter } from './converters/user.converter';
import { User } from './entities/user.entity';
import { CreateUserDto } from './dto/create-user.dto';
import { StylesService } from '../styles/styles.service';

@Injectable()
export class UsersService {
static readonly collection: string = 'users';
constructor(@Inject(FirebaseProvider) private readonly firestoreProvider: FirebaseProvider, private userConverter: UserConverter, @Inject(StylesService) private readonly stylesService: StylesService) { }
constructor(
@Inject(FirebaseProvider)
private readonly firestoreProvider: FirebaseProvider,
private userConverter: UserConverter,
@Inject(StylesService) private readonly stylesService: StylesService,
) {}

async findAll(): Promise<User[]> {
const users = await this.firestoreProvider.getFirestore().collection(UsersService.collection).withConverter(this.userConverter).get();
return users.docs.map(user => this.userConverter.fromFirestore(user));
const users = await this.firestoreProvider
.getFirestore()
.collection(UsersService.collection)
.withConverter(this.userConverter)
.get();
return users.docs.map((user) => this.userConverter.fromFirestore(user));
}

async findOne(id: string): Promise<User | undefined> {
const user = await this.firestoreProvider.getFirestore().collection(UsersService.collection).doc(id).withConverter(this.userConverter).get();
const user = await this.firestoreProvider
.getFirestore()
.collection(UsersService.collection)
.doc(id)
.withConverter(this.userConverter)
.get();
if (!user.exists) {
return undefined;
}
let userObject = this.userConverter.fromFirestoreDocumentSnapshot(user);
const userObject = this.userConverter.fromFirestoreDocumentSnapshot(user);
// Object style in user only contains id, so we need to fetch the styles from the styles collection.
if (user.data().styles != undefined && user.data().styles.length === 0) {
const styles = await this.stylesService.findAll(user.data().styles.map(style => style.id));
const styles = await this.stylesService.findAll(
user.data().styles.map((style) => style.id),
);
userObject.styles = styles;
}
return userObject;
}

async update(user: User): Promise<User> {
const userRef = await this.firestoreProvider
.getFirestore()
.collection(UsersService.collection)
.doc(user.id)
.withConverter(this.userConverter)
.update(user.toFirestoreDocument());
return user;
}
}

0 comments on commit 688ed1b

Please sign in to comment.