Skip to content

Commit

Permalink
unified standar and modernui auth
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <mandic00@live.com>
  • Loading branch information
vladmandic committed Jan 24, 2025
1 parent 8b787ec commit 6811541
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
- **loader**: ability to run in-memory models
- **schedulers**: ability to create model-less schedulers
- **quantiation**: code refactor into dedicated module
- **Authentication**:
- perform auth check on ui startup
- unified standard and modern-ui authentication method
- cleanup auth logging
- **Fixes**:
- non-full vae decode
- send-to image transfer
Expand Down
72 changes: 72 additions & 0 deletions javascript/login.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
const loginCSS = `
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background: var(--background-fill-primary);
color: var(--body-text-color-subdued);
font-family: monospace;
z-index: 100;
`;

const loginHTML = `
<div id="loginDiv" style="margin: 15% auto; max-width: 200px; padding: 2em; background: var(--background-fill-secondary);">
<h2>Login</h2>
<label for="username" style="margin-top: 0.5em">Username</label>
<input type="text" id="loginUsername" name="username" style="width: 92%; padding: 0.5em; margin-top: 0.5em">
<label for="password" style="margin-top: 0.5em">Password</label>
<input type="text" id="loginPassword" name="password" style="width: 92%; padding: 0.5em; margin-top: 0.5em">
<div id="loginStatus" style="margin-top: 0.5em"></div>
<button type="submit" style="width: 100%; padding: 0.5em; margin-top: 0.5em; background: var(--button-primary-background-fill); color: var(--button-primary-text-color); border: var(--button-primary-border-color);">Login</button>
</div>
`;

function forceLogin() {
const form = document.createElement('form');
form.method = 'POST';
form.action = '/login';
form.id = 'loginForm';
form.style.cssText = loginCSS;
form.innerHTML = loginHTML;
document.body.appendChild(form);
const username = form.querySelector('#loginUsername');
const password = form.querySelector('#loginPassword');
const status = form.querySelector('#loginStatus');

form.addEventListener('submit', (event) => {
event.preventDefault();
const formData = new FormData(form);
formData.append('username', username.value);
formData.append('password', password.value);
console.warn('login', formData);
fetch('/login', {
method: 'POST',
body: formData,
})
.then(async (res) => {
const json = await res.json();
const txt = `${res.status}: ${res.statusText} - ${json.detail}`;
status.textContent = txt;
console.log('login', txt);
if (res.status === 200) location.reload();
})
.catch((err) => {
status.textContent = err;
console.error('login', err);
});
});
}

function loginCheck() {
fetch('/login_check', {})
.then((res) => {
if (res.status === 200) console.log('login ok');
else forceLogin();
})
.catch((err) => {
console.error('login', err);
});
}

window.onload = loginCheck;
8 changes: 6 additions & 2 deletions modules/api/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def log_and_time(req: Request, call_next):
if (cmd_opts.api_log or cmd_opts.api_only) and endpoint.startswith('/sdapi'):
if '/sdapi/v1/log' in endpoint or '/sdapi/v1/browser' in endpoint:
return res
log.info('API {user} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( # pylint: disable=consider-using-f-string, logging-format-interpolation
log.info('API user={user} code={code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( # pylint: disable=consider-using-f-string, logging-format-interpolation
user = app.tokens.get(token) if hasattr(app, 'tokens') else None,
code = res.status_code,
ver = req.scope.get('http_version', '0.0'),
Expand All @@ -69,10 +69,14 @@ def handle_exception(req: Request, e: Exception):
"body": vars(e).get('body', ''),
"errors": str(e),
}
if err['code'] == 401 and 'file=' in req.url.path: # dont spam with unauth
return JSONResponse(status_code=err['code'], content=jsonable_encoder(err))

log.error(f"API error: {req.method}: {req.url} {err}")

if not isinstance(e, HTTPException) and err['error'] != 'TypeError': # do not print backtrace on known httpexceptions
errors.display(e, 'HTTP API', [anyio, fastapi, uvicorn, starlette])
elif err['code'] == 404 or err['code'] == 401:
elif err['code'] in [404, 401, 400]:
pass
else:
log.debug(e, exc_info=True) # print stack trace
Expand Down
21 changes: 16 additions & 5 deletions modules/ui_javascript.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ def webpath(fn):
def html_head():
head = ''
main = ['script.js']
skip = ['login.js']
for js in main:
script_js = os.path.join(script_path, "javascript", js)
head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
added = []
for script in modules.scripts.list_scripts("javascript", ".js"):
if script.filename in main:
if script.filename in main or script.filename in skip:
continue
head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
added.append(script.path)
Expand All @@ -43,6 +44,14 @@ def html_body():
return body


def html_login():
fn = os.path.join(script_path, "javascript", "login.js")
with open(fn, 'r', encoding='utf8') as f:
inline = f.read()
js = f'<script type="text/javascript">{inline}</script>\n'
return js


def html_css(css: str):
def stylesheet(fn):
return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
Expand Down Expand Up @@ -78,17 +87,19 @@ def stylesheet(fn):

def reload_javascript():
base_css = theme.reload_gradio_theme()
head = html_head()
css = html_css(base_css)
body = html_body()
title = '<title>SD.Next</title>'
manifest = f'<link rel="manifest" href="{webpath(os.path.join(script_path, "html", "manifest.json"))}">'
login = html_login()
js = html_head()
css = html_css(base_css)
body = html_body()

def template_response(*args, **kwargs):
res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'<head>', f'<head>{title}'.encode("utf8"))
res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
res.body = res.body.replace(b'</head>', f'{manifest}</head>'.encode("utf8"))
res.body = res.body.replace(b'</head>', f'{login}</head>'.encode("utf8"))
res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
res.body = res.body.replace(b'</body>', f'{css}{body}</body>'.encode("utf8"))
lines = res.body.decode("utf8").split('\n')
for line in lines:
Expand Down

0 comments on commit 6811541

Please sign in to comment.