Skip to content

Commit

Permalink
Will test this later
Browse files Browse the repository at this point in the history
  • Loading branch information
peregrineshahin committed Jul 1, 2024
1 parent 9fbd700 commit 999899b
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 40 deletions.
54 changes: 37 additions & 17 deletions server/fishtest/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pyramid.security import forget, remember
from pyramid.view import forbidden_view_config, notfound_view_config, view_config
from requests.exceptions import ConnectionError, HTTPError
from urllib.parse import urljoin, urlparse, urlsplit, urlunparse, urlunsplit
from vtjson import ValidationError, union, validate

HTTP_TIMEOUT = 15.0
Expand Down Expand Up @@ -772,24 +773,31 @@ def get_valid_books():

def get_sha(branch, repo_url):
"""Resolves the git branch to sha commit"""
api_url = repo_url.replace(
"https://github.com", "https://api.github.com/repos"
).rstrip("/")
parsed_url = urlparse(repo_url)
api_url = parsed_url._replace(
netloc="api.github.com", path=f"/repos{parsed_url.path}"
)._replace(scheme="https")
api_url = urlunparse(api_url)
try:
commit = requests.get(api_url + "/commits/" + branch).json()
except:
response = requests.get(urljoin(api_url, f"commits/{branch}"))
response.raise_for_status()
commit = response.json()
except requests.RequestException:
raise Exception("Unable to access developer repository")
if "sha" in commit:
return commit["sha"], commit["commit"]["message"].split("\n")[0]
else:
return "", ""

# Extract the SHA and commit message
sha = commit.get("sha", "")
message = commit.get("commit", {}).get("message", "").split("\n")[0]

return sha, message


def get_nets(commit_sha, repo_url):
"""Get the nets from evaluate.h or ucioption.cpp in the repo"""
api_url = repo_url.replace(
"https://github.com", "https://raw.githubusercontent.com"
).rstrip("/")
urlparts = urlsplit(repo_url)
raw_netloc = urlparts.netloc.replace("github.com", "raw.githubusercontent.com")
raw_url_parts = urlparts._replace(netloc=raw_netloc)
api_url = urlunsplit(raw_url_parts)
try:
nets = []
pattern = re.compile("nn-[a-f0-9]{12}.nnue")
Expand Down Expand Up @@ -1179,13 +1187,25 @@ def new_run_message(request, run):


def get_master_sha(repo_url):
"""Resolves the master branch to its sha commit"""
try:
repo_url = repo_url.rstrip("/") + "/commits/master"
response = requests.get(repo_url).json()
if "commit" not in response:
urlparts = urlsplit(repo_url)
urlparts = urlparts._replace(
netloc=urlparts.netloc.replace("github.com", "api.github.com")
)
api_url = urlunsplit(urlparts)
get_url = urljoin(api_url, f"repos{urlparts.path}/commits/master")

# Fetch the commit information
response = requests.get(get_url)
response.raise_for_status()
commit = response.json()

if "commit" not in commit:
raise Exception("Cannot find branch in repository")
return response["sha"]
except Exception as e:

return commit["sha"]
except requests.RequestException as e:
raise Exception("Unable to access repository") from e


Expand Down
15 changes: 8 additions & 7 deletions server/utils/clone_fish.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests
from bson.binary import Binary
from pymongo import ASCENDING, MongoClient
from urllib.parse import urljoin

# fish_host = 'http://localhost:6543'
fish_host = "http://94.198.98.239" # 'http://tests.stockfishchess.org'
Expand All @@ -29,9 +30,9 @@ def main():
in_sync = False
loaded = {}
while True:
pgn_list = requests.get(
fish_host.rstrip("/") + "/api/pgn_100/" + str(skip)
).json()
api_url = urljoin(fish_host, f"/api/pgn_100/{skip}")
pgn_list = requests.get(api_url).json()

for pgn_file in pgn_list:
print(pgn_file)
if pgndb.find_one({"run_id": pgn_file}):
Expand All @@ -43,11 +44,11 @@ def main():
run_id = pgn_file.split("-")[0]
if not runs.find_one({"_id": run_id}):
print("New run: " + run_id)
run = requests.get(
fish_host.rstrip("/") + "/api/get_run/" + run_id
).json()
run_url = urljoin(fish_host, f"/api/get_run/{run_id}")
run = requests.get(api_url).json()
runs.insert(run)
pgn = requests.get(fish_host.rstrip("/") + "/api/pgn/" + pgn_file)
pgn_url = urljoin(fish_host, f"/api/pgn/{pgn_file}")
pgn = requests.get(api_url)
pgndb.insert(
dict(pgn_bz2=Binary(bz2.compress(pgn.content)), run_id=pgn_file)
)
Expand Down
11 changes: 5 additions & 6 deletions worker/games.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from datetime import datetime, timedelta, timezone
from pathlib import Path
from queue import Empty, Queue
from urllib.parse import urljoin
from zipfile import ZipFile

import requests
Expand Down Expand Up @@ -287,7 +288,7 @@ def required_nets_from_source():


def download_net(remote, testing_dir, net):
url = remote.rstrip("/") + "/api/nn/" + net
url = urljoin(remote, f"/api/nn/{net}")
print("Downloading {}".format(net))
r = requests_get(url, allow_redirects=True, timeout=HTTP_TIMEOUT)
(testing_dir / net).write_bytes(r.content)
Expand Down Expand Up @@ -672,7 +673,7 @@ def setup_engine(
tmp_dir = Path(tempfile.mkdtemp(dir=worker_dir))

try:
item_url = github_api(repo_url).rstrip("/") + "/zipball/" + sha
item_url = urljoin(github_api(repo_url), f"/zipball/{sha}")
print("Downloading {}".format(item_url))
blob = requests_get(item_url).content
file_list = unzip(blob, tmp_dir)
Expand Down Expand Up @@ -1023,9 +1024,7 @@ def shorten_hash(match):
update_succeeded = False
for _ in range(5):
try:
response = send_api_post_request(
remote.rstrip("/") + "/api/update_task", result
)
requests.post(urljoin(remote, "/api/update_task"), json=result)
if "error" in response:
break
except Exception as e:
Expand Down Expand Up @@ -1071,7 +1070,7 @@ def launch_cutechess(
):
if spsa_tuning:
# Request parameters for next game.
req = send_api_post_request(remote.rstrip("/") + "/api/request_spsa", result)
req = send_api_post_request(urljoin(remote, "/api/request_spsa"), result)
if "error" in req:
raise WorkerException(req["error"])

Expand Down
2 changes: 1 addition & 1 deletion worker/sri.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"__version": 239, "updater.py": "Mg+pWOgGA0gSo2TuXuuLCWLzwGwH91rsW1W3ixg3jYauHQpRMtNdGnCfuD1GqOhV", "worker.py": "+ubGHk3rIV0ILVhg1dxpsOkRF8GUaO5otPVnAyw8kKKq9Rqzksv02xj6wjYpSTmA", "games.py": "6vKH51UtL56oNvA539hLXRzgE1ADXy3QZNJohoK94RntM72+iMancSJZHaNjEb5+"}
{"__version": 239, "updater.py": "Mg+pWOgGA0gSo2TuXuuLCWLzwGwH91rsW1W3ixg3jYauHQpRMtNdGnCfuD1GqOhV", "worker.py": "mi75Jx7bQEyqkcd7BhvWFYqbOM2nGJisludtEom/MaKUcyRiENdWTk2pgpOu87BE", "games.py": "Q/h9MKhaqKNdP9dnarrxDpr8r8GK/SYv46iL1ooDmdPYsM4XboEzx6zVC/4rWzMy"}
17 changes: 8 additions & 9 deletions worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from datetime import datetime, timedelta, timezone
from functools import partial
from pathlib import Path
from urllib.parse import urljoin

# Fall back to the provided packages if missing in the local system.

Expand Down Expand Up @@ -341,7 +342,7 @@ def verify_credentials(remote, username, password, cached):
payload = {"worker_info": {"username": username}, "password": password}
try:
req = send_api_post_request(
remote.rstrip("/") + "/api/request_version", payload, quiet=True
urljoin(remote, "/api/request_version"), payload, quiet=True
)
except:
return None # network problem (unrecoverable)
Expand Down Expand Up @@ -1202,7 +1203,7 @@ def heartbeat(worker_info, password, remote, current_state):
continue
try:
req = send_api_post_request(
remote.rstrip("/") + "/api/beat", payload, quiet=True
urljoin(remote, "/api/beat"), payload, quiet=True
)
except Exception as e:
print("Exception calling heartbeat:\n", e, sep="", file=sys.stderr)
Expand Down Expand Up @@ -1324,9 +1325,7 @@ def verify_worker_version(remote, username, password):
print("Verify worker version...")
payload = {"worker_info": {"username": username}, "password": password}
try:
req = send_api_post_request(
remote.rstrip("/") + "/api/request_version", payload
)
req = send_api_post_request(urljoin(remote, "/api/request_version"), payload)
except WorkerException:
return None # the error message has already been written
if "error" in req:
Expand Down Expand Up @@ -1393,7 +1392,7 @@ def fetch_and_handle_task(
print("Fetching task...")
payload = {"worker_info": worker_info, "password": password}
try:
req = send_api_post_request(remote.rstrip("/") + "/api/request_task", payload)
req = send_api_post_request(urljoin(remote, "/api/request_task"), payload)
except WorkerException:
return False # error message has already been printed

Expand Down Expand Up @@ -1435,7 +1434,7 @@ def fetch_and_handle_task(
success = False
message = ""
server_message = ""
api = remote.rstrip("/") + "/api/failed_task"
api = urljoin(remote, "/api/failed_task")
pgn_file = [None]
try:
run_games(
Expand All @@ -1456,7 +1455,7 @@ def fetch_and_handle_task(
except RunException as e:
message = str(e)
server_message = message
api = remote.rstrip("/") + "/api/stop_run"
api = urljoin(remote, "/api/stop_run")
except WorkerException as e:
message = str(e)
server_message = message
Expand Down Expand Up @@ -1505,7 +1504,7 @@ def fetch_and_handle_task(
)
)
req = send_api_post_request(
remote.rstrip("/") + "/api/upload_pgn", payload
urljoin(remote, "/api/upload_pgn"), payload
)
except Exception as e:
print(
Expand Down

0 comments on commit 999899b

Please sign in to comment.