-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
46 lines (37 loc) · 2.27 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import cv2
import matplotlib.pyplot as plt
from words2contact import Words2Contact
def main(image_path, prompt, use_gpt, yello_vlm, output_path, llm_path, chat_template):
    """Run Words2Contact on one image/prompt pair and save a visualization.

    The image is read with OpenCV, flipped vertically, and passed to
    ``Words2Contact.predict`` together with ``prompt``. The predicted bounding
    boxes and contact point are then drawn over the image with matplotlib and
    the figure is written to ``output_path``.

    Args:
        image_path: Path to the input image file (read with ``cv2.imread``).
        prompt: Text prompt sent to Words2Contact.
        use_gpt: Forwarded to ``Words2Contact``; per the CLI help, selects the
            OpenAI API as the LLM.
        yello_vlm: Name of the model to use for YELLO VLM, forwarded to
            ``Words2Contact``.
        output_path: File path where the rendered figure is saved.
        llm_path: Path to local .gguf LLM weights, forwarded to ``Words2Contact``.
        chat_template: Chat template name for local LLMs, forwarded to
            ``Words2Contact``.

    Raises:
        FileNotFoundError: If ``image_path`` cannot be read as an image.
    """
    # cv2.imread reports failure by returning None rather than raising; fail
    # early with a clear message instead of letting cv2.flip choke on None.
    raw = cv2.imread(image_path)
    if raw is None:
        raise FileNotFoundError(f"Could not read image file: {image_path}")
    # Flip vertically (flipCode=0). Combined with origin='lower' in imshow
    # below, the image still displays upright on axes whose y grows upward.
    img = cv2.flip(raw, 0)

    # Initialize the Words2Contact model
    words2contact = Words2Contact(use_gpt=use_gpt, yello_vlm=yello_vlm,
                                  llm_path=llm_path, chat_template=chat_template)

    # predict returns a 5-tuple; only the contact point, the bounding boxes
    # and the model's response are used here.
    point, _, bbs, _, response = words2contact.predict(prompt, img)

    # Print prompt and response
    print("User: ", prompt)
    print("Response: ", response)

    # Visualize results on an explicit figure and always release it afterwards
    # so repeated calls do not accumulate open matplotlib figures.
    fig, ax = plt.subplots()
    try:
        # OpenCV images are BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), origin='lower')
        for bb in bbs:
            bb.plot_bb(ax)
        ax.scatter(point.x, point.y, color='red')
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    print(f"Output saved to {output_path}")
if __name__ == "__main__":
    # Command-line entry point: parse options and run a single prediction.
    parser = argparse.ArgumentParser(description="Run Words2Contact with an image and a text prompt.")
    parser.add_argument("--image_path", type=str, default="data/test.png",
                        help="Path to the input image file. Default: 'data/test.png'.")
    parser.add_argument("--prompt", type=str, default="Place your hand above the red bowl.",
                        help="Text prompt for Words2Contact. Default: 'Place your hand above the red bowl.'.")
    # NOTE(review): confirm the exact env var name Words2Contact reads; the
    # official OpenAI Python SDK looks for OPENAI_API_KEY by default.
    parser.add_argument("--use_gpt", action="store_true",
                        help="use openai api for the llm, remember to export OPENAI_KEY")
    parser.add_argument("--yello_vlm", type=str, default="GroundingDINO",
                        help="Model to use for YELLO VLM. Default: 'GroundingDINO'.")
    parser.add_argument("--output_path", type=str, default="data/test_output.png",
                        help="Path to save the output image. Default: 'data/test_output.png'.")
    parser.add_argument("--llm_path", type=str, default="models/Calme-7B-Instruct-v0.4.Q8_0.gguf",
                        help="Path to the .gguf llm model weights")
    parser.add_argument("--chat_template", type=str, default="ChatML",
                        help="Which chat template to use for local llms, Default: ChatML ")

    args = parser.parse_args()

    # Forward every parsed option to main in its declared parameter order.
    main(args.image_path, args.prompt, args.use_gpt, args.yello_vlm,
         args.output_path, args.llm_path, args.chat_template)