support save dataset compressed in zstd
Binh Vu committed Jan 21, 2024
1 parent 34cf045 commit d497560
Showing 5 changed files with 57 additions and 92 deletions.
26 changes: 25 additions & 1 deletion kgdata/spark/common.py
@@ -28,6 +28,8 @@
from pyspark import RDD, SparkConf, SparkContext, TaskContext
from sm.misc.funcs import assert_not_null

from kgdata.misc.funcs import deser_zstd_records

# SparkContext singleton
_sc = None

@@ -42,7 +44,7 @@
V2 = TypeVar("V2")


def get_spark_context():
def get_spark_context() -> SparkContext:
"""Get spark context
Returns
@@ -531,6 +533,28 @@ def save_partition(partition: Iterable[str] | Iterable[bytes]):
raise Exception(f"Unknown compression: {compression}")


def text_file(
filepattern: Path, min_partitions: Optional[int] = None, use_unicode: bool = True
):
"""Drop-in replacement for SparkContext.textFile that supports zstd files."""
filepattern = Path(filepattern)
# to support zst files (indir)
if (
filepattern.is_dir()
and any(
file.name.startswith("part-") and file.name.endswith(".zst")
for file in filepattern.iterdir()
)
) or filepattern.name.endswith(".zst"):
return (
get_spark_context()
.binaryFiles(str(filepattern), min_partitions)
.flatMap(lambda x: deser_zstd_records(x[1]))
)

return get_spark_context().textFile(str(filepattern), min_partitions, use_unicode)


@dataclass
class EmptyBroadcast(Generic[V]):
value: V
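
The new text_file helper above relies on kgdata.misc.funcs.deser_zstd_records to turn the raw bytes of each part-*.zst file back into text records; that helper is imported but not shown in this diff. A minimal sketch of what it might look like, assuming each part file is a single zstd frame of newline-delimited records (the zstandard package and the decoding details below are assumptions, not code from this commit):

import zstandard as zstd


def deser_zstd_records(dat: bytes) -> list[str]:
    """Decompress one zstd part file and split it into its newline-delimited records."""
    # decompressobj() also copes with frames that do not embed their decompressed size.
    raw = zstd.ZstdDecompressor().decompressobj().decompress(dat)
    # Drop the empty trailing element produced by the final newline.
    return raw.decode("utf-8").rstrip("\n").split("\n")

With something like this in place, text_file can be pointed at either a directory of part-*.zst files or an uncompressed dataset and returns an RDD of strings in both cases.
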
68 changes: 26 additions & 42 deletions kgdata/spark/extended_rdd.py
@@ -14,6 +14,7 @@
Generic,
Hashable,
Iterable,
Literal,
Optional,
Protocol,
Sequence,
@@ -32,6 +33,8 @@
get_spark_context,
join_repartition,
left_outer_join_repartition,
save_as_text_file,
text_file,
)

if TYPE_CHECKING:
@@ -219,7 +222,7 @@ def save_as_single_text_file(
shutil.rmtree(outfile + "_tmp")

def save_like_dataset(
self,
self: ExtendedRDD[str] | ExtendedRDD[bytes],
dataset: "Dataset",
checksum: bool = True,
auto_coalesce: bool = False,
@@ -228,6 +231,7 @@ def save_like_dataset(
min_num_partitions: int = 64,
max_num_partitions: int = 1024,
trust_dataset_dependencies: bool = False,
compression_level: Optional[int] = None,
) -> None:
"""Save this RDD as a dataset similar to the given dataset. By default, checksum of the dataset is computed
so we can be confident that the data hasn't changed, or that multiple copies are indeed equal.
@@ -242,19 +246,21 @@
min_num_partitions: if auto_coalesce is enabled and this variable is not None, this will be the minimum number of partitions to coalesce to.
max_num_partitions: if auto_coalesce is enabled and this variable is not None, this will be the maximum number of partitions to coalesce to.
trust_dataset_dependencies: whether to trust the dataset dependencies. If this is False, we will verify the dataset dependencies and ensure that they are equal.
compression_level: compression level to use. This is only used when compression is zst.
"""
file_pattern = Path(dataset.file_pattern)
if file_pattern.suffix == ".gz":
compressionCodecClass = "org.apache.hadoop.io.compress.GzipCodec"
compression = "gz"
elif file_pattern.suffix == ".zst":
# this is a dummy codec for our custom version of save_as_text_file
compressionCodecClass = "kgdata.compress.ZstdCodec"
compression = "zst"
compression_level = compression_level or 3
else:
# this is to make sure the dataset file pattern matches the files generated by spark.
assert file_pattern.suffix == "" and file_pattern.name.startswith(
"part-"
), file_pattern.name
compressionCodecClass = None
compression = None

# verify dataset dependencies
if trust_dataset_dependencies:
@@ -283,7 +289,8 @@

self.save_as_dataset(
dataset.get_data_directory(),
compressionCodecClass=compressionCodecClass,
compression=compression,
compression_level=compression_level,
name=dataset.name,
checksum=checksum,
auto_coalesce=auto_coalesce,
@@ -294,9 +301,10 @@
)

def save_as_dataset(
self,
self: ExtendedRDD[str] | ExtendedRDD[bytes],
outdir: StrPath,
compressionCodecClass: Optional[str] = None,
compression: Optional[Literal["gz", "zst"]] = None,
compression_level: Optional[int] = None,
name: Optional[str] = None,
checksum: bool = True,
auto_coalesce: bool = False,
@@ -310,7 +318,8 @@ def save_as_dataset(
# Arguments
outdir: output directory
compressionCodecClass: compression codec class to use
compression: compression to use. If None, no compression is used.
compression_level: compression level to use. This is only used when compression is zst.
name: name of the dataset, by default, we use the output directory name
checksum: whether to compute checksum of the dataset. Usually, we don't need to compute the checksum
for intermediate datasets.
Expand All @@ -322,32 +331,25 @@ def save_as_dataset(
"""
outdir = str(outdir)

# if compressionCodecClass == "kgdata.compress.ZstdCodec":
# compression = "zst"
# elif compressionCodecClass == "org.apache.hadoop.io.compress.GzipCodec":
# compression = "gz"
# else:
# assert compressionCodecClass is None
# compression = None

if not auto_coalesce:
self.rdd.saveAsTextFile(outdir, compressionCodecClass=compressionCodecClass)
save_as_text_file(self.rdd, Path(outdir), compression, compression_level)
else:
tmp_dir = str(outdir) + "_tmp"
self.rdd.saveAsTextFile(
tmp_dir, compressionCodecClass=compressionCodecClass
)
save_as_text_file(self.rdd, Path(tmp_dir), compression, compression_level)

rdd = get_spark_context().textFile(tmp_dir)
rdd = text_file(Path(tmp_dir))
num_partitions = math.ceil(
sum((os.path.getsize(file) for file in glob.glob(tmp_dir + "/part-*")))
/ partition_size
)
num_partitions = max(min_num_partitions, num_partitions)
num_partitions = min(max_num_partitions, num_partitions)

rdd.coalesce(num_partitions, shuffle).saveAsTextFile(
outdir, compressionCodecClass=compressionCodecClass
save_as_text_file(
rdd.coalesce(num_partitions, shuffle),
Path(outdir),
compression,
compression_level,
)
shutil.rmtree(tmp_dir)

@@ -499,25 +501,7 @@ def textFile(
dependencies={},
)

# to support zst files (indir)
if (
indir.is_dir()
and any(
file.name.startswith("part-") and file.name.endswith(".zst")
for file in indir.iterdir()
)
) or indir.name.endswith(".zst"):
return ExtendedRDD(
get_spark_context()
.binaryFiles(str(indir), minPartitions)
.flatMap(lambda x: deser_zstd_records(x[1])),
sig,
)

return ExtendedRDD(
get_spark_context().textFile(str(indir), minPartitions, use_unicode),
sig,
)
return ExtendedRDD(text_file(indir, minPartitions, use_unicode), sig)

@staticmethod
def binaryFiles(
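
Putting the pieces together, a hypothetical round trip with the new API (paths, the dataset name, and record contents below are made up for illustration):

from kgdata.spark.extended_rdd import ExtendedRDD

# Write: the part files come out as part-*.zst instead of plain text.
ExtendedRDD.parallelize(['{"id": "Q1"}', '{"id": "Q2"}']).save_as_dataset(
    "/tmp/example_dataset",
    compression="zst",
    compression_level=3,
    name="example-dataset",
    checksum=False,
)

# Read: the ".zst" pattern routes through binaryFiles + deser_zstd_records
# inside text_file, so the records come back as plain strings.
records = ExtendedRDD.textFile("/tmp/example_dataset/part-*.zst").rdd.collect()
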
3 changes: 0 additions & 3 deletions kgdata/wikidata/datasets/deprecated/__init__.py

This file was deleted.

41 changes: 0 additions & 41 deletions kgdata/wikidata/datasets/deprecated/wp2wd.py

This file was deleted.

11 changes: 6 additions & 5 deletions kgdata/wikidata/datasets/entity_redirections.py
@@ -139,18 +139,19 @@ def entity_redirections() -> Dataset[tuple[str, str]]:
lst = serde.csv.deser(redirection_file, delimiter="\t")

unk_target_ds = Dataset.string(
cfg.entity_redirections / "unknown_target_entities/part-*"
cfg.entity_redirections / "unknown_target_entities/part-*",
name="entity-redirections/unknown-target-entities",
dependencies=[redirection_ds, entity_ids()],
)

if not unk_target_ds.has_complete_data():
(
ExtendedRDD.parallelize(list(set(x[1] for x in lst)))
.subtract(entity_ids().get_extended_rdd())
.save_as_dataset(
cfg.entity_redirections / "unknown_target_entities",
name="entity-redirections/unknown-target-entities",
auto_coalesce=True,
.save_like_dataset(
dataset=unk_target_ds,
checksum=False,
auto_coalesce=True,
)
)

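
One effect of switching this dataset to save_like_dataset is that compression is now decided in a single place, from the target dataset's file pattern: unknown_target_entities/part-* has no suffix, so this dataset stays uncompressed, whereas a part-*.zst pattern would have turned on zstd. The suffix handling, restated as a tiny standalone helper purely for illustration (not code from this commit):

from pathlib import Path
from typing import Optional, Tuple


def infer_compression(file_pattern: str) -> Tuple[Optional[str], Optional[int]]:
    """Mirror save_like_dataset's suffix check: return (compression, default level)."""
    suffix = Path(file_pattern).suffix
    if suffix == ".gz":
        return "gz", None
    if suffix == ".zst":
        return "zst", 3
    # Uncompressed datasets must use Spark's plain part-* naming.
    assert suffix == "" and Path(file_pattern).name.startswith("part-")
    return None, None


assert infer_compression("unknown_target_entities/part-*") == (None, None)
assert infer_compression("entities/part-*.zst") == ("zst", 3)
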
