eval and GPTQ work (pytorch#304)
* eval and GPTQ work

Summary: flesh out the eval code so it runs reliably, add GPTQ
quantization, and add eval/GPTQ runs to CI. Fixed the defaults for eval
and GPTQ so they do something meaningful even when options are not
specified. Note: we need a better way to save/load GPTQ models since
they take so long to quantize; I tried using .so files, but that
doesn't seem to work reliably.
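
As a rough sketch of the .so round trip mentioned above (the export.py/generate.py flags match the CI script below; the model_gptq.so output path is only illustrative):

python export.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --quant '{"linear:int4-gptq" : {"groupsize" : 32} }' \
    --output-dso-path checkpoints/$MODEL_REPO/model_gptq.so

python generate.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --dso-path checkpoints/$MODEL_REPO/model_gptq.so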

Test Plan:

python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
  --device cuda --dtype bfloat16

python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --quant '{"linear:int4" : {"groupsize" : 32} }' \
    --compile

python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --quant '{"linear:int4" : {"groupsize" : 32} }'

python eval.py --checkpoint-path checkpoints/$MODEL_REPO/model.pth \
    --dtype bfloat16 --device cuda \
    --quant '{"linear:int4-gptq" : {"groupsize" : 32} }'

...running...
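
The new eval path is also exercised in CI; the validation script can be run locally the same way (assuming the checkpoint layout used above):

bash .ci/scripts/validate.sh checkpoints/$MODEL_REPO/model.pth cuda eval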

Reviewers:

Subscribers:

Tasks:

Tags:

* fix language in help doc

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* declare scales_and_zeros

---------

Co-authored-by: HDCharles <charlesdavidhernandez@gmail.com>
2 people authored and malfet committed Jul 17, 2024
1 parent 6c2fc8c commit 85132f2
Showing 10 changed files with 576 additions and 51 deletions.
74 changes: 66 additions & 8 deletions .ci/scripts/validate.sh
@@ -90,15 +90,19 @@ function generate_compiled_model_output() {
cat "$MODEL_DIR/output_eager"
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
cat "$MODEL_DIR/output_compiled"
fi

echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
cat "$MODEL_DIR/output_eager"
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
cat "$MODEL_DIR/output_compiled"
echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"
python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
cat "$MODEL_DIR/output_eager"
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
cat "$MODEL_DIR/output_compiled"
if [ "$TARGET_DEVICE" == "cuda" ]; then
python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4-gptq" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
cat "$MODEL_DIR/output_compiled"
fi
fi
done
}

@@ -179,6 +183,12 @@ function generate_aoti_model_output() {
python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
cat "$MODEL_DIR/output_aoti"

if [ "$TARGET_DEVICE" == "cuda" ]; then
python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4-gptq" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
cat "$MODEL_DIR/output_aoti"
fi
fi
done
}
@@ -194,6 +204,48 @@ function generate_executorch_model_output() {
cat "$MODEL_DIR/output_et"
}

function eval_model() {
local CHECKPOINT_PATH="$1"
local TARGET_DEVICE="${2:-cpu}"
local MODEL_DIR="${CHECKPOINT_PATH%/*}"
local MODEL_NAME=$(basename "$CHECKPOINT_PATH" | sed 's/\.[^.]*$//')

for DTYPE in float32 bfloat16 float16; do
echo ""############### Run eval with torch.compile for dtype $DTYPE "###############"
echo ""
echo "******************************************"
echo "************** non-quantized *************"
echo "******************************************"
python -W ignore eval.py --compile --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
cat "$MODEL_DIR/eval"
# extract perplexity number and compare with a constant
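# REF_PERPLEXITY is a loose upper bound: the run passes as long as the reported perplexity stays below it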
local REF_PERPLEXITY=100000
PERPLEXITY=$(cat "$MODEL_DIR/eval" | tail -n 1 | awk -F '[, ]' '{print $4}')
# bc prints 1 when the comparison holds, i.e. the perplexity is at or above the reference bound
if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1 ]; then
echo "perplexity checking failed for non-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE"
else
echo "perplexity checking succeeded for non-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE"
fi;

echo "******************************************"
echo "******** INT4 group-wise quantized *******"
echo "******************************************"

QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
cat "$MODEL_DIR/eval"
local REF_PERPLEXITY=100000
PERPLEXITY=$(cat "$MODEL_DIR/eval" | tail -n 1 | awk -F '[, ]' '{print $4}')
# bc prints 1 when the comparison holds, i.e. the perplexity is at or above the reference bound
if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1 ]; then
echo "perplexity checking failed for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
else
echo "perplexity checking succeeded for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
fi;
done
}

function run_compile() {
generate_compiled_model_output "$CHECKPOINT_PATH" "$TARGET_DEVICE" || exit 1
}
@@ -210,6 +262,9 @@ function run_executorch() {
fi
}

function run_eval(){
eval_model "$CHECKPOINT_PATH" "$TARGET_DEVICE" || exit 1
}

CHECKPOINT_PATH="$1"
TARGET_DEVICE="${2:-cpu}"
@@ -229,6 +284,9 @@ if [ "$#" -gt 2 ]; then
"executorch")
run_executorch || exit 1
;;
"eval")
run_eval || exit 1
;;
*)
echo "Unknown argument: $arg" >&2
exit 1
2 changes: 2 additions & 0 deletions .github/workflows/periodic.yml
@@ -144,4 +144,6 @@ jobs:
echo "::group::Run inference"
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "compile"
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "aoti"
echo "::group::Run eval"
bash .ci/scripts/validate.sh "./checkpoints/${REPO_NAME}/model.pth" "cuda" "eval"
echo "::endgroup::"