support pinning adapters

rohithkrn committed Jun 17, 2024
1 parent 1f12122 commit 838f602

Showing 13 changed files with 149 additions and 4 deletions.
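In brief: this commit threads a new pin_lora(lora_id) operation from LLMEngine through every executor and worker layer down to the LoRA manager, which pins the adapter in both its CPU (registered) and GPU (active) LRU caches; a pinned adapter is exempt from LRU eviction until it is explicitly removed. A rough usage sketch follows — the engine setup and adapter id are illustrative assumptions, not part of this diff:

    from vllm import EngineArgs, LLMEngine

    # Hypothetical setup: the model name, LoRA limits, and adapter id
    # are assumptions for illustration only.
    engine = LLMEngine.from_engine_args(
        EngineArgs(model="meta-llama/Llama-2-7b-hf",
                   enable_lora=True,
                   max_loras=2))

    # ... assume adapter 1 was registered earlier via add_lora(LoRARequest(...))
    engine.pin_lora(1)  # keep adapter 1 resident despite LRU pressure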
63 changes: 63 additions & 0 deletions tests/lora/test_lora_manager.py
@@ -209,7 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.pin_lora(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.activate_lora(1)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.deactivate_lora(2)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.pin_lora(3)
assert manager.pin_lora(1)
with pytest.raises(RuntimeError):
assert manager.pin_lora(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
with pytest.raises(RuntimeError):
assert manager.activate_lora(2)

assert manager.deactivate_lora(3)
assert manager.pin_lora(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.remove_lora(3)
with pytest.raises(ValueError):
assert manager.pin_lora(3)

def test_lru_lora_model_manager(dist_init, dummy_model):
# This tests just the LRU cache functionality, everything else is
@@ -288,6 +315,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
assert set(manager.list_loras()) == set()
assert all(x is None for x in manager.lora_index_to_id)

# pinning
assert manager.add_lora(model_lora3)
assert manager.activate_lora(3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(4)
assert set(manager.list_loras()) == {3, 4}
with pytest.raises(ValueError):
assert manager.pin_lora(1)
assert manager.pin_lora(3)
# Remove manually
assert manager.remove_lora(3)
assert not manager.remove_lora(3)

assert set(manager.list_loras()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4

assert manager.add_lora(model_lora1)
assert manager.pin_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)

assert set(manager.list_loras()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2

assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == {1}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] is None

with pytest.raises(RuntimeError):
assert manager.remove_oldest_lora()

assert set(manager.list_loras()) == {1}
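Taken together, these assertions pin down the contract: pin_lora returns True on success, raises ValueError for an adapter id that is not registered, and any operation that would have to evict when every slot is pinned (activating a new adapter, remove_oldest_lora) raises RuntimeError.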


def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -976,5 +976,8 @@ def remove_lora(self, lora_id: int) -> bool:
def list_loras(self) -> Set[int]:
return self.model_executor.list_loras()

def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)

def check_health(self) -> None:
self.model_executor.check_health()
3 changes: 3 additions & 0 deletions vllm/executor/cpu_executor.py
@@ -83,6 +83,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:

def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
7 changes: 7 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
@@ -99,6 +99,13 @@ def remove_lora(self, lora_id: int) -> bool:
"remove_lora",
lora_id=lora_id,
)

def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"pin_lora",
lora_id=lora_id,
)

def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
4 changes: 4 additions & 0 deletions vllm/executor/executor_base.py
@@ -85,6 +85,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def list_loras(self) -> Set[int]:
4 changes: 4 additions & 0 deletions vllm/executor/gpu_executor.py
@@ -98,6 +98,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
3 changes: 3 additions & 0 deletions vllm/executor/neuron_executor.py
@@ -64,6 +64,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:

def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()
21 changes: 21 additions & 0 deletions vllm/lora/models.py
@@ -524,6 +524,27 @@ def remove_lora(self, lora_id: int) -> bool:
if self.long_lora_context:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None))

def pin_lora(self, lora_id: int) -> bool:
    """Pin a LoRAModel in the manager cache."""
    self._pin_lora_in_cpu_cache(lora_id)
    self._pin_lora_in_gpu_cache(lora_id)
    return True

def _pin_lora_in_cpu_cache(self, lora_id: int):
    try:
        self._registered_loras.pin(lora_id)
    except ValueError as err:
        raise ValueError(
            f"Pinning failed. LoRA {lora_id} is not registered.") from err

def _pin_lora_in_gpu_cache(self, lora_id: int):
    if lora_id not in self._active_loras:
        # move lora to gpu if not already active
        self.activate_lora(lora_id)

    self._active_loras.pin(lora_id)


# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
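A compact sketch of the manager-level contract these helpers and the tests above exercise; manager and model_lora1 stand in for the test fixtures, and an LRU-cache LoRA manager with two GPU slots is assumed:

    # Sketch only; names reuse the test fixtures, not this diff.
    assert manager.add_lora(model_lora1)  # register in the CPU cache
    assert manager.pin_lora(1)            # activates on GPU if needed, then pins

    assert manager.remove_lora(1)         # explicit removal evicts even a pinned adapter
    with pytest.raises(ValueError):
        manager.pin_lora(1)               # id 1 is no longer registered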
3 changes: 3 additions & 0 deletions vllm/lora/worker_manager.py
@@ -220,6 +220,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:

def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self._lora_manager.pin_lora(lora_id)

def remove_all_loras(self):
self._lora_manager.remove_all_loras()
27 changes: 23 additions & 4 deletions vllm/utils.py
@@ -66,6 +66,7 @@ class LRUCache(Generic[T]):

def __init__(self, capacity: int):
self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.pinned_items: Set[Hashable] = set()
self.capacity = capacity

def __contains__(self, key: Hashable) -> bool:
@@ -101,14 +102,29 @@ def put(self, key: Hashable, value: T) -> None:
self.cache.move_to_end(key)
self._remove_old_if_needed()

def pin(self, key: Hashable) -> None:
if key not in self.cache:
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)

def _unpin(self, key: Hashable) -> None:
self.pinned_items.remove(key)

def _on_remove(self, key: Hashable, value: Optional[T]):
pass

def remove_oldest(self, remove_pinned: bool = False):
    if not self.cache:
        return

    if not remove_pinned:
        if all(key in self.pinned_items for key in self.cache):
            raise RuntimeError(
                "All items are pinned, cannot remove oldest from the cache.")
        # pop the oldest item in the cache that is not pinned
        lru_key = next(key for key in self.cache
                       if key not in self.pinned_items)
    else:
        lru_key = next(iter(self.cache))
    self.pop(lru_key)

def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
@@ -119,13 +135,16 @@ def pop(self,
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
return value

def clear(self):
    while len(self.cache) > 0:
        self.remove_oldest(remove_pinned=True)
    self.cache.clear()


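The cache-level behavior in isolation, as a minimal sketch (the integer keys, string values, and capacity are arbitrary):

    from vllm.utils import LRUCache

    cache: LRUCache[str] = LRUCache(capacity=2)
    cache.put(1, "adapter-a")
    cache.put(2, "adapter-b")
    cache.pin(1)               # exempt key 1 from LRU eviction
    cache.put(3, "adapter-c")  # evicts key 2, the oldest unpinned entry
    assert 1 in cache and 3 in cache

    cache.pin(3)
    cache.clear()              # clear() force-removes even pinned entries
    assert len(cache.cache) == 0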
5 changes: 5 additions & 0 deletions vllm/worker/model_runner.py
@@ -865,6 +865,11 @@ def remove_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
if not self.lora_manager:
3 changes: 3 additions & 0 deletions vllm/worker/worker.py
@@ -322,6 +322,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:

def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()
7 changes: 7 additions & 0 deletions vllm/worker/worker_base.py
@@ -69,6 +69,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def list_loras(self) -> Set[int]:
@@ -85,6 +89,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:

def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")

def pin_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")

def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")
