diff --git a/src/database-controller/sdk/index.js b/src/database-controller/sdk/index.js index 244092c5d6..a74e3ab2c1 100644 --- a/src/database-controller/sdk/index.js +++ b/src/database-controller/sdk/index.js @@ -224,10 +224,41 @@ class DatabaseModel { }, ); + class Tag extends Model {} + Tag.init( + { + insertedAt: Sequelize.DATE, + uid: { + type: Sequelize.STRING(36), + primaryKey: true, + }, + frameworkName: { + type: Sequelize.STRING(64), + allowNull: false, + }, + name: { + type: Sequelize.STRING(64), + allowNull: false, + }, + }, + { + sequelize, + modelName: 'tag', + createdAt: 'insertedAt', + indexes: [ + { + unique: false, + fields: ['frameworkName'], + }, + ], + }, + ); + Framework.hasMany(FrameworkHistory); Framework.hasMany(Pod); Framework.hasMany(FrameworkEvent); Framework.hasMany(PodEvent); + Framework.hasMany(Tag); class Version extends Model {} Version.init( @@ -253,6 +284,7 @@ class DatabaseModel { this.Pod = Pod; this.FrameworkEvent = FrameworkEvent; this.PodEvent = PodEvent; + this.Tag = Tag; this.Version = Version; this.synchronizeSchema = this.synchronizeSchema.bind(this); } @@ -267,6 +299,7 @@ class DatabaseModel { this.Pod.sync({ alter: true }), this.FrameworkEvent.sync({ alter: true }), this.PodEvent.sync({ alter: true }), + this.Tag.sync({ alter: true }), this.Version.sync({ alter: true }), ]); } diff --git a/src/database-controller/sdk/package.json b/src/database-controller/sdk/package.json index aa918ca671..4fb39bc2b3 100644 --- a/src/database-controller/sdk/package.json +++ b/src/database-controller/sdk/package.json @@ -1,6 +1,6 @@ { "name": "openpaidbsdk", - "version": "1.0.0", + "version": "1.0.1", "scripts": { "test": "echo \"Error: no test specified\" && exit 1", "lint": "eslint --ext .js --ext .jsx ." 
diff --git a/src/database-controller/src/initializer/index.js b/src/database-controller/src/initializer/index.js index 119334c3e8..6e21fe8b6b 100644 --- a/src/database-controller/src/initializer/index.js +++ b/src/database-controller/src/initializer/index.js @@ -34,6 +34,8 @@ async function main() { const previousVersion = (await databaseModel.getVersion()).version; if (!previousVersion) { await updateFromNoDatabaseVersion(databaseModel); + } else { + await databaseModel.synchronizeSchema(); } await databaseModel.setVersion(paiVersion, paiCommitVersion); logger.info('Database has been successfully initialized.', function() { diff --git a/src/rest-server/docs/swagger.yaml b/src/rest-server/docs/swagger.yaml index 709a2c4b22..3e62a8faad 100644 --- a/src/rest-server/docs/swagger.yaml +++ b/src/rest-server/docs/swagger.yaml @@ -5,14 +5,14 @@ info: Open Platform for AI RESTful API docs. Version 2.0.1: add more examples and fix typos Version 2.0.2: update job detail and job attempt schema - Version 2.0.3: update parameters description of get storage list + Version 2.0.3: update parameters description of get storage list, update storage example and add get job config example Version 2.0.4: add default field in get storage list Version 2.0.5: add more parameters to job list; add submissionTime - Version 2.0.3: update storage example and add get job config example + Version 2.1.0: add add/delete tag api; add tags field in get job detail and get job list; add tags filter in get job list license: name: MIT License url: "https://github.com/microsoft/pai/blob/master/LICENSE" - version: 2.0.5 + version: 2.1.0 externalDocs: description: Find out more about OpenPAI url: "https://github.com/microsoft/pai" @@ -1121,6 +1121,16 @@ paths: description: filter jobs with keyword, we search keyword in user name, job name, and virtual cluster name schema: type: string + - name: tagsContain + in: query + description: filter jobs with tags. 
When multiple tags are specified (comma-separated), every job selected should have at least one of these tags + schema: + type: string + - name: tagsNotContain + in: query + description: filter jobs with tags. When multiple tags are specified (comma-separated), every job selected should have none of these tags + schema: + type: string - name: offset in: query description: list job offset @@ -1157,6 +1167,7 @@ paths: state: SUCCEEDED subState: Completed executionType: STOP + tags: ['abnormal', 'low_gpu_utilization'] retries: 0 submissionTime: 0 createdTime: 0 @@ -1186,6 +1197,7 @@ paths: $ref: "#/components/schemas/JobDetail" example: name: job name + tags: ['abnormal', 'low_gpu_utilization'] jobStatus: username: user name state: SUCCEEDED @@ -1303,6 +1315,88 @@ paths: $ref: "#/components/responses/NoJobError" "500": $ref: "#/components/responses/UnknownError" + "/api/v2/jobs/{user}~{job}/tag": + put: + tags: + - job + summary: Add a tag to a job. + description: Add a tag to a job. + operationId: addTag + security: + - bearerAuth: [] + parameters: + - $ref: "#/components/parameters/user" + - $ref: "#/components/parameters/job" + requestBody: + content: + application/json: + schema: + type: object + properties: + value: + type: string + description: tag to add, must not contain a comma + required: + - value + required: true + responses: + "200": + description: Succeeded + content: + application/json: + schema: + $ref: "#/components/schemas/Response" + example: + message: "Add tag {tag} for job {job} successfully." + "404": + $ref: "#/components/responses/NoJobError" + "500": + $ref: "#/components/responses/UnknownError" + delete: + tags: + - job + summary: Delete a tag from a job. + description: Delete a tag from a job. 
+ operationId: deleteTag + security: + - bearerAuth: [] + parameters: + - $ref: "#/components/parameters/user" + - $ref: "#/components/parameters/job" + requestBody: + content: + application/json: + schema: + type: object + properties: + value: + type: string + description: tag to delete + required: + - value + required: true + responses: + "200": + description: Succeeded + content: + application/json: + schema: + $ref: "#/components/schemas/Response" + example: + message: "Delete tag {tag} from job {job} successfully." + "404": + description: NoJobError or NoTagError + content: + application/json: + schema: + $ref: "#/components/schemas/Response" + examples: + NoJobError: + $ref: "#/components/responses/NoJobError/content/application~1json/examples/NoJobError" + NoTagError: + $ref: "#/components/responses/NoTagError/content/application~1json/examples/NoTagError" + "500": + $ref: "#/components/responses/UnknownError" "/api/v2/jobs/{user}~{job}/job-attempts/healthz": get: tags: @@ -1675,6 +1769,11 @@ components: enum: - START - STOP + tags: + type: array + description: job tags + items: + type: string retries: type: integer description: job retried times @@ -1739,6 +1838,11 @@ components: name: type: string description: job name + tags: + type: array + description: job tags + items: + type: string jobStatus: type: object description: job status @@ -2566,6 +2670,17 @@ components: value: code: NoJobError message: "Job {job} is not found." + NoTagError: + description: NoTagError + content: + application/json: + schema: + $ref: "#/components/schemas/Response" + examples: + NoTagError: + value: + code: NoTagError + message: "Tag {tag} is not found for job {job}." 
NoJobConfigError: description: NoJobConfigError content: diff --git a/src/rest-server/src/controllers/v2/job.js b/src/rest-server/src/controllers/v2/job.js index 92620f0373..988bf0ef6a 100644 --- a/src/rest-server/src/controllers/v2/job.js +++ b/src/rest-server/src/controllers/v2/job.js @@ -26,8 +26,11 @@ const { Op } = require('sequelize'); const list = asyncHandler(async (req, res) => { // ?keyword=&username=,&vc=, // &state=,&offset=&limit=&withTotalCount=true + // &tagsContain=,&tagsNotContain=, // &order=state,DESC const filters = {}; + const tagsContainFilter = {}; + const tagsNotContainFilter = {}; let offset = 0; let limit; let withTotalCount = false; @@ -63,6 +66,12 @@ const list = asyncHandler(async (req, res) => { if ('withTotalCount' in req.query && req.query.withTotalCount === 'true') { withTotalCount = true; } + if ('tagsContain' in req.query) { + tagsContainFilter.name = req.query.tagsContain.split(','); + } + if ('tagsNotContain' in req.query) { + tagsNotContainFilter.name = req.query.tagsNotContain.split(','); + } if ('keyword' in req.query) { // match text in username, jobname, or vc filters[Op.or] = [ @@ -126,6 +135,8 @@ const list = asyncHandler(async (req, res) => { const data = await job.list( attributes, filters, + tagsContainFilter, + tagsNotContainFilter, order, offset, limit, @@ -207,6 +218,22 @@ const getSshInfo = asyncHandler(async (req, res) => { res.json(data); }); +const addTag = asyncHandler(async (req, res) => { + await job.addTag(req.params.frameworkName, req.body.value); + res.status(status('OK')).json({ + status: status('OK'), + message: `Add tag ${req.body.value} for job ${req.params.frameworkName} successfully.`, + }); +}); + +const deleteTag = asyncHandler(async (req, res) => { + await job.deleteTag(req.params.frameworkName, req.body.value); + res.status(status('OK')).json({ + status: status('OK'), + message: `Delete tag ${req.body.value} from job ${req.params.frameworkName} successfully.`, + }); +}); + // module exports module.exports = { list, @@ 
-215,4 +242,6 @@ module.exports = { execute, getConfig, getSshInfo, + addTag, + deleteTag, }; diff --git a/src/rest-server/src/middlewares/v2/protocol.js b/src/rest-server/src/middlewares/v2/protocol.js index 6f5dad3d78..649e36ca52 100644 --- a/src/rest-server/src/middlewares/v2/protocol.js +++ b/src/rest-server/src/middlewares/v2/protocol.js @@ -211,9 +211,24 @@ const protocolSubmitMiddleware = [ }), ]; +const validateTagMiddleware = async (req, _, next) => { + // tag should not include ',' + if (req.body.value.includes(',')) { + return next( + createError( + 'Bad Request', + 'InvalidProtocolError', + "tag should not include ','", + ), + ); + } + next(); +}; + // module exports module.exports = { validate: protocolValidate, render: protocolRender, submit: protocolSubmitMiddleware, + validateTag: validateTagMiddleware, }; diff --git a/src/rest-server/src/models/v2/job/k8s.js b/src/rest-server/src/models/v2/job/k8s.js index 7e68e3ee03..83e4bbb9b6 100644 --- a/src/rest-server/src/models/v2/job/k8s.js +++ b/src/rest-server/src/models/v2/job/k8s.js @@ -36,6 +36,8 @@ const { apiserver } = require('@pai/config/kubernetes'); const schedulePort = require('@pai/config/schedule-port'); const databaseModel = require('@pai/utils/dbUtils'); +const Sequelize = require('sequelize'); + let exitSpecPath; if (process.env[env.exitSpecPath]) { exitSpecPath = process.env[env.exitSpecPath]; @@ -134,6 +136,7 @@ const convertFrameworkSummary = (framework) => { state: framework.state, subState: framework.subState, executionType: framework.executionType.toUpperCase(), + tags: framework.tags.reduce((arr, curr) => [...arr, curr.name], []), retries: framework.retries, retryDetails: { user: framework.userRetries, @@ -241,7 +244,7 @@ const convertTaskDetail = async (taskStatus, ports, logPathPrefix) => { }; }; -const convertFrameworkDetail = async (framework) => { +const convertFrameworkDetail = async (framework, tags) => { const attemptStatus = framework.status.attemptStatus; // check fields 
which may be compressed if (attemptStatus.taskRoleStatuses == null) { @@ -267,9 +270,11 @@ const convertFrameworkDetail = async (framework) => { const completionStatus = attemptStatus.completionStatus; const diagnostics = completionStatus ? completionStatus.diagnostics : null; const exitDiagnostics = generateExitDiagnostics(diagnostics); + const detail = { debugId: framework.metadata.name, name: jobName, + tags: tags.reduce((arr, curr) => [...arr, curr.name], []), jobStatus: { username: userName, state: convertState( @@ -878,6 +883,8 @@ const getConfigSecretDef = (frameworkName, secrets) => { const list = async ( attributes, filters, + tagsContainFilter, + tagsNotContainFilter, order, offset, limit, @@ -885,12 +892,53 @@ const list = async ( ) => { let frameworks; let totalCount; + + if ( + Object.keys(tagsContainFilter).length !== 0 || + Object.keys(tagsNotContainFilter).length !== 0 + ) { + filters.name = {}; + // tagsContain + if (Object.keys(tagsContainFilter).length !== 0) { + const queryContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery( + 'tags', + { + attributes: ['frameworkName'], + where: tagsContainFilter, + }, + ); + filters.name[Sequelize.Op.in] = Sequelize.literal(` + (${queryContainFrameworkName.slice(0, -1)}) + `); + } + // tagsNotContain + if (Object.keys(tagsNotContainFilter).length !== 0) { + const queryNotContainFrameworkName = databaseModel.sequelize.dialect.QueryGenerator.selectQuery( + 'tags', + { + attributes: ['frameworkName'], + where: tagsNotContainFilter, + }, + ); + filters.name[Sequelize.Op.notIn] = Sequelize.literal(` + (${queryNotContainFrameworkName.slice(0, -1)}) + `); + } + } + frameworks = await databaseModel.Framework.findAll({ attributes: attributes, where: filters, offset: offset, limit: limit, order: order, + include: [ + { + attributes: ['name'], + required: Object.keys(tagsContainFilter).length !== 0, + model: databaseModel.Tag, + }, + ], }); if (withTotalCount) { totalCount = await 
databaseModel.Framework.count({ where: filters }); @@ -912,10 +960,18 @@ const get = async (frameworkName) => { const framework = await databaseModel.Framework.findOne({ attributes: ['submissionTime', 'snapshot'], where: { name: encodeName(frameworkName) }, + include: [ + { + attributes: ['name'], + model: databaseModel.Tag, + as: 'tags', + }, + ], }); if (framework) { const frameworkDetail = await convertFrameworkDetail( JSON.parse(framework.snapshot), + framework.tags, ); frameworkDetail.jobStatus.submissionTime = new Date( framework.submissionTime, @@ -1142,6 +1198,63 @@ const getSshInfo = async (frameworkName) => { ); }; +const addTag = async (frameworkName, tag) => { + // check if frameworkName exist + const framework = await databaseModel.Framework.findOne({ + where: { name: encodeName(frameworkName) }, + }); + + if (framework) { + // add tag + const data = await databaseModel.Tag.findOrCreate({ + where: { + frameworkName: encodeName(frameworkName), + name: tag, + uid: encodeName(`${frameworkName}+${tag}`), + }, + }); + return data; + } else { + throw createError( + 'Not Found', + 'NoJobError', + `Job ${frameworkName} is not found.`, + ); + } +}; + +const deleteTag = async (frameworkName, tag) => { + // check if frameworkName exist + const framework = await databaseModel.Framework.findOne({ + where: { name: encodeName(frameworkName) }, + }); + + if (framework) { + // remove tag + const numDestroyedRows = await databaseModel.Tag.destroy({ + where: { + frameworkName: encodeName(frameworkName), + name: tag, + }, + }); + if (numDestroyedRows === 0) { + throw createError( + 'Not Found', + 'NoTagError', + `Tag ${tag} is not found for job ${frameworkName}.`, + ); + } else { + return numDestroyedRows; + } + } else { + throw createError( + 'Not Found', + 'NoJobError', + `Job ${frameworkName} is not found.`, + ); + } +}; + const generateExitDiagnostics = (diag) => { if (_.isEmpty(diag)) { return null; @@ -1249,4 +1362,6 @@ module.exports = { execute, getConfig, 
getSshInfo, + addTag, + deleteTag, }; diff --git a/src/rest-server/src/routes/v2/job.js b/src/rest-server/src/routes/v2/job.js index 17f11dde32..573587b998 100644 --- a/src/rest-server/src/routes/v2/job.js +++ b/src/rest-server/src/routes/v2/job.js @@ -54,5 +54,15 @@ router router.use('/:frameworkName/job-attempts', jobAttemptRouter); +router + .route('/:frameworkName/tag') + /** PUT /api/v2/jobs/:frameworkName/tag - Add a framework tag */ + .put(token.check, protocol.validateTag, controller.addTag); + +router + .route('/:frameworkName/tag') + /** DELETE /api/v2/jobs/:frameworkName/tag - Delete a framework tag */ + .delete(token.check, controller.deleteTag); + // module exports module.exports = router;