class LWTRetryPolicy(ExponentialBackoffRetryPolicy):
    """
    A retry policy tailored for Lightweight Transaction (LWT) queries.

    LWT queries use Paxos consensus, where the first replica in the token ring
    acts as the Paxos coordinator (leader). Retrying LWT queries on a *different*
    host causes Paxos contention — the new coordinator must compete with the
    original one, potentially causing cascading timeouts.

    This policy addresses that by:

    - **CAS write timeouts**: Retrying on the **same host** (the Paxos coordinator)
      with exponential backoff, giving the coordinator time to complete the Paxos
      round. If the timeout reports a serial consistency level, the retry keeps
      the statement's own consistency instead of adopting the serial one.
    - **CAS read timeouts** (serial consistency): Retrying on the same host.
    - **Unavailable at serial consistency**: Retrying on the **next host**, since the
      Paxos phase failed on this node (not enough replicas alive to form quorum).
    - **Non-CAS operations**: Delegating to the standard
      :class:`ExponentialBackoffRetryPolicy` behavior.

    Every handler returns a 3-tuple ``(decision, consistency, delay)`` and
    rethrows once ``retry_num`` reaches ``max_num_retries``.

    This is modeled after gocql's ``LWTRetryPolicy`` interface, which retries LWT
    queries on the same host to avoid Paxos contention.

    Example usage::

        from cassandra.cluster import Cluster
        from cassandra.policies import LWTRetryPolicy

        # Use as the default retry policy for the cluster
        cluster = Cluster(
            default_retry_policy=LWTRetryPolicy(max_num_retries=3)
        )

        # Or assign to a specific statement
        statement.retry_policy = LWTRetryPolicy(max_num_retries=5)

    :param max_num_retries: Maximum number of retry attempts (default: 3).
    :param min_interval: Initial backoff delay in seconds (default: 0.1).
    :param max_interval: Maximum backoff delay in seconds (default: 10.0).
    """

    def __init__(self, max_num_retries=3, min_interval=0.1, max_interval=10.0,
                 *args, **kwargs):
        # Forward the three settings positionally. Passing them as keywords and
        # then unpacking ``*args`` would bind any extra positional argument to
        # the parent's first parameter and raise "got multiple values for
        # argument".
        # NOTE(review): assumes the parent's leading parameters are
        # (max_num_retries, min_interval, max_interval) in this order — confirm
        # against ExponentialBackoffRetryPolicy.__init__.
        super(LWTRetryPolicy, self).__init__(
            max_num_retries, min_interval, max_interval, *args, **kwargs)

    def on_write_timeout(self, query, consistency, write_type,
                         required_responses, received_responses, retry_num):
        """
        For CAS (LWT) write timeouts, retry on the **same host** with exponential
        backoff. Retrying on a different host would cause Paxos contention.

        For non-CAS writes, delegates to the base ExponentialBackoffRetryPolicy
        behavior (retry BATCH_LOG only, RETHROW otherwise).

        :return: ``(RETRY, consistency_or_None, delay)`` for CAS timeouts within
            the retry budget, ``(RETHROW, None, None)`` once the budget is spent,
            otherwise whatever the parent policy decides.
        """
        if retry_num >= self.max_num_retries:
            return self.RETHROW, None, None

        if write_type == WriteType.CAS:
            # Retry on the SAME host — this is the Paxos coordinator.
            # Moving to another host causes contention in the Paxos protocol.
            #
            # A Paxos-phase timeout reports the *serial* consistency level.
            # Returning it would make the retry run with SERIAL/LOCAL_SERIAL as
            # the write's regular consistency, which is not a valid level for a
            # write. Returning None leaves the statement's consistency as it is.
            if ConsistencyLevel.is_serial(consistency):
                retry_consistency = None
            else:
                retry_consistency = consistency
            return self.RETRY, retry_consistency, self._calculate_backoff(retry_num)

        # Non-CAS: delegate to parent (retries BATCH_LOG, rethrows others)
        return super(LWTRetryPolicy, self).on_write_timeout(
            query, consistency, write_type,
            required_responses, received_responses, retry_num)

    def on_read_timeout(self, query, consistency, required_responses,
                        received_responses, data_retrieved, retry_num):
        """
        For reads at serial consistency (CAS reads), retry on the **same host**
        with backoff.

        For non-serial reads, delegates to the base ExponentialBackoffRetryPolicy
        behavior.

        :return: ``(RETRY, consistency, delay)`` for serial reads within the
            retry budget, ``(RETHROW, None, None)`` once the budget is spent,
            otherwise whatever the parent policy decides.
        """
        if retry_num >= self.max_num_retries:
            return self.RETHROW, None, None

        if ConsistencyLevel.is_serial(consistency):
            # Serial read = CAS/Paxos read. Retry on same host; the serial level
            # is the read's own consistency here, so it is passed back as-is.
            return self.RETRY, consistency, self._calculate_backoff(retry_num)

        # Non-serial: delegate to parent
        return super(LWTRetryPolicy, self).on_read_timeout(
            query, consistency, required_responses,
            received_responses, data_retrieved, retry_num)

    def on_unavailable(self, query, consistency, required_replicas,
                       alive_replicas, retry_num):
        """
        For serial consistency (CAS/Paxos phase), retry on the **next host** —
        this node couldn't form a Paxos quorum, so a different coordinator
        might see a different set of available replicas.

        For non-serial consistency, delegates to the base
        ExponentialBackoffRetryPolicy behavior.

        :return: ``(RETRY_NEXT_HOST, None, delay)`` for serial consistency
            within the retry budget, ``(RETHROW, None, None)`` once the budget
            is spent, otherwise whatever the parent policy decides.
        """
        if retry_num >= self.max_num_retries:
            return self.RETHROW, None, None

        if ConsistencyLevel.is_serial(consistency):
            # Paxos phase failed — not enough replicas for serial quorum.
            # Try a different coordinator; it might have better connectivity.
            return self.RETRY_NEXT_HOST, None, self._calculate_backoff(retry_num)

        # Non-serial: delegate to parent
        return super(LWTRetryPolicy, self).on_unavailable(
            query, consistency, required_replicas, alive_replicas, retry_num)
class LWTRetryPolicyTest(unittest.TestCase):
    """Unit tests for LWTRetryPolicy — LWT-aware retry with same-host preference."""

    # --- Helpers: one thin wrapper per policy callback ---

    def _make_policy(self, max_retries=3):
        """Build the policy under test with the given retry budget."""
        return LWTRetryPolicy(max_num_retries=max_retries)

    @staticmethod
    def _write_timeout(policy, write_type, retry_num=0, consistency=None):
        """Report a write timeout (1 of 3 acks); consistency defaults to QUORUM."""
        if consistency is None:
            consistency = ConsistencyLevel.QUORUM
        return policy.on_write_timeout(
            query=None, consistency=consistency, write_type=write_type,
            required_responses=3, received_responses=1, retry_num=retry_num)

    @staticmethod
    def _read_timeout(policy, consistency, retry_num=0, required=3,
                      received=1, data_retrieved=False):
        """Report a read timeout with the given response counts."""
        return policy.on_read_timeout(
            query=None, consistency=consistency,
            required_responses=required, received_responses=received,
            data_retrieved=data_retrieved, retry_num=retry_num)

    @staticmethod
    def _unavailable(policy, consistency, retry_num=0):
        """Report an unavailable error (1 of 3 replicas alive)."""
        return policy.on_unavailable(
            query=None, consistency=consistency,
            required_replicas=3, alive_replicas=1, retry_num=retry_num)

    @staticmethod
    def _request_error(policy, retry_num=0):
        """Report a generic request error at QUORUM."""
        return policy.on_request_error(
            query=None, consistency=ConsistencyLevel.QUORUM,
            error=Exception("overloaded"), retry_num=retry_num)

    # --- CAS write timeout: retry on SAME host ---

    def test_cas_write_timeout_retries_same_host(self):
        """A first CAS write timeout is retried on the same host with a delay."""
        outcome = self._write_timeout(self._make_policy(), WriteType.CAS)
        retry, consistency, delay = outcome
        assert retry == RetryPolicy.RETRY
        assert consistency == ConsistencyLevel.QUORUM
        assert delay is not None and delay > 0

    def test_cas_write_timeout_retries_with_backoff(self):
        """The CAS write timeout delay grows as retry_num grows."""
        policy = self._make_policy(max_retries=5)
        delays = [
            self._write_timeout(policy, WriteType.CAS, retry_num=attempt)[2]
            for attempt in range(3)
        ]
        # Jitter is ± min_interval/2 = ±0.05s, so attempt 0 lands in 0.05–0.15s
        # and attempt 2 in 0.35–0.45s: the ranges cannot overlap.
        assert delays[0] < delays[2], (
            f"Backoff should increase: delays={delays}")

    def test_cas_write_timeout_max_retries_exceeded(self):
        """A CAS write timeout is rethrown once the retry budget is used up."""
        policy = self._make_policy(max_retries=2)
        retry, consistency, delay = self._write_timeout(
            policy, WriteType.CAS, retry_num=2)
        assert retry == RetryPolicy.RETHROW

    def test_cas_write_timeout_preserves_consistency(self):
        """A CAS retry hands back the consistency level it was given."""
        policy = self._make_policy()
        levels = (ConsistencyLevel.QUORUM, ConsistencyLevel.LOCAL_QUORUM,
                  ConsistencyLevel.ONE, ConsistencyLevel.ALL)
        for cl in levels:
            retry, consistency, _ = self._write_timeout(
                policy, WriteType.CAS, consistency=cl)
            assert retry == RetryPolicy.RETRY
            assert consistency == cl, f"Expected {cl}, got {consistency}"

    # --- Non-CAS write timeout: delegate to parent ---

    def test_simple_write_timeout_rethrows(self):
        """A SIMPLE write timeout is rethrown, as in the base policy."""
        retry, consistency, delay = self._write_timeout(
            self._make_policy(), WriteType.SIMPLE)
        assert retry == RetryPolicy.RETHROW

    def test_batch_log_write_timeout_retries(self):
        """A BATCH_LOG write timeout is retried (inherited behavior)."""
        retry, consistency, delay = self._write_timeout(
            self._make_policy(), WriteType.BATCH_LOG)
        assert retry == RetryPolicy.RETRY
        assert consistency == ConsistencyLevel.QUORUM

    def test_batch_log_write_timeout_max_retries_exceeded(self):
        """A BATCH_LOG write timeout is rethrown once the budget is used up."""
        policy = self._make_policy(max_retries=1)
        retry, consistency, delay = self._write_timeout(
            policy, WriteType.BATCH_LOG, retry_num=1)
        assert retry == RetryPolicy.RETHROW

    def test_counter_write_timeout_rethrows(self):
        """A COUNTER write timeout is rethrown, as in the base policy."""
        retry, consistency, delay = self._write_timeout(
            self._make_policy(), WriteType.COUNTER)
        assert retry == RetryPolicy.RETHROW

    # --- Serial (CAS) read timeout: retry on SAME host ---

    def test_serial_read_timeout_retries_same_host(self):
        """A read timeout at SERIAL is retried on the same host."""
        retry, consistency, delay = self._read_timeout(
            self._make_policy(), ConsistencyLevel.SERIAL)
        assert retry == RetryPolicy.RETRY
        assert consistency == ConsistencyLevel.SERIAL
        assert delay is not None and delay > 0

    def test_local_serial_read_timeout_retries_same_host(self):
        """A read timeout at LOCAL_SERIAL is retried on the same host."""
        retry, consistency, delay = self._read_timeout(
            self._make_policy(), ConsistencyLevel.LOCAL_SERIAL)
        assert retry == RetryPolicy.RETRY
        assert consistency == ConsistencyLevel.LOCAL_SERIAL
        assert delay is not None and delay > 0

    def test_serial_read_timeout_max_retries_exceeded(self):
        """A serial read timeout is rethrown once the budget is used up."""
        policy = self._make_policy(max_retries=1)
        retry, consistency, delay = self._read_timeout(
            policy, ConsistencyLevel.SERIAL, retry_num=1)
        assert retry == RetryPolicy.RETHROW

    # --- Non-serial read timeout: delegate to parent ---

    def test_non_serial_read_timeout_delegates_to_parent(self):
        """A non-serial read timeout follows the base policy."""
        policy = self._make_policy()

        # Base: retry if enough responses but no data
        retry, consistency, delay = self._read_timeout(
            policy, ConsistencyLevel.QUORUM,
            required=2, received=2, data_retrieved=False)
        assert retry == RetryPolicy.RETRY
        assert consistency == ConsistencyLevel.QUORUM

        # Base: rethrow if we got data
        retry, consistency, delay = self._read_timeout(
            policy, ConsistencyLevel.QUORUM,
            required=2, received=2, data_retrieved=True)
        assert retry == RetryPolicy.RETHROW

    # --- Serial unavailable: retry on NEXT host ---

    def test_serial_unavailable_retries_next_host(self):
        """Unavailable at SERIAL is retried on the next host."""
        retry, consistency, delay = self._unavailable(
            self._make_policy(), ConsistencyLevel.SERIAL)
        assert retry == RetryPolicy.RETRY_NEXT_HOST
        assert consistency is None
        assert delay is not None and delay > 0

    def test_local_serial_unavailable_retries_next_host(self):
        """Unavailable at LOCAL_SERIAL is retried on the next host."""
        retry, consistency, delay = self._unavailable(
            self._make_policy(), ConsistencyLevel.LOCAL_SERIAL)
        assert retry == RetryPolicy.RETRY_NEXT_HOST
        assert consistency is None
        assert delay is not None and delay > 0

    def test_serial_unavailable_max_retries_exceeded(self):
        """Serial unavailable is rethrown once the budget is used up."""
        policy = self._make_policy(max_retries=1)
        retry, consistency, delay = self._unavailable(
            policy, ConsistencyLevel.SERIAL, retry_num=1)
        assert retry == RetryPolicy.RETHROW

    # --- Non-serial unavailable: delegate to parent ---

    def test_non_serial_unavailable_delegates_to_parent(self):
        """Non-serial unavailable follows the base policy."""
        # Base: RETRY_NEXT_HOST on first attempt
        retry, consistency, delay = self._unavailable(
            self._make_policy(), ConsistencyLevel.QUORUM)
        assert retry == RetryPolicy.RETRY_NEXT_HOST

    # --- on_request_error: inherited from parent ---

    def test_request_error_retries_next_host(self):
        """Request errors are retried on the next host (inherited behavior)."""
        retry, consistency, delay = self._request_error(self._make_policy())
        assert retry == RetryPolicy.RETRY_NEXT_HOST

    def test_request_error_max_retries_exceeded(self):
        """Request errors are rethrown once the budget is used up."""
        policy = self._make_policy(max_retries=1)
        retry, consistency, delay = self._request_error(policy, retry_num=1)
        assert retry == RetryPolicy.RETHROW

    # --- Constructor defaults ---

    def test_default_constructor(self):
        """A bare LWTRetryPolicy() carries the documented defaults."""
        policy = LWTRetryPolicy()
        assert policy.max_num_retries == 3
        assert policy.min_interval == 0.1
        assert policy.max_interval == 10.0

    def test_custom_constructor(self):
        """Constructor keyword arguments end up on the instance."""
        policy = LWTRetryPolicy(max_num_retries=5, min_interval=0.5, max_interval=30.0)
        assert policy.max_num_retries == 5
        assert policy.min_interval == 0.5
        assert policy.max_interval == 30.0

    def test_inherits_exponential_backoff(self):
        """LWTRetryPolicy is an ExponentialBackoffRetryPolicy and a RetryPolicy."""
        policy = LWTRetryPolicy()
        assert isinstance(policy, ExponentialBackoffRetryPolicy)
        assert isinstance(policy, RetryPolicy)

    # --- Verify 3-tuple return format for all methods ---

    def test_all_methods_return_3_tuples(self):
        """Every retry decision is a 3-tuple (decision, cl, delay)."""
        policy = self._make_policy()

        # Each scenario is invoked lazily so the calls and their assertions
        # stay interleaved in order: CAS write timeout, serial read timeout,
        # serial unavailable, then a RETHROW case.
        scenarios = (
            lambda: self._write_timeout(policy, WriteType.CAS),
            lambda: self._read_timeout(policy, ConsistencyLevel.SERIAL),
            lambda: self._unavailable(policy, ConsistencyLevel.SERIAL),
            lambda: self._write_timeout(policy, WriteType.SIMPLE),
        )
        for run_scenario in scenarios:
            result = run_scenario()
            assert len(result) == 3, f"Expected 3-tuple, got {result}"