feat: support complex schemas in append (#2209)
Fixes two bugs, both associated with schemas that have holes in the
field ids:

1. `write_fragments()` and `LanceFragment.create()` assumed they could
derive the field ids from the Arrow schema. That does not hold when there
are holes in the field ids. Now, when the mode is `Append`, we read the
existing dataset's schema and reuse its field ids (a minimal sketch follows
this list). Fixes #2179
2. `PageTable` assumed that there were no holes in the field ids. It was
parameterized by `field_id_offset` and `num_fields`, assuming the field
ids were `field_id_offset..(field_id_offset + num_fields)`. It is now
parameterized by the minimum and maximum field id.
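
A minimal sketch of the append path (illustrative paths and values; it
mirrors the new `test_write_fragments_schema_holes` test):

```python
import pyarrow as pa
import lance
from lance import LanceDataset, LanceOperation
from lance.fragment import write_fragments

uri = "/tmp/demo.lance"  # placeholder path

# Build a dataset whose schema has a hole in the field ids.
dataset = lance.write_dataset(pa.table({"a": range(3)}), uri)
dataset.add_columns({"b": "a + 1"})
dataset.add_columns({"c": "a + 2"})
dataset.drop_columns(["b"])  # dropping "b" leaves a hole

# With mode="append" (the default), the new fragments reuse the existing
# dataset's field ids instead of deriving fresh ones from the Arrow schema.
new_data = pa.table({"a": range(3, 6), "c": range(5, 8)})
fragments = write_fragments(new_data, uri, mode="append")
dataset = LanceDataset.commit(
    uri, LanceOperation.Append(fragments), read_version=dataset.version
)
```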

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
wjones127 and westonpace authored Apr 18, 2024
1 parent 615a77f commit d1582a5
Showing 17 changed files with 431 additions and 70 deletions.
2 changes: 1 addition & 1 deletion protos/table.proto
@@ -198,7 +198,7 @@ message DataFile {
// or it could be encoded into a single packed column. To determine column indices
// the column_indices property should be used instead.
//
// In Lance v1 these ids must be sorted and contiguous.
// In Lance v1 these ids must be sorted but might not always be contiguous.
repeated int32 fields = 2;
// The top-level column indices for each field in the file.
//
12 changes: 12 additions & 0 deletions python/python/lance/fragment.py
@@ -141,6 +141,7 @@ def create(
schema: Optional[pa.Schema] = None,
max_rows_per_group: int = 1024,
progress: Optional[FragmentWriteProgress] = None,
mode: str = "append",
) -> FragmentMetadata:
"""Create a :class:`FragmentMetadata` from the given data.
@@ -167,6 +168,10 @@ def create(
*Experimental API*. Progress tracking for writing the fragment. Pass
a custom class that defines hooks to be called when each fragment is
starting to write and finishing writing.
mode: str, default "append"
The write mode. If "append" is specified, the data will be checked
against the existing dataset's schema. Otherwise, pass "create" or
"overwrite" to assign new field ids to the schema.
See Also
--------
@@ -204,6 +209,7 @@ def create(
reader,
max_rows_per_group=max_rows_per_group,
progress=progress,
mode=mode,
)
return FragmentMetadata(inner_meta.json())

@@ -487,6 +493,7 @@ def write_fragments(
dataset_uri: Union[str, Path],
schema: Optional[pa.Schema] = None,
*,
mode: str = "append",
max_rows_per_file: int = 1024 * 1024,
max_rows_per_group: int = 1024,
max_bytes_per_file: int = 90 * 1024 * 1024 * 1024,
@@ -509,6 +516,10 @@ def write_fragments(
schema : pa.Schema, optional
The schema of the data. If not specified, the schema will be inferred
from the data.
mode : str, default "append"
The write mode. If "append" is specified, the data will be checked
against the existing dataset's schema. Otherwise, pass "create" or
"overwrite" to assign new field ids to the schema.
max_rows_per_file : int, default 1024 * 1024
The maximum number of rows per data file.
max_rows_per_group : int, default 1024
@@ -548,6 +559,7 @@ def write_fragments(
fragments = _write_fragments(
dataset_uri,
reader,
mode=mode,
max_rows_per_file=max_rows_per_file,
max_rows_per_group=max_rows_per_group,
max_bytes_per_file=max_bytes_per_file,
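
A minimal usage sketch of the new `mode` parameter (placeholder paths; an
existing dataset is assumed at `uri` for the append case):

```python
import pyarrow as pa
from lance.fragment import LanceFragment, write_fragments

data = pa.table({"a": range(10), "c": range(10)})
uri = "/tmp/demo.lance"  # placeholder; assumed to hold an existing dataset

# "append" (the default): field ids are looked up from the existing
# dataset's schema, so holes left by dropped columns are preserved.
fragment = LanceFragment.create(uri, data, mode="append")

# "create" or "overwrite": fresh field ids are derived from the Arrow
# schema, e.g. when the fragments will form a brand-new dataset.
fragments = write_fragments(data, "/tmp/new.lance", mode="create")
```
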
29 changes: 29 additions & 0 deletions python/python/tests/test_fragment.py
@@ -104,6 +104,35 @@ def test_write_fragments(tmp_path: Path):
assert progress.complete_called == 2


def test_write_fragments_schema_holes(tmp_path: Path):
# Create table with 3 cols
data = pa.table({"a": range(3)})
dataset = write_dataset(data, tmp_path)
dataset.add_columns({"b": "a + 1"})
dataset.add_columns({"c": "a + 2"})
# Delete the middle column to create a hole in the field ids
dataset.drop_columns(["b"])

def get_field_ids(fragment):
return [id for f in fragment.data_files() for id in f.field_ids()]

field_ids = get_field_ids(dataset.get_fragments()[0])

data = pa.table({"a": range(3, 6), "c": range(5, 8)})
fragment = LanceFragment.create(tmp_path, data)
assert get_field_ids(fragment) == field_ids

data = pa.table({"a": range(6, 9), "c": range(8, 11)})
fragments = write_fragments(data, tmp_path)
assert len(fragments) == 1
assert get_field_ids(fragments[0]) == field_ids

operation = LanceOperation.Append([fragment, *fragments])
dataset = LanceDataset.commit(tmp_path, operation, read_version=dataset.version)

assert dataset.to_table().equals(pa.table({"a": range(9), "c": range(2, 11)}))


def test_write_fragment_with_progress(tmp_path: Path):
df = pd.DataFrame({"a": [10 * 10]})
progress = ProgressForTest()
3 changes: 3 additions & 0 deletions python/python/tests/test_ray.py
@@ -30,6 +30,9 @@ def test_ray_sink(tmp_path: Path):
ds = lance.dataset(tmp_path)
assert ds.count_rows() == 10
assert ds.schema.names == schema.names
# The schema is platform-dependent, because numpy uses int32 on Windows.
# So we observe the schema that is written and use that.
schema = ds.schema

tbl = ds.to_table()
assert sorted(tbl["id"].to_pylist()) == list(range(10))
1 change: 1 addition & 0 deletions rust/lance-file/Cargo.toml
@@ -42,6 +42,7 @@ tracing.workspace = true
rand.workspace = true
tempfile.workspace = true
proptest.workspace = true
pretty_assertions.workspace = true

[build-dependencies]
prost-build.workspace = true
156 changes: 121 additions & 35 deletions rust/lance-file/src/page_table.rs
@@ -36,28 +36,49 @@ pub struct PageTable {
impl PageTable {
/// Load [PageTable] from disk.
///
/// The field_ids that are loaded are `field_id_offset` to `field_id_offset + num_columns`.
/// `field_id_offset` should be the smallest field_id in the schema. `num_columns` should
/// be the total unique number of field ids, including struct fields despite the fact
/// they have no data pages.
/// Parameters:
/// * `position`: The start position in the file where the page table is stored.
/// * `min_field_id`: The smallest field_id that is present in the schema.
/// * `max_field_id`: The largest field_id that is present in the schema.
/// * `num_batches`: The number of batches in the file.
///
/// The page table is stored as an array. The on-disk size is determined based
/// on the `min_field_id`, `max_field_id`, and `num_batches` parameters. If
/// these are incorrect, the page table will not be read correctly.
///
/// The full sequence of field ids `min_field_id..=max_field_id` will be loaded.
/// Non-existent pages will be represented as (0, 0) in the page table. Pages
/// can be non-existent because they are not present in the file, or because
/// they are struct fields which have no data pages.
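///
/// As an illustrative example (the numbers match the round-trip test below):
/// with `min_field_id = 9`, `max_field_id = 13`, and `num_batches = 4`,
/// twenty (position, length) pairs are decoded, one per (field_id, batch),
/// including entries for field ids that have no data pages.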
pub async fn load<'a>(
reader: &dyn Reader,
position: usize,
num_columns: i32,
min_field_id: i32,
max_field_id: i32,
num_batches: i32,
field_id_offset: i32,
) -> Result<Self> {
let length = num_columns * num_batches * 2;
let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length as usize)?;
if max_field_id < min_field_id {
return Err(Error::Internal {
message: format!(
"max_field_id {} is less than min_field_id {}",
max_field_id, min_field_id
),
location: location!(),
});
}

let field_ids = min_field_id..=max_field_id;
let num_columns = field_ids.clone().count();
let length = num_columns * num_batches as usize * 2;
let decoder = PlainDecoder::new(reader, &DataType::Int64, position, length)?;
let raw_arr = decoder.decode().await?;
let arr = raw_arr.as_any().downcast_ref::<Int64Array>().unwrap();

let mut pages = BTreeMap::default();
for col in 0..num_columns {
let field_id = col + field_id_offset;
for (field_pos, field_id) in field_ids.enumerate() {
pages.insert(field_id, BTreeMap::default());
for batch in 0..num_batches {
let idx = col * num_batches + batch;
let idx = field_pos as i32 * num_batches + batch;
let batch_position = &arr.value((idx * 2) as usize);
let batch_length = &arr.value((idx * 2 + 1) as usize);
pages.get_mut(&field_id).unwrap().insert(
@@ -75,22 +96,35 @@ impl PageTable {

/// Write [PageTable] to disk.
///
/// `field_id_offset` is the smallest field_id that is present in the schema.
/// `min_field_id` is the smallest field_id that is present in the schema.
/// This might be a struct field, which has no data pages, but it still must
/// be serialized to the page table per the format spec.
///
/// Any (field_id, batch_id) combinations that are not present in the page table
/// will be written as (0, 0) to indicate an empty page.
pub async fn write(&self, writer: &mut dyn Writer, field_id_offset: i32) -> Result<usize> {
/// will be written as (0, 0) to indicate an empty page. This includes any
/// holes in the field ids as well as struct fields which have no data pages.
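///
/// For example (hypothetical ids): if the table holds pages only for fields
/// 10 and 13 and `min_field_id` is 10, entries for fields 11 and 12 are
/// still serialized, each as (0, 0).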
pub async fn write(&self, writer: &mut dyn Writer, min_field_id: i32) -> Result<usize> {
if self.pages.is_empty() {
return Err(Error::InvalidInput {
source: "empty page table".into(),
location: location!(),
});
}

let observed_min = *self.pages.keys().min().unwrap();
if min_field_id > observed_min {
return Err(Error::invalid_input(
format!(
"field_id_offset {} is greater than the minimum field_id {}",
min_field_id, observed_min
),
location!(),
));
}
let max_field_id = *self.pages.keys().max().unwrap();
let field_ids = min_field_id..=max_field_id;

let pos = writer.tell().await?;
let num_columns = self.pages.keys().max().unwrap() + 1 - field_id_offset;
let num_batches = self
.pages
.values()
@@ -99,10 +133,10 @@ impl PageTable {
.unwrap()
+ 1;

let mut builder = Int64Builder::with_capacity((num_columns * num_batches) as usize);
for col in 0..num_columns {
let mut builder =
Int64Builder::with_capacity(field_ids.clone().count() * num_batches as usize);
for field_id in field_ids {
for batch in 0..num_batches {
let field_id = col + field_id_offset;
if let Some(page_info) = self.get(field_id, batch) {
builder.append_value(page_info.position as i64);
builder.append_value(page_info.length as i64);
@@ -138,6 +172,7 @@ impl PageTable {
mod tests {

use super::*;
use pretty_assertions::assert_eq;

use lance_io::local::LocalObjectReader;

@@ -156,13 +191,14 @@ mod tests {
let mut page_table = PageTable::default();
let page_info = PageInfo::new(1, 2);

// Add fields 10..13, 4 batches with some missing
// Add fields 10..14, 4 batches with some missing
page_table.set(10, 2, page_info.clone());
page_table.set(11, 1, page_info.clone());
page_table.set(12, 0, page_info.clone());
page_table.set(12, 1, page_info.clone());
page_table.set(12, 2, page_info.clone());
page_table.set(12, 3, page_info.clone());
// A hole at 12
page_table.set(13, 0, page_info.clone());
page_table.set(13, 1, page_info.clone());
page_table.set(13, 2, page_info.clone());
page_table.set(13, 3, page_info.clone());

let test_dir = tempfile::tempdir().unwrap();
let path = test_dir.path().join("test");
@@ -185,27 +221,77 @@ mod tests {
let actual = PageTable::load(
reader.as_ref(),
pos,
3, // There are three columns
4, // 4 batches
starting_field_id, // First field id is 10, but we want to start at 9
13, // Last field id is 13
4, // 4 batches
)
.await
.unwrap();

// Output should have filled in the empty pages.
let mut expected = actual.clone();
let default_page_info = PageInfo::new(0, 0);
expected.set(9, 0, default_page_info.clone());
expected.set(9, 1, default_page_info.clone());
expected.set(9, 2, default_page_info.clone());
expected.set(9, 3, default_page_info.clone());
expected.set(10, 0, default_page_info.clone());
expected.set(10, 1, default_page_info.clone());
expected.set(10, 3, default_page_info.clone());
expected.set(11, 0, default_page_info.clone());
expected.set(11, 2, default_page_info.clone());
expected.set(11, 3, default_page_info);
let expected_default_pages = [
(9, 0),
(9, 1),
(9, 2),
(9, 3),
(10, 0),
(10, 1),
(10, 3),
(11, 0),
(11, 2),
(11, 3),
(12, 0),
(12, 1),
(12, 2),
(12, 3),
];
for (field_id, batch) in expected_default_pages.iter() {
expected.set(*field_id, *batch, default_page_info.clone());
}

assert_eq!(expected, actual);
}

#[tokio::test]
async fn test_error_handling() {
let mut page_table = PageTable::default();

let test_dir = tempfile::tempdir().unwrap();
let path = test_dir.path().join("test");

// Returns an error if the page table is empty
let mut writer = tokio::fs::File::create(&path).await.unwrap();
let res = page_table.write(&mut writer, 1).await;
assert!(res.is_err());
assert!(
matches!(res.unwrap_err(), Error::InvalidInput { source, .. } if source.to_string().contains("empty page table"))
);

let page_info = PageInfo::new(1, 2);
page_table.set(0, 0, page_info.clone());

// Returns an error if passing a min_field_id higher than the lowest field_id
let mut writer = tokio::fs::File::create(&path).await.unwrap();
let res = page_table.write(&mut writer, 1).await;
assert!(res.is_err());
assert!(
matches!(res.unwrap_err(), Error::InvalidInput { source, .. }
if source.to_string().contains("field_id_offset 1 is greater than the minimum field_id 0"))
);

let mut writer = tokio::fs::File::create(&path).await.unwrap();
let res = page_table.write(&mut writer, 0).await.unwrap();

let reader = LocalObjectReader::open_local_path(&path, 1024)
.await
.unwrap();

// Returns an error if max_field_id is less than min_field_id
let res = PageTable::load(reader.as_ref(), res, 1, 0, 1).await;
assert!(res.is_err());
assert!(matches!(res.unwrap_err(), Error::Internal { message, .. }
if message.contains("max_field_id 0 is less than min_field_id 1")));
}
}
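
A back-of-the-envelope sketch of the sizing arithmetic that `PageTable::load`
and `PageTable::write` now agree on: one (position, length) pair of int64s per
(field_id, batch) for every id in `min_field_id..=max_field_id`, holes
included. The helper below is illustrative Python, not code from this commit.

```python
def page_table_layout(min_field_id: int, max_field_id: int, num_batches: int):
    """Illustrative arithmetic only; mirrors PageTable::load/write above."""
    num_columns = max_field_id - min_field_id + 1   # holes still occupy slots
    num_values = num_columns * num_batches * 2      # int64 values decoded by load()
    size_bytes = num_values * 8

    def entry_index(field_id: int, batch: int) -> int:
        field_pos = field_id - min_field_id
        return (field_pos * num_batches + batch) * 2  # position here, length at +1

    return num_values, size_bytes, entry_index

# The round-trip test above: fields 9..=13 (hole at 12), 4 batches.
num_values, size_bytes, entry_index = page_table_layout(9, 13, 4)
assert (num_values, size_bytes) == (40, 320)
assert entry_index(12, 0) == 24  # the hole at field 12 still gets slots, stored as (0, 0)
```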