diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
index f358e0823c..248aa5eee1 100644
--- a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
+++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -88,9 +88,10 @@ const LORA_MODEL_REF_INPUT_STREAM = 'lora_model_ref_in';
 const LORA_MODEL_ID_TO_LOAD_INPUT_STREAM = 'lora_model_id_to_load_in';
 
 const DEFAULT_MAX_TOKENS = 512;
-const DEFAULT_TOP_K = 1;
+const DEFAULT_TOP_K = 40;
 const DEFAULT_TOP_P = 1.0;
-const DEFAULT_TEMPERATURE = 1.0;
+const DEFAULT_TEMPERATURE = 0.8;
+const DEFAULT_RANDOM_SEED = 0;
 const DEFAULT_SAMPLER_TYPE = SamplerParameters.Type.TOP_P;
 const DEFAULT_NUM_RESPONSES = 1;
 
@@ -407,24 +408,14 @@ export class LlmInference extends TaskRunner {
     }
     if ('topK' in options) {
       this.samplerParams.setK(options.topK ?? DEFAULT_TOP_K);
-      if (options.topK && !options.randomSeed) {
-        console.warn(
-          `'topK' option ignored; it requires randomSeed to be set.`,
-        );
-      }
     }
     if ('temperature' in options) {
       this.samplerParams.setTemperature(
         options.temperature ?? DEFAULT_TEMPERATURE,
       );
-      if (options.temperature && !options.randomSeed) {
-        console.warn(
-          `'temperature' option ignored; it requires randomSeed to be set.`,
-        );
-      }
     }
-    if (options.randomSeed) {
-      this.samplerParams.setSeed(options.randomSeed);
+    if ('randomSeed' in options) {
+      this.samplerParams.setSeed(options.randomSeed ?? DEFAULT_RANDOM_SEED);
     }
     if ('loraRanks' in options) {
       this.options.setLoraRanksList(options.loraRanks ?? []);
@@ -804,6 +795,7 @@ export class LlmInference extends TaskRunner {
     this.samplerParams.setType(DEFAULT_SAMPLER_TYPE);
     this.samplerParams.setK(DEFAULT_TOP_K);
     this.samplerParams.setP(DEFAULT_TOP_P);
+    this.samplerParams.setSeed(DEFAULT_RANDOM_SEED);
     this.samplerParams.setTemperature(DEFAULT_TEMPERATURE);
     this.options.setNumResponses(DEFAULT_NUM_RESPONSES);
   }
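
For reference, a minimal usage sketch of what this change enables: `topK` and `temperature` now take effect even when `randomSeed` is omitted, with the seed falling back to `DEFAULT_RANDOM_SEED` (0). The option names come from the diff above; `FilesetResolver.forGenAiTasks` and `LlmInference.createFromOptions` follow the published `@mediapipe/tasks-genai` API, and the WASM URL and model path are placeholders.

```ts
// Sketch only: option names (topK, temperature, randomSeed) are from the diff;
// the WASM URL and model asset path below are placeholders.
import {FilesetResolver, LlmInference} from '@mediapipe/tasks-genai';

async function demo(): Promise<void> {
  const genai = await FilesetResolver.forGenAiTasks(
    'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm',  // placeholder
  );
  const llm = await LlmInference.createFromOptions(genai, {
    baseOptions: {modelAssetPath: '/assets/gemma.task'},  // placeholder model
    // After this change, topK and temperature are applied even though
    // randomSeed is omitted; the seed falls back to DEFAULT_RANDOM_SEED (0).
    topK: 40,
    temperature: 0.8,
  });
  console.log(await llm.generateResponse('Tell me about MediaPipe.'));
}
```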