remove qat and ptq func
kzaleskaa committed Jun 6, 2024
1 parent b2be87c commit 5c0c739
Showing 1 changed file with 0 additions and 46 deletions.
46 changes: 0 additions & 46 deletions src/quantization.py
@@ -39,8 +39,6 @@
    log_hyperparameters,
    task_wrapper,
)
from src.models.unet_module import UNETLitModule
from src.data.depth_datamodule import DepthDataModule

log = RankedLogger(__name__, rank_zero_only=True)

@@ -64,50 +64,6 @@ def calibration(model, dataloader, num_iterations):
    return model


def ptq(cfg, model, datamodule):
    quantizer = QATQuantizer(
        model.net,
        torch.randn(1, 3, 52, 52),
        work_dir=cfg.work_dir,
        config=cfg.config
    )
    model.net = quantizer.quantize()
    model.net.apply(torch.quantization.disable_fake_quant)
    model.net.apply(torch.quantization.enable_observer)

    calibration(model.net, datamodule.train_dataloader(), 50)

    model.net.apply(torch.quantization.disable_observer)
    model.net.apply(torch.quantization.enable_fake_quant)

    with torch.no_grad():
        model.net.eval()
        model.net.cpu()
        model.net = torch.quantization.convert(model.net)

    return model

def qat(cfg, model, datamodule, trainer):
    quantizer = QATQuantizer(
        model.net,
        torch.randn(1, 3, 52, 52),
        work_dir=cfg.qat.work_dir,
        config=cfg.qat.config
    )
    model.net = quantizer.quantize()

    # quantization-aware training
    trainer.fit(model=model, datamodule=datamodule)

    with torch.no_grad():
        model.net.eval()
        model.net.cpu()

        model.net = quantizer.convert(model.net)

    return model


@task_wrapper
def quantization(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    assert cfg.ckpt_path
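Note: the removed `ptq` and `qat` helpers shared the same `QATQuantizer` setup and differed only in how the fake-quantization parameters were obtained (observer calibration versus fine-tuning with the trainer). A minimal sketch of what a call site for them might have looked like, assuming the `quantization` task builds `model`, `datamodule`, and `trainer` from `cfg` and uses a hypothetical `cfg.method` switch (none of this wiring is visible in the diff):

```python
# Hypothetical call site for the removed helpers; `cfg.method` and the object
# construction are assumptions, not taken from this commit.
if cfg.method == "ptq":
    # post-training quantization: calibrate observers on training data, then convert
    model = ptq(cfg, model, datamodule)
else:
    # quantization-aware training: fit with fake-quant enabled, then convert
    model = qat(cfg, model, datamodule, trainer)
```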
