From 12452cb7b92f514b9f18fb2a715c97c701f097b5 Mon Sep 17 00:00:00 2001 From: Anqi Date: Fri, 15 Sep 2023 14:04:34 +0800 Subject: [PATCH] support to set headers for http2 (#548) --- .../nebula/client/graph/NebulaPoolConfig.java | 14 ++++++++ .../nebula/client/graph/SessionPool.java | 9 +++--- .../client/graph/SessionPoolConfig.java | 16 +++++++++- .../client/graph/net/ConnObjectPool.java | 5 +-- .../nebula/client/graph/net/Connection.java | 6 ++-- .../nebula/client/graph/net/NebulaPool.java | 4 +-- .../graph/net/RoundRobinLoadBalancer.java | 20 ++++++++---- .../client/graph/net/SyncConnection.java | 32 +++++++++++++------ 8 files changed, 77 insertions(+), 29 deletions(-) diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java b/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java index 0189366a5..c76b2edb6 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/NebulaPoolConfig.java @@ -7,6 +7,8 @@ import com.vesoft.nebula.client.graph.data.SSLParam; import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; public class NebulaPoolConfig implements Serializable { @@ -45,6 +47,9 @@ public class NebulaPoolConfig implements Serializable { // Set if use http2 protocol private boolean useHttp2 = false; + // Set custom headers for http2 + private Map customHeaders = new HashMap<>(); + public boolean isEnableSsl() { return enableSsl; } @@ -132,4 +137,13 @@ public NebulaPoolConfig setUseHttp2(boolean useHttp2) { this.useHttp2 = useHttp2; return this; } + + public Map getCustomHeaders() { + return customHeaders; + } + + public NebulaPoolConfig setCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } } diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java index 0daeef8ad..f9e4abdf0 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPool.java @@ -5,9 +5,6 @@ package com.vesoft.nebula.client.graph; -import static com.vesoft.nebula.client.graph.exception.IOErrorException.E_CONNECT_BROKEN; -import static com.vesoft.nebula.client.graph.exception.IOErrorException.E_UNKNOWN; - import com.alibaba.fastjson.JSON; import com.vesoft.nebula.ErrorCode; import com.vesoft.nebula.client.graph.data.HostAddress; @@ -353,10 +350,12 @@ private NebulaSession createSessionObject(SessionState state) if (sessionPoolConfig.isEnableSsl()) { connection.open(getAddress(), sessionPoolConfig.getTimeout(), sessionPoolConfig.getSslParam(), - sessionPoolConfig.isUseHttp2()); + sessionPoolConfig.isUseHttp2(), + sessionPoolConfig.getCustomHeaders()); } else { connection.open(getAddress(), sessionPoolConfig.getTimeout(), - sessionPoolConfig.isUseHttp2()); + sessionPoolConfig.isUseHttp2(), + sessionPoolConfig.getCustomHeaders()); } break; } catch (Exception e) { diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java index 5a9fb3fcf..f5a66875a 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/SessionPoolConfig.java @@ -8,7 +8,9 @@ import com.vesoft.nebula.client.graph.data.HostAddress; import com.vesoft.nebula.client.graph.data.SSLParam; import java.io.Serializable; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class SessionPoolConfig implements Serializable { @@ -59,6 +61,8 @@ public class SessionPoolConfig implements Serializable { private boolean useHttp2 = false; + private Map customHeaders = new HashMap<>(); + public SessionPoolConfig(List addresses, String spaceName, @@ -243,6 +247,15 @@ public SessionPoolConfig setUseHttp2(boolean useHttp2) { return this; } + public Map getCustomHeaders() { + return customHeaders; + } + + public SessionPoolConfig setCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + return this; + } + @Override public String toString() { return "SessionPoolConfig{" @@ -259,8 +272,9 @@ public String toString() { + ", intervalTIme=" + intervalTime + ", reconnect=" + reconnect + ", enableSsl=" + enableSsl - + ",sslParam=" + sslParam + + ", sslParam=" + sslParam + ", useHttp2=" + useHttp2 + + ", customHeaders=" + customHeaders + '}'; } } diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java index dba3a02b9..d9e215fe5 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/ConnObjectPool.java @@ -40,9 +40,10 @@ public SyncConnection create() throws IOErrorException, ClientServerIncompatible + "is set to true"); } conn.open(address, config.getTimeout(), - config.getSslParam(), config.isUseHttp2()); + config.getSslParam(), config.isUseHttp2(), config.getCustomHeaders()); } else { - conn.open(address, config.getTimeout(), config.isUseHttp2()); + conn.open(address, config.getTimeout(), + config.isUseHttp2(), config.getCustomHeaders()); } return conn; } catch (IOErrorException e) { diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java index 47abf3524..bd871f530 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/Connection.java @@ -5,6 +5,7 @@ import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException; import com.vesoft.nebula.client.graph.exception.IOErrorException; import java.io.Serializable; +import java.util.Map; public abstract class Connection implements Serializable { @@ -20,14 +21,15 @@ public abstract void open(HostAddress address, int timeout, SSLParam sslParam) throws IOErrorException, ClientServerIncompatibleException; public abstract void open(HostAddress address, int timeout, - SSLParam sslParam, boolean isUseHttp2) + SSLParam sslParam, boolean isUseHttp2, Map headers) throws IOErrorException, ClientServerIncompatibleException; public abstract void open(HostAddress address, int timeout) throws IOErrorException, ClientServerIncompatibleException; - public abstract void open(HostAddress address, int timeout, boolean isUseHttp2) + public abstract void open(HostAddress address, int timeout, + boolean isUseHttp2, Map headers) throws IOErrorException, ClientServerIncompatibleException; public abstract void reopen() throws IOErrorException, ClientServerIncompatibleException; diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java index 6090da72b..882023601 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/NebulaPool.java @@ -86,9 +86,9 @@ public boolean init(List addresses, NebulaPoolConfig config) this.waitTime = config.getWaitTime(); this.loadBalancer = config.isEnableSsl() ? new RoundRobinLoadBalancer(addresses, config.getTimeout(), config.getSslParam(), - config.getMinClusterHealthRate(), config.isUseHttp2()) + config.getMinClusterHealthRate(), config.isUseHttp2(), config.getCustomHeaders()) : new RoundRobinLoadBalancer(addresses, config.getTimeout(), - config.getMinClusterHealthRate(),config.isUseHttp2()); + config.getMinClusterHealthRate(),config.isUseHttp2(), config.getCustomHeaders()); ConnObjectPool objectPool = new ConnObjectPool(this.loadBalancer, config); this.objectPool = new GenericObjectPool<>(objectPool); GenericObjectPoolConfig objConfig = new GenericObjectPoolConfig(); diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java index c4da59766..04f77a919 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/RoundRobinLoadBalancer.java @@ -5,6 +5,7 @@ import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException; import com.vesoft.nebula.client.graph.exception.IOErrorException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -31,13 +32,16 @@ public class RoundRobinLoadBalancer implements LoadBalancer { private boolean useHttp2 = false; + private Map customHeaders; + public RoundRobinLoadBalancer(List addresses, int timeout, double minClusterHealthRate) { - this(addresses, timeout, minClusterHealthRate, false); + this(addresses, timeout, minClusterHealthRate, false, new HashMap<>()); } public RoundRobinLoadBalancer(List addresses, int timeout, - double minClusterHealthRate, boolean useHttp2) { + double minClusterHealthRate, boolean useHttp2, + Map headers) { this.timeout = timeout; for (HostAddress addr : addresses) { this.addresses.add(addr); @@ -45,17 +49,19 @@ public RoundRobinLoadBalancer(List addresses, int timeout, } this.minClusterHealthRate = minClusterHealthRate; this.useHttp2 = useHttp2; + this.customHeaders = headers; schedule.scheduleAtFixedRate(this::scheduleTask, 0, delayTime, TimeUnit.SECONDS); } public RoundRobinLoadBalancer(List addresses, int timeout, SSLParam sslParam, double minClusterHealthRate) { - this(addresses, timeout, sslParam, minClusterHealthRate, false); + this(addresses, timeout, sslParam, minClusterHealthRate, false, new HashMap<>()); } public RoundRobinLoadBalancer(List addresses, int timeout, SSLParam sslParam, - double minClusterHealthRate, boolean useHttp2) { - this(addresses, timeout, minClusterHealthRate, useHttp2); + double minClusterHealthRate, boolean useHttp2, + Map headers) { + this(addresses, timeout, minClusterHealthRate, useHttp2, headers); this.sslParam = sslParam; this.enabledSsl = true; } @@ -95,9 +101,9 @@ public boolean ping(HostAddress addr) { try { Connection connection = new SyncConnection(); if (enabledSsl) { - connection.open(addr, this.timeout, sslParam, useHttp2); + connection.open(addr, this.timeout, sslParam, useHttp2, customHeaders); } else { - connection.open(addr, this.timeout, useHttp2); + connection.open(addr, this.timeout, useHttp2, customHeaders); } boolean pong = connection.ping(); connection.close(); diff --git a/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java b/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java index 28a1ba027..dc339ff90 100644 --- a/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java +++ b/client/src/main/java/com/vesoft/nebula/client/graph/net/SyncConnection.java @@ -33,9 +33,11 @@ import com.vesoft.nebula.util.SslUtil; import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.TrustManager; +import okhttp3.internal.http2.Http2Connection; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,14 +55,17 @@ public class SyncConnection extends Connection { private SSLSocketFactory sslSocketFactory = null; private boolean useHttp2 = false; + private Map headers = new HashMap<>(); + @Override public void open(HostAddress address, int timeout, SSLParam sslParam) throws IOErrorException, ClientServerIncompatibleException { - this.open(address, timeout, sslParam, false); + this.open(address, timeout, sslParam, false, headers); } @Override - public void open(HostAddress address, int timeout, SSLParam sslParam, boolean isUseHttp2) + public void open(HostAddress address, int timeout, SSLParam sslParam, boolean isUseHttp2, + Map headers) throws IOErrorException, ClientServerIncompatibleException { try { this.serverAddr = address; @@ -68,6 +73,7 @@ public void open(HostAddress address, int timeout, SSLParam sslParam, boolean is this.enabledSsl = true; this.sslParam = sslParam; this.useHttp2 = isUseHttp2; + this.headers = headers; if (sslSocketFactory == null) { if (sslParam.getSignMode() == SSLParam.SignMode.CA_SIGNED) { sslSocketFactory = @@ -77,7 +83,7 @@ public void open(HostAddress address, int timeout, SSLParam sslParam, boolean is SslUtil.getSSLSocketFactoryWithoutCA((SelfSignedSSLParam) sslParam); } } - if (isUseHttp2) { + if (useHttp2) { getProtocolWithTlsHttp2(); } else { getProtocolForTls(); @@ -102,16 +108,19 @@ public void open(HostAddress address, int timeout, SSLParam sslParam, boolean is @Override public void open(HostAddress address, int timeout) throws IOErrorException, ClientServerIncompatibleException { - this.open(address, timeout, false); + this.open(address, timeout, false, headers); } @Override - public void open(HostAddress address, int timeout, boolean isUseHttp2) + public void open(HostAddress address, int timeout, + boolean isUseHttp2, Map headers) throws IOErrorException, ClientServerIncompatibleException { try { this.serverAddr = address; this.timeout = timeout <= 0 ? Integer.MAX_VALUE : timeout; - if (isUseHttp2) { + this.useHttp2 = isUseHttp2; + this.headers = headers; + if (useHttp2) { getProtocolForHttp2(); } else { getProtocol(); @@ -144,7 +153,9 @@ private void getProtocolWithTlsHttp2() { } this.transport = new THttp2Client(url, sslSocketFactory, trustManager) .setConnectTimeout(timeout) - .setReadTimeout(timeout); + .setReadTimeout(timeout) + .setCustomHeaders(headers); + transport.open(); this.protocol = new TBinaryProtocol(transport); } @@ -166,7 +177,8 @@ private void getProtocolForHttp2() { String url = "http://" + serverAddr.getHost() + ":" + serverAddr.getPort(); this.transport = new THttp2Client(url) .setConnectTimeout(timeout) - .setReadTimeout(timeout); + .setReadTimeout(timeout) + .setCustomHeaders(headers); transport.open(); this.protocol = new TBinaryProtocol(transport); } @@ -196,9 +208,9 @@ private void getProtocol() { public void reopen() throws IOErrorException, ClientServerIncompatibleException { close(); if (enabledSsl) { - open(serverAddr, timeout, sslParam, useHttp2); + open(serverAddr, timeout, sslParam, useHttp2, headers); } else { - open(serverAddr, timeout, useHttp2); + open(serverAddr, timeout, useHttp2, headers); } }