Skip to content

Commit

Permalink
XTC: added xtc_threshold_max parameter as an upper limit
Browse files Browse the repository at this point in the history
* 1.0 by default, so it doesn't affect existing behavior
* can be used to eliminate only the tokens within a probability range, if you are sure that some top tokens are not clichéd (in finetuned models, for example)
  • Loading branch information
MaggotHATE committed Aug 25, 2024
1 parent 44eb8c9 commit 6b69d0b
Show file tree
Hide file tree
Showing 7 changed files with 11 additions and 5 deletions.
6 changes: 3 additions & 3 deletions base/llama-addon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
#include <type_traits>
#include <unordered_map>

void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array * candidates, float xtc_probability, float xtc_threshold, bool xtc_probability_once, int xtc_min, size_t min_keep) {
if (xtc_probability <= 0.0f || xtc_threshold <= 0.0f || candidates->size <= 1) {
void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array * candidates, float xtc_probability, float xtc_threshold, float xtc_threshold_max, bool xtc_probability_once, int xtc_min, size_t min_keep) {
if (xtc_probability <= 0.0f || xtc_threshold <= 0.0f || xtc_threshold_max == xtc_threshold || xtc_min < 1 || candidates->size <= 1) {
return;
}

Expand All @@ -52,7 +52,7 @@ void llama_sample_xtc_addon(struct llama_context * ctx, llama_token_data_array *
size_t removed = 0;
// going through all candidates to correctly trigger the effect
for (size_t i = 0; i < candidates->size; ++i) {
if (candidates->data[i].p >= xtc_threshold) {
if (candidates->data[i].p >= xtc_threshold && candidates->data[i].p <= xtc_threshold_max) {
if (id_first == -1) {
id_first = i;
++removed;
Expand Down
1 change: 1 addition & 0 deletions base/llama-addon.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
llama_token_data_array * candidates,
float xtc_probability,
float xtc_threshold,
float xtc_threshold_max,
bool xtc_probability_once,
int xtc_min,
size_t min_keep);
Expand Down
3 changes: 2 additions & 1 deletion base/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ void sampler_queue(
const float p_step = params.p_step;
const float xtc_probability = params.xtc_probability;
const float xtc_threshold = params.xtc_threshold;
const float xtc_threshold_max = params.xtc_threshold_max;
const float xtc_probability_once = params.xtc_probability_once;
const float xtc_min = params.xtc_min;
const std::string samplers_sequence = params.samplers_sequence;
Expand All @@ -158,7 +159,7 @@ void sampler_queue(
case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case 'm': llama_sample_min_p_addon (ctx_main, &cur_p, min_p, min_keep); break;
case 's': llama_sample_p_step_addon (ctx_main, &cur_p, p_step, min_keep); break;
case 'x': llama_sample_xtc_addon (ctx_main, &cur_p, xtc_probability, xtc_threshold, xtc_probability_once, xtc_min, min_keep); break;
case 'x': llama_sample_xtc_addon (ctx_main, &cur_p, xtc_probability, xtc_threshold, xtc_threshold_max, xtc_probability_once, xtc_min, min_keep); break;
case 't': {
if (dynatemp_range>0)
{
Expand Down
1 change: 1 addition & 0 deletions base/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ typedef struct llama_sampling_params {
int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
float xtc_probability = 0.5; // probability of removing a top token
float xtc_threshold = 0.1; // minimum tokens probablitity for this to run
float xtc_threshold_max = 1.0; // maximum tokens probablitity for this to run
bool xtc_probability_once = false; // should we calculate chances one or for each token
int xtc_min = 2; // minimum number of penalizeable tokens
std::string samplers_sequence = "kfypmts"; // top_k, tail_free, typical_p, top_p, min_p, temp, p_step
Expand Down
2 changes: 1 addition & 1 deletion chat_plain.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ class chat
case 'f': result += name_tfs_z; if (params.sparams.tfs_z != paramsDefault.sparams.tfs_z) result += std::format("={:.3f}",params.sparams.tfs_z); break;
case 'y': result += name_typical_p; if (params.sparams.typical_p != paramsDefault.sparams.typical_p) result += std::format("={:.3f}",params.sparams.typical_p); break;
case 's': result += name_p_step; if (params.sparams.p_step != paramsDefault.sparams.p_step) result += std::format("={:.3f}",params.sparams.p_step); break;
case 'x': result += std::format("xtc={:.3f}-{:.03f}%",params.sparams.xtc_threshold,params.sparams.xtc_probability); break;
case 'x': result += std::format("xtc={:.3f}-{:.3f}({}%/{})",params.sparams.xtc_threshold,params.sparams.xtc_threshold_max,params.sparams.xtc_probability*100,params.sparams.xtc_min); if (params.sparams.xtc_probability_once) result += "once"; else result += "each"; break;
case 'p': result += name_top_p; if (params.sparams.top_p != paramsDefault.sparams.top_p) result += std::format("={:.3f}",params.sparams.top_p); break;
case 'm': result += name_min_p; if (params.sparams.min_p != paramsDefault.sparams.min_p) result += std::format("={:.3f}",params.sparams.min_p); break;
case 't': {
Expand Down
1 change: 1 addition & 0 deletions include/jsonParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ static void getParamsFromJson(nlohmann::json& config, gpt_params& params, bool h
if (checkJNum(config, "tfs_z")) params.sparams.tfs_z = config["tfs_z"];
if (checkJNum(config, "xtc_probability")) params.sparams.xtc_probability = config["xtc_probability"];
if (checkJNum(config, "xtc_threshold")) params.sparams.xtc_threshold = config["xtc_threshold"];
if (checkJNum(config, "xtc_threshold_max")) params.sparams.xtc_threshold_max = config["xtc_threshold_max"];
if (checkJNum(config, "xtc_min")) params.sparams.xtc_min = config["xtc_min"];
if (checkJBool(config, "xtc_probability_once")) params.sparams.xtc_probability_once = config["xtc_probability_once"];

Expand Down
2 changes: 2 additions & 0 deletions thread_chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,7 @@ struct configurableChat{
} else if (params.sparams.p_step != paramsDefault.sparams.p_step) modelConfig[model]["p_step"] = params.sparams.p_step;
if (params.sparams.xtc_probability != paramsDefault.sparams.xtc_probability) modelConfig[model]["xtc_probability"] = params.sparams.xtc_probability;
if (params.sparams.xtc_threshold != paramsDefault.sparams.xtc_threshold) modelConfig[model]["xtc_threshold"] = params.sparams.xtc_threshold;
if (params.sparams.xtc_threshold_max != paramsDefault.sparams.xtc_threshold_max) modelConfig[model]["xtc_threshold_max"] = params.sparams.xtc_threshold_max;
if (params.sparams.xtc_min != paramsDefault.sparams.xtc_min) modelConfig[model]["xtc_min"] = params.sparams.xtc_min;
if (params.sparams.xtc_probability_once != paramsDefault.sparams.xtc_probability_once) modelConfig[model]["xtc_probability_once"] = params.sparams.xtc_probability_once;
// penalties
Expand Down Expand Up @@ -1424,6 +1425,7 @@ struct configurableChat{
if (params.sparams.p_step != paramsDefault.sparams.p_step) newCard["p_step"] = params.sparams.p_step;
if (params.sparams.xtc_probability != paramsDefault.sparams.xtc_probability) newCard["xtc_probability"] = params.sparams.xtc_probability;
if (params.sparams.xtc_threshold != paramsDefault.sparams.xtc_threshold) newCard["xtc_threshold"] = params.sparams.xtc_threshold;
if (params.sparams.xtc_threshold_max != paramsDefault.sparams.xtc_threshold_max) newCard["xtc_threshold_max"] = params.sparams.xtc_threshold_max;
if (params.sparams.xtc_min != paramsDefault.sparams.xtc_min) newCard["xtc_min"] = params.sparams.xtc_min;
if (params.sparams.xtc_probability_once != paramsDefault.sparams.xtc_probability_once) newCard["xtc_probability_once"] = params.sparams.xtc_probability_once;
//penalties
Expand Down

0 comments on commit 6b69d0b

Please sign in to comment.