Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(dataset): remove dataset batch_size argument #910

Merged
merged 3 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ Starwhale is an MLOps platform. It provides **Instance**, **Project**, **Runtime
super().__init__(merge_label=True, ignore_error=False)
self.model = self._load_model()

def ppl(self, data:bytes, batch_size:int, **kw):
data = self._pre(data, batch_size)
def ppl(self, data:bytes, **kw):
data = self._pre(data)
output = self.model(data)
return self._post(output)

def handle_label(self, label:bytes, batch_size:int, **kw):
def handle_label(self, label:bytes, **kw):
return [int(l) for l in label]

@multi_classification(
Expand All @@ -150,7 +150,7 @@ Starwhale is an MLOps platform. It provides **Instance**, **Project**, **Runtime
_result.extend([int(l) for l in pred])
_pr.extend([l for l in pr])

def _pre(self, input:bytes, batch_size:int):
def _pre(self, input:bytes):
"""write some mnist preprocessing code"""

def _post(self, input:bytes):
Expand Down Expand Up @@ -216,7 +216,6 @@ Starwhale is an MLOps platform. It provides **Instance**, **Project**, **Runtime
label_filter: "t10k-label*"
process: mnist.process:DataSetProcessExecutor
attr:
batch_size: 50
alignment_size: 4k
volume_size: 2M
```
Expand Down
34 changes: 14 additions & 20 deletions client/starwhale/api/_impl/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import sys
import math
import struct
import typing as t
from abc import ABCMeta, abstractmethod
Expand All @@ -24,17 +23,14 @@
)
from starwhale.api._impl.wrapper import Dataset as DatastoreWrapperDataset
from starwhale.core.dataset.store import DatasetStorage
from starwhale.core.dataset.dataset import (
D_ALIGNMENT_SIZE,
D_USER_BATCH_SIZE,
D_FILE_VOLUME_SIZE,
)
from starwhale.core.dataset.dataset import D_ALIGNMENT_SIZE, D_FILE_VOLUME_SIZE

# TODO: tune header size
_header_magic = struct.unpack(">I", b"SWDS")[0]
_data_magic = struct.unpack(">I", b"SDWS")[0]
_header_struct = struct.Struct(">IIQIIII")
_header_size = _header_struct.size
_header_version = 0


@unique
Expand Down Expand Up @@ -192,13 +188,13 @@ class BuildExecutor(metaclass=ABCMeta):
BuildExecutor can build swds.

swds_bin format:
header_magic uint32 I
crc uint32 I
idx uint64 Q
size uint32 I
padding_size uint32 I
batch_size uint32 I
data_magic uint32 I --> above 32 bytes
header_magic uint32 I
crc uint32 I
idx uint64 Q
size uint32 I
padding_size uint32 I
header_version uint32 I
data_magic uint32 I --> above 32 bytes
data bytes...
padding bytes... --> default 4K padding
"""
Expand All @@ -216,13 +212,11 @@ def __init__(
output_dir: Path = Path("./sw_output"),
data_filter: str = "*",
label_filter: str = "*",
batch: int = D_USER_BATCH_SIZE,
alignment_bytes_size: int = D_ALIGNMENT_SIZE,
volume_bytes_size: int = D_FILE_VOLUME_SIZE,
) -> None:
# TODO: add more docstring for args
# TODO: validate group upper and lower?
self._batch = max(batch, 1)
self.data_dir = data_dir
self.data_filter = data_filter
self.label_filter = label_filter
Expand Down Expand Up @@ -269,7 +263,7 @@ def _write(self, writer: t.Any, idx: int, data: bytes) -> t.Tuple[int, int]:
padding_size = self._get_padding_size(size + _header_size)

_header = _header_struct.pack(
_header_magic, crc, idx, size, padding_size, self._batch, _data_magic
_header_magic, crc, idx, size, padding_size, _header_version, _data_magic
)
_padding = b"\0" * padding_size
writer.write(_header + data + _padding)
Expand Down Expand Up @@ -366,10 +360,10 @@ def iter_data_slice(self, path: str) -> t.Generator[bytes, None, None]:

with fpath.open("rb") as f:
_, number, height, width = struct.unpack(">IIII", f.read(16))
print(f">data({fpath.name}) split {math.ceil(number / self._batch)} group")
print(f">data({fpath.name}) split {number} group")

while True:
content = f.read(self._batch * height * width)
content = f.read(height * width)
if not content:
break
yield content
Expand All @@ -379,10 +373,10 @@ def iter_label_slice(self, path: str) -> t.Generator[bytes, None, None]:

with fpath.open("rb") as f:
_, number = struct.unpack(">II", f.read(8))
print(f">label({fpath.name}) split {math.ceil(number / self._batch)} group")
print(f">label({fpath.name}) split {number} group")

while True:
content = f.read(self._batch)
content = f.read(1)
if not content:
break
yield content
Expand Down
5 changes: 1 addition & 4 deletions client/starwhale/api/_impl/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def _get_bucket_by_uri(self) -> str:
class DataField(t.NamedTuple):
idx: int
data_size: int
batch_size: int
data: bytes
ext_attr: t.Dict[str, t.Any]

Expand Down Expand Up @@ -176,7 +175,6 @@ def __iter__(self) -> t.Generator[t.Tuple[DataField, DataField], None, None]:
label = DataField(
idx=row.id,
data_size=len(row.label),
batch_size=data.batch_size,
data=row.label,
ext_attr=_attr,
)
Expand All @@ -193,12 +191,11 @@ def _do_iter(
header: bytes = _file.read(_header_size)
if not header:
break
_, _, idx, size, padding_size, batch, _ = _header_struct.unpack(header)
_, _, idx, size, padding_size, _, _ = _header_struct.unpack(header)
data = _file.read(size + padding_size)
yield DataField(
idx,
size,
batch,
data[:size].tobytes() if isinstance(data, memoryview) else data[:size],
attr,
)
Expand Down
12 changes: 4 additions & 8 deletions client/starwhale/api/_impl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from starwhale.consts import CURRENT_FNAME
from starwhale.base.uri import URI
from starwhale.utils.fs import ensure_dir, ensure_file
from starwhale.base.type import URIType
from starwhale.utils.log import StreamWrapper
from starwhale.consts.env import SWEnv

from .loader import DataField, DataLoader, get_data_loader
from ...base.type import URIType

_TASK_ROOT_DIR = "/var/starwhale" if in_production() else "/tmp/starwhale"

Expand Down Expand Up @@ -210,8 +210,8 @@ def __exit__(
self._sw_logger.remove()

@abstractmethod
def ppl(self, data: bytes, batch_size: int, **kw: t.Any) -> t.Any:
# TODO: how to handle each batch element is not equal.
def ppl(self, data: bytes, **kw: t.Any) -> t.Any:
# TODO: how to handle each element is not equal.
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -239,7 +239,7 @@ def deserialize(self, data: t.Union[str, bytes]) -> t.Any:
ret[self._label_field] = self.label_data_deserialize(ret[self._label_field])
return ret

def handle_label(self, label: bytes, batch_size: int, **kw: t.Any) -> t.Any:
def handle_label(self, label: bytes, **kw: t.Any) -> t.Any:
return label.decode()

def _record_status(func): # type: ignore
Expand Down Expand Up @@ -291,12 +291,10 @@ def _starwhale_internal_run_ppl(self) -> None:
# TODO: inspect profiling
pred = self.ppl(
data.data,
data.batch_size,
data_index=data.idx,
data_size=data.data_size,
label_content=label.data,
label_size=label.data_size,
label_batch=label.batch_size,
label_index=label.idx,
ds_name=data.ext_attr.get("ds_name", ""),
ds_version=data.ext_attr.get("ds_version", ""),
Expand Down Expand Up @@ -335,13 +333,11 @@ def _do_record(
self._ppl_data_field: base64.b64encode(
self.ppl_data_serialize(*args)
).decode("ascii"),
"batch": data.batch_size,
}
if self.merge_label:
try:
label = self.handle_label(
label.data,
label.batch_size,
index=label.idx,
size=label.data_size,
)
Expand Down
3 changes: 0 additions & 3 deletions client/starwhale/core/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class DSProcessMode:

D_FILE_VOLUME_SIZE = 64 * 1024 * 1024 # 64MB
D_ALIGNMENT_SIZE = 4 * 1024 # 4k for page cache
D_USER_BATCH_SIZE = 1
ARCHIVE_SWDS_META = "archive.%s" % SWDSSubFileType.META


Expand All @@ -26,10 +25,8 @@ def __init__(
self,
volume_size: t.Union[int, str] = D_FILE_VOLUME_SIZE,
alignment_size: t.Union[int, str] = D_ALIGNMENT_SIZE,
batch_size: int = D_USER_BATCH_SIZE,
**kw: t.Any,
) -> None:
self.batch_size = batch_size
self.volume_size = convert_to_bytes(volume_size)
self.alignment_size = convert_to_bytes(alignment_size)
self.kw = kw
Expand Down
1 change: 0 additions & 1 deletion client/starwhale/core/dataset/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def _call_make_swds(self, workdir: Path, swds_config: DatasetConfig) -> None:
output_dir=self.store.data_dir,
data_filter=swds_config.data_filter,
label_filter=swds_config.label_filter,
batch=swds_config.attr.batch_size,
alignment_bytes_size=swds_config.attr.alignment_size,
volume_bytes_size=swds_config.attr.volume_size,
) as _obj:
Expand Down
3 changes: 1 addition & 2 deletions client/tests/data/dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ build:
created_at: 2022-05-30 17:05:31 CST
dataset_attr:
alignment_size: 4096
batch_size: 50
volume_size: 2097152
dataset_byte_size: 9033838
mode: generate
Expand All @@ -21,4 +20,4 @@ signature:
label_ubyte_2.swds_bin: 211328:blake2b:33cc73cea261f2d534121962f1388389518673ce8d16d9c74ef47fb6c89e6dc7be3381b8e9479e4e085d23a0c9368bb012e724663b8cd0066f16cca0e6e4da23
label_ubyte_3.swds_bin: 178816:blake2b:baa1d6ba0f3bdae60fd261d6694ff0ece80c87dd4d69c00dabf9850f0dd72ede4430c9a1cbbce5f3846c8c80c669bcec31fad208ea0af6237e88618e9e029afc
user_raw_config: {}
version: me4dczlegzswgnrtmftdgyjznqywwza
version: me4dczlegzswgnrtmftdgyjznqywwza
1 change: 0 additions & 1 deletion client/tests/data/dataset/dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ tag:
- bin

attr:
batch_size: 50
alignment_size: 4k
volume_size: 2M
5 changes: 1 addition & 4 deletions client/tests/sdk/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def setUp(self) -> None:
f.write(_mnist_label)

def test_workflow(self) -> None:
batch_size = 2
with MNISTBuildExecutor(
dataset_name="mnist",
dataset_version="112233",
Expand All @@ -43,7 +42,6 @@ def test_workflow(self) -> None:
output_dir=Path(self.output_data),
data_filter="mnist-data-*",
label_filter="mnist-data-*",
batch=batch_size,
alignment_bytes_size=64,
volume_bytes_size=100,
) as e:
Expand All @@ -58,8 +56,7 @@ def test_workflow(self) -> None:
_parser = _header_struct.unpack(data_content[:_header_size])
assert _parser[0] == _header_magic
assert _parser[2] == 0
assert _parser[3] == 28 * 28 * batch_size
assert _parser[5] == batch_size
assert _parser[3] == 28 * 28
assert _parser[6] == _data_magic
assert len(data_content) == _header_size + _parser[3] + _parser[4]

Expand Down
8 changes: 3 additions & 5 deletions client/tests/sdk/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


class SimpleHandler(PipelineHandler):
def ppl(self, data: bytes, batch_size: int, **kw: t.Any) -> t.Any:
def ppl(self, data: bytes, **kw: t.Any) -> t.Any:
return [1, 2], 0.1

def cmp(self, _data_loader: t.Any) -> t.Any:
Expand Down Expand Up @@ -105,7 +105,6 @@ def test_cmp(self) -> None:
"ppl": base64.b64encode(
_handler.ppl_data_serialize([1, 2], 0.1)
).decode("ascii"),
"batch": 10,
"label": base64.b64encode(
_handler.label_data_serialize([3, 4])
).decode("ascii"),
Expand Down Expand Up @@ -167,7 +166,6 @@ def test_ppl(self, m_scan: MagicMock) -> None:
assert result == [1, 2]
assert pr == 0.1
assert line["index"] == 0
assert line["batch"] == 10

@pytest.mark.skip(reason="wait job scheduler feature, cmp will use datastore")
def test_deserializer(self) -> None:
Expand All @@ -181,10 +179,10 @@ def test_deserializer(self) -> None:
label_data = [1, 2, 3]

class Dummy(PipelineHandler):
def ppl(self, data: bytes, batch_size: int, **kw: t.Any) -> t.Any:
def ppl(self, data: bytes, **kw: t.Any) -> t.Any:
return builtin_data, np_data, tensor_data

def handle_label(self, label: bytes, batch_size: int, **kw: t.Any) -> t.Any:
def handle_label(self, label: bytes, **kw: t.Any) -> t.Any:
return label_data

def cmp(self, _data_loader: t.Any) -> t.Any:
Expand Down
2 changes: 0 additions & 2 deletions docs/docs/standalone/dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ title: Starwhale Dataset
|`label_filter`|the filter for label files, support regular expression|✅||string|`t10k-label*`|
|`process`|the class import path which is inherited by `starwhale.api.dataset.BuildExecutor` class. The format is {module path}:{class name}|✅||String|`mnist.process:DataSetProcessExecutor`|
|`desc`|description|❌|""|String|`This is a mnist dataset.`|
|`attr.batch_size`|data batch size|❌|`50`|Integer|`50`|
|`attr.alignment_size`|every section data alignment size|❌|`4k`|String|`4k`|
|`attr.volume_size`|data volume size|❌|`64M`|String|`2M`|

Expand All @@ -29,7 +28,6 @@ process: mnist.process:DataSetProcessExecutor

desc: MNIST data and label test dataset
attr:
batch_size: 50
alignment_size: 4k
volume_size: 2M
```
Loading