Skip to content

Commit

Permalink
create demo folder
Browse files Browse the repository at this point in the history
  • Loading branch information
FemkeBakker committed Jul 9, 2024
1 parent 3a104d7 commit 1758cf7
Show file tree
Hide file tree
Showing 182 changed files with 41,777 additions and 2,127 deletions.
1,070 changes: 317 additions & 753 deletions PredictionAnalysis/BaselineAnalysis.ipynb

Large diffs are not rendered by default.

230 changes: 226 additions & 4 deletions PredictionAnalysis/FinetuningExperiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -396,18 +396,240 @@
"\n",
"predictions = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/2epochs/GEITjeFirst200Last0Predictions.pkl\")\n",
"predictions2 = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/2epochs/LlamaFirst200Last0Predictions.pkl\")\n",
"predictions3 = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/2epochs/MistralTry2epochsFirst200Last0Predictions.pkl\")\n",
"predictions3 = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/2epochs/MistralFirst200Last0Predictions.pkl\")\n",
"epoch2 = pd.concat([predictions, predictions2, predictions3])\n",
"\n",
"predictions = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/3epochs/GEITjeFirst200Last0Predictions.pkl\")\n",
"predictions2 = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/3epochs/LlamaFirst200Last0Predictions.pkl\")\n",
"predictions3 = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/3epochs/MistralTry3epochsFirst200Last0Predictions.pkl\")\n",
"predictions3 = pd.read_pickle(f\"{cf.output_path}/predictionsFinal/finetuning/3epochs/MistralFirst200Last0Predictions.pkl\")\n",
"epoch3 = pd.concat([predictions, predictions2, predictions3])\n",
"\n",
"combined = pd.concat([epoch0, epoch1, epoch2, epoch3])\n",
"combined = pd.concat([epoch1, epoch2, epoch3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"FT_AmsterdamDocClassificationGEITje200T1Epochszeroshot_prompt_geitjeLlamaTokens200_0traintest_numEx0\n",
"0\n",
" precision recall f1-score support\n",
"\n",
" actualiteit 0.93 0.83 0.88 100\n",
" agenda 0.99 0.98 0.98 100\n",
" besluit 0.97 0.96 0.96 100\n",
" brief 0.99 0.92 0.95 100\n",
" factsheet 1.00 0.25 0.40 100\n",
" motie 0.89 0.99 0.94 100\n",
" onderzoeksrapport 0.55 1.00 0.71 100\n",
" raadsadres 0.89 0.98 0.93 100\n",
" raadsnotulen 1.00 0.97 0.98 100\n",
"schriftelijke vraag 1.00 0.93 0.96 100\n",
" voordracht 0.97 0.99 0.98 100\n",
"\n",
" accuracy 0.89 1100\n",
" macro avg 0.93 0.89 0.88 1100\n",
" weighted avg 0.93 0.89 0.88 1100\n",
"\n",
"FT_AmsterdamDocClassificationGEITje200T2Epochszeroshot_prompt_geitjeLlamaTokens200_0traintest_numEx0\n",
"0\n",
" precision recall f1-score support\n",
"\n",
" actualiteit 0.93 0.81 0.87 100\n",
" agenda 0.95 0.99 0.97 100\n",
" besluit 0.96 0.99 0.98 100\n",
" brief 0.98 0.93 0.95 100\n",
" factsheet 1.00 0.34 0.51 100\n",
" motie 0.90 0.98 0.94 100\n",
" onderzoeksrapport 0.58 0.98 0.73 100\n",
" raadsadres 0.89 0.99 0.94 100\n",
" raadsnotulen 1.00 0.97 0.98 100\n",
"schriftelijke vraag 1.00 0.92 0.96 100\n",
" voordracht 1.00 0.98 0.99 100\n",
"\n",
" accuracy 0.90 1100\n",
" macro avg 0.93 0.90 0.89 1100\n",
" weighted avg 0.93 0.90 0.89 1100\n",
"\n",
"FT_AmsterdamDocClassificationGEITje200T3Epochszeroshot_prompt_geitjeLlamaTokens200_0traintest_numEx0\n",
"0\n",
" precision recall f1-score support\n",
"\n",
" actualiteit 0.93 0.83 0.88 100\n",
" agenda 0.94 0.99 0.97 100\n",
" besluit 0.95 0.99 0.97 100\n",
" brief 0.99 0.93 0.96 100\n",
" factsheet 1.00 0.38 0.55 100\n",
" motie 0.91 0.98 0.94 100\n",
" onderzoeksrapport 0.60 0.98 0.74 100\n",
" raadsadres 0.90 0.98 0.94 100\n",
" raadsnotulen 1.00 0.97 0.98 100\n",
"schriftelijke vraag 0.99 0.93 0.96 100\n",
" voordracht 1.00 0.98 0.99 100\n",
"\n",
" accuracy 0.90 1100\n",
" macro avg 0.93 0.90 0.90 1100\n",
" weighted avg 0.93 0.90 0.90 1100\n",
"\n",
"FT_AmsterdamDocClassificationLlama200T1Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0\n",
"1\n",
" precision recall f1-score support\n",
"\n",
" PredictionError 0.00 0.00 0.00 0\n",
" actualiteit 0.93 0.67 0.78 100\n",
" agenda 0.84 0.97 0.90 100\n",
" besluit 0.96 0.93 0.94 100\n",
" brief 0.96 0.91 0.93 100\n",
" factsheet 0.96 0.26 0.41 100\n",
" motie 0.89 0.95 0.92 100\n",
" onderzoeksrapport 0.54 0.97 0.69 100\n",
" raadsadres 0.85 0.98 0.91 100\n",
" raadsnotulen 1.00 0.97 0.98 100\n",
"schriftelijke vraag 0.99 0.93 0.96 100\n",
" voordracht 0.99 0.99 0.99 100\n",
"\n",
" accuracy 0.87 1100\n",
" macro avg 0.83 0.79 0.79 1100\n",
" weighted avg 0.90 0.87 0.86 1100\n",
"\n",
"FT_AmsterdamDocClassificationLlama200T2Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0\n",
"1\n",
" precision recall f1-score support\n",
"\n",
" PredictionError 0.00 0.00 0.00 0\n",
" actualiteit 0.95 0.71 0.81 100\n",
" agenda 0.87 0.97 0.92 100\n",
" besluit 0.94 0.96 0.95 100\n",
" brief 0.94 0.95 0.95 100\n",
" factsheet 1.00 0.18 0.31 100\n",
" motie 0.93 0.94 0.94 100\n",
" onderzoeksrapport 0.51 0.96 0.66 100\n",
" raadsadres 0.86 0.97 0.91 100\n",
" raadsnotulen 1.00 0.97 0.98 100\n",
"schriftelijke vraag 0.99 0.92 0.95 100\n",
" voordracht 1.00 0.98 0.99 100\n",
"\n",
" accuracy 0.86 1100\n",
" macro avg 0.83 0.79 0.78 1100\n",
" weighted avg 0.91 0.86 0.85 1100\n",
"\n",
"FT_AmsterdamDocClassificationLlama200T3Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0\n",
"1\n",
" precision recall f1-score support\n",
"\n",
" PredictionError 0.00 0.00 0.00 0\n",
" actualiteit 0.92 0.70 0.80 100\n",
" agenda 0.87 0.97 0.92 100\n",
" besluit 0.94 0.95 0.95 100\n",
" brief 0.95 0.93 0.94 100\n",
" factsheet 1.00 0.21 0.35 100\n",
" motie 0.92 0.96 0.94 100\n",
" onderzoeksrapport 0.52 0.97 0.68 100\n",
" raadsadres 0.85 0.97 0.91 100\n",
" raadsnotulen 1.00 0.97 0.98 100\n",
"schriftelijke vraag 1.00 0.92 0.96 100\n",
" voordracht 0.99 0.98 0.98 100\n",
"\n",
" accuracy 0.87 1100\n",
" macro avg 0.83 0.79 0.78 1100\n",
" weighted avg 0.91 0.87 0.85 1100\n",
"\n",
"FT_AmsterdamDocClassificationMistral200T1Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0\n",
"0\n",
" precision recall f1-score support\n",
"\n",
" actualiteit 0.90 0.83 0.86 100\n",
" agenda 0.97 0.99 0.98 100\n",
" besluit 0.98 0.96 0.97 100\n",
" brief 0.95 0.96 0.96 100\n",
" factsheet 0.94 0.31 0.47 100\n",
" motie 0.93 0.94 0.94 100\n",
" onderzoeksrapport 0.56 0.98 0.72 100\n",
" raadsadres 0.90 0.99 0.94 100\n",
" raadsnotulen 1.00 0.98 0.99 100\n",
"schriftelijke vraag 1.00 0.93 0.96 100\n",
" voordracht 1.00 0.98 0.99 100\n",
"\n",
" accuracy 0.90 1100\n",
" macro avg 0.92 0.90 0.89 1100\n",
" weighted avg 0.92 0.90 0.89 1100\n",
"\n",
"FT_AmsterdamDocClassificationMistral200T2Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0\n",
"0\n",
" precision recall f1-score support\n",
"\n",
" actualiteit 0.90 0.85 0.88 100\n",
" agenda 0.97 0.99 0.98 100\n",
" besluit 0.97 0.95 0.96 100\n",
" brief 0.96 0.96 0.96 100\n",
" factsheet 1.00 0.39 0.56 100\n",
" motie 0.94 0.95 0.95 100\n",
" onderzoeksrapport 0.61 0.97 0.75 100\n",
" raadsadres 0.91 1.00 0.95 100\n",
" raadsnotulen 1.00 0.99 0.99 100\n",
"schriftelijke vraag 0.96 0.94 0.95 100\n",
" voordracht 0.99 0.98 0.98 100\n",
"\n",
" accuracy 0.91 1100\n",
" macro avg 0.93 0.91 0.90 1100\n",
" weighted avg 0.93 0.91 0.90 1100\n",
"\n",
"FT_AmsterdamDocClassificationMistral200T3Epochszeroshot_prompt_mistral_llamaLlamaTokens200_0traintest_numEx0\n",
"0\n",
" precision recall f1-score support\n",
"\n",
" actualiteit 0.90 0.88 0.89 100\n",
" agenda 0.95 0.99 0.97 100\n",
" besluit 0.96 0.96 0.96 100\n",
" brief 0.97 0.97 0.97 100\n",
" factsheet 1.00 0.40 0.57 100\n",
" motie 0.97 0.95 0.96 100\n",
" onderzoeksrapport 0.61 0.95 0.74 100\n",
" raadsadres 0.91 0.98 0.94 100\n",
" raadsnotulen 1.00 0.98 0.99 100\n",
"schriftelijke vraag 0.96 0.94 0.95 100\n",
" voordracht 0.98 0.98 0.98 100\n",
"\n",
" accuracy 0.91 1100\n",
" macro avg 0.93 0.91 0.90 1100\n",
" weighted avg 0.93 0.91 0.90 1100\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"def cf_ma(df):\n",
" run_ids = sorted(list(set(df['run_id'])))\n",
" error_names = ['NoPredictionInOutput', 'MultiplePredictionErrorInFormatting','NoPredictionFormat', 'MultiplePredictionErrorInOutput']\n",
"\n",
" for run_id in run_ids:\n",
" print(run_id)\n",
" subdf = df.loc[df['run_id']==run_id]\n",
" print(len(subdf.loc[subdf['prediction'].isin(error_names)]))\n",
" subdf.loc[subdf['prediction'].isin(error_names),'prediction'] = 'PredictionError'\n",
"\n",
" y_pred = subdf['prediction']\n",
" y_true = subdf['label']\n",
" print(classification_report(y_true, y_pred))\n",
"\n",
"cf_ma(combined)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -417,7 +639,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down
Loading

0 comments on commit 1758cf7

Please sign in to comment.