This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import threading
|
||||
import time as mod_time
|
||||
import uuid
|
||||
@@ -7,6 +8,8 @@ from typing import Optional, Type
|
||||
from redis.exceptions import LockError, LockNotOwnedError
|
||||
from redis.typing import Number
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Lock:
|
||||
"""
|
||||
@@ -82,6 +85,7 @@ class Lock:
|
||||
blocking: bool = True,
|
||||
blocking_timeout: Optional[Number] = None,
|
||||
thread_local: bool = True,
|
||||
raise_on_release_error: bool = True,
|
||||
):
|
||||
"""
|
||||
Create a new Lock instance named ``name`` using the Redis client
|
||||
@@ -125,6 +129,11 @@ class Lock:
|
||||
thread-1 would see the token value as "xyz" and would be
|
||||
able to successfully release the thread-2's lock.
|
||||
|
||||
``raise_on_release_error`` indicates whether to raise an exception when
|
||||
the lock is no longer owned when exiting the context manager. By default,
|
||||
this is True, meaning an exception will be raised. If False, the warning
|
||||
will be logged and the exception will be suppressed.
|
||||
|
||||
In some use cases it's necessary to disable thread local storage. For
|
||||
example, if you have code where one thread acquires a lock and passes
|
||||
that lock instance to a worker thread to release later. If thread
|
||||
@@ -140,6 +149,7 @@ class Lock:
|
||||
self.blocking = blocking
|
||||
self.blocking_timeout = blocking_timeout
|
||||
self.thread_local = bool(thread_local)
|
||||
self.raise_on_release_error = raise_on_release_error
|
||||
self.local = threading.local() if self.thread_local else SimpleNamespace()
|
||||
self.local.token = None
|
||||
self.register_scripts()
|
||||
@@ -157,7 +167,10 @@ class Lock:
|
||||
def __enter__(self) -> "Lock":
|
||||
if self.acquire():
|
||||
return self
|
||||
raise LockError("Unable to acquire lock within the time specified")
|
||||
raise LockError(
|
||||
"Unable to acquire lock within the time specified",
|
||||
lock_name=self.name,
|
||||
)
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
@@ -165,7 +178,14 @@ class Lock:
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.release()
|
||||
try:
|
||||
self.release()
|
||||
except LockError:
|
||||
if self.raise_on_release_error:
|
||||
raise
|
||||
logger.warning(
|
||||
"Lock was unlocked or no longer owned when exiting context manager."
|
||||
)
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
@@ -248,7 +268,10 @@ class Lock:
|
||||
"""
|
||||
expected_token = self.local.token
|
||||
if expected_token is None:
|
||||
raise LockError("Cannot release an unlocked lock")
|
||||
raise LockError(
|
||||
"Cannot release a lock that's not owned or is already unlocked.",
|
||||
lock_name=self.name,
|
||||
)
|
||||
self.local.token = None
|
||||
self.do_release(expected_token)
|
||||
|
||||
@@ -256,9 +279,12 @@ class Lock:
|
||||
if not bool(
|
||||
self.lua_release(keys=[self.name], args=[expected_token], client=self.redis)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot release a lock that's no longer owned")
|
||||
raise LockNotOwnedError(
|
||||
"Cannot release a lock that's no longer owned",
|
||||
lock_name=self.name,
|
||||
)
|
||||
|
||||
def extend(self, additional_time: int, replace_ttl: bool = False) -> bool:
|
||||
def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool:
|
||||
"""
|
||||
Adds more time to an already acquired lock.
|
||||
|
||||
@@ -270,12 +296,12 @@ class Lock:
|
||||
`additional_time`.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot extend an unlocked lock")
|
||||
raise LockError("Cannot extend an unlocked lock", lock_name=self.name)
|
||||
if self.timeout is None:
|
||||
raise LockError("Cannot extend a lock with no timeout")
|
||||
raise LockError("Cannot extend a lock with no timeout", lock_name=self.name)
|
||||
return self.do_extend(additional_time, replace_ttl)
|
||||
|
||||
def do_extend(self, additional_time: int, replace_ttl: bool) -> bool:
|
||||
def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool:
|
||||
additional_time = int(additional_time * 1000)
|
||||
if not bool(
|
||||
self.lua_extend(
|
||||
@@ -284,7 +310,10 @@ class Lock:
|
||||
client=self.redis,
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot extend a lock that's no longer owned")
|
||||
raise LockNotOwnedError(
|
||||
"Cannot extend a lock that's no longer owned",
|
||||
lock_name=self.name,
|
||||
)
|
||||
return True
|
||||
|
||||
def reacquire(self) -> bool:
|
||||
@@ -292,9 +321,12 @@ class Lock:
|
||||
Resets a TTL of an already acquired lock back to a timeout value.
|
||||
"""
|
||||
if self.local.token is None:
|
||||
raise LockError("Cannot reacquire an unlocked lock")
|
||||
raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name)
|
||||
if self.timeout is None:
|
||||
raise LockError("Cannot reacquire a lock with no timeout")
|
||||
raise LockError(
|
||||
"Cannot reacquire a lock with no timeout",
|
||||
lock_name=self.name,
|
||||
)
|
||||
return self.do_reacquire()
|
||||
|
||||
def do_reacquire(self) -> bool:
|
||||
@@ -304,5 +336,8 @@ class Lock:
|
||||
keys=[self.name], args=[self.local.token, timeout], client=self.redis
|
||||
)
|
||||
):
|
||||
raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned")
|
||||
raise LockNotOwnedError(
|
||||
"Cannot reacquire a lock that's no longer owned",
|
||||
lock_name=self.name,
|
||||
)
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user