Source code for pjrpc.client.retry

import asyncio
import dataclasses as dc
import itertools as it
import logging
import time
from typing import Any, Callable, Generator, Iterator, Mapping, Optional

from pjrpc.client import AsyncMiddlewareHandler, MiddlewareHandler
from pjrpc.common import AbstractRequest, AbstractResponse

# Module logger, named after the enclosing package.
logger = logging.getLogger(name=__package__)

# Jitter callable: receives the zero-based retry attempt index and returns an
# extra delay (in seconds) added to the backoff delay for that attempt.
Jitter = Callable[[int], float]


@dc.dataclass(frozen=True)
class Backoff:
    """
    Base class of the JSON-RPC request retry delay strategies.

    Concrete strategies override :py:meth:`__call__` to produce the delays
    (in seconds) to wait before each retry.

    :param attempts: retries number
    :param jitter: retry delay jitter generator; called with the retry attempt
                   index, its result is added to the attempt's delay
    """

    attempts: int
    jitter: Jitter = dc.field(default=lambda attempt: 0.0)

    def __call__(self) -> Iterator[float]:
        """
        Returns delay iterator.

        The base class defines no delay sequence; subclasses must override this.
        """

        raise NotImplementedError()
@dc.dataclass(frozen=True)
class PeriodicBackoff(Backoff):
    """
    Periodic request retry strategy.

    Every retry waits the same ``interval`` plus the jitter for that attempt.

    :param interval: retry delay
    """

    interval: float = 1.0

    def __call__(self) -> Iterator[float]:
        def _delays() -> Generator[float, None, None]:
            # lazily produce one delay per allowed retry attempt
            yield from (self.interval + self.jitter(n) for n in range(self.attempts))

        return _delays()
@dc.dataclass(frozen=True)
class ExponentialBackoff(Backoff):
    """
    Exponential request retry strategy.

    The delay for attempt ``n`` (zero-based) is ``base * factor ** n`` plus jitter,
    optionally capped at ``max_value``.

    :param base: exponentially growing delay base
    :param factor: exponentially growing delay factor (multiplier)
    :param max_value: delay max value (``None`` means no cap)
    """

    base: float = 1.0
    factor: float = 2.0
    max_value: Optional[float] = None

    def __call__(self) -> Iterator[float]:
        def _delays() -> Generator[float, None, None]:
            for n in range(self.attempts):
                delay = self.base * self.factor ** n + self.jitter(n)
                if self.max_value is not None:
                    delay = min(self.max_value, delay)
                yield delay

        return _delays()
@dc.dataclass(frozen=True)
class FibonacciBackoff(Backoff):
    """
    Fibonacci request retry strategy.

    Delays follow ``multiplier * F`` for ``F`` in 1, 2, 3, 5, 8, ... plus the
    attempt's jitter, capped at ``max_value``.

    :param multiplier: fibonacci interval sequence multiplier
    :param max_value: delay max value; ``None`` disables the cap. Note that the
                      default (``1.0``) with the default ``multiplier`` caps every
                      delay at 1.0, so pass a larger value or ``None`` to get growth.
    """

    multiplier: float = 1.0
    # Optional so that ``None`` (no cap) is a valid, type-correct value, matching
    # ExponentialBackoff; the 1.0 default is kept for backward compatibility.
    max_value: Optional[float] = 1.0

    def __call__(self) -> Iterator[float]:
        def gen() -> Generator[float, None, None]:
            prev, cur = 1, 1
            for attempt in range(self.attempts):
                value = cur * self.multiplier + self.jitter(attempt)
                yield min(self.max_value, value) if self.max_value is not None else value
                # advance the Fibonacci sequence: 1, 2, 3, 5, 8, ...
                prev, cur = cur, prev + cur

        return gen()
@dc.dataclass(frozen=True)
class RetryStrategy:
    """
    JSON-RPC request retry strategy: how long to wait and what to retry on.

    :param backoff: backoff delay generator
    :param codes: JSON-RPC response codes receiving which the request will be retried
                  (``None`` or empty disables code-based retries)
    :param exceptions: exceptions catching which the request will be retried
                       (``None`` or empty disables exception-based retries)
    """

    backoff: Backoff
    codes: Optional[set[int]] = dc.field(default=None)
    exceptions: Optional[set[type[Exception]]] = dc.field(default=None)
class RetryMiddleware:
    """
    Synchronous client middleware that retries requests according to a retry strategy.

    :param retry_strategy: backoff delays plus the retriable response codes and exceptions
    """

    def __init__(self, retry_strategy: RetryStrategy):
        self._retry_strategy = retry_strategy

    def __call__(
        self,
        request: AbstractRequest,
        request_kwargs: Mapping[str, Any],
        /,
        handler: MiddlewareHandler,
    ) -> Optional[AbstractResponse]:
        """
        Request retrying middleware.

        Calls ``handler`` and repeats the call while it either raises one of the
        strategy's ``exceptions`` or returns an error response whose code is in the
        strategy's ``codes``. Between attempts it sleeps for the next backoff delay.
        Both retry kinds consume the same delay iterator, so its length bounds the
        total number of retries.

        :returns: the first non-retriable response, or the last error response once
                  the backoff delays are exhausted
        :raises Exception: the last retriable exception once the delays are exhausted,
                           or any non-retriable exception raised by ``handler``
        """

        strategy = self._retry_strategy
        # Computed once per call; an empty tuple in ``except`` matches nothing.
        retry_exceptions = tuple(strategy.exceptions or ())
        delays = strategy.backoff()

        for attempt in it.count(start=1):
            # Only the handler call is guarded, so failures in the retry
            # bookkeeping below are never mistaken for retriable request errors.
            try:
                response = handler(request, request_kwargs)
            except retry_exceptions as e:
                delay = next(delays, None)
                if delay is None:
                    raise  # retries exhausted: propagate with the original traceback
                logger.debug("retrying request: attempt=%d, exception=%r", attempt, e)
                time.sleep(delay)
                continue

            if response is not None and response.is_error and strategy.codes:
                if (code := response.unwrap_error().code) in strategy.codes:
                    delay = next(delays, None)
                    if delay is not None:
                        logger.debug("retrying request: attempt=%d, code=%s", attempt, code)
                        time.sleep(delay)
                        continue

            return response

        raise AssertionError("unreachable")  # it.count() never terminates


class AsyncRetryMiddleware:
    """
    Asynchronous client middleware that retries requests according to a retry strategy.

    :param retry_strategy: backoff delays plus the retriable response codes and exceptions
    """

    def __init__(self, retry_strategy: RetryStrategy):
        self._retry_strategy = retry_strategy

    async def __call__(
        self,
        request: AbstractRequest,
        request_kwargs: Mapping[str, Any],
        /,
        handler: AsyncMiddlewareHandler,
    ) -> Optional[AbstractResponse]:
        """
        Asynchronous request retrying middleware.

        Awaits ``handler`` and repeats the call while it either raises one of the
        strategy's ``exceptions`` or returns an error response whose code is in the
        strategy's ``codes``. Between attempts it awaits ``asyncio.sleep`` for the next
        backoff delay. Both retry kinds consume the same delay iterator, so its length
        bounds the total number of retries.

        :returns: the first non-retriable response, or the last error response once
                  the backoff delays are exhausted
        :raises Exception: the last retriable exception once the delays are exhausted,
                           or any non-retriable exception raised by ``handler``
        """

        strategy = self._retry_strategy
        # Computed once per call; an empty tuple in ``except`` matches nothing.
        retry_exceptions = tuple(strategy.exceptions or ())
        delays = strategy.backoff()

        for attempt in it.count(start=1):
            # Only the handler call is guarded, so failures in the retry
            # bookkeeping below are never mistaken for retriable request errors.
            try:
                response = await handler(request, request_kwargs)
            except retry_exceptions as e:
                delay = next(delays, None)
                if delay is None:
                    raise  # retries exhausted: propagate with the original traceback
                logger.debug("retrying request: attempt=%d, exception=%r", attempt, e)
                await asyncio.sleep(delay)
                continue

            if response is not None and response.is_error and strategy.codes:
                if (code := response.unwrap_error().code) in strategy.codes:
                    delay = next(delays, None)
                    if delay is not None:
                        logger.debug("retrying request: attempt=%d, code=%s", attempt, code)
                        await asyncio.sleep(delay)
                        continue

            return response

        raise AssertionError("unreachable")  # it.count() never terminates