Skip to content

Commit

Permalink
A Python way to work around too-early NodeFactory disposal
Browse files Browse the repository at this point in the history
  • Loading branch information
slyalin committed Mar 8, 2024
1 parent 0d2ba62 commit 1866b77
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from typing import Dict, List, Optional, Tuple, Union
import math
import gc

import numpy as np
import torch
Expand Down Expand Up @@ -296,15 +297,11 @@ def wrapper(module, target_op, *args, **kwargs):
model._openvino_patch_orig_forward = model.forward
model.forward = partial(ov_wrapper, model)

from openvino.runtime.utils.node_factory import NodeFactory
factory = NodeFactory()

def patch_stateful_model(model):
def patch_stateful_model(model, factory):
print('TRANSFORMING OPTIMUM-INTEL MODEL TO vLLM COMPATIBLE FORM')
from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher, AnyInput, Or
from openvino.runtime import opset13
from openvino.runtime.utils import replace_node
factory.add_extension("libuser_ov_extensions.so")

#model.remove_parameter(model.input('beam_idx').get_node())
max_context_len = opset13.parameter(shape=[], dtype=np.int64, name='max_context_len') # max_context_len
Expand Down Expand Up @@ -551,7 +548,12 @@ def load_model(self) -> None:
import openvino as ov
from optimum.intel import OVModelForCausalLM
self.model = OVModelForCausalLM.from_pretrained(self.model_config.model, export=True, compile=False, load_in_8bit=False, trust_remote_code=True) # need stateful because it also enables SDPA
patch_stateful_model(self.model.model)
if not hasattr(self.model, 'ov_node_factory'):
from openvino.runtime.utils.node_factory import NodeFactory
# Keep the factory so that it is destroyed at a specific moment: only after all other objects referencing custom nodes are destroyed
self.model.ov_node_factory = NodeFactory()
self.model.ov_node_factory.add_extension('libuser_ov_extensions.so')
patch_stateful_model(self.model.model, self.model.ov_node_factory)
#ov.serialize(self.model.model, 'vllm_openvino_model.xml')
core = ov.Core()
ov_compiled = core.compile_model(self.model.model, "CPU")
Expand All @@ -570,6 +572,14 @@ def load_model(self) -> None:
else:
self.model = get_model(self.model_config)

def __del__(self):
    """Release OpenVINO objects in a dependency-safe order.

    The infer request and the OpenVINO model are dropped first, a garbage
    collection pass is forced, and only then is the node factory released,
    so the factory outlives every object that references its custom nodes.

    The finalizer is a no-op when no model was loaded or when the loaded
    model carries no ``ov_node_factory`` attribute (the attribute that
    ``load_model`` attaches on the OpenVINO path). Attributes that are
    missing, e.g. because loading failed midway, are skipped instead of
    raising ``AttributeError`` from inside the finalizer.
    """
    model = getattr(self, 'model', None)
    if model is None or not hasattr(model, 'ov_node_factory'):
        # Nothing OpenVINO-specific to tear down; leave the model untouched.
        return

    # Order is important: release everything that may reference custom
    # nodes before the factory that owns them.
    for attr in ('ov_request', 'model'):
        if hasattr(model, attr):
            delattr(model, attr)
    if gc:  # when app is being destroyed the module may not be available
        gc.collect()
    del model.ov_node_factory

def set_block_size(self, block_size: int) -> None:
self.block_size = block_size

Expand Down

0 comments on commit 1866b77

Please sign in to comment.