Skip to content

Commit

Permalink
support to set headers for http2 (#548)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicole00 authored Sep 15, 2023
1 parent 20d0c76 commit 12452cb
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import com.vesoft.nebula.client.graph.data.SSLParam;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

public class NebulaPoolConfig implements Serializable {

Expand Down Expand Up @@ -45,6 +47,9 @@ public class NebulaPoolConfig implements Serializable {
// Set if use http2 protocol
private boolean useHttp2 = false;

// Set custom headers for http2
private Map<String,String> customHeaders = new HashMap<>();

public boolean isEnableSsl() {
return enableSsl;
}
Expand Down Expand Up @@ -132,4 +137,13 @@ public NebulaPoolConfig setUseHttp2(boolean useHttp2) {
this.useHttp2 = useHttp2;
return this;
}

public Map<String, String> getCustomHeaders() {
return customHeaders;
}

public NebulaPoolConfig setCustomHeaders(Map<String, String> customHeaders) {
this.customHeaders = customHeaders;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

package com.vesoft.nebula.client.graph;

import static com.vesoft.nebula.client.graph.exception.IOErrorException.E_CONNECT_BROKEN;
import static com.vesoft.nebula.client.graph.exception.IOErrorException.E_UNKNOWN;

import com.alibaba.fastjson.JSON;
import com.vesoft.nebula.ErrorCode;
import com.vesoft.nebula.client.graph.data.HostAddress;
Expand Down Expand Up @@ -353,10 +350,12 @@ private NebulaSession createSessionObject(SessionState state)
if (sessionPoolConfig.isEnableSsl()) {
connection.open(getAddress(), sessionPoolConfig.getTimeout(),
sessionPoolConfig.getSslParam(),
sessionPoolConfig.isUseHttp2());
sessionPoolConfig.isUseHttp2(),
sessionPoolConfig.getCustomHeaders());
} else {
connection.open(getAddress(), sessionPoolConfig.getTimeout(),
sessionPoolConfig.isUseHttp2());
sessionPoolConfig.isUseHttp2(),
sessionPoolConfig.getCustomHeaders());
}
break;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SessionPoolConfig implements Serializable {

Expand Down Expand Up @@ -59,6 +61,8 @@ public class SessionPoolConfig implements Serializable {

private boolean useHttp2 = false;

private Map<String, String> customHeaders = new HashMap<>();


public SessionPoolConfig(List<HostAddress> addresses,
String spaceName,
Expand Down Expand Up @@ -243,6 +247,15 @@ public SessionPoolConfig setUseHttp2(boolean useHttp2) {
return this;
}

public Map<String, String> getCustomHeaders() {
return customHeaders;
}

public SessionPoolConfig setCustomHeaders(Map<String, String> customHeaders) {
this.customHeaders = customHeaders;
return this;
}

@Override
public String toString() {
return "SessionPoolConfig{"
Expand All @@ -259,8 +272,9 @@ public String toString() {
+ ", intervalTIme=" + intervalTime
+ ", reconnect=" + reconnect
+ ", enableSsl=" + enableSsl
+ ",sslParam=" + sslParam
+ ", sslParam=" + sslParam
+ ", useHttp2=" + useHttp2
+ ", customHeaders=" + customHeaders
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ public SyncConnection create() throws IOErrorException, ClientServerIncompatible
+ "is set to true");
}
conn.open(address, config.getTimeout(),
config.getSslParam(), config.isUseHttp2());
config.getSslParam(), config.isUseHttp2(), config.getCustomHeaders());
} else {
conn.open(address, config.getTimeout(), config.isUseHttp2());
conn.open(address, config.getTimeout(),
config.isUseHttp2(), config.getCustomHeaders());
}
return conn;
} catch (IOErrorException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import java.io.Serializable;
import java.util.Map;

public abstract class Connection implements Serializable {

Expand All @@ -20,14 +21,15 @@ public abstract void open(HostAddress address, int timeout, SSLParam sslParam)
throws IOErrorException, ClientServerIncompatibleException;

public abstract void open(HostAddress address, int timeout,
SSLParam sslParam, boolean isUseHttp2)
SSLParam sslParam, boolean isUseHttp2, Map<String, String> headers)
throws IOErrorException, ClientServerIncompatibleException;


public abstract void open(HostAddress address, int timeout) throws IOErrorException,
ClientServerIncompatibleException;

public abstract void open(HostAddress address, int timeout, boolean isUseHttp2)
public abstract void open(HostAddress address, int timeout,
boolean isUseHttp2, Map<String, String> headers)
throws IOErrorException, ClientServerIncompatibleException;

public abstract void reopen() throws IOErrorException, ClientServerIncompatibleException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ public boolean init(List<HostAddress> addresses, NebulaPoolConfig config)
this.waitTime = config.getWaitTime();
this.loadBalancer = config.isEnableSsl()
? new RoundRobinLoadBalancer(addresses, config.getTimeout(), config.getSslParam(),
config.getMinClusterHealthRate(), config.isUseHttp2())
config.getMinClusterHealthRate(), config.isUseHttp2(), config.getCustomHeaders())
: new RoundRobinLoadBalancer(addresses, config.getTimeout(),
config.getMinClusterHealthRate(),config.isUseHttp2());
config.getMinClusterHealthRate(),config.isUseHttp2(), config.getCustomHeaders());
ConnObjectPool objectPool = new ConnObjectPool(this.loadBalancer, config);
this.objectPool = new GenericObjectPool<>(objectPool);
GenericObjectPoolConfig objConfig = new GenericObjectPoolConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -31,31 +32,36 @@ public class RoundRobinLoadBalancer implements LoadBalancer {

private boolean useHttp2 = false;

private Map<String, String> customHeaders;

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout,
double minClusterHealthRate) {
this(addresses, timeout, minClusterHealthRate, false);
this(addresses, timeout, minClusterHealthRate, false, new HashMap<>());
}

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout,
double minClusterHealthRate, boolean useHttp2) {
double minClusterHealthRate, boolean useHttp2,
Map<String, String> headers) {
this.timeout = timeout;
for (HostAddress addr : addresses) {
this.addresses.add(addr);
this.serversStatus.put(addr, S_BAD);
}
this.minClusterHealthRate = minClusterHealthRate;
this.useHttp2 = useHttp2;
this.customHeaders = headers;
schedule.scheduleAtFixedRate(this::scheduleTask, 0, delayTime, TimeUnit.SECONDS);
}

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout, SSLParam sslParam,
double minClusterHealthRate) {
this(addresses, timeout, sslParam, minClusterHealthRate, false);
this(addresses, timeout, sslParam, minClusterHealthRate, false, new HashMap<>());
}

public RoundRobinLoadBalancer(List<HostAddress> addresses, int timeout, SSLParam sslParam,
double minClusterHealthRate, boolean useHttp2) {
this(addresses, timeout, minClusterHealthRate, useHttp2);
double minClusterHealthRate, boolean useHttp2,
Map<String, String> headers) {
this(addresses, timeout, minClusterHealthRate, useHttp2, headers);
this.sslParam = sslParam;
this.enabledSsl = true;
}
Expand Down Expand Up @@ -95,9 +101,9 @@ public boolean ping(HostAddress addr) {
try {
Connection connection = new SyncConnection();
if (enabledSsl) {
connection.open(addr, this.timeout, sslParam, useHttp2);
connection.open(addr, this.timeout, sslParam, useHttp2, customHeaders);
} else {
connection.open(addr, this.timeout, useHttp2);
connection.open(addr, this.timeout, useHttp2, customHeaders);
}
boolean pong = connection.ping();
connection.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
import com.vesoft.nebula.util.SslUtil;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import okhttp3.internal.http2.Http2Connection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -53,21 +55,25 @@ public class SyncConnection extends Connection {
private SSLSocketFactory sslSocketFactory = null;
private boolean useHttp2 = false;

private Map<String, String> headers = new HashMap<>();

@Override
public void open(HostAddress address, int timeout, SSLParam sslParam)
throws IOErrorException, ClientServerIncompatibleException {
this.open(address, timeout, sslParam, false);
this.open(address, timeout, sslParam, false, headers);
}

@Override
public void open(HostAddress address, int timeout, SSLParam sslParam, boolean isUseHttp2)
public void open(HostAddress address, int timeout, SSLParam sslParam, boolean isUseHttp2,
Map<String, String> headers)
throws IOErrorException, ClientServerIncompatibleException {
try {
this.serverAddr = address;
this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout;
this.enabledSsl = true;
this.sslParam = sslParam;
this.useHttp2 = isUseHttp2;
this.headers = headers;
if (sslSocketFactory == null) {
if (sslParam.getSignMode() == SSLParam.SignMode.CA_SIGNED) {
sslSocketFactory =
Expand All @@ -77,7 +83,7 @@ public void open(HostAddress address, int timeout, SSLParam sslParam, boolean is
SslUtil.getSSLSocketFactoryWithoutCA((SelfSignedSSLParam) sslParam);
}
}
if (isUseHttp2) {
if (useHttp2) {
getProtocolWithTlsHttp2();
} else {
getProtocolForTls();
Expand All @@ -102,16 +108,19 @@ public void open(HostAddress address, int timeout, SSLParam sslParam, boolean is
@Override
public void open(HostAddress address, int timeout) throws IOErrorException,
ClientServerIncompatibleException {
this.open(address, timeout, false);
this.open(address, timeout, false, headers);
}

@Override
public void open(HostAddress address, int timeout, boolean isUseHttp2)
public void open(HostAddress address, int timeout,
boolean isUseHttp2, Map<String,String> headers)
throws IOErrorException, ClientServerIncompatibleException {
try {
this.serverAddr = address;
this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout;
if (isUseHttp2) {
this.useHttp2 = isUseHttp2;
this.headers = headers;
if (useHttp2) {
getProtocolForHttp2();
} else {
getProtocol();
Expand Down Expand Up @@ -144,7 +153,9 @@ private void getProtocolWithTlsHttp2() {
}
this.transport = new THttp2Client(url, sslSocketFactory, trustManager)
.setConnectTimeout(timeout)
.setReadTimeout(timeout);
.setReadTimeout(timeout)
.setCustomHeaders(headers);

transport.open();
this.protocol = new TBinaryProtocol(transport);
}
Expand All @@ -166,7 +177,8 @@ private void getProtocolForHttp2() {
String url = "http://" + serverAddr.getHost() + ":" + serverAddr.getPort();
this.transport = new THttp2Client(url)
.setConnectTimeout(timeout)
.setReadTimeout(timeout);
.setReadTimeout(timeout)
.setCustomHeaders(headers);
transport.open();
this.protocol = new TBinaryProtocol(transport);
}
Expand Down Expand Up @@ -196,9 +208,9 @@ private void getProtocol() {
public void reopen() throws IOErrorException, ClientServerIncompatibleException {
close();
if (enabledSsl) {
open(serverAddr, timeout, sslParam, useHttp2);
open(serverAddr, timeout, sslParam, useHttp2, headers);
} else {
open(serverAddr, timeout, useHttp2);
open(serverAddr, timeout, useHttp2, headers);
}
}

Expand Down

0 comments on commit 12452cb

Please sign in to comment.