Commit

SAM2 AMG server side request batching (#1197)

cpuhrsch authored Oct 31, 2024
1 parent ae77f40 commit fe498e4
Showing 3 changed files with 273 additions and 107 deletions.
16 changes: 10 additions & 6 deletions examples/sam2_amg_server/README.md
@@ -21,12 +21,16 @@ xargs -I {} curl -s -w "\n" -X POST http://<your_hostname>:<your_port>/upload_rl

Experiments run on an H100; batch size and points per batch are listed per row.

| mode | mIoU | mask count mismatch | avg. ms per request |
| --- |--- | ------------------ | ----------------- |
| baseline | 1.0 | 0 | 786 |
| ao | 1.0 | 0 | 738 |
| fast | 0.95 | 190 | 563 |
| furious | 0 | 1000 | 204 |
| mode | mIoU | mask count mismatch | avg. ms per request | batch size | points per batch |
| --- | --- | ------------------- | ------------------- | ---------- | ---------------- |
| baseline | 1.0 | 0 | 786 | 1 | 64 |
| baseline | N/A | N/A | N/A | 32 | 1024 |
| ao | 1.0 | 0 | 738 | 1 | 64 |
| ao | 0.9999994993636996 | 0 | 564 | 32 | 1024 |
| fast | 0.95 | 190 | 563 | 1 | 64 |
| fast | 0.9527849197435295 | 191 | 460 | 32 | 1024 |
| furious | 0 | 1000 | 204 | 1 | 64 |
| furious | 0 | 1000 | 210 | 32 | 1024 |

mask count mismatch counts the number of requests for which the number of masks differs from the baseline.
For example, the baseline may have chosen to segment an image into 18 masks, but the fast variant produces 17 or 19.
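
A minimal sketch (not part of this repo) of how these two metrics can be computed for a single image, assuming both runs saved their `/upload_rle` responses as JSON and that `rle_to_mask` can be imported from `torchao._models.sam2.utils.amg` (the module `server.py` already uses for `area_from_rle`):

```python
import json

import torch
from torchao._models.sam2.utils.amg import rle_to_mask


def compare_rle_files(baseline_path, candidate_path):
    # Hypothetical helper, not part of server.py. Both files map mask ids
    # (e.g. "mask_0") to RLE-encoded segmentations.
    baseline = json.loads(open(baseline_path).read())
    candidate = json.loads(open(candidate_path).read())
    # Mask count mismatch: the two runs produced a different number of masks.
    count_mismatch = len(baseline) != len(candidate)
    ious = []
    for key in baseline.keys() & candidate.keys():
        m0 = torch.from_numpy(rle_to_mask(baseline[key]))
        m1 = torch.from_numpy(rle_to_mask(candidate[key]))
        ious.append(((m0 & m1).sum() / (m0 | m1).sum()).item())
    miou = sum(ious) / max(len(ious), 1)
    return count_mismatch, miou
```
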
270 changes: 170 additions & 100 deletions examples/sam2_amg_server/server.py
@@ -23,6 +23,10 @@
import matplotlib.pyplot as plt
import numpy as np

import asyncio
from contextlib import asynccontextmanager
import contextlib

# from torch._inductor import config as inductorconfig
# inductorconfig.triton.unique_kernel_names = True
# inductorconfig.coordinate_descent_tuning = True
@@ -39,13 +43,16 @@ def iou(mask1, mask2):
    return (intersection.sum(dim=(-1, -2)) / union.sum(dim=(-1, -2)))


def show_anns(anns):
def show_anns(anns, rle_to_mask):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    for ann in sorted_anns:
        ann['segmentation'] = rle_to_mask(ann['segmentation'])

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    ms = []
@@ -68,6 +75,93 @@ def profiler_runner(path, fn, *args, **kwargs):
return result


def image_tensor_to_masks(example_image, mask_generator):
    masks = mask_generator.generate(example_image)
    return masks


def image_tensors_to_masks(example_images, mask_generator):
    return mask_generator.generate_batch(example_images)


def file_bytes_to_image_tensor(file_bytes):
    image_array = np.asarray(file_bytes, dtype=np.uint8)
    example_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    return example_image


def masks_to_rle_dict(masks):
    ret_data = {}
    for mask_id in range(len(masks)):
        ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"]
    return ret_data


# Queue to hold incoming requests
request_queue = asyncio.Queue()
batch_interval = 1  # Time in seconds to wait before processing a batch
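
# How server-side batching works: endpoint handlers enqueue
# (image_tensor, future) pairs onto request_queue; batch_worker below wakes
# up every batch_interval seconds, drains up to batch_size pending requests,
# runs them through the mask generator as one batch, and resolves each
# request's future with its result.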


def process_batch(batch, mask_generator):
    print(f"Processing batch of len {len(batch)}")
    t = time.time()
    image_tensors = [image_tensor for (image_tensor, _) in batch]
    masks = mask_generator.generate_batch(image_tensors)
    print(f"Took avg. {(time.time() - t) / len(batch)}s per batch entry")
    return masks


async def batch_worker(mask_generator, batch_size, *, pad_batch=True, furious=False):
    # In furious mode the worker runs under bfloat16 autocast for its entire
    # lifetime, so the context is entered once and never exited.
    cm = torch.autocast("cuda", dtype=torch.bfloat16) if furious else contextlib.nullcontext()
    cm.__enter__()
    while True:
        # Drain up to batch_size pending requests without blocking.
        batch = []
        while len(batch) < batch_size and not request_queue.empty():
            batch.append(await request_queue.get())

        if batch:
            # Optionally pad the batch to the full batch_size by repeating
            # the last entry, so the model always sees a fixed batch size.
            padded_batch = batch
            if pad_batch:
                padded_batch = batch + ([batch[-1]] * (batch_size - len(batch)))
            print(f"len(padded_batch): {len(padded_batch)}")
            results = process_batch(padded_batch, mask_generator)
            # Resolve only the futures of the real (unpadded) requests.
            for i, (_, response_future) in enumerate(batch):
                response_future.set_result(results[i])

        await asyncio.sleep(batch_interval)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup logic
    mask_generator = app.state.mask_generator
    batch_size = app.state.batch_size
    furious = app.state.furious
    task = asyncio.create_task(batch_worker(mask_generator, batch_size, furious=furious))
    yield
    # Shutdown logic (if needed)
    task.cancel()


def benchmark_fn(func, inp, mask_generator):
    torch.cuda.reset_peak_memory_stats()
    logging.info("Running 3 warmup iterations.")
    for _ in range(3):
        func(inp, mask_generator)
    logging.info("Running 10 benchmark iterations.")
    t = time.time()
    for _ in range(10):
        func(inp, mask_generator)
    print(f"Benchmark took {(time.time() - t)/10.0}s per iteration.")
    max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
    _, total_memory = torch.cuda.mem_get_info()
    max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
    max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
    print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%")


def main(checkpoint_path,
         baseline=False,
         fast=False,
@@ -79,13 +173,15 @@ def main(checkpoint_path,
         points_per_batch=64,
         port=5000,
         host="127.0.0.1",
         dry=False):
         dry=False,
         batch_size=1):
    if verbose:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(levelname)s - %(message)s',
                            datefmt='%Y-%m-%d %H:%M:%S')
    logging.info(f"Running with fast set to {fast} and furious set to {furious}")
    logging.info(f"Running with port {port} and host {host}")
    logging.info(f"Running with batch size {batch_size}")

    if baseline:
        logging.info(f"Importing sam2 from outside of torchao. If this errors, install https://github.com/facebookresearch/sam2")
@@ -136,67 +232,65 @@ def main(checkpoint_path,
            dynamic=True,
        )

    example_image = cv2.imread('dog.jpg')
    example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
    with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
        t = time.time()
        logging.info(f"Running one iteration to compile.")
        masks = mask_generator.generate(example_image)
        logging.info(f"First iteration took {time.time() - t}s.")
        if unittest:
            logging.info(f"Running strict comparison to reference mask")
            import json
            ref_masks = json.loads(open("dog_rle.json").read())
            ret_data = {}
            for mask_id in range(len(masks)):
                ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"]
            v0_areas = []
            v1_areas = []
            miou_sum = 0.0
            miou_count = 0
            for k0 in ref_masks:
                assert k0 in ret_data, f"Expected {k0} to be in return data"
                from torchao._models.sam2.utils.amg import area_from_rle
                v0_area = area_from_rle(ref_masks[k0])
                v1_area = area_from_rle(ret_data[k0])
                v0_areas.append(v0_area)
                v1_areas.append(v1_area)
                if v0_area != v1_area:
                    print(f"v0 area {v0_area} doesn't match v1 area {v1_area}")
                v0_mask = torch.from_numpy(rle_to_mask(ref_masks[k0]))
                v1_mask = torch.from_numpy(rle_to_mask(ret_data[k0]))
                if not torch.allclose(v0_mask, v1_mask):
                    miou_sum += iou(v0_mask, v1_mask)
                    miou_count += 1
                    print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")
            if miou_count == 0:
                print("Masks exactly match reference.")
            else:
                print(f"mIoU is {miou_sum / miou_count}")

        if benchmark:
            logging.info(f"Running 3 warmup iterations.")
            for _ in range(3):
                masks = mask_generator.generate(example_image)
            logging.info(f"Running 10 benchmark iterations, then exit.")
            t = time.time()
            for _ in range(10):
                masks = mask_generator.generate(example_image)
            print(f"Benchmark took {(time.time() - t)/10.0}s per iteration.")
            max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
            _, total_memory = torch.cuda.mem_get_info()
            max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
            max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
            print(f"max_memory_allocated_bytes: {max_memory_allocated_bytes}MiB or {max_memory_allocated_percentage}%")

        if profile is not None:
            print(f"Saving profile under {profile}")
            profiler_runner(profile, mask_generator.generate, example_image)
    with open('dog.jpg', 'rb') as f:
        image_tensor = file_bytes_to_image_tensor(bytearray(f.read()))

    t = time.time()
    logging.info("Running three iterations to compile and warmup.")
    image_tensor_to_masks(image_tensor, mask_generator)
    image_tensor_to_masks(image_tensor, mask_generator)
    image_tensor_to_masks(image_tensor, mask_generator)
    logging.info(f"Three iterations took {time.time() - t}s.")

    if unittest:
        masks = image_tensor_to_masks(image_tensor, mask_generator)
        # Smoke test only for now. Need more images for batch.
        logging.info(f"batched smoke test")
        _ = image_tensors_to_masks([image_tensor] * batch_size, mask_generator)
        ret_data = masks_to_rle_dict(masks)
        import json
        ref_masks = json.loads(open("dog_rle.json").read())
        v0_areas = []
        v1_areas = []
        miou_sum = 0.0
        miou_count = 0
        for k0 in ref_masks:
            assert k0 in ret_data, f"Expected {k0} to be in return data"
            from torchao._models.sam2.utils.amg import area_from_rle
            v0_area = area_from_rle(ref_masks[k0])
            v1_area = area_from_rle(ret_data[k0])
            v0_areas.append(v0_area)
            v1_areas.append(v1_area)
            if v0_area != v1_area:
                print(f"v0 area {v0_area} doesn't match v1 area {v1_area}")
            v0_mask = torch.from_numpy(rle_to_mask(ref_masks[k0]))
            v1_mask = torch.from_numpy(rle_to_mask(ret_data[k0]))
            if not torch.allclose(v0_mask, v1_mask):
                miou_sum += iou(v0_mask, v1_mask)
                miou_count += 1
                print(f"Masks don't match for key {k0}. IoU is {iou(v0_mask, v1_mask)}")
        if miou_count == 0:
            print("Masks exactly match reference.")
        else:
            print(f"mIoU is {miou_sum / miou_count}")

    if benchmark:
        print("batch size 1 test")
        benchmark_fn(image_tensor_to_masks, image_tensor, mask_generator)
        print(f"batch size {batch_size} test")
        benchmark_fn(image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

    if profile is not None:
        print(f"Saving profile under {profile}")
        profiler_runner(profile, image_tensors_to_masks, [image_tensor] * batch_size, mask_generator)

    if dry:
        return

    app = FastAPI()
    app = FastAPI(lifespan=lifespan)
    app.state.mask_generator = mask_generator
    app.state.batch_size = batch_size
    app.state.furious = furious

    # Allow all origins (you can restrict it in production)
    app.add_middleware(
@@ -209,53 +303,29 @@ def main(checkpoint_path,

@app.post("/upload_rle")
async def upload_rle(image: UploadFile = File(...)):
# Save the uploaded image to a temporary location
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{image.filename}")
with open(temp_file.name, "wb") as b:
shutil.copyfileobj(image.file, b)

# Read the image back into memory to send as response
example_image = cv2.imread(temp_file.name)
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
t = time.time()
with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
masks = mask_generator.generate(example_image)
print(f"Took {time.time() - t} to generate a mask for input image.")
ret_data = {}
for mask_id in range(len(masks)):
ret_data[f"mask_{mask_id}"] = masks[mask_id]["segmentation"]
return ret_data
image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
response_future = asyncio.Future()
await request_queue.put((image_tensor, response_future))
masks = await response_future
return masks_to_rle_dict(masks)

@app.post("/upload")
async def upload_image(image: UploadFile = File(...)):
# Save the uploaded image to a temporary location
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f"_{image.filename}")
with open(temp_file.name, "wb") as b:
shutil.copyfileobj(image.file, b)

# Read the image back into memory to send as response
example_image = cv2.imread(temp_file.name)
example_image = cv2.cvtColor(example_image, cv2.COLOR_BGR2RGB)
t = time.time()
with torch.backends.cuda.sdp_kernel(enable_cudnn=True):
masks = mask_generator.generate(example_image)
print(f"Took {time.time() - t} to generate a mask for input image.")
image_tensor = file_bytes_to_image_tensor(bytearray(await image.read()))
response_future = asyncio.Future()
await request_queue.put((image_tensor, response_future))
masks = await response_future

# Save an example
plt.figure(figsize=(example_image.shape[1]/100., example_image.shape[0]/100.), dpi=100)
plt.imshow(example_image)
for i in range(len(masks)):
masks[i]["segmentation"] = rle_to_mask(masks[i]["segmentation"])
show_anns(masks)
plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100)
plt.imshow(image_tensor)
show_anns(masks, rle_to_mask)
plt.axis('off')
plt.tight_layout()
plt.savefig(temp_file.name, format='png')

# Read the image back into memory to send as response
with open(temp_file.name, "rb") as f:
image_data = f.read()

# Return the image as a StreamingResponse
return StreamingResponse(BytesIO(image_data), media_type="image/png")
buf = BytesIO()
plt.savefig(buf, format='png')
buf.seek(0)
return StreamingResponse(buf, media_type="image/png")


    uvicorn.run(app, host=host, port=port, log_level="info")
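
    # Example client usage (illustration only, not part of this commit):
    # keeping many requests in flight lets the batch worker fill whole
    # batches, e.g. with 32 concurrent uploads:
    #   ls images/*.jpg | xargs -P 32 -I {} curl -s -w "\n" -X POST \
    #     http://127.0.0.1:5000/upload_rle -F 'image=@{}'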