Skip to content

Commit

Permalink
Support compact RunRequest so requests sent to heavy plugin services …
Browse files Browse the repository at this point in the history
…still fit in the default gRPC message limit.

PiperOrigin-RevId: 656270995
Change-Id: I3216fe01796ff866caa464aad1d49d1c0dd35256
  • Loading branch information
Tsunami Team authored and copybara-github committed Jul 26, 2024
1 parent b4b2089 commit 67f5d79
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.tsunami.common.server;

import com.google.common.collect.ImmutableList;
import com.google.tsunami.proto.MatchedPlugin;
import com.google.tsunami.proto.NetworkService;
import com.google.tsunami.proto.RunCompactRequest;
import com.google.tsunami.proto.RunCompactRequest.PluginNetworkServiceTarget;
import com.google.tsunami.proto.RunRequest;
import java.util.HashMap;

/**
* CompactRunRequestHelper is a helper class to compress/uncompress the RunRequest into/from the
* compact representation.
*/
public final class CompactRunRequestHelper {

private CompactRunRequestHelper() {}

public static RunCompactRequest compress(RunRequest runRequest) {
var builder = RunCompactRequest.newBuilder().setTarget(runRequest.getTarget());
HashMap<NetworkService, Integer> serviceIndexMap = new HashMap<>();
int pluginIndex = -1;
for (MatchedPlugin matchedPlugin : runRequest.getPluginsList()) {
pluginIndex++;
builder.addPlugins(matchedPlugin.getPlugin());
for (NetworkService service : matchedPlugin.getServicesList()) {
Integer serviceIndex = serviceIndexMap.get(service);
if (serviceIndex == null) {
serviceIndex = serviceIndexMap.size();
serviceIndexMap.put(service, serviceIndex);
builder.addServices(service);
}

builder.addScanTargets(
PluginNetworkServiceTarget.newBuilder()
.setPluginIndex(pluginIndex)
.setServiceIndex(serviceIndex)
.build());
}
}
return builder.build();
}

public static RunRequest uncompress(RunCompactRequest runCompactRequest) {
ImmutableList.Builder<MatchedPlugin> matchedPlugins = ImmutableList.builder();
for (var target : runCompactRequest.getScanTargetsList()) {
var plugin = runCompactRequest.getPlugins(target.getPluginIndex());
var networkService = runCompactRequest.getServices(target.getServiceIndex());
matchedPlugins.add(
MatchedPlugin.newBuilder().setPlugin(plugin).addServices(networkService).build());
}

return RunRequest.newBuilder()
.setTarget(runCompactRequest.getTarget())
.addAllPlugins(matchedPlugins.build())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.tsunami.common.server;

import static com.google.common.truth.Truth.assertThat;

import com.google.common.collect.ImmutableList;
import com.google.tsunami.proto.Hostname;
import com.google.tsunami.proto.MatchedPlugin;
import com.google.tsunami.proto.NetworkEndpoint;
import com.google.tsunami.proto.NetworkService;
import com.google.tsunami.proto.PluginDefinition;
import com.google.tsunami.proto.PluginInfo;
import com.google.tsunami.proto.RunCompactRequest;
import com.google.tsunami.proto.RunCompactRequest.PluginNetworkServiceTarget;
import com.google.tsunami.proto.RunRequest;
import com.google.tsunami.proto.TargetInfo;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class CompactRunRequestHelperTest {

@Test
public void compressingRunRequest_isMoreCompact() {
NetworkService service1 = NetworkService.newBuilder().setServiceName("service1").build();
NetworkService service2 = NetworkService.newBuilder().setServiceName("service2").build();
PluginDefinition plugin1 =
PluginDefinition.newBuilder()
.setInfo(PluginInfo.newBuilder().setName("plugin1").build())
.build();
PluginDefinition plugin2 =
PluginDefinition.newBuilder()
.setInfo(PluginInfo.newBuilder().setName("plugin2").build())
.build();
PluginDefinition plugin3 =
PluginDefinition.newBuilder()
.setInfo(PluginInfo.newBuilder().setName("plugin3").build())
.build();
MatchedPlugin matchedPlugin1 =
MatchedPlugin.newBuilder().addServices(service1).setPlugin(plugin1).build();
MatchedPlugin matchedPlugin2 =
MatchedPlugin.newBuilder().addServices(service2).setPlugin(plugin2).build();
MatchedPlugin matchedPlugin3 =
MatchedPlugin.newBuilder().addServices(service1).setPlugin(plugin3).build();
ImmutableList<MatchedPlugin> expectedMatchedPlugins =
ImmutableList.of(matchedPlugin1, matchedPlugin2, matchedPlugin3);
TargetInfo expectedTargetInfo =
TargetInfo.newBuilder()
.addNetworkEndpoints(
NetworkEndpoint.newBuilder()
.setHostname(Hostname.newBuilder().setName("example.com").build())
.build())
.build();
RunRequest expectedUncompressedRunRequest =
RunRequest.newBuilder()
.setTarget(expectedTargetInfo)
.addAllPlugins(expectedMatchedPlugins)
.build();
var actualCompressedRunRequest =
CompactRunRequestHelper.compress(expectedUncompressedRunRequest);

var expectedCompressedRunRequest =
RunCompactRequest.newBuilder()
.setTarget(expectedTargetInfo)
.addServices(service1)
.addServices(service2)
.addPlugins(plugin1)
.addPlugins(plugin2)
.addPlugins(plugin3)
.addScanTargets(
PluginNetworkServiceTarget.newBuilder()
.setPluginIndex(0)
.setServiceIndex(0)
.build())
.addScanTargets(
PluginNetworkServiceTarget.newBuilder()
.setPluginIndex(1)
.setServiceIndex(1)
.build())
.addScanTargets(
PluginNetworkServiceTarget.newBuilder()
.setPluginIndex(2)
.setServiceIndex(0)
.build())
.build();
assertThat(actualCompressedRunRequest).isEqualTo(expectedCompressedRunRequest);

// And now uncompressing it again:
var actualUncompressedRunRequest =
CompactRunRequestHelper.uncompress(actualCompressedRunRequest);

// It should match the original setup
assertThat(actualUncompressedRunRequest).isEqualTo(expectedUncompressedRunRequest);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.tsunami.proto.ListPluginsResponse;
import com.google.tsunami.proto.PluginServiceGrpc;
import com.google.tsunami.proto.PluginServiceGrpc.PluginServiceFutureStub;
import com.google.tsunami.proto.RunCompactRequest;
import com.google.tsunami.proto.RunRequest;
import com.google.tsunami.proto.RunResponse;
import io.grpc.Channel;
Expand Down Expand Up @@ -56,6 +57,18 @@ public ListenableFuture<RunResponse> runWithDeadline(RunRequest request, Deadlin
return pluginService.withDeadline(deadline).run(request);
}

/**
* Sends a runCompact request to the gRPC language server with a specified deadline.
*
* @param request The main request containing plugins to run.
* @param deadline The timeout of the service.
* @return The future of the run response.
*/
public ListenableFuture<RunResponse> runCompactWithDeadline(
RunCompactRequest request, Deadline deadline) {
return pluginService.withDeadline(deadline).runCompact(request);
}

/**
* Sends a list plugins request to the gRPC language server with a specified deadline.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
import com.google.common.collect.Sets;
import com.google.common.flogger.GoogleLogger;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.tsunami.common.server.CompactRunRequestHelper;
import com.google.tsunami.proto.DetectionReportList;
import com.google.tsunami.proto.ListPluginsRequest;
import com.google.tsunami.proto.MatchedPlugin;
import com.google.tsunami.proto.NetworkService;
import com.google.tsunami.proto.PluginDefinition;
import com.google.tsunami.proto.RunRequest;
import com.google.tsunami.proto.RunResponse;
import com.google.tsunami.proto.TargetInfo;
import io.grpc.Channel;
import io.grpc.Deadline;
Expand All @@ -53,6 +55,7 @@ public final class RemoteVulnDetectorImpl implements RemoteVulnDetector {
private final ExponentialBackOff backoff;
private final int maxAttempts;
private final Deadline deadline;
private boolean wantCompactRunRequest = false;

RemoteVulnDetectorImpl(
Channel channel, ExponentialBackOff backoff, int maxAttempts, Deadline deadline) {
Expand All @@ -68,13 +71,17 @@ public DetectionReportList detect(
TargetInfo target, ImmutableList<NetworkService> matchedServices) {
try {
if (checkHealthWithBackoffs()) {
var runRequest =
RunRequest.newBuilder().setTarget(target).addAllPlugins(pluginsToRun).build();
logger.atInfo().log("Detecting with language server plugins...");
return service
.runWithDeadline(
RunRequest.newBuilder().setTarget(target).addAllPlugins(pluginsToRun).build(),
deadline)
.get()
.getReports();
RunResponse runResponse;
if (this.wantCompactRunRequest) {
var runCompactRequest = CompactRunRequestHelper.compress(runRequest);
runResponse = service.runCompactWithDeadline(runCompactRequest, deadline).get();
} else {
runResponse = service.runWithDeadline(runRequest, deadline).get();
}
return runResponse.getReports();
}
} catch (InterruptedException | ExecutionException e) {
throw new LanguageServerException("Failed to get response from language server.", e);
Expand All @@ -87,11 +94,14 @@ public ImmutableList<PluginDefinition> getAllPlugins() {
try {
if (checkHealthWithBackoffs()) {
logger.atInfo().log("Getting language server plugins...");
return ImmutableList.copyOf(
var listPluginsResponse =
service
.listPluginsWithDeadline(ListPluginsRequest.getDefaultInstance(), DEFAULT_DEADLINE)
.get()
.getPluginsList());
.get();
// Note: each plugin service client has a dedicated RemoteVulnDetectorImpl instance,
// so we can safely set this flag here.
this.wantCompactRunRequest = listPluginsResponse.getWantCompactRunRequest();
return ImmutableList.copyOf(listPluginsResponse.getPluginsList());
} else {
return ImmutableList.of();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.tsunami.proto.PluginDefinition;
import com.google.tsunami.proto.PluginInfo;
import com.google.tsunami.proto.PluginServiceGrpc.PluginServiceImplBase;
import com.google.tsunami.proto.RunCompactRequest;
import com.google.tsunami.proto.RunRequest;
import com.google.tsunami.proto.RunResponse;
import com.google.tsunami.proto.TargetInfo;
Expand Down Expand Up @@ -174,6 +175,62 @@ public void listPlugins(
assertThat(pluginToTest.getAllPlugins()).containsExactly(plugin);
}

@Test
public void getAllPlugins_withCompactRunRequest_callsRunCompact() throws Exception {
registerHealthCheckWithStatus(ServingStatus.SERVING);

var targetInfo =
TargetInfo.newBuilder()
.addNetworkEndpoints(NetworkEndpointUtils.forIpAndPort("1.1.1.1", 80))
.build();
var someNetworkService = NetworkService.getDefaultInstance();
var expectedDetectionReport =
DetectionReport.newBuilder()
.setTargetInfo(targetInfo)
.setNetworkService(someNetworkService)
.build();

var plugin = createSinglePluginDefinitionWithName("test");
RemoteVulnDetector pluginToTest = getNewRemoteVulnDetectorInstance();
serviceRegistry.addService(
new PluginServiceImplBase() {
@Override
public void listPlugins(
ListPluginsRequest request, StreamObserver<ListPluginsResponse> responseObserver) {
responseObserver.onNext(
ListPluginsResponse.newBuilder()
.setWantCompactRunRequest(true)
.addPlugins(plugin)
.build());
responseObserver.onCompleted();
}

@Override
public void run(RunRequest request, StreamObserver<RunResponse> responseObserver) {
responseObserver.onError(new Exception("run should not be called"));
}

@Override
public void runCompact(
RunCompactRequest request, StreamObserver<RunResponse> responseObserver) {
responseObserver.onNext(
RunResponse.newBuilder()
.setReports(
DetectionReportList.newBuilder()
.addDetectionReports(expectedDetectionReport))
.build());
responseObserver.onCompleted();
}
});

assertThat(pluginToTest.getAllPlugins()).containsExactly(plugin);
assertThat(
pluginToTest
.detect(targetInfo, ImmutableList.of(someNetworkService))
.getDetectionReportsList())
.containsExactly(expectedDetectionReport);
}

@Test
public void getAllPlugins_withNonServingServer_returnsEmptyList() throws Exception {
registerHealthCheckWithStatus(ServingStatus.NOT_SERVING);
Expand Down
Loading

0 comments on commit 67f5d79

Please sign in to comment.