SAM2 AMG cli and other QoL improvements #1336

Merged: 3 commits, Nov 23, 2024
10 changes: 5 additions & 5 deletions examples/sam2_amg_server/README.md
@@ -8,7 +8,7 @@ curl -X POST http://127.0.0.1:5000/upload -F 'image=@/path/to/file.jpg' --output
Start the server

```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast
```

Collect the rles
@@ -58,7 +58,7 @@ Make sure you've installed https://github.com/facebookresearch/sam2

Start server
```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --baseline
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --baseline
```

Generate and save rles (one line per json via `-w "\n"`)
@@ -73,7 +73,7 @@ sys 0m4.137s
### 3. Start server with torchao variant of SAM2
Start server
```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname>
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname>
```

Generate and save rles (one line per json via `-w "\n"`)
@@ -88,7 +88,7 @@ sys 0m4.350s
### 4. Start server with torchao variant of SAM2 and `--fast` optimizations
Start server
```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast
```

Generate and save rles (one line per json via `-w "\n"`)
@@ -103,7 +103,7 @@ sys 0m4.138s
### 5. Start server with torchao variant of SAM2 and `--fast` and `--furious` optimizations
Start server
```
-python server.py ~/checkpoints/sam2 --port <your_port> --host <your_hostname> --fast --furious
+python server.py ~/checkpoints/sam2 large --port <your_port> --host <your_hostname> --fast --furious
```

Generate and save rles (one line per json via `-w "\n"`)
48 changes: 48 additions & 0 deletions examples/sam2_amg_server/cli.py
@@ -0,0 +1,48 @@
import fire
import logging
import matplotlib.pyplot as plt
from server import file_bytes_to_image_tensor
from server import show_anns
from server import model_type_to_paths
from server import MODEL_TYPES_TO_MODEL
from torchao._models.sam2.build_sam import build_sam2
from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from torchao._models.sam2.utils.amg import rle_to_mask
from io import BytesIO

def main_docstring():
    return f"""
    Args:
        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
        model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())}
        input_path (str): Path to input image
        output_path (str): Path to output image
    """

def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False):
    device = "cuda"
    sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
    if verbose:
        print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
    sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
    mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
    image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read()))
    if verbose:
        print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
    masks = mask_generator.generate(image_tensor)

    # Save an example
    plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100)
    plt.imshow(image_tensor)
    show_anns(masks, rle_to_mask)
    plt.axis('off')
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format=output_format)
    buf.seek(0)
    with open(output_path, "wb") as file:
        file.write(buf.getvalue())

main.__doc__ = main_docstring()
if __name__ == "__main__":
    fire.Fire(main)
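For reference, a hypothetical invocation of the new CLI (the checkpoint directory and image paths are illustrative, and a CUDA device is assumed since `main` hardcodes `device = "cuda"`). From a shell this corresponds to `python cli.py ~/checkpoints/sam2 large dog.jpg dog_masks.png --verbose True`; programmatically:

```
from pathlib import Path
from cli import main

# Illustrative paths; model_type must be one of tiny, small, plus, large.
# A missing checkpoint is downloaded automatically via model_type_to_paths.
main(
    checkpoint_path=str(Path.home() / "checkpoints" / "sam2"),
    model_type="large",
    input_path="dog.jpg",
    output_path="dog_masks.png",
    points_per_batch=1024,
    verbose=True,
)
```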
1 change: 1 addition & 0 deletions examples/sam2_amg_server/requirements.txt
@@ -7,3 +7,4 @@ hydra-core
tqdm
iopath
python-multipart
requests
67 changes: 64 additions & 3 deletions examples/sam2_amg_server/server.py
@@ -1,4 +1,5 @@
import itertools
import requests
import uvicorn
import fire
import tempfile
@@ -37,6 +38,23 @@
# torch._dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

def download_file(url, download_dir):
    # Create the directory if it doesn't exist
    download_dir = Path(download_dir)
    download_dir.mkdir(parents=True, exist_ok=True)
    # Extract the file name from the URL
    file_name = url.split('/')[-1]
    # Define the full path for the downloaded file
    file_path = download_dir / file_name
    # Download the file
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Raise an error for bad responses
    # Write the file to the specified directory
    print(f"Downloading '{file_name}' to '{download_dir}'")
    with open(file_path, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"Downloaded '{file_name}' to '{download_dir}'")

def example_shapes():
    return [(848, 480, 3),
@@ -272,7 +290,51 @@ def unittest_fn(masks, ref_masks, order_by_area=False, verbose=False):
print(f"mIoU is {miou} with equal count {equal_count} out of {len(masks)}")


MODEL_TYPES_TO_CONFIG = {
    "tiny": "sam2.1_hiera_t.yaml",
    "small": "sam2.1_hiera_s.yaml",
    "plus": "sam2.1_hiera_b+.yaml",
    "large": "sam2.1_hiera_l.yaml",
}

MODEL_TYPES_TO_MODEL = {
    "tiny": "sam2.1_hiera_tiny.pt",
    "small": "sam2.1_hiera_small.pt",
    "plus": "sam2.1_hiera_base_plus.pt",
    "large": "sam2.1_hiera_large.pt",
}


MODEL_TYPES_TO_URL = {
    "tiny": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
    "small": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
    "plus": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
    "large": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
}


def main_docstring():
    return f"""
    Args:
        checkpoint_path (str): Path to folder containing checkpoints from https://github.com/facebookresearch/sam2?tab=readme-ov-file#download-checkpoints
        model_type (str): Choose from one of {", ".join(MODEL_TYPES_TO_MODEL.keys())}
    """


def model_type_to_paths(checkpoint_path, model_type):
    if model_type not in MODEL_TYPES_TO_CONFIG.keys():
        raise ValueError(f"Expected model_type to be one of {', '.join(MODEL_TYPES_TO_MODEL.keys())} but got {model_type}")
    sam2_checkpoint = Path(checkpoint_path) / Path(MODEL_TYPES_TO_MODEL[model_type])
    if not sam2_checkpoint.exists():
        print(f"Can't find checkpoint {sam2_checkpoint} in folder {checkpoint_path}. Downloading.")
        download_file(MODEL_TYPES_TO_URL[model_type], checkpoint_path)
    assert sam2_checkpoint.exists(), "Can't find downloaded file. Please open an issue."
    model_cfg = f"configs/sam2.1/{MODEL_TYPES_TO_CONFIG[model_type]}"
    return sam2_checkpoint, model_cfg


def main(checkpoint_path,
         model_type,
         baseline=False,
         fast=False,
         furious=False,
@@ -306,9 +368,7 @@ def main(checkpoint_path,
    from torchao._models.sam2.utils.amg import rle_to_mask

    device = "cuda"
-    from pathlib import Path
-    sam2_checkpoint = Path(checkpoint_path) / Path("sam2.1_hiera_large.pt")
-    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+    sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)

logging.info(f"Loading model {sam2_checkpoint} with config {model_cfg}")
sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
@@ -450,5 +510,6 @@ async def upload_image(image: UploadFile = File(...)):
    # uvicorn.run(app, host=host, port=port, log_level="info")
    uvicorn.run(app, host=host, port=port)

main.__doc__ = main_docstring()
if __name__ == "__main__":
    fire.Fire(main)
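For reference, a minimal sketch of how the new helpers compose (the checkpoint directory below is illustrative): `model_type_to_paths` maps a model size to a checkpoint file and Hydra config, calling `download_file` with the matching `MODEL_TYPES_TO_URL` entry when the checkpoint is absent.

```
from pathlib import Path
from server import model_type_to_paths

# Illustrative directory; any of tiny, small, plus, large works as model_type.
ckpt_dir = Path.home() / "checkpoints" / "sam2"
sam2_checkpoint, model_cfg = model_type_to_paths(ckpt_dir, "large")
print(sam2_checkpoint)  # <ckpt_dir>/sam2.1_hiera_large.pt, downloaded if absent
print(model_cfg)        # configs/sam2.1/sam2.1_hiera_l.yaml
```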
28 changes: 14 additions & 14 deletions torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml
@@ -2,18 +2,18 @@

# Model
model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
+  _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base
  image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
+      _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 112
      num_heads: 2
    neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
@@ -24,17 +24,17 @@ model:
      fpn_interp_model: nearest

  memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
+    _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
@@ -45,7 +45,7 @@ model:
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
@@ -57,23 +57,23 @@ model:
    num_layers: 4

  memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
+      _target_: torchao._models.sam2.modeling.memory_encoder.Fuser
      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
+        _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
28 changes: 14 additions & 14 deletions torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml
@@ -2,21 +2,21 @@

# Model
model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
+  _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base
  image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
+      _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 96
      num_heads: 1
      stages: [1, 2, 11, 2]
      global_att_blocks: [7, 10, 13]
      window_pos_embed_bkg_spatial_size: [7, 7]
    neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
@@ -27,17 +27,17 @@ model:
      fpn_interp_model: nearest

  memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
+    _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
@@ -48,7 +48,7 @@ model:
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
@@ -60,23 +60,23 @@ model:
    num_layers: 4

  memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
+    _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+      _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
+      _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
+      _target_: torchao._models.sam2.modeling.memory_encoder.Fuser
      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
+        _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
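The two config diffs above repoint every Hydra `_target_` entry from the upstream `sam2` package to torchao's vendored copy, so that `build_sam2` instantiates the vendored modules. As a minimal sketch of why the dotted path matters, assuming hydra-core (already listed in requirements.txt), the fragment below mirrors the `position_encoding` block from the edited configs:

```
from hydra.utils import instantiate
from omegaconf import OmegaConf

# _target_ must be an importable dotted path; Hydra imports the class named
# there and calls it with the remaining keys as keyword arguments. That is
# why the vendored torchao._models.sam2 paths replace the upstream sam2 ones.
cfg = OmegaConf.create({
    "_target_": "torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine",
    "num_pos_feats": 64,
    "normalize": True,
    "scale": None,
    "temperature": 10000,
})
pos_enc = instantiate(cfg)
```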