diff --git a/images/pyspark-notebook/Dockerfile b/images/pyspark-notebook/Dockerfile
index f64bb75db6..c9c9326b47 100644
--- a/images/pyspark-notebook/Dockerfile
+++ b/images/pyspark-notebook/Dockerfile
@@ -41,11 +41,11 @@ ENV SPARK_OPTS="--driver-java-options=-Xms1024M --driver-java-options=-Xmx4096M
 COPY setup_spark.py /opt/setup-scripts/
 
 # Setup Spark
-RUN SPARK_VERSION="${spark_version}" \
-    HADOOP_VERSION="${hadoop_version}" \
-    SCALA_VERSION="${scala_version}" \
-    SPARK_DOWNLOAD_URL="${spark_download_url}" \
-    /opt/setup-scripts/setup_spark.py
+RUN /opt/setup-scripts/setup_spark.py \
+    --spark-version="${spark_version}" \
+    --hadoop-version="${hadoop_version}" \
+    --scala-version="${scala_version}" \
+    --spark-download-url="${spark_download_url}"
 
 # Configure IPython system-wide
 COPY ipython_kernel_config.py "/etc/ipython/"
diff --git a/images/pyspark-notebook/setup_spark.py b/images/pyspark-notebook/setup_spark.py
index 3481cc701f..a494b8322a 100755
--- a/images/pyspark-notebook/setup_spark.py
+++ b/images/pyspark-notebook/setup_spark.py
@@ -4,9 +4,9 @@
 
 # Requirements:
 # - Run as the root user
-# - Required env variables: SPARK_HOME, HADOOP_VERSION, SPARK_DOWNLOAD_URL
-# - Optional env variables: SPARK_VERSION, SCALA_VERSION
+# - Required env variable: SPARK_HOME
 
+import argparse
 import logging
 import os
 import subprocess
@@ -27,13 +27,10 @@ def get_all_refs(url: str) -> list[str]:
     return [a["href"] for a in soup.find_all("a", href=True)]
 
 
-def get_spark_version() -> str:
+def get_latest_spark_version() -> str:
     """
-    If ${SPARK_VERSION} env variable is non-empty, simply returns it
-    Otherwise, returns the last stable version of Spark using spark archive
+    Returns the last stable version of Spark using spark archive
     """
-    if (version := os.environ["SPARK_VERSION"]) != "":
-        return version
     LOGGER.info("Downloading Spark versions information")
     all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
     stable_versions = [
@@ -106,12 +103,20 @@ def configure_spark(spark_dir_name: str, spark_home: Path) -> None:
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
-    spark_version = get_spark_version()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--spark-version", required=True)
+    arg_parser.add_argument("--hadoop-version", required=True)
+    arg_parser.add_argument("--scala-version", required=True)
+    arg_parser.add_argument("--spark-download-url", type=Path, required=True)
+    args = arg_parser.parse_args()
+
+    args.spark_version = args.spark_version or get_latest_spark_version()
+
     spark_dir_name = download_spark(
-        spark_version=spark_version,
-        hadoop_version=os.environ["HADOOP_VERSION"],
-        scala_version=os.environ["SCALA_VERSION"],
-        spark_download_url=Path(os.environ["SPARK_DOWNLOAD_URL"]),
+        spark_version=args.spark_version,
+        hadoop_version=args.hadoop_version,
+        scala_version=args.scala_version,
+        spark_download_url=args.spark_download_url,
     )
     configure_spark(
         spark_dir_name=spark_dir_name, spark_home=Path(os.environ["SPARK_HOME"])
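
For reviewers, a sketch of the new invocation as it would run inside the image. The argument values below are hypothetical placeholders, not values pinned by this diff. Note that all four flags are declared required=True, so each must be passed on the command line, but argparse still accepts an empty string as a value; an empty --spark-version is what triggers the get_latest_spark_version() fallback:

    # Hypothetical invocation; the values are illustrative only.
    # An empty --spark-version makes the script resolve the latest
    # stable release from the Apache Spark archive.
    /opt/setup-scripts/setup_spark.py \
        --spark-version="" \
        --hadoop-version="3" \
        --scala-version="" \
        --spark-download-url="https://archive.apache.org/dist/spark/"

This preserves the old behavior, where an empty SPARK_VERSION environment variable meant "use the latest stable version", while making every input an explicit, named argument of the script rather than an ambient environment variable.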