Skip to content

Commit

Permalink
Merge pull request #42 from alan-turing-institute/blackvoid-tls
Browse files Browse the repository at this point in the history
Add TLS support
  • Loading branch information
jemrobinson authored May 30, 2024
2 parents ff7ed85 + 9b9011d commit d66a77a
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 31 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ RUN chmod ugo+x ./entrypoint.sh

# Open appropriate ports
EXPOSE 1389
EXPOSE 1636

# Run the server
ENTRYPOINT ["./entrypoint.sh"]
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ from the `docker` directory.
You can use a Redis server to store generated `uidNumber` and `gidNumber` values in a more persistent way.
To do this, you will need to provide the `--redis-host` and `--redis-port` arguments to `run.py`.

### Configure background refresh [Optional]

By default Apricot will refresh the LDAP tree whenever it is accessed and it contains data older than 60 seconds.
If it takes a long time to fetch all users and groups, or you want to ensure that each request gets a prompt response, you may want to configure background refresh to have it periodically be refreshed in the background.

This is enabled with the `--background-refresh` flag, which uses the `--refresh-interval` parameter as the interval to refresh the ldap database.

### Using TLS [Optional]

You can set up a TLS listener to communicate with encryption enabled over the configured port.
To enable it you need to configure the tls port ex. `--tls-port=1636`, and provide a path to the pem files for the certificate `--tls-certificate=<path>` and the private key `--tls-private-key=<path>`.

## Outputs

This will create an LDAP tree that looks like this:
Expand Down
41 changes: 37 additions & 4 deletions apricot/apricot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import sys
from typing import Any, cast

from twisted.internet import reactor
from twisted.internet.endpoints import serverFromString
from twisted.internet import reactor, task
from twisted.internet.endpoints import quoteStringArgument, serverFromString
from twisted.internet.interfaces import IReactorCore, IStreamServerEndpoint
from twisted.python import log

Expand All @@ -21,10 +21,15 @@ def __init__(
domain: str,
port: int,
*,
background_refresh: bool = False,
debug: bool = False,
enable_mirrored_groups: bool = True,
redis_host: str | None = None,
redis_port: int | None = None,
refresh_interval: int = 60,
tls_port: int | None = None,
tls_certificate: str | None = None,
tls_private_key: str | None = None,
**kwargs: Any,
) -> None:
self.debug = debug
Expand Down Expand Up @@ -66,15 +71,43 @@ def __init__(
if self.debug:
log.msg("Creating an LDAPServerFactory.")
factory = OAuthLDAPServerFactory(
domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups
domain,
oauth_client,
background_refresh=background_refresh,
enable_mirrored_groups=enable_mirrored_groups,
refresh_interval=refresh_interval,
)

if background_refresh:
if self.debug:
log.msg(
f"Starting background refresh (interval={factory.adaptor.refresh_interval})"
)
loop = task.LoopingCall(factory.adaptor.refresh)
loop.start(factory.adaptor.refresh_interval)

# Attach a listening endpoint
if self.debug:
log.msg("Attaching a listening endpoint.")
log.msg("Attaching a listening endpoint (plain).")
endpoint: IStreamServerEndpoint = serverFromString(reactor, f"tcp:{port}")
endpoint.listen(factory)

# Attach a listening endpoint
if tls_certificate or tls_private_key:
if not tls_certificate:
msg = "No TLS certificate provided. Please provide one with --tls-certificate or disable TLS."
raise ValueError(msg)
if not tls_private_key:
msg = "No TLS private key provided. Please provide one with --tls-private-key or disable TLS."
raise ValueError(msg)
if self.debug:
log.msg("Attaching a listening endpoint (TLS).")
ssl_endpoint: IStreamServerEndpoint = serverFromString(
reactor,
f"ssl:{tls_port}:privateKey={quoteStringArgument(tls_private_key)}:certKey={quoteStringArgument(tls_certificate)}",
)
ssl_endpoint.listen(factory)

# Load the Twisted reactor
self.reactor = cast(IReactorCore, reactor)

Expand Down
20 changes: 17 additions & 3 deletions apricot/ldap/oauth_ldap_server_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,30 @@

class OAuthLDAPServerFactory(ServerFactory):
def __init__(
self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool
self,
domain: str,
oauth_client: OAuthClient,
*,
background_refresh: bool,
enable_mirrored_groups: bool,
refresh_interval: int,
):
"""
Initialise an LDAPServerFactory
Initialise an OAuthLDAPServerFactory
@param background_refresh: Whether to refresh the LDAP tree in the background rather than on access
@param domain: The root domain of the LDAP tree
@param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users
@param oauth_client: An OAuth client used to construct the LDAP tree
@param refresh_interval: Interval in seconds after which the tree must be refreshed
"""
# Create an LDAP lookup tree
self.adaptor = OAuthLDAPTree(
domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups
domain,
oauth_client,
background_refresh=background_refresh,
enable_mirrored_groups=enable_mirrored_groups,
refresh_interval=refresh_interval,
)

def __repr__(self) -> str:
Expand Down
19 changes: 16 additions & 3 deletions apricot/ldap/oauth_ldap_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,27 @@ def __init__(
domain: str,
oauth_client: OAuthClient,
*,
background_refresh: bool,
enable_mirrored_groups: bool,
refresh_interval: int = 60,
refresh_interval: int,
) -> None:
"""
Initialise an OAuthLDAPTree
@param background_refresh: Whether to refresh the LDAP tree in the background rather than on access
@param domain: The root domain of the LDAP tree
@param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users
@param oauth_client: An OAuth client used to construct the LDAP tree
@param refresh_interval: Interval in seconds after which the tree must be refreshed
"""
self.background_refresh = background_refresh
self.debug = oauth_client.debug
self.domain = domain
self.enable_mirrored_groups = enable_mirrored_groups
self.last_update = time.monotonic()
self.oauth_client = oauth_client
self.refresh_interval = refresh_interval
self.root_: OAuthLDAPEntry | None = None
self.enable_mirrored_groups = enable_mirrored_groups

@property
def dn(self) -> DistinguishedName:
Expand All @@ -46,7 +50,17 @@ def root(self) -> OAuthLDAPEntry:
Lazy-load the LDAP tree on request
@return: An OAuthLDAPEntry for the tree
@raises: ValueError.
"""
if not self.background_refresh:
self.refresh()
if not self.root_:
msg = "LDAP tree could not be loaded"
raise ValueError(msg)
return self.root_

def refresh(self) -> None:
if (
not self.root_
or (time.monotonic() - self.last_update) > self.refresh_interval
Expand Down Expand Up @@ -104,7 +118,6 @@ def root(self) -> OAuthLDAPEntry:
# Set last updated time
log.msg("Finished building LDAP tree.")
self.last_update = time.monotonic()
return self.root_

def __repr__(self) -> str:
return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}"
Expand Down
7 changes: 7 additions & 0 deletions apricot/oauth/oauth_data_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ class OAuthDataAdaptor:
def __init__(
self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool
):
"""
Initialise an OAuthDataAdaptor
@param domain: The root domain of the LDAP tree
@param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users
@param oauth_client: An OAuth client used to construct the LDAP tree
"""
self.debug = oauth_client.debug
self.oauth_client = oauth_client
self.root_dn = "DC=" + domain.replace(".", ",DC=")
Expand Down
1 change: 1 addition & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ services:
REDIS_HOST: "redis"
ports:
- "1389:1389"
- "1636:1636"
restart: always

redis:
Expand Down
47 changes: 39 additions & 8 deletions docker/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ if [ -z "${DOMAIN}" ]; then
exit 1
fi


# Arguments with defaults
if [ -z "${PORT}" ]; then
PORT="1389"
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] PORT environment variable is not set: using default of '${PORT}'"
fi



# Optional arguments
EXTRA_OPTS=""
if [ -n "${DEBUG}" ]; then
Expand All @@ -41,18 +41,14 @@ if [ -n "${DISABLE_MIRRORED_GROUPS}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --disable-mirrored-groups"
fi


# Backend arguments: Entra
if [ -n "${ENTRA_TENANT_ID}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID"
fi

if [ -n "${REDIS_HOST}" ]; then
if [ -z "${REDIS_PORT}" ]; then
REDIS_PORT="6379"
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'"
fi
EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT"
fi

# Backend arguments: Keycloak
if [ -n "${KEYCLOAK_BASE_URL}" ]; then
if [ -z "${KEYCLOAK_REALM}" ]; then
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] KEYCLOAK_REALM environment variable is not set"
Expand All @@ -61,6 +57,41 @@ if [ -n "${KEYCLOAK_BASE_URL}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --keycloak-base-url $KEYCLOAK_BASE_URL --keycloak-realm $KEYCLOAK_REALM"
fi


# LDAP refresh arguments
if [ -n "${BACKGROUND_REFRESH}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --background-refresh"
fi

if [ -n "${REFRESH_INTERVAL}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --refresh-interval $REFRESH_INTERVAL"
fi


# TLS arguments
if [ -n "${TLS_PORT}" ]; then
if [ -z "${TLS_CERTIFICATE}" ]; then
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] TLS_CERTIFICATE environment variable is not set"
exit 1
fi
if [ -z "${TLS_PRIVATE_KEY}" ]; then
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] TLS_PRIVATE_KEY environment variable is not set"
exit 1
fi
EXTRA_OPTS="${EXTRA_OPTS} --tls-port $TLS_PORT --tls-certificate $TLS_CERTIFICATE --tls-private-key $TLS_PRIVATE_KEY"
fi


# Redis arguments
if [ -n "${REDIS_HOST}" ]; then
if [ -z "${REDIS_PORT}" ]; then
REDIS_PORT="6379"
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'"
fi
EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT"
fi


# Run the server
hatch run python run.py \
--backend "${BACKEND}" \
Expand Down
74 changes: 61 additions & 13 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,75 @@
description="Apricot is a proxy for delegating LDAP requests to an OpenID Connect backend.",
)
# Common options needed for all backends
parser.add_argument("-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use.")
parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.")
parser.add_argument(
"-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use."
)
parser.add_argument(
"-d", "--domain", type=str, help="Which domain users belong to."
)
parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.")
parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.")
parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.")
parser.add_argument("--disable-mirrored-groups", action="store_false",
dest="enable_mirrored_groups", default=True,
help="Disable creation of mirrored groups.")
parser.add_argument("--debug", action="store_true", help="Enable debug logging.")
parser.add_argument(
"-p", "--port", type=int, default=1389, help="Port to run on."
)
parser.add_argument(
"-s", "--client-secret", type=str, help="OAuth client secret."
)
parser.add_argument(
"--background-refresh",
action="store_true",
default=False,
help="Refresh in the background instead of as needed per request",
)
parser.add_argument(
"--debug", action="store_true", help="Enable debug logging."
)
parser.add_argument(
"--disable-mirrored-groups",
action="store_false",
default=True,
dest="enable_mirrored_groups",
help="Disable creation of mirrored groups.",
)
parser.add_argument(
"--refresh-interval",
type=int,
default=60,
help="How often to refresh the database in seconds",
)

# Options for Microsoft Entra backend
entra_group = parser.add_argument_group("Microsoft Entra")
entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False)
entra_group.add_argument(
"--entra-tenant-id", type=str, help="Microsoft Entra tenant ID."
)

# Options for Keycloak backend
keycloak_group = parser.add_argument_group("Keycloak")
keycloak_group.add_argument("--keycloak-base-url", type=str, help="Keycloak base URL.", required=False)
keycloak_group.add_argument("--keycloak-realm", type=str, help="Keycloak Realm.", required=False)
keycloak_group.add_argument(
"--keycloak-base-url", type=str, help="Keycloak base URL."
)
keycloak_group.add_argument(
"--keycloak-realm", type=str, help="Keycloak Realm."
)
# Options for Redis cache
redis_group = parser.add_argument_group("Redis")
redis_group.add_argument("--redis-host", type=str, help="Host for Redis server.")
redis_group.add_argument("--redis-port", type=int, help="Port for Redis server.")
redis_group.add_argument(
"--redis-host", type=str, help="Host for Redis server."
)
redis_group.add_argument(
"--redis-port", type=int, help="Port for Redis server."
)
# Options for TLS
tls_group = parser.add_argument_group("TLS")
tls_group.add_argument(
"--tls-certificate", type=str, help="Location of TLS certificate (pem)."
)
tls_group.add_argument(
"--tls-port", type=int, default=1636, help="Port to run on with encryption."
)
tls_group.add_argument(
"--tls-private-key", type=str, help="Location of TLS private key (pem)."
)
# Parse arguments
args = parser.parse_args()

Expand Down

0 comments on commit d66a77a

Please sign in to comment.