Skip to content

Commit

Permalink
Revert "EHN: clean cache for VL models (xorbitsai#2163)" (xorbitsai#2230
Browse files Browse the repository at this point in the history
)
  • Loading branch information
qinxuye authored Sep 5, 2024
1 parent 3190701 commit ea669ca
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 34 deletions.
3 changes: 0 additions & 3 deletions xinference/model/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
CompletionChunk,
CompletionUsage,
)
from ..utils import ensure_cache_cleared
from .llm_family import (
LlamaCppLLMSpecV1,
LLMFamilyV1,
Expand Down Expand Up @@ -249,7 +248,6 @@ def _get_final_chat_completion_chunk(
return cast(ChatCompletionChunk, chat_chunk)

@classmethod
@ensure_cache_cleared
def _to_chat_completion_chunks(
cls,
chunks: Iterator[CompletionChunk],
Expand Down Expand Up @@ -282,7 +280,6 @@ async def _async_to_chat_completion_chunks(
i += 1

@staticmethod
@ensure_cache_cleared
def _to_chat_completion(completion: Completion) -> ChatCompletion:
return {
"id": "chat" + completion["id"],
Expand Down
32 changes: 1 addition & 31 deletions xinference/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import gc
import inspect
import json
import logging
import os
Expand All @@ -28,7 +24,7 @@
import torch

from ..constants import XINFERENCE_CACHE_DIR, XINFERENCE_ENV_MODEL_SRC
from ..device_utils import empty_cache, get_available_device, is_device_available
from ..device_utils import get_available_device, is_device_available
from .core import CacheableModelSpec

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -357,32 +353,6 @@ def convert_float_to_int_or_str(model_size: float) -> Union[int, str]:
return str(model_size)


def ensure_cache_cleared(func: Callable):
assert not inspect.iscoroutinefunction(func) and not inspect.isasyncgenfunction(
func
)
if inspect.isgeneratorfunction(func):

@functools.wraps(func)
def inner(*args, **kwargs):
for obj in func(*args, **kwargs):
yield obj
gc.collect()
empty_cache()

else:

@functools.wraps(func)
def inner(*args, **kwargs):
try:
return func(*args, **kwargs)
finally:
gc.collect()
empty_cache()

return inner


def set_all_random_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
Expand Down

0 comments on commit ea669ca

Please sign in to comment.