Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[djl-convert] Support convert local model to DJL format #3386

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,19 @@ def __init__(self):
self.outputs = None
self.api = HfApi()

def save_model(self, model_info, task: str, args: Namespace, temp_dir: str,
               model_zoo: bool):
    """Dispatch to the format-specific saver selected by ``args.output_format``.

    ``task`` is only forwarded to the ONNX path (the exporter needs it for
    local models); the Rust and PyTorch savers take their usual arguments.
    Returns whatever the chosen saver returns (a (result, reason, size)-style
    tuple, judging by the callers visible in this diff).
    """
    output_format = args.output_format
    if output_format == "OnnxRuntime":
        return self.save_onnx_model(model_info, task, args, temp_dir,
                                    model_zoo)
    if output_format == "Rust":
        return self.save_rust_model(model_info, args, temp_dir, model_zoo)
    # Default: PyTorch (jit-traced) output.
    return self.save_pytorch_model(model_info, args, temp_dir, model_zoo)

def save_onnx_model(self, model_info, args: Namespace, temp_dir: str,
model_zoo: bool):
def save_onnx_model(self, model_info, task: str, args: Namespace,
temp_dir: str, model_zoo: bool):
model_id = model_info.modelId

if not os.path.exists(temp_dir):
Expand All @@ -82,6 +83,8 @@ def save_onnx_model(self, model_info, args: Namespace, temp_dir: str,
sys.argv.extend(["--dtype", args.dtype])
if args.trust_remote_code:
sys.argv.append("--trust-remote-code")
if os.path.exists(model_id):
sys.argv.extend(["--task", task])
sys.argv.append(temp_dir)

main()
Expand Down Expand Up @@ -135,29 +138,46 @@ def save_rust_model(self, model_info, args: Namespace, temp_dir: str,
return False, "Failed to save tokenizer", -1

# Save config.json
config_file = hf_hub_download(repo_id=model_id, filename="config.json")
if os.path.exists(model_id):
config_file = os.path.join(model_id, "config.json")
else:
config_file = hf_hub_download(repo_id=model_id,
filename="config.json")

shutil.copyfile(config_file, os.path.join(temp_dir, "config.json"))

target = os.path.join(temp_dir, "model.safetensors")
model = self.api.model_info(model_id, files_metadata=True)
has_sf_file = False
has_pt_file = False
for sibling in model.siblings:
if sibling.rfilename == "model.safetensors":
has_sf_file = True
elif sibling.rfilename == "pytorch_model.bin":
has_pt_file = True

if has_sf_file:
file = hf_hub_download(repo_id=model_id,
filename="model.safetensors")
shutil.copyfile(file, target)
elif has_pt_file:
file = hf_hub_download(repo_id=model_id,
filename="pytorch_model.bin")
convert_file(file, target)

if os.path.exists(model_id):
file = os.path.join(model_id, "model.safetensors")
if os.path.exists(file):
shutil.copyfile(file, target)
else:
file = os.path.join(model_id, "pytorch_model.bin")
if os.path.exists(file):
convert_file(file, target)
else:
return False, f"No model file found for: {model_id}", -1
else:
return False, f"No model file found for: {model_id}", -1
model = self.api.model_info(model_id, files_metadata=True)
has_sf_file = False
has_pt_file = False
for sibling in model.siblings:
if sibling.rfilename == "model.safetensors":
has_sf_file = True
elif sibling.rfilename == "pytorch_model.bin":
has_pt_file = True

if has_sf_file:
file = hf_hub_download(repo_id=model_id,
filename="model.safetensors")
shutil.copyfile(file, target)
elif has_pt_file:
file = hf_hub_download(repo_id=model_id,
filename="pytorch_model.bin")
convert_file(file, target)
else:
return False, f"No model file found for: {model_id}", -1

arguments = self.save_serving_properties(model_info, "Rust", temp_dir,
hf_pipeline, include_types)
Expand Down Expand Up @@ -191,8 +211,13 @@ def save_pytorch_model(self, model_info, args: Namespace, temp_dir: str,
return False, "Failed to save tokenizer", -1

# Save config.json just for reference
config = hf_hub_download(repo_id=model_id, filename="config.json")
shutil.copyfile(config, os.path.join(temp_dir, "config.json"))
if os.path.exists(model_id):
config_file = os.path.join(model_id, "config.json")
else:
config_file = hf_hub_download(repo_id=model_id,
filename="config.json")

shutil.copyfile(config_file, os.path.join(temp_dir, "config.json"))

# Save jit traced .pt file to temp dir
include_types = "token_type_ids" in hf_pipeline.tokenizer.model_input_names
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import json
import logging
import os
import sys

from huggingface_hub import HfApi

sys.path.append(os.path.dirname(os.path.realpath(__file__)))

from djl_converter.arg_parser import converter_args


class ModelInfoHolder(object):
    """Lightweight stand-in for ``HfApi.model_info(...)`` results, backed by a
    local model directory.

    Exposes the two attributes downstream conversion code reads:
    ``modelId`` (mirrors the hub object's attribute name, so local and hub
    models are handled uniformly) and ``config`` (the parsed ``config.json``).
    """

    def __init__(self, model_id: str):
        """Load model metadata from a local directory.

        :param model_id: path to a local model directory that must contain
            a ``config.json`` file.
        :raises FileNotFoundError: if ``config.json`` is missing.
        :raises json.JSONDecodeError: if ``config.json`` is not valid JSON.
        """
        self.modelId = model_id
        config_path = os.path.join(model_id, "config.json")
        # JSON is UTF-8 by spec; be explicit so reads don't depend on the
        # platform's locale default encoding.
        with open(config_path, encoding="utf-8") as f:
            self.config = json.load(f)


def main():
logging.basicConfig(stream=sys.stdout,
format="%(message)s",
Expand All @@ -38,10 +45,17 @@ def main():
logging.error(f"output directory: {output_dir} is not empty.")
return

api = HfApi()
model_info = api.model_info(args.model_id,
revision=args.revision,
token=args.token)
if os.path.exists(args.model_id):
logging.info(f"converting local model: {args.model_id}")
model_info = ModelInfoHolder(args.model_id)
else:
logging.info(f"converting HuggingFace hub model: {args.model_id}")
from huggingface_hub import HfApi

api = HfApi()
model_info = api.model_info(args.model_id,
revision=args.revision,
token=args.token)

from djl_converter.huggingface_models import HuggingfaceModels, SUPPORTED_TASKS

Expand All @@ -51,14 +65,14 @@ def main():
task = "sentence-similarity"
if not task:
logging.error(
f"Unsupported model architecture: {arch} for {model_id}.")
f"Unsupported model architecture: {arch} for {args.model_id}.")
return

converter = SUPPORTED_TASKS[task]

try:
result, reason, _ = converter.save_model(model_info, args, output_dir,
False)
result, reason, _ = converter.save_model(model_info, task, args,
output_dir, False)
if result:
logging.info(f"Convert model {model_info.modelId} finished.")
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main():

try:
result, reason, size = converter.save_model(
model_info, args, temp_dir, True)
model_info, task, args, temp_dir, True)
if not result:
logging.error(f"{model_info.modelId}: {reason}")
except Exception as e:
Expand Down
Loading