Skip to content

Commit

Permalink
launcher for rope
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Jul 20, 2023
1 parent 39dc1a4 commit e85557f
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions koboldcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,10 @@ def getfilename(var, text):

context_var = ctk.IntVar()

customrope_var = ctk.IntVar()
customrope_scale = ctk.StringVar(value="1.0")
customrope_base = ctk.StringVar(value="10000")

model_var = ctk.StringVar()
lora_var = ctk.StringVar()
lora_base_var = ctk.StringVar()
Expand Down Expand Up @@ -904,6 +908,19 @@ def togglemiro(a,b,c):
# context size
makeslider(tokens_tab, "Context Size:",contextsize_text, context_var, 0, 4, 20, set=2)


customrope_scale_entry, customrope_scale_label = makelabelentry(tokens_tab, "RoPE Scale:", customrope_scale)
customrope_base_entry, customrope_base_label = makelabelentry(tokens_tab, "RoPE Base:", customrope_base)
def togglerope(a,b,c):
items = [customrope_scale_label, customrope_scale_entry,customrope_base_label, customrope_base_entry]
for idx, item in enumerate(items):
if customrope_var.get() == 1:
item.grid(row=23 + int(idx/2), column=idx%2, padx=8, stick="nw")
else:
item.grid_forget()
makecheckbox(tokens_tab, "Custom RoPE Config", variable=customrope_var, row=22, command=togglerope)
togglerope(1,1,1)

# Model Tab
model_tab = tabcontent["Model"]

Expand Down Expand Up @@ -996,6 +1013,9 @@ def export_vars():
args.mirostat = [int(mirostat_var.get()), float(mirostat_tau.get()), float(mirostat_eta.get())] if usemirostat.get()==1 else None
args.contextsize = int(contextsize_text[context_var.get()])

if customrope_var.get()==1:
args.ropeconfig = [float(customrope_scale.get()),float(customrope_base.get())]

args.model_param = None if model_var.get() == "" else model_var.get()
args.lora = None if lora_var.get() == "" else ([lora_var.get()] if lora_base_var.get()=="" else [lora_var.get(), lora_base_var.get()])

Expand Down Expand Up @@ -1046,6 +1066,15 @@ def import_vars(dict):

if dict["contextsize"]:
context_var.set(contextsize_text.index(str(dict["contextsize"])))

if dict["ropeconfig"] and len(dict["ropeconfig"])>1:
if dict["ropeconfig"][0]>0:
customrope_var.set(1)
customrope_scale.set(str(dict["ropeconfig"][0]))
customrope_base.set(str(dict["ropeconfig"][1]))
else:
customrope_var.set(0)

if dict["blasbatchsize"]:
blas_size_var.set(blasbatchsize_values.index(str(dict["blasbatchsize"])))
if dict["forceversion"]:
Expand Down

0 comments on commit e85557f

Please sign in to comment.