Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security Solution][Detections] Performance enhancement for readRules function #76398

Closed
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,21 @@ export type Immutable = t.TypeOf<typeof immutable>;
// in case we encounter 3rd party rule systems which might be using auto incrementing numbers
// or other different things.
export const rule_id = t.string;
export const rule_ids = t.array(rule_id);
export type RuleId = t.TypeOf<typeof rule_id>;
export type RuleIds = t.TypeOf<typeof rule_ids>;

export const ruleIdOrUndefined = t.union([rule_id, t.undefined]);
export const ruleIdsOrUndefined = t.union([rule_ids, t.undefined]);
export type RuleIdOrUndefined = t.TypeOf<typeof ruleIdOrUndefined>;
export type RuleIdsOrUndefined = t.TypeOf<typeof ruleIdsOrUndefined>;

export const id = UUID;
export const ids = t.array(UUID);
export const idOrUndefined = t.union([id, t.undefined]);
export const idsOrUndefined = t.union([ids, t.undefined]);
export type IdOrUndefined = t.TypeOf<typeof idOrUndefined>;
export type IdsOrUndefined = t.TypeOf<typeof idsOrUndefined>;

export const index = t.array(t.string);
export type Index = t.TypeOf<typeof index>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import { SavedObjectsFindResponse } from 'kibana/server';
import { ActionResult } from '../../../../../../actions/server';
import { SanitizedAlert } from '../../../../../../alerts/common';
import { SignalSearchResponse } from '../../signals/types';
import {
DETECTION_ENGINE_RULES_URL,
Expand Down Expand Up @@ -159,14 +160,14 @@ export const getEmptyFindResult = (): FindHit => ({
data: [],
});

export const getFindResultWithSingleHit = (): FindHit => ({
export const getFindResultWithSingleHit = (): FindHit<SanitizedAlert> => ({
page: 1,
perPage: 1,
total: 1,
data: [getResult()],
});

export const nonRuleFindResult = (): FindHit => ({
export const nonRuleFindResult = (): FindHit<SanitizedAlert> => ({
page: 1,
perPage: 1,
total: 1,
Expand Down Expand Up @@ -409,7 +410,6 @@ export const getResult = (): RuleAlertType => ({
throttle: null,
createdBy: 'elastic',
updatedBy: 'elastic',
apiKey: null,
apiKeyOwner: 'elastic',
muteAll: false,
mutedInstanceIds: [],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/

import { Alert } from '../../../../../../alerts/common';
import { DETECTION_ENGINE_RULES_URL } from '../../../../../common/constants';
import { mlServicesMock, mlAuthzMock as mockMlAuthzFactory } from '../../../machine_learning/mocks';
import { buildMlAuthz } from '../../../machine_learning/authz';
Expand Down Expand Up @@ -34,7 +35,7 @@ describe('create_rules_bulk', () => {

clients.clusterClient.callAsCurrentUser.mockResolvedValue(getNonEmptyIndex()); // index exists
clients.alertsClient.find.mockResolvedValue(getEmptyFindResult()); // no existing rules
clients.alertsClient.create.mockResolvedValue(getResult()); // successful creation
clients.alertsClient.create.mockResolvedValue((getResult() as unknown) as Alert); // successful creation

createRulesBulkRoute(server.router, ml);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ export const createRulesBulkRoute = (router: IRouter, ml: SetupPlugins['ml']) =>

const ruleDefinitions = request.body;
const dupes = getDuplicates(ruleDefinitions, 'rule_id');
const foundRules = await readRules({
alertsClient,
ruleIds: ruleDefinitions.map((rule) => rule.rule_id),
id: undefined,
});

const rules = await Promise.all(
ruleDefinitions
Expand Down Expand Up @@ -132,15 +137,16 @@ export const createRulesBulkRoute = (router: IRouter, ml: SetupPlugins['ml']) =>
message: `To create a rule, the index must exist first. Index ${finalIndex} does not exist`,
});
}
if (ruleId != null) {
const rule = await readRules({ alertsClient, ruleId, id: undefined });
if (rule != null) {
return createBulkErrorObject({
ruleId,
statusCode: 409,
message: `rule_id: "${ruleId}" already exists`,
});
}
if (
ruleId != null &&
foundRules != null &&
foundRules.some((rule) => rule.params.ruleId === ruleId)
) {
return createBulkErrorObject({
ruleId,
statusCode: 409,
message: `rule_id: "${ruleId}" already exists`,
});
}
const createdRule = await createRules({
alertsClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/

import { Alert } from '../../../../../../alerts/common';
import { DETECTION_ENGINE_RULES_URL } from '../../../../../common/constants';
import {
getEmptyFindResult,
Expand Down Expand Up @@ -36,7 +37,7 @@ describe('create_rules', () => {

clients.clusterClient.callAsCurrentUser.mockResolvedValue(getNonEmptyIndex()); // index exists
clients.alertsClient.find.mockResolvedValue(getEmptyFindResult()); // no current rules
clients.alertsClient.create.mockResolvedValue(getResult()); // creation succeeds
clients.alertsClient.create.mockResolvedValue((getResult() as unknown) as Alert); // creation succeeds
clients.savedObjectsClient.find.mockResolvedValue(getFindResultStatus()); // needed to transform

createRulesRoute(server.router, ml);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,12 @@ export const createRulesRoute = (router: IRouter, ml: SetupPlugins['ml']): void
});
}
if (ruleId != null) {
const rule = await readRules({ alertsClient, ruleId, id: undefined });
if (rule != null) {
const readRulesResult = await readRules({
alertsClient,
ruleIds: [ruleId],
id: undefined,
});
if (readRulesResult != null && readRulesResult.length > 0) {
return siemResponse.error({
statusCode: 409,
body: `rule_id: "${ruleId}" already exists`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import {
import { createMockConfig, requestContextMock, serverMock, requestMock } from '../__mocks__';
import { mlServicesMock, mlAuthzMock as mockMlAuthzFactory } from '../../../machine_learning/mocks';
import { buildMlAuthz } from '../../../machine_learning/authz';
import { Alert } from '../../../../../../alerts/common';
import { importRulesRoute } from './import_rules_route';
import * as createRulesStreamFromNdJson from '../../rules/create_rules_stream_from_ndjson';
import {
Expand Down Expand Up @@ -55,6 +56,18 @@ describe('import_rules_route', () => {
expect(response.status).toEqual(200);
});

test('returns 500 if more than 10,000 rules are imported', async () => {
const ruleIds = new Array(10001).fill(undefined).map((_, index) => `rule-${index}`);
const multiRequest = getImportRulesRequest(buildHapiStream(ruleIdsToNdJsonString(ruleIds)));
const response = await server.inject(multiRequest, context);

expect(response.status).toEqual(500);
expect(response.body).toEqual({
message: "Can't import more than 10000 rules",
status_code: 500,
});
});

test('returns 404 if alertClient is not available on the route', async () => {
context.alerting!.getAlertsClient = jest.fn();
const response = await server.inject(request, context);
Expand Down Expand Up @@ -145,7 +158,7 @@ describe('import_rules_route', () => {

describe('single rule import', () => {
test('returns 200 if rule imported successfully', async () => {
clients.alertsClient.create.mockResolvedValue(getResult());
clients.alertsClient.create.mockResolvedValue((getResult() as unknown) as Alert);
const response = await server.inject(request, context);
expect(response.status).toEqual(200);
expect(response.body).toEqual({
Expand Down Expand Up @@ -229,6 +242,19 @@ describe('import_rules_route', () => {
});
});

test('returns 200 if many rules are imported successfully', async () => {
const ruleIds = new Array(9999).fill(undefined).map((_, index) => `rule-${index}`);
const multiRequest = getImportRulesRequest(buildHapiStream(ruleIdsToNdJsonString(ruleIds)));
const response = await server.inject(multiRequest, context);

expect(response.status).toEqual(200);
expect(response.body).toEqual({
errors: [],
success: true,
success_count: 9999,
});
});
dhurley14 marked this conversation as resolved.
Show resolved Hide resolved

test('returns 200 with errors if all rules are missing rule_ids and import fails on validation', async () => {
const rulesWithoutRuleIds = ['rule-1', 'rule-2'].map((ruleId) =>
getImportRulesWithIdSchemaMock(ruleId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

import { chunk } from 'lodash/fp';
import { extname } from 'path';
import { schema } from '@kbn/config-schema';

import { validate } from '../../../../../common/validate';
import {
importRulesQuerySchema,
ImportRulesQuerySchemaDecoded,
importRulesPayloadSchema,
ImportRulesPayloadSchemaDecoded,
ImportRulesSchemaDecoded,
} from '../../../../../common/detection_engine/schemas/request/import_rules_schema';
import {
Expand Down Expand Up @@ -48,7 +47,7 @@ import { PartialFilter } from '../../types';

type PromiseFromStreams = ImportRulesSchemaDecoded | Error;

const CHUNK_PARSED_OBJECT_SIZE = 10;
const CHUNK_PARSED_OBJECT_SIZE = 100;

export const importRulesRoute = (router: IRouter, config: ConfigType, ml: SetupPlugins['ml']) => {
router.post(
Expand All @@ -58,10 +57,7 @@ export const importRulesRoute = (router: IRouter, config: ConfigType, ml: SetupP
query: buildRouteValidation<typeof importRulesQuerySchema, ImportRulesQuerySchemaDecoded>(
importRulesQuerySchema
),
body: buildRouteValidation<
typeof importRulesPayloadSchema,
ImportRulesPayloadSchemaDecoded
>(importRulesPayloadSchema),
body: schema.any(), // validation on file object is accomplished later in the handler.
dhurley14 marked this conversation as resolved.
Show resolved Hide resolved
},
options: {
tags: ['access:securitySolution'],
Expand Down Expand Up @@ -119,6 +115,19 @@ export const importRulesRoute = (router: IRouter, config: ConfigType, ml: SetupP

while (chunkParseObjects.length) {
const batchParseObjects = chunkParseObjects.shift() ?? [];
const ruleIds = await Promise.all(
batchParseObjects
.filter((item): item is ImportRulesSchemaDecoded => !(item instanceof Error))
.map((parsedObject) => {
return parsedObject.rule_id;
})
);

const rules = await readRules({
alertsClient,
ruleIds: ruleIds.map((someId) => someId),
id: undefined,
});
const newImportRuleResponse = await Promise.all(
batchParseObjects.reduce<Array<Promise<ImportRuleResponse>>>((accum, parsedRule) => {
const importsWorkerPromise = new Promise<ImportRuleResponse>(async (resolve) => {
Expand Down Expand Up @@ -185,7 +194,8 @@ export const importRulesRoute = (router: IRouter, config: ConfigType, ml: SetupP

throwHttpError(await mlAuthz.validateRuleType(type));

const rule = await readRules({ alertsClient, ruleId, id: undefined });
const rule = rules?.find((aRule) => aRule.params.ruleId === ruleId);

if (rule == null) {
await createRules({
alertsClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ describe('patch_rules_bulk', () => {
path: `${DETECTION_ENGINE_RULES_URL}/bulk_update`,
body: [
{
rule_id: 'my-rule-id',
rule_id: 'rule-1',
anomaly_threshold: 4,
machine_learning_job_id: 'some_job_id',
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import { updateRulesNotifications } from '../../rules/update_rules_notifications
import { ruleStatusSavedObjectsClientFactory } from '../../signals/rule_status_saved_objects_client';
import { readRules } from '../../rules/read_rules';
import { PartialFilter } from '../../types';
import { RuleAlertType } from '../../rules/types';

export const patchRulesBulkRoute = (router: IRouter, ml: SetupPlugins['ml']) => {
router.patch(
Expand All @@ -51,8 +52,20 @@ export const patchRulesBulkRoute = (router: IRouter, ml: SetupPlugins['ml']) =>

const mlAuthz = buildMlAuthz({ license: context.licensing.license, ml, request });
const ruleStatusClient = ruleStatusSavedObjectsClientFactory(savedObjectsClient);
const ruleDefinitions = request.body;
const ruleIds = ruleDefinitions.reduce<string[]>((acc, rule) => {
if (rule != null && rule.rule_id != null) {
return [rule.rule_id, ...acc];
}
return acc;
}, []);
const foundRules: RuleAlertType[] | null = await readRules({
alertsClient,
ruleIds,
id: undefined,
});
const rules = await Promise.all(
request.body.map(async (payloadRule) => {
ruleDefinitions.map(async (payloadRule) => {
const {
actions: actionsRest,
author,
Expand Down Expand Up @@ -106,14 +119,18 @@ export const patchRulesBulkRoute = (router: IRouter, ml: SetupPlugins['ml']) =>
throwHttpError(await mlAuthz.validateRuleType(type));
}

const existingRule = await readRules({ alertsClient, ruleId, id });
if (existingRule?.params.type) {
const searchedRule = foundRules?.find((rule) => rule.params.ruleId === ruleId);
const existingRule = searchedRule
? [searchedRule]
: await readRules({ alertsClient, ruleIds: undefined, id });

if (existingRule != null && existingRule[0].params.type) {
// reject an unauthorized modification of an ML rule
throwHttpError(await mlAuthz.validateRuleType(existingRule?.params.type));
throwHttpError(await mlAuthz.validateRuleType(existingRule[0].params.type));
}

const rule = await patchRules({
rule: existingRule,
rule: existingRule != null && existingRule.length > 0 ? existingRule[0] : null,
alertsClient,
author,
buildingBlockType,
Expand Down Expand Up @@ -185,7 +202,6 @@ export const patchRulesBulkRoute = (router: IRouter, ml: SetupPlugins['ml']) =>
}
})
);

const [validated, errors] = validate(rules, rulesBulkSchema);
if (errors != null) {
return siemResponse.error({ statusCode: 500, body: errors });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,14 @@ export const patchRulesRoute = (router: IRouter, ml: SetupPlugins['ml']) => {
throwHttpError(await mlAuthz.validateRuleType(type));
}

const existingRule = await readRules({ alertsClient, ruleId, id });
if (existingRule?.params.type) {
const existingRules = await readRules({
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nit: seems like theres some repeated logic being used in the routes, is it worth pulling out somewhere? Maybe not since they're small chunks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think we could pull out the logic for rejecting unauthorized modification of an ML rule into a little function. That's a good idea. Not sure if that refactor would be a good fit for a minor release like this PR but that would clean up the code nicely. I'll take a look.

alertsClient,
ruleIds: ruleId ? [ruleId] : undefined,
id,
});
if (existingRules != null && existingRules.length > 0 && existingRules[0].params.type) {
// reject an unauthorized modification of an ML rule
throwHttpError(await mlAuthz.validateRuleType(existingRule?.params.type));
throwHttpError(await mlAuthz.validateRuleType(existingRules[0].params.type));
}

const ruleStatusClient = ruleStatusSavedObjectsClientFactory(savedObjectsClient);
Expand All @@ -129,7 +133,7 @@ export const patchRulesRoute = (router: IRouter, ml: SetupPlugins['ml']) => {
timelineTitle,
meta,
filters,
rule: existingRule,
rule: existingRules != null && existingRules.length > 0 ? existingRules[0] : null,
index,
interval,
maxSignals,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,25 @@ export const readRulesRoute = (router: IRouter) => {
}

const ruleStatusClient = ruleStatusSavedObjectsClientFactory(savedObjectsClient);
const rule = await readRules({
const rules = await readRules({
alertsClient,
id,
ruleId,
ruleIds: ruleId ? [ruleId] : undefined,
});
if (rule != null) {
if (rules != null && rules.length > 0) {
const ruleActions = await getRuleActionsSavedObject({
savedObjectsClient,
ruleAlertId: rule.id,
ruleAlertId: rules[0].id,
});
const ruleStatuses = await ruleStatusClient.find({
perPage: 1,
sortField: 'statusDate',
sortOrder: 'desc',
search: rule.id,
search: rules[0].id,
searchFields: ['alertId'],
});
const [validated, errors] = transformValidate(
rule,
rules[0],
ruleActions,
ruleStatuses.saved_objects[0]
);
Expand Down
Loading