Skip to content

Commit

Permalink
Add python server cli option to Tsunami main cli, fix setup script an…
Browse files Browse the repository at this point in the history
…d upgrade guice and mokito versions.

PiperOrigin-RevId: 636573428
Change-Id: Ib105fb5459d6081bb9b45d569df989e9f1e36d36
  • Loading branch information
maoning authored and copybara-github committed May 23, 2024
1 parent d9bc309 commit 41097b5
Show file tree
Hide file tree
Showing 11 changed files with 203 additions and 37 deletions.
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ subprojects {
googleCloudStorageVersion = '1.103.1'
googleHttpClientVersion = '1.44.1'
guavaVersion = '28.2-jre'
guiceVersion = '4.2.3'
guiceVersion = '6.0.0'
grpcVersion = '1.60.0'
gsonVersion = '2.8.6'
jaxbVersion = '2.3.1'
Expand All @@ -47,7 +47,7 @@ subprojects {
protocVersion = protobufVersion
snakeyamlVersion = '1.26'
junitVersion = '4.13'
mockitoVersion = '2.28.2'
mockitoVersion = '5.12.0'
truthVersion = '1.4.0'
tcsVersion = '0.0.1'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

import com.google.auto.value.AutoValue;
import java.time.Duration;
import javax.annotation.Nullable;

/** Command to spawn a language server and associated command lines. */
@AutoValue
public abstract class LanguageServerCommand {
public static LanguageServerCommand create(
String serverCommand,
@Nullable String serverAddress,
String port,
String logId,
String outputDir,
Expand All @@ -33,6 +35,7 @@ public static LanguageServerCommand create(
String pollingUri) {
return new AutoValue_LanguageServerCommand(
serverCommand,
serverAddress,
port,
logId,
outputDir,
Expand All @@ -45,6 +48,8 @@ public static LanguageServerCommand create(

public abstract String serverCommand();

public abstract String serverAddress();

public abstract String port();

public abstract String logId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ public final class LanguageServerOptions implements CliOption {
+ " chosen.")
public List<String> pluginServerPorts;

@Parameter(
names = "--python-plugin-server-address",
description = "The address for python language server.")
public String pythonPluginServerAddress;

@Parameter(
names = "--python-plugin-server-port",
description = "The port of the python plugin server to open connection with.")
public Integer pythonPluginServerPort;

@Override
public void validate() {
if (pluginServerFilenames != null || pluginServerPorts != null) {
Expand Down Expand Up @@ -81,5 +91,15 @@ public void validate() {
pathCounts, portCounts));
}
}

if (pythonPluginServerAddress != null) {
if (!(pythonPluginServerPort <= NetworkEndpointUtils.MAX_PORT_NUMBER
&& pythonPluginServerPort > 0)) {
throw new ParameterException(
String.format(
"Python plugin server port out of range. Expected [0, %s], actual %s.",
NetworkEndpointUtils.MAX_PORT_NUMBER, pythonPluginServerPort));
}
}
}
}
91 changes: 77 additions & 14 deletions main/src/main/java/com/google/tsunami/main/cli/TsunamiCli.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.beust.jcommander.ParameterException;
import com.google.common.base.Splitter;
import com.google.common.base.Stopwatch;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -183,24 +184,41 @@ private ImmutableList<LanguageServerCommand> extractPluginServerArgs(
Boolean trustAllSslCertCli = extractCliTrustAllSslCert(args);
var paths = extractCliPluginServerArgs(args, "--plugin-server-paths=");
var ports = extractCliPluginServerArgs(args, "--plugin-server-ports=");
var pythonServerAddress = extractPythonPluginServerAddress(args);
var pythonServerPort = extractPythonPluginServerPort(args);
if (paths.size() != ports.size()) {
throw new ParameterException(
String.format(
"Number of plugin server paths must be equal to number of plugin server ports."
+ " Paths: %s. Ports: %s.",
paths.size(), ports.size()));
}
if (paths.size() == 0) {
if (paths.isEmpty() && Strings.isNullOrEmpty(pythonServerAddress)) {
return ImmutableList.of();
}
if (tsunamiConfig.getRawConfigData().isEmpty()) {
for (int i = 0; i < paths.size(); ++i) {
commands.add(
LanguageServerCommand.create(
paths.get(i),
"",
ports.get(i),
logId,
extractOutputDir(args),
trustAllSslCertCli != null && trustAllSslCertCli.booleanValue(),
Duration.ZERO,
"",
0,
""));
}
if (!Strings.isNullOrEmpty(pythonServerAddress)) {
commands.add(
LanguageServerCommand.create(
"",
pythonServerAddress,
pythonServerPort,
logId,
extractOutputDir(args),
trustAllSslCertCli != null && trustAllSslCertCli.booleanValue(),
Duration.ZERO,
"",
Expand All @@ -219,6 +237,7 @@ private ImmutableList<LanguageServerCommand> extractPluginServerArgs(
commands.add(
LanguageServerCommand.create(
paths.get(i),
"",
ports.get(i),
logId,
extractOutputDir(args),
Expand All @@ -230,16 +249,32 @@ private ImmutableList<LanguageServerCommand> extractPluginServerArgs(
(Integer) ((Map) callbackConfig).get("callback_port"),
(String) ((Map) callbackConfig).get("polling_uri")));
}
if (!Strings.isNullOrEmpty(pythonServerAddress)) {
commands.add(
LanguageServerCommand.create(
"",
pythonServerAddress,
pythonServerPort,
logId,
extractOutputDir(args),
trustAllSslCertCli == null
? trustAllSslCertConfig
: trustAllSslCertCli.booleanValue(),
Duration.ofSeconds((int) ((Map) httpClientConfig).get("connect_timeout_seconds")),
(String) ((Map) callbackConfig).get("callback_address"),
(Integer) ((Map) callbackConfig).get("callback_port"),
(String) ((Map) callbackConfig).get("polling_uri")));
}
return ImmutableList.copyOf(commands);
}
}

@Nullable
private Boolean extractCliTrustAllSslCert(String[] args) {
for (int i = 0; i < args.length; ++i) {
if (args[i].startsWith("--http-client-trust-all-certificates")) {
if (args[i].contains("=")) {
return Boolean.valueOf(Iterables.get(Splitter.on('=').split(args[i]), 1));
for (String arg : args) {
if (arg.startsWith("--http-client-trust-all-certificates")) {
if (arg.contains("=")) {
return Boolean.valueOf(Iterables.get(Splitter.on('=').split(arg), 1));
} else {
return true;
}
Expand All @@ -248,20 +283,48 @@ private Boolean extractCliTrustAllSslCert(String[] args) {
return null;
}

@Nullable
private String extractPythonPluginServerAddress(String[] args) {
for (String arg : args) {
if (arg.startsWith("--python-plugin-server-address")) {
if (arg.contains("=")) {
return Iterables.get(Splitter.on('=').split(arg), 1);
} else {
return null;
}
}
}
return null;
}

@Nullable
private String extractPythonPluginServerPort(String[] args) {
for (String arg : args) {
if (arg.startsWith("--python-plugin-server-port")) {
if (arg.contains("=")) {
return Iterables.get(Splitter.on('=').split(arg), 1);
} else {
return null;
}
}
}
return null;
}

private String extractOutputDir(String[] args) {
for (int i = 0; i < args.length; ++i) {
if (args[i].startsWith("--scan-results-local-output-filename=")) {
String filename = Iterables.get(Splitter.on('=').split(args[i]), 1) + ": ";
for (String arg : args) {
if (arg.startsWith("--scan-results-local-output-filename=")) {
String filename = Iterables.get(Splitter.on('=').split(arg), 1) + ": ";
return Path.of(filename).getParent().toString();
}
}
return "";
}

private ImmutableList<String> extractCliPluginServerArgs(String[] args, String flag) {
for (int i = 0; i < args.length; ++i) {
if (args[i].startsWith(flag)) {
var count = Iterables.get(Splitter.on('=').split(args[i]), 1);
for (String arg : args) {
if (arg.startsWith(flag)) {
var count = Iterables.get(Splitter.on('=').split(arg), 1);
return ImmutableList.copyOf(Splitter.on(',').split(count));
}
}
Expand All @@ -270,9 +333,9 @@ private ImmutableList<String> extractCliPluginServerArgs(String[] args, String f

private String extractLogId(String[] args) {
// TODO(b/171405612): Use the Flag class instead of manual parsing.
for (int i = 0; i < args.length; ++i) {
if (args[i].startsWith("--log-id=")) {
return Iterables.get(Splitter.on('=').split(args[i]), 1) + ": ";
for (String arg : args) {
if (arg.startsWith("--log-id=")) {
return Iterables.get(Splitter.on('=').split(arg), 1) + ": ";
}
}
return "";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.flogger.GoogleLogger;
import com.google.tsunami.common.command.CommandExecutor;
Expand Down Expand Up @@ -46,6 +47,8 @@ public class RemoteServerLoader {
public ImmutableList<Process> runServerProcesses() {
logger.atInfo().log("Starting language server processes (if any)...");
return commands.stream()
// Filter out commands that don't need server start up
.filter(command -> !Strings.isNullOrEmpty(command.serverCommand()))
.map(
command ->
runProcess(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,19 @@ public void validate_whenPortNumberOutOfRange_throwsParameterException() {
ParameterException.class,
options::validate);
}

@Test
public void validate_whenPythonPluginServerPortNumberOutOfRange_throwsParameterException() {
LanguageServerOptions options = new LanguageServerOptions();
options.pythonPluginServerAddress = "127.0.0.1";
options.pythonPluginServerPort = -1;

assertThrows(
"Python plugin server port out of range. Expected [0, "
+ NetworkEndpointUtils.MAX_PORT_NUMBER
+ "]"
+ ", actual -1",
ParameterException.class,
options::validate);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public void runServerProcess_whenPathExistsAndNormalPort_returnsValidProcessList
ImmutableList.of(
LanguageServerCommand.create(
"/bin/sh",
"",
"34567",
"34",
"/output-here",
Expand All @@ -50,6 +51,29 @@ public void runServerProcess_whenPathExistsAndNormalPort_returnsValidProcessList
assertThat(processList).hasSize(1);
assertThat(processList.get(0)).isNotNull();
}

@Test
public void runServerProcess_whenServerAddressExistsAndNormalPort_returnsEmptyProcessList() {
ImmutableList<LanguageServerCommand> commands =
ImmutableList.of(
LanguageServerCommand.create(
"",
"127.0.0.1",
"34567",
"34",
"/output-here",
false,
Duration.ofSeconds(10),
"157.34.0.2",
8080,
"157.34.0.2:8881"));

RemoteServerLoader loader =
Guice.createInjector(new RemoteServerLoaderModule(commands))
.getInstance(RemoteServerLoader.class);
var processList = loader.runServerProcesses();
assertThat(processList).isEmpty();
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.auto.value.AutoAnnotation;
import com.google.auto.value.AutoBuilder;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.inject.AbstractModule;
import com.google.inject.multibindings.MapBinder;
Expand Down Expand Up @@ -74,12 +75,21 @@ private ImmutableList<Channel> getLanguageServerChannels(
ImmutableList<LanguageServerCommand> commands) {
return commands.stream()
.map(
command ->
command -> {
if (Strings.isNullOrEmpty(command.serverCommand())) {
return NettyChannelBuilder.forTarget(
String.format("%s:%s", command.serverAddress(), command.port()))
.negotiationType(NegotiationType.PLAINTEXT)
.maxInboundMessageSize(MAX_MESSAGE_SIZE)
.build();
} else {
// TODO(b/289462738): Support IPv6 loopback (::1) interface
NettyChannelBuilder.forTarget("127.0.0.1:" + command.port())
return NettyChannelBuilder.forTarget("127.0.0.1:" + command.port())
.negotiationType(NegotiationType.PLAINTEXT)
.maxInboundMessageSize(MAX_MESSAGE_SIZE)
.build())
.build();
}
})
.collect(toImmutableList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ public void configure_always_loadsAllRemotePlugins() {
var path0 =
LanguageServerCommand.create(
generateServerName(),
"",
"34567",
"193",
"/output/here",
Expand All @@ -66,6 +67,7 @@ public void configure_always_loadsAllRemotePlugins() {
var path1 =
LanguageServerCommand.create(
generateServerName(),
"",
"34566",
"193",
"/output/now",
Expand All @@ -74,10 +76,23 @@ public void configure_always_loadsAllRemotePlugins() {
"157.34.0.2",
8080,
"157.34.0.2:8881");
var server0 =
LanguageServerCommand.create(
"",
"127.0.0.1",
"34567",
"193",
"/output/here",
false,
Duration.ofSeconds(10),
"157.34.0.2",
8080,
"157.34.0.2:8881");
Map<PluginDefinition, TsunamiPlugin> remotePlugins =
Guice.createInjector(new RemoteVulnDetectorLoadingModule(ImmutableList.of(path0, path1)))
Guice.createInjector(
new RemoteVulnDetectorLoadingModule(ImmutableList.of(path0, path1, server0)))
.getInstance(PLUGIN_BINDING_KEY);

assertThat(remotePlugins).hasSize(2);
assertThat(remotePlugins).hasSize(3);
}
}
Loading

0 comments on commit 41097b5

Please sign in to comment.