Skip to content

Commit

Permalink
Support push_to_hub without org/user to default to logged-in user (#6629
Browse files Browse the repository at this point in the history
)

Revert "Support push_to_hub canonical datasets (#6519)"

This reverts commit a887ee7.
  • Loading branch information
albertvillanova authored Feb 5, 2024
1 parent 991169e commit ea261dd
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5384,13 +5384,14 @@ def push_to_hub(

api = HfApi(endpoint=config.HF_ENDPOINT, token=token)

_ = api.create_repo(
repo_url = api.create_repo(
repo_id,
token=token,
repo_type="dataset",
private=private,
exist_ok=True,
)
repo_id = repo_url.repo_id

if revision is not None:
api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True)
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,13 +1699,14 @@ def push_to_hub(

api = HfApi(endpoint=config.HF_ENDPOINT, token=token)

_ = api.create_repo(
repo_url = api.create_repo(
repo_id,
token=token,
repo_type="dataset",
private=private,
exist_ok=True,
)
repo_id = repo_url.repo_id

if revision is not None:
api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True)
Expand Down
14 changes: 10 additions & 4 deletions tests/test_upstream_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import numpy as np
import pytest
from huggingface_hub import DatasetCard, HfApi
from huggingface_hub.utils import RepositoryNotFoundError

from datasets import (
Audio,
Expand Down Expand Up @@ -71,9 +70,16 @@ def test_push_dataset_dict_to_hub_name_without_namespace(self, temporary_repo):
local_ds = DatasetDict({"train": ds})

with temporary_repo() as ds_name:
# cannot create a repo without namespace
with pytest.raises(RepositoryNotFoundError):
local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token)
local_ds.push_to_hub(ds_name.split("/")[-1], token=self._token)
hub_ds = load_dataset(ds_name, download_mode="force_redownload")

assert local_ds.column_names == hub_ds.column_names
assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys())
assert local_ds["train"].features == hub_ds["train"].features

# Ensure that there is a single file on the repository that has the correct name
files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset"))
assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"]

def test_push_dataset_dict_to_hub_datasets_with_different_features(self, cleanup_repo):
ds_train = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]})
Expand Down

0 comments on commit ea261dd

Please sign in to comment.