Generalizing Dask-XGBoost #3075

Open · mrocklin opened this issue Sep 20, 2019 · 7 comments
Labels: discussion (Discussing a topic with no specific actions yet)

@mrocklin (Member)

In many ML workloads we want to do pre-processing with Dask, load all of the data into memory, and then hand off to some other system:

  1. XGBoost
  2. LightGBM
  3. Horovod/Tensorflow
  4. Various cuML projects
  5. Dask itself in some future Actor-filled world

The Dask-XGBoost relationship does this in a few ways today:

  1. https://github.com/dask/dask-xgboost/ : we wait until all data is ready, then we query the scheduler for the location of each partition, and submit a function on each worker that grabs the local data and sends it to XGBoost
  2. https://xgboost.readthedocs.io/en/latest/python/python_api.html#module-xgboost.dask : we wait until all the data is ready, then we run a function on all of the workers that grabs partitions from the data without thinking about locality. Training tends to take much longer than a full data transfer, so this approach is less error prone without being much slower
  3. Rewrite Dask interface. dmlc/xgboost#4819 : a proposed rewrite in XGBoost that is similar to option 1 above

The processes above work today, but there are some problems:

  1. The code within dask-xgboost that figures out where data lives and then runs a function on every worker uses internal APIs that ML researchers don't understand. There is a lot of blind copy-pasting going on. If we go this route then we should probably give them something higher level. (A rough sketch of this pattern follows the list.)
  2. This approach is error prone:
    1. If data moves between querying its location and running the function, things break.
    2. If a worker dies anywhere in this process, things break.
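
For reference, here is roughly what that pattern looks like, written here against public client methods (persist/wait, Client.futures_of, Client.who_has, and Client.submit with workers=) rather than the internal scheduler comms that dask-xgboost actually uses. This is only an illustrative sketch; train_on_worker is a hypothetical user function:

from collections import defaultdict
from dask.distributed import Client, wait


def run_per_worker(client: Client, df, train_on_worker):
    # Sketch: hand each worker the partitions it already holds, then run one
    # pinned task per worker.
    df = df.persist()
    wait(df)                                    # every partition is now in memory somewhere

    worker_map = defaultdict(list)
    for fut in client.futures_of(df):           # one future per partition
        (addresses,) = client.who_has([fut]).values()
        worker_map[addresses[0]].append(fut)    # group partitions by the worker holding them

    # One task per worker, pinned to that worker; the futures are resolved to
    # the local partition data before train_on_worker is called
    tasks = [
        client.submit(train_on_worker, parts, workers=[addr], allow_other_workers=False)
        for addr, parts in worker_map.items()
    ]
    return client.gather(tasks)

This is exactly where the fragility above shows up: if a partition moves between the who_has call and the submit, or a worker dies while the pinned tasks run, things break.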

So, here are some things that we could do:

  1. We could encode either of the approaches above into some higher level API that others could use in the future. This might make things easier to use, and also allow us to improve the internals behind the scenes in the future. It would be good to figure out what this contract would look like regardless.
  2. We could implement a few coordination primitives that would make writing code like this at a lower level more approachable. This would probably help enable more creative solutions. For example operations like barrier, or collect_local_data_that_looks_like_X, might be useful.

I was doodling some pseudocode on a plane about what a solution for XGBoost might look like with some higher level primitives and came up with the following (although I don't think that people should read too much into it).

Disorganized XGBoost ravings

import dask
import toolz
from dask.base import tokenize
from dask.distributed import get_worker, rejoin, secede


def train_xgboost(df):
    tasks = [
        dask.delayed(train_xgboost_task)(
            part,
            n=df.npartitions,
            name="xgboost-" + tokenize(df),
        )
        for part in df.to_delayed()
    ]

    @dask.delayed
    def first_nonempty(L):
        return toolz.first(filter(None, L))

    return first_nonempty(tasks)


def train_xgboost_task(partition, n=None, name=None):
    group_data = group_action(data=partition, n=n, name=name)

    if not group_data:  # Someone else collected all of the data
        return None

    partitions = group_data

    # ... do XGBoost training on the collected partitions here ...
    result = ...

    return result


def group_action(data=None, n=None, name=None):
    worker = get_worker()

    # Send message to scheduler that a new task has checked in
    # This increments some counter that we'll check later
    # This is kind of like a Semaphore, but in reverse?
    some_semaphore.release()

    # This will have to happen in a threadsafe way, maybe on the event loop
    if name in worker.group_action_data:  # someone beat us here
        worker.group_action_data[name].append(data)
        return []  # we're not the first one here, return empty list
    else:
        group_data = [data]
        worker.group_action_data[name] = group_data

    secede()  # leave the thread pool so that we don't block progress
    # Block until N tasks have checked in
    some_semaphore.acquire()
    rejoin()  # rejoin thread pool

    return group_data  # this has now collected lots of partitions
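
As an aside, the blocking part of group_action could probably be expressed with coordination primitives that distributed ships, rather than the hypothetical "reverse semaphore" above. Here is a minimal sketch of a named barrier, assuming a distributed version recent enough to have Lock, Variable, and Event (not all of these existed when this issue was opened); setup_group and barrier are made-up helpers, not distributed APIs:

from dask.distributed import Client, Event, Lock, Variable, rejoin, secede


def setup_group(client: Client, name: str):
    # Run once on the client, before the tasks are submitted
    Variable(name + "-count", client=client).set(0)


def barrier(name: str, n: int):
    # Called from inside each task; blocks until n tasks have arrived
    counter = Variable(name + "-count")
    done = Event(name + "-done")

    with Lock(name + "-lock"):       # serialize the read-modify-write on the counter
        count = counter.get() + 1
        counter.set(count)

    if count == n:
        done.set()                   # last arrival wakes everybody up
    else:
        secede()                     # leave the thread pool while we wait
        done.wait()
        rejoin()

What this does not give you is the per-worker collection of partitions; that still needs something like the worker.group_action_data dictionary sketched above.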

I think that focusing on what a good contract would look like for XGBoost, and then copying over one of the solutions for that, might be a helpful start.

cc @TomAugspurger @trivialfis @ogrisel @RAMitchell

@felipeblazing commented Nov 14, 2019

I was looking at the _train function in https://github.com/dask/dask-xgboost/blob/master/dask_xgboost/core.py and found these lines:

# find the locations of all chunks and map them to particular Dask workers

key_to_part_dict = dict([(part.key, part) for part in parts])
who_has = yield client.scheduler.who_has(keys=[part.key for part in parts])
worker_map = defaultdict(list)
for key, workers in who_has.items():
    worker_map[first(workers)].append(key_to_part_dict[key])

I made the following test code to see whether the partitions would land on different workers, and here is what I am seeing:

import cudf
import dask_cudf
import dask
import dask.dataframe
from dask.distributed import wait

df = cudf.DataFrame()
df['col1'] = cudf.Series(list(range(10000)))
df['col2'] = cudf.Series(list(range(10000)))
ds = dask_cudf.from_cudf(df,npartitions=2)

ds = ds.persist()
wait(ds)
print(client.who_has(ds.get_partition(0)))
print(client.who_has(ds.get_partition(1)))

Output:

{"('from_pandas-22fbbb8d31eea4f29bc46bf1d987f2d7', 1)": ('tcp://127.0.0.1:38971',),
 "('from_pandas-22fbbb8d31eea4f29bc46bf1d987f2d7', 0)": ('tcp://127.0.0.1:32851',)}
{"('from_pandas-22fbbb8d31eea4f29bc46bf1d987f2d7', 1)": ('tcp://127.0.0.1:38971',),
 "('from_pandas-22fbbb8d31eea4f29bc46bf1d987f2d7', 0)": ('tcp://127.0.0.1:32851',)}

So, given that the code above uses worker_map[first(workers)].append(key_to_part_dict[key]), both jobs will be launched on a single worker and the second worker will not be used. In our particular use case, what we would like to do is launch a task on every worker that grabs the partitions local to it and starts a computation process. Even if a worker has no data at all, we still want to run the same process on an empty dataframe (we just do a_dask_cudf.head(0) to get the schema). Is there a way to get a better understanding of which worker is holding which piece of memory?
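
Concretely, the kind of pattern in question looks something like the following sketch (process_partitions is a hypothetical function, ds is assumed to be already persisted and waited on, and the usual caveat applies that partitions can move between the who_has query and the submit):

from collections import defaultdict
from dask.distributed import Client


def launch_on_every_worker(client: Client, ds, process_partitions):
    futures = client.futures_of(ds)                    # one future per persisted partition
    empty = ds.head(0)                                 # empty frame that carries the schema

    # Group partition futures by the worker that currently holds them
    worker_map = defaultdict(list)
    for fut in futures:
        (addresses,) = client.who_has([fut]).values()
        worker_map[addresses[0]].append(fut)

    # One task per worker, including workers that hold no data, which get the
    # empty frame instead
    return [
        client.submit(
            process_partitions,
            worker_map.get(addr) or [empty],
            workers=[addr],
            allow_other_workers=False,
        )
        for addr in client.scheduler_info()["workers"]
    ]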

@felipeblazing commented Nov 14, 2019

new_ds = ds.get_partition(0).persist()
wait(new_ds)
new_ds2 = ds.get_partition(1).persist()
wait(new_ds2)
print(client.who_has(new_ds))
print(client.who_has(new_ds2))

This seems to do the trick for me.

@TomAugspurger (Member)

That client.who_has output looks a bit strange for your input, but I think things are balanced across the multiple workers. Notice that the output of each who_has call is the same: it's giving the mapping for all of the futures backing ds, not just that one partition.

If you want the who_has for a single partition, you may have to do a bit more:

>>> client.who_has(client.futures_of(ds)[0])
{"('from_pandas-02ae7c2929339175d14a7c1c3e7c60b2', 0)": ('tcp://127.0.0.1:59229',)}

@TomAugspurger (Member)

I'd recommend avoiding

new_ds = ds.get_partition(0).persist()
wait(new_ds)
new_ds2 = ds.get_partition(1).persist()
wait(new_ds2)

since that blocks on each wait serially, whereas a single ds.persist() can run everything in parallel. If that's not doing the right thing, then we should fix it there (but I think it is doing the right thing).

@felipeblazing

Awesome, thank you for the clarification!

@felipeblazing

ds = ds.persist()
wait(ds)
print(client.who_has(client.futures_of(ds)[0]))
print(client.who_has(client.futures_of(ds.get_partition(0))[0]))

print()
print(client.who_has(client.futures_of(ds)[1]))
print(client.who_has(client.futures_of(ds.get_partition(1))[0]))

I would expect each pair of print statements to output the same thing, but instead they output:

{"('from_pandas-d3911cc4d2fa8389fe6ac0f5afbd2cae', 0)": ('tcp://127.0.0.1:32851',)}
{"('from_pandas-d3911cc4d2fa8389fe6ac0f5afbd2cae', 0)": ('tcp://127.0.0.1:32851',)}

{"('from_pandas-d3911cc4d2fa8389fe6ac0f5afbd2cae', 1)": ('tcp://127.0.0.1:38971',)}
{"('from_pandas-d3911cc4d2fa8389fe6ac0f5afbd2cae', 0)": ('tcp://127.0.0.1:32851',)}

Perhaps I am misunderstanding something. Is number 0 in this case "('from_pandas-02ae7c2929339175d14a7c1c3e7c60b2', 0)" the partition number?

@mrocklin (Member, Author)

See #3236

@GenevieveBuckley added the discussion label (Discussing a topic with no specific actions yet) on Oct 22, 2021