
Commit

add: add queue name to job prefixes around the globe - this is totally broken and not completed.
JonasKs committed May 19, 2023
1 parent e0cd916 commit 038a363
Showing 8 changed files with 128 additions and 72 deletions.
58 changes: 41 additions & 17 deletions arq/connections.py
@@ -12,7 +12,13 @@
from redis.asyncio.sentinel import Sentinel
from redis.exceptions import RedisError, WatchError

from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
from .constants import (
arq_prefix,
default_queue_name,
expires_extra_ms,
default_job_key_suffix,
default_result_key_suffix,
)
from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
from .utils import timestamp_ms, to_ms, to_unix_ms

@@ -102,6 +108,8 @@ def __init__(
self.job_serializer = job_serializer
self.job_deserializer = job_deserializer
self.default_queue_name = default_queue_name
self.job_key_prefix = self.default_queue_name + default_job_key_suffix
self.result_key_prefix = self.default_queue_name + default_result_key_suffix
if pool_or_conn:
kwargs['connection_pool'] = pool_or_conn
self.expires_extra_ms = expires_extra_ms
@@ -136,16 +144,21 @@ async def enqueue_job(
"""
if _queue_name is None:
_queue_name = self.default_queue_name
_job_key_prefix = self.job_key_prefix
_result_key_prefix = self.result_key_prefix
else:
_job_key_prefix = _queue_name + default_job_key_suffix
_result_key_prefix = _queue_name + default_result_key_suffix
job_id = _job_id or uuid4().hex
job_key = job_key_prefix + job_id
job_key = _job_key_prefix + job_id
assert not (_defer_until and _defer_by), "use either 'defer_until' or 'defer_by' or neither, not both"

defer_by_ms = to_ms(_defer_by)
expires_ms = to_ms(_expires)

async with self.pipeline(transaction=True) as pipe:
await pipe.watch(job_key)
if await pipe.exists(job_key, result_key_prefix + job_id):
await pipe.watch(arq_prefix + job_key)
if await pipe.exists(arq_prefix + job_key, arq_prefix + _result_key_prefix + job_id):
await pipe.reset()
return None

@@ -161,35 +174,44 @@ async def enqueue_job(

job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
pipe.multi()
pipe.psetex(job_key, expires_ms, job) # type: ignore[no-untyped-call]
pipe.zadd(_queue_name, {job_id: score}) # type: ignore[unused-coroutine]
pipe.psetex(arq_prefix + job_key, expires_ms, job) # type: ignore[no-untyped-call]
pipe.zadd(arq_prefix + _queue_name, {job_id: score}) # type: ignore[unused-coroutine]
try:
await pipe.execute()
except WatchError:
# job got enqueued since we checked 'job_exists'
return None
return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)

async def _get_job_result(self, key: bytes) -> JobResult:
job_id = key[len(result_key_prefix) :].decode()
async def _get_job_result(self, key: bytes, result_key: str) -> JobResult:
job_id = key[len(result_key) :].decode()
job = Job(job_id, self, _deserializer=self.job_deserializer)
r = await job.result_info()
if r is None:
raise KeyError(f'job "{key.decode()}" not found')
r.job_id = job_id
return r

async def all_job_results(self) -> List[JobResult]:
async def all_job_results(self, queue_name: Optional[str] = None) -> List[JobResult]:
"""
Get results for all jobs in redis.
Get results for all jobs in a redis queue.
"""
keys = await self.keys(result_key_prefix + '*')
results = await asyncio.gather(*[self._get_job_result(k) for k in keys])
if queue_name is None:
_result_key_prefix = self.result_key_prefix
else:
_result_key_prefix = queue_name + default_result_key_suffix

keys = await self.keys(arq_prefix + _result_key_prefix + '*')
results = await asyncio.gather(*[self._get_job_result(k, result_key=_result_key_prefix) for k in keys])
return sorted(results, key=attrgetter('enqueue_time'))

async def _get_job_def(self, job_id: bytes, score: int) -> JobDef:
key = job_key_prefix + job_id.decode()
v = await self.get(key)
async def _get_job_def(self, job_id: bytes, score: int, queue_name: Optional[str]) -> JobDef:
if queue_name is None:
_job_key_prefix = self.job_key_prefix
else:
_job_key_prefix = queue_name + default_job_key_suffix
key = _job_key_prefix + job_id.decode()
v = await self.get(arq_prefix + key)
assert v is not None, f'job "{key}" not found'
jd = deserialize_job(v, deserializer=self.job_deserializer)
jd.score = score
@@ -201,8 +223,10 @@ async def queued_jobs(self, *, queue_name: Optional[str] = None) -> List[JobDef]
"""
if queue_name is None:
queue_name = self.default_queue_name
jobs = await self.zrange(queue_name, withscores=True, start=0, end=-1)
return await asyncio.gather(*[self._get_job_def(job_id, int(score)) for job_id, score in jobs])
jobs = await self.zrange(arq_prefix + queue_name, withscores=True, start=0, end=-1)
return await asyncio.gather(
*[self._get_job_def(job_id, int(score), queue_name=queue_name) for job_id, score in jobs]
)


async def create_pool(
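For orientation, a rough usage sketch of the per-queue behaviour this hunk reaches for — hedged, since the commit message itself flags the change as broken and incomplete. The queue name 'emails' and the task name 'send_email' are invented for illustration; only enqueue_job's existing _queue_name argument and the queue_name parameter added to all_job_results above come from the diff.

import asyncio

from arq.connections import RedisSettings, create_pool


async def main() -> None:
    redis = await create_pool(RedisSettings())
    # Per the new prefixes, a job enqueued on 'emails' should land in the sorted set
    # 'arq:emails', with its payload at 'arq:emails:job:<id>' and, once a worker has
    # run it, its result at 'arq:emails:result:<id>'.
    job = await redis.enqueue_job('send_email', 'user@example.com', _queue_name='emails')
    print(job)
    # all_job_results now accepts an optional queue_name to scan a specific queue.
    print(await redis.all_job_results(queue_name='emails'))


asyncio.run(main())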
13 changes: 7 additions & 6 deletions arq/constants.py
@@ -1,9 +1,10 @@
default_queue_name = 'arq:queue'
job_key_prefix = 'arq:job:'
in_progress_key_prefix = 'arq:in-progress:'
result_key_prefix = 'arq:result:'
retry_key_prefix = 'arq:retry:'
abort_jobs_ss = 'arq:abort'
arq_prefix = 'arq:'
default_queue_name = 'queue'
default_job_key_suffix = ':job:'
default_in_progress_key_suffix = ':in-progress:'
default_result_key_suffix = ':result:'
default_retry_key_suffix = ':retry:'
abort_jobs_ss = 'abort'
# age of items in the abort_key sorted set after which they're deleted
abort_job_max_age = 60
health_check_key_suffix = ':health-check'
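A small sketch of how the new constants compose into full Redis keys, assuming callers prepend arq_prefix the way the rest of this diff does. Note that the concrete job and result key names change even for the default queue, so an in-place upgrade would not find jobs stored under the old keys.

from arq.constants import (
    arq_prefix,
    default_job_key_suffix,
    default_queue_name,
    default_result_key_suffix,
)

job_id = 'abc123'

# Queue key: 'arq:' + 'queue' -> 'arq:queue', the same string the old default_queue_name held.
assert arq_prefix + default_queue_name == 'arq:queue'

# Job key: 'arq:' + 'queue' + ':job:' + job_id -> 'arq:queue:job:abc123'
# (previously job_key_prefix + job_id -> 'arq:job:abc123').
assert arq_prefix + default_queue_name + default_job_key_suffix + job_id == 'arq:queue:job:abc123'

# Result key: 'arq:queue:result:abc123' (previously 'arq:result:abc123').
assert arq_prefix + default_queue_name + default_result_key_suffix + job_id == 'arq:queue:result:abc123'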
40 changes: 26 additions & 14 deletions arq/jobs.py
@@ -9,7 +9,8 @@

from redis.asyncio import Redis

from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
from .constants import abort_jobs_ss, arq_prefix, default_queue_name, default_in_progress_key_suffix, \
default_job_key_suffix, default_result_key_suffix
from .utils import ms_to_datetime, poll, timestamp_ms

logger = logging.getLogger('arq.jobs')
@@ -68,7 +69,15 @@ class Job:
Holds data a reference to a job.
"""

__slots__ = 'job_id', '_redis', '_queue_name', '_deserializer'
__slots__ = (
'job_id',
'_redis',
'queue_name',
'_job_key_prefix',
'_result_key_prefix',
'_in_progress_key_prefix',
'_deserializer',
)

def __init__(
self,
@@ -79,7 +88,10 @@ def __init__(
):
self.job_id = job_id
self._redis = redis
self._queue_name = _queue_name
self.queue_name = _queue_name
self._job_key_prefix = _queue_name + default_job_key_suffix
self._result_key_prefix = _queue_name + default_result_key_suffix
self._in_progress_key_prefix = _queue_name + default_in_progress_key_suffix
self._deserializer = _deserializer

async def result(
@@ -103,8 +115,8 @@ async def result(

async for delay in poll(poll_delay):
async with self._redis.pipeline(transaction=True) as tr:
tr.get(result_key_prefix + self.job_id) # type: ignore[unused-coroutine]
tr.zscore(self._queue_name, self.job_id) # type: ignore[unused-coroutine]
tr.get(arq_prefix + self.queue_name + self._result_key_prefix + self.job_id) # type: ignore[unused-coroutine]
tr.zscore(arq_prefix + self.queue_name, self.job_id) # type: ignore[unused-coroutine]
v, s = await tr.execute()

if v:
@@ -130,11 +142,11 @@ async def info(self) -> Optional[JobDef]:
"""
info: Optional[JobDef] = await self.result_info()
if not info:
v = await self._redis.get(job_key_prefix + self.job_id)
v = await self._redis.get(arq_prefix + self.queue_name + self._job_key_prefix + self.job_id)
if v:
info = deserialize_job(v, deserializer=self._deserializer)
if info:
s = await self._redis.zscore(self._queue_name, self.job_id)
s = await self._redis.zscore(arq_prefix + self.queue_name, self.job_id)
info.score = None if s is None else int(s)
return info

@@ -143,7 +155,7 @@ async def result_info(self) -> Optional[JobResult]:
Information about the job result if available, does not wait for the result. Does not raise an exception
even if the job raised one.
"""
v = await self._redis.get(result_key_prefix + self.job_id)
v = await self._redis.get(arq_prefix + self._result_key_prefix + self.job_id)
if v:
return deserialize_result(v, deserializer=self._deserializer)
else:
@@ -154,9 +166,9 @@ async def status(self) -> JobStatus:
Status of the job.
"""
async with self._redis.pipeline(transaction=True) as tr:
tr.exists(result_key_prefix + self.job_id) # type: ignore[unused-coroutine]
tr.exists(in_progress_key_prefix + self.job_id) # type: ignore[unused-coroutine]
tr.zscore(self._queue_name, self.job_id) # type: ignore[unused-coroutine]
tr.exists(arq_prefix + self._result_key_prefix + self.job_id) # type: ignore[unused-coroutine]
tr.exists(arq_prefix + self._in_progress_key_prefix + self.job_id) # type: ignore[unused-coroutine]
tr.zscore(arq_prefix + self.queue_name, self.job_id) # type: ignore[unused-coroutine]
is_complete, is_in_progress, score = await tr.execute()

if is_complete:
@@ -180,11 +192,11 @@ async def abort(self, *, timeout: Optional[float] = None, poll_delay: float = 0.
job_info = await self.info()
if job_info and job_info.score and job_info.score > timestamp_ms():
async with self._redis.pipeline(transaction=True) as tr:
tr.zrem(self._queue_name, self.job_id) # type: ignore[unused-coroutine]
tr.zadd(self._queue_name, {self.job_id: 1}) # type: ignore[unused-coroutine]
tr.zrem(arq_prefix + self.queue_name, self.job_id) # type: ignore[unused-coroutine]
tr.zadd(arq_prefix + self.queue_name, {self.job_id: 1}) # type: ignore[unused-coroutine]
await tr.execute()

await self._redis.zadd(abort_jobs_ss, {self.job_id: timestamp_ms()})
await self._redis.zadd(abort_jobs_ss, {arq_prefix + self.job_id: timestamp_ms()})

try:
await self.result(timeout=timeout, poll_delay=poll_delay)
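A brief sketch of how a Job instance is now bound to its queue through the new attributes — the values follow directly from the __slots__ and __init__ changes above, while the 'emails' queue name and the connection setup are illustrative only.

import asyncio

from arq.connections import RedisSettings, create_pool
from arq.jobs import Job


async def main() -> None:
    redis = await create_pool(RedisSettings())
    job = Job('abc123', redis, _queue_name='emails')
    # queue_name is now a public attribute, and the per-queue key prefixes are
    # derived from it at construction time.
    assert job.queue_name == 'emails'
    assert job._job_key_prefix == 'emails:job:'
    assert job._result_key_prefix == 'emails:result:'
    assert job._in_progress_key_prefix == 'emails:in-progress:'


asyncio.run(main())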
50 changes: 27 additions & 23 deletions arq/worker.py
@@ -19,14 +19,14 @@
from .constants import (
abort_job_max_age,
abort_jobs_ss,
default_queue_name,
arq_prefix, default_queue_name,
expires_extra_ms,
health_check_key_suffix,
in_progress_key_prefix,
job_key_prefix,
default_in_progress_key_suffix,
default_job_key_suffix,
keep_cronjob_progress,
result_key_prefix,
retry_key_prefix,
default_result_key_suffix,
default_retry_key_suffix,
)
from .utils import (
args_to_string,
@@ -224,6 +224,10 @@ def __init__(
else:
raise ValueError('If queue_name is absent, redis_pool must be present.')
self.queue_name = queue_name
self.job_key_prefix = self.queue_name + default_job_key_suffix
self.result_key_prefix = self.queue_name + default_result_key_suffix
self.in_progress_key_prefix = self.queue_name + default_in_progress_key_suffix
self.retry_key_prefix = self.queue_name + default_retry_key_suffix
self.cron_jobs: List[CronJob] = []
if cron_jobs is not None:
assert all(isinstance(cj, CronJob) for cj in cron_jobs), 'cron_jobs, must be instances of CronJob'
@@ -426,7 +430,7 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
for job_id_b in job_ids:
await self.sem.acquire()
job_id = job_id_b.decode()
in_progress_key = in_progress_key_prefix + job_id
in_progress_key = self.in_progress_key_prefix + job_id
async with self.pool.pipeline(transaction=True) as pipe:
await pipe.watch(in_progress_key)
ongoing_exists = await pipe.exists(in_progress_key)
@@ -455,9 +459,9 @@ async def start_jobs(self, job_ids: List[bytes]) -> None:
async def run_job(self, job_id: str, score: int) -> None: # noqa: C901
start_ms = timestamp_ms()
async with self.pool.pipeline(transaction=True) as pipe:
pipe.get(job_key_prefix + job_id) # type: ignore[unused-coroutine]
pipe.incr(retry_key_prefix + job_id) # type: ignore[unused-coroutine]
pipe.expire(retry_key_prefix + job_id, 88400) # type: ignore[unused-coroutine]
pipe.get(self.job_key_prefix + job_id) # type: ignore[unused-coroutine]
pipe.incr(self.retry_key_prefix + job_id) # type: ignore[unused-coroutine]
pipe.expire(self.retry_key_prefix + job_id, 88400) # type: ignore[unused-coroutine]
if self.allow_abort_jobs:
pipe.zrem(abort_jobs_ss, job_id) # type: ignore[unused-coroutine]
v, job_try, _, abort_job = await pipe.execute()
@@ -520,7 +524,7 @@ async def job_failed(exc: BaseException) -> None:

if enqueue_job_try and enqueue_job_try > job_try:
job_try = enqueue_job_try
await self.pool.setex(retry_key_prefix + job_id, 88400, str(job_try))
await self.pool.setex(self.retry_key_prefix + job_id, 88400, str(job_try))

max_tries = self.max_tries if function.max_tries is None else function.max_tries
if job_try > max_tries:
@@ -665,39 +669,39 @@ async def finish_job(
) -> None:
async with self.pool.pipeline(transaction=True) as tr:
delete_keys = []
in_progress_key = in_progress_key_prefix + job_id
in_progress_key = self.in_progress_key_prefix + job_id
if keep_in_progress is None:
delete_keys += [in_progress_key]
else:
tr.pexpire(in_progress_key, to_ms(keep_in_progress)) # type: ignore[unused-coroutine]
tr.pexpire(arq_prefix + in_progress_key, to_ms(keep_in_progress)) # type: ignore[unused-coroutine]

if finish:
if result_data:
expire = None if keep_result_forever else result_timeout_s
tr.set(result_key_prefix + job_id, result_data, px=to_ms(expire)) # type: ignore[unused-coroutine]
delete_keys += [retry_key_prefix + job_id, job_key_prefix + job_id]
tr.zrem(abort_jobs_ss, job_id) # type: ignore[unused-coroutine]
tr.zrem(self.queue_name, job_id) # type: ignore[unused-coroutine]
tr.set(arq_prefix + self.result_key_prefix + job_id, result_data, px=to_ms(expire)) # type: ignore[unused-coroutine]
delete_keys += [arq_prefix + self.retry_key_prefix + job_id, arq_prefix + self.job_key_prefix + job_id]
tr.zrem(abort_jobs_ss, arq_prefix + job_id) # type: ignore[unused-coroutine]
tr.zrem(arq_prefix + self.queue_name, job_id) # type: ignore[unused-coroutine]
elif incr_score:
tr.zincrby(self.queue_name, incr_score, job_id) # type: ignore[unused-coroutine]
tr.zincrby(arq_prefix + self.queue_name, incr_score, job_id) # type: ignore[unused-coroutine]
if delete_keys:
tr.delete(*delete_keys) # type: ignore[unused-coroutine]
await tr.execute()

async def finish_failed_job(self, job_id: str, result_data: Optional[bytes]) -> None:
async with self.pool.pipeline(transaction=True) as tr:
tr.delete( # type: ignore[unused-coroutine]
retry_key_prefix + job_id,
in_progress_key_prefix + job_id,
job_key_prefix + job_id,
arq_prefix + self.retry_key_prefix + job_id,
arq_prefix + self.in_progress_key_prefix + job_id,
arq_prefix + self.job_key_prefix + job_id,
)
tr.zrem(abort_jobs_ss, job_id) # type: ignore[unused-coroutine]
tr.zrem(self.queue_name, job_id) # type: ignore[unused-coroutine]
tr.zrem(abort_jobs_ss, arq_prefix + job_id) # type: ignore[unused-coroutine]
tr.zrem(arq_prefix + self.queue_name, job_id) # type: ignore[unused-coroutine]
# result_data would only be None if serializing the result fails
keep_result = self.keep_result_forever or self.keep_result_s > 0
if result_data is not None and keep_result: # pragma: no branch
expire = 0 if self.keep_result_forever else self.keep_result_s
tr.set(result_key_prefix + job_id, result_data, px=to_ms(expire)) # type: ignore[unused-coroutine]
tr.set(arq_prefix + self.result_key_prefix + job_id, result_data, px=to_ms(expire)) # type: ignore[unused-coroutine]
await tr.execute()

async def heart_beat(self) -> None:
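On the worker side the same per-queue prefixes are computed once in __init__. A minimal construction sketch against a named queue — send_email and 'emails' are illustrative, not part of the diff, and given the unfinished state of the commit the runtime behaviour may not yet match the intended key layout.

from arq.connections import RedisSettings
from arq.worker import Worker


async def send_email(ctx: dict, to: str) -> None:
    print(f'sending email to {to}')


# The worker derives '<queue>:job:', '<queue>:result:', '<queue>:in-progress:' and
# '<queue>:retry:' prefixes from queue_name, mirroring the connections side.
worker = Worker(
    functions=[send_email],
    queue_name='emails',
    redis_settings=RedisSettings(),
)
# worker.run() would then block and process jobs enqueued on the 'emails' queue.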
9 changes: 9 additions & 0 deletions docker-compose.yaml
@@ -0,0 +1,9 @@
version: '3.8'

services:
redis:
container_name: arq_redis
image: redis:7-alpine
ports:
- '127.0.0.1:6379:6379'
restart: always
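Presumably added so local development and the test suite have a Redis instance to talk to: docker compose up -d brings it up on 127.0.0.1:6379, which matches arq's default RedisSettings.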
4 changes: 2 additions & 2 deletions tests/test_cron.py
@@ -9,7 +9,7 @@

import arq
from arq import Worker
from arq.constants import in_progress_key_prefix
from arq.constants import default_in_progress_key_suffix, default_queue_name
from arq.cron import cron, next_cron


@@ -112,7 +112,7 @@ async def test_job_successful(worker, caplog, arq_redis, poll_delay):
assert ' 0.XXs → cron:foobar()\n 0.XXs ← cron:foobar ● 42' in log

# Assert the in-progress key still exists.
keys = await arq_redis.keys(in_progress_key_prefix + '*')
keys = await arq_redis.keys(default_queue_name + default_in_progress_key_suffix + '*')
assert len(keys) == 1
assert await arq_redis.pttl(keys[0]) > 0.0
