diff --git a/Dockerfile b/Dockerfile index 6f8fe57..cd6db06 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,6 +21,7 @@ RUN chmod ugo+x ./entrypoint.sh # Open appropriate ports EXPOSE 1389 +EXPOSE 1636 # Run the server ENTRYPOINT ["./entrypoint.sh"] diff --git a/README.md b/README.md index 170917d..8911e95 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,18 @@ from the `docker` directory. You can use a Redis server to store generated `uidNumber` and `gidNumber` values in a more persistent way. To do this, you will need to provide the `--redis-host` and `--redis-port` arguments to `run.py`. +### Configure background refresh [Optional] + +By default Apricot will refresh the LDAP tree whenever it is accessed and it contains data older than 60 seconds. +If it takes a long time to fetch all users and groups, or you want to ensure that each request gets a prompt response, you may want to configure background refresh to have it periodically be refreshed in the background. + +This is enabled with the `--background-refresh` flag, which uses the `--refresh-interval` parameter as the interval to refresh the ldap database. + +### Using TLS [Optional] + +You can set up a TLS listener to communicate with encryption enabled over the configured port. +To enable it you need to configure the tls port ex. `--tls-port=1636`, and provide a path to the pem files for the certificate `--tls-certificate=` and the private key `--tls-private-key=`. + ## Outputs This will create an LDAP tree that looks like this: diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index fa98c22..8776335 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -2,8 +2,8 @@ import sys from typing import Any, cast -from twisted.internet import reactor -from twisted.internet.endpoints import serverFromString +from twisted.internet import reactor, task +from twisted.internet.endpoints import quoteStringArgument, serverFromString from twisted.internet.interfaces import IReactorCore, IStreamServerEndpoint from twisted.python import log @@ -21,10 +21,15 @@ def __init__( domain: str, port: int, *, + background_refresh: bool = False, debug: bool = False, enable_mirrored_groups: bool = True, redis_host: str | None = None, redis_port: int | None = None, + refresh_interval: int = 60, + tls_port: int | None = None, + tls_certificate: str | None = None, + tls_private_key: str | None = None, **kwargs: Any, ) -> None: self.debug = debug @@ -66,15 +71,43 @@ def __init__( if self.debug: log.msg("Creating an LDAPServerFactory.") factory = OAuthLDAPServerFactory( - domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups + domain, + oauth_client, + background_refresh=background_refresh, + enable_mirrored_groups=enable_mirrored_groups, + refresh_interval=refresh_interval, ) + if background_refresh: + if self.debug: + log.msg( + f"Starting background refresh (interval={factory.adaptor.refresh_interval})" + ) + loop = task.LoopingCall(factory.adaptor.refresh) + loop.start(factory.adaptor.refresh_interval) + # Attach a listening endpoint if self.debug: - log.msg("Attaching a listening endpoint.") + log.msg("Attaching a listening endpoint (plain).") endpoint: IStreamServerEndpoint = serverFromString(reactor, f"tcp:{port}") endpoint.listen(factory) + # Attach a listening endpoint + if tls_certificate or tls_private_key: + if not tls_certificate: + msg = "No TLS certificate provided. Please provide one with --tls-certificate or disable TLS." + raise ValueError(msg) + if not tls_private_key: + msg = "No TLS private key provided. Please provide one with --tls-private-key or disable TLS." + raise ValueError(msg) + if self.debug: + log.msg("Attaching a listening endpoint (TLS).") + ssl_endpoint: IStreamServerEndpoint = serverFromString( + reactor, + f"ssl:{tls_port}:privateKey={quoteStringArgument(tls_private_key)}:certKey={quoteStringArgument(tls_certificate)}", + ) + ssl_endpoint.listen(factory) + # Load the Twisted reactor self.reactor = cast(IReactorCore, reactor) diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index bcabc6c..303d9e4 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -9,16 +9,30 @@ class OAuthLDAPServerFactory(ServerFactory): def __init__( - self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool + self, + domain: str, + oauth_client: OAuthClient, + *, + background_refresh: bool, + enable_mirrored_groups: bool, + refresh_interval: int, ): """ - Initialise an LDAPServerFactory + Initialise an OAuthLDAPServerFactory + @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access + @param domain: The root domain of the LDAP tree + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users @param oauth_client: An OAuth client used to construct the LDAP tree + @param refresh_interval: Interval in seconds after which the tree must be refreshed """ # Create an LDAP lookup tree self.adaptor = OAuthLDAPTree( - domain, oauth_client, enable_mirrored_groups=enable_mirrored_groups + domain, + oauth_client, + background_refresh=background_refresh, + enable_mirrored_groups=enable_mirrored_groups, + refresh_interval=refresh_interval, ) def __repr__(self) -> str: diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index d9eb133..66e649f 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -18,23 +18,27 @@ def __init__( domain: str, oauth_client: OAuthClient, *, + background_refresh: bool, enable_mirrored_groups: bool, - refresh_interval: int = 60, + refresh_interval: int, ) -> None: """ Initialise an OAuthLDAPTree + @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access @param domain: The root domain of the LDAP tree + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users @param oauth_client: An OAuth client used to construct the LDAP tree @param refresh_interval: Interval in seconds after which the tree must be refreshed """ + self.background_refresh = background_refresh self.debug = oauth_client.debug self.domain = domain + self.enable_mirrored_groups = enable_mirrored_groups self.last_update = time.monotonic() self.oauth_client = oauth_client self.refresh_interval = refresh_interval self.root_: OAuthLDAPEntry | None = None - self.enable_mirrored_groups = enable_mirrored_groups @property def dn(self) -> DistinguishedName: @@ -46,7 +50,17 @@ def root(self) -> OAuthLDAPEntry: Lazy-load the LDAP tree on request @return: An OAuthLDAPEntry for the tree + + @raises: ValueError. """ + if not self.background_refresh: + self.refresh() + if not self.root_: + msg = "LDAP tree could not be loaded" + raise ValueError(msg) + return self.root_ + + def refresh(self) -> None: if ( not self.root_ or (time.monotonic() - self.last_update) > self.refresh_interval @@ -104,7 +118,6 @@ def root(self) -> OAuthLDAPEntry: # Set last updated time log.msg("Finished building LDAP tree.") self.last_update = time.monotonic() - return self.root_ def __repr__(self) -> str: return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}" diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index e2e6ea5..58aaf8d 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -24,6 +24,13 @@ class OAuthDataAdaptor: def __init__( self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool ): + """ + Initialise an OAuthDataAdaptor + + @param domain: The root domain of the LDAP tree + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users + @param oauth_client: An OAuth client used to construct the LDAP tree + """ self.debug = oauth_client.debug self.oauth_client = oauth_client self.root_dn = "DC=" + domain.replace(".", ",DC=") diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3824a35..7ce3b01 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -14,6 +14,7 @@ services: REDIS_HOST: "redis" ports: - "1389:1389" + - "1636:1636" restart: always redis: diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 04261da..4739a1f 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -23,6 +23,7 @@ if [ -z "${DOMAIN}" ]; then exit 1 fi + # Arguments with defaults if [ -z "${PORT}" ]; then PORT="1389" @@ -30,7 +31,6 @@ if [ -z "${PORT}" ]; then fi - # Optional arguments EXTRA_OPTS="" if [ -n "${DEBUG}" ]; then @@ -41,18 +41,14 @@ if [ -n "${DISABLE_MIRRORED_GROUPS}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --disable-mirrored-groups" fi + +# Backend arguments: Entra if [ -n "${ENTRA_TENANT_ID}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID" fi -if [ -n "${REDIS_HOST}" ]; then - if [ -z "${REDIS_PORT}" ]; then - REDIS_PORT="6379" - echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'" - fi - EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT" -fi +# Backend arguments: Keycloak if [ -n "${KEYCLOAK_BASE_URL}" ]; then if [ -z "${KEYCLOAK_REALM}" ]; then echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] KEYCLOAK_REALM environment variable is not set" @@ -61,6 +57,41 @@ if [ -n "${KEYCLOAK_BASE_URL}" ]; then EXTRA_OPTS="${EXTRA_OPTS} --keycloak-base-url $KEYCLOAK_BASE_URL --keycloak-realm $KEYCLOAK_REALM" fi + +# LDAP refresh arguments +if [ -n "${BACKGROUND_REFRESH}" ]; then + EXTRA_OPTS="${EXTRA_OPTS} --background-refresh" +fi + +if [ -n "${REFRESH_INTERVAL}" ]; then + EXTRA_OPTS="${EXTRA_OPTS} --refresh-interval $REFRESH_INTERVAL" +fi + + +# TLS arguments +if [ -n "${TLS_PORT}" ]; then + if [ -z "${TLS_CERTIFICATE}" ]; then + echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] TLS_CERTIFICATE environment variable is not set" + exit 1 + fi + if [ -z "${TLS_PRIVATE_KEY}" ]; then + echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] TLS_PRIVATE_KEY environment variable is not set" + exit 1 + fi + EXTRA_OPTS="${EXTRA_OPTS} --tls-port $TLS_PORT --tls-certificate $TLS_CERTIFICATE --tls-private-key $TLS_PRIVATE_KEY" +fi + + +# Redis arguments +if [ -n "${REDIS_HOST}" ]; then + if [ -z "${REDIS_PORT}" ]; then + REDIS_PORT="6379" + echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'" + fi + EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT" +fi + + # Run the server hatch run python run.py \ --backend "${BACKEND}" \ diff --git a/run.py b/run.py index c228f20..98f4831 100644 --- a/run.py +++ b/run.py @@ -11,27 +11,75 @@ description="Apricot is a proxy for delegating LDAP requests to an OpenID Connect backend.", ) # Common options needed for all backends - parser.add_argument("-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use.") - parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.") + parser.add_argument( + "-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use." + ) + parser.add_argument( + "-d", "--domain", type=str, help="Which domain users belong to." + ) parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") - parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.") - parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.") - parser.add_argument("--disable-mirrored-groups", action="store_false", - dest="enable_mirrored_groups", default=True, - help="Disable creation of mirrored groups.") - parser.add_argument("--debug", action="store_true", help="Enable debug logging.") + parser.add_argument( + "-p", "--port", type=int, default=1389, help="Port to run on." + ) + parser.add_argument( + "-s", "--client-secret", type=str, help="OAuth client secret." + ) + parser.add_argument( + "--background-refresh", + action="store_true", + default=False, + help="Refresh in the background instead of as needed per request", + ) + parser.add_argument( + "--debug", action="store_true", help="Enable debug logging." + ) + parser.add_argument( + "--disable-mirrored-groups", + action="store_false", + default=True, + dest="enable_mirrored_groups", + help="Disable creation of mirrored groups.", + ) + parser.add_argument( + "--refresh-interval", + type=int, + default=60, + help="How often to refresh the database in seconds", + ) + # Options for Microsoft Entra backend entra_group = parser.add_argument_group("Microsoft Entra") - entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False) + entra_group.add_argument( + "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID." + ) # Options for Keycloak backend keycloak_group = parser.add_argument_group("Keycloak") - keycloak_group.add_argument("--keycloak-base-url", type=str, help="Keycloak base URL.", required=False) - keycloak_group.add_argument("--keycloak-realm", type=str, help="Keycloak Realm.", required=False) + keycloak_group.add_argument( + "--keycloak-base-url", type=str, help="Keycloak base URL." + ) + keycloak_group.add_argument( + "--keycloak-realm", type=str, help="Keycloak Realm." + ) # Options for Redis cache redis_group = parser.add_argument_group("Redis") - redis_group.add_argument("--redis-host", type=str, help="Host for Redis server.") - redis_group.add_argument("--redis-port", type=int, help="Port for Redis server.") + redis_group.add_argument( + "--redis-host", type=str, help="Host for Redis server." + ) + redis_group.add_argument( + "--redis-port", type=int, help="Port for Redis server." + ) + # Options for TLS + tls_group = parser.add_argument_group("TLS") + tls_group.add_argument( + "--tls-certificate", type=str, help="Location of TLS certificate (pem)." + ) + tls_group.add_argument( + "--tls-port", type=int, default=1636, help="Port to run on with encryption." + ) + tls_group.add_argument( + "--tls-private-key", type=str, help="Location of TLS private key (pem)." + ) # Parse arguments args = parser.parse_args()