aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLarry Safran <lsafran@google.com>2023-09-26 17:31:58 -0700
committerGitHub <noreply@github.com>2023-09-26 17:31:58 -0700
commitbc784c0ef9f7727497f8543492979c60d5ce5ef8 (patch)
tree5fe983b69f364ebf0bf150beffc99950fa079acf
parentcf4cf03d79e321083015fe2e62b4e244112d88f0 (diff)
downloadgrpc-grpc-java-bc784c0ef9f7727497f8543492979c60d5ce5ef8.tar.gz
Revert "Change Round Robin and WeightedRoundRobin into petiole policies (#10528)" (#10575)
This reverts commit e1334eae7bba39d85a952bc5ab5aeb4cb05a56d8.
-rw-r--r--core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java2
-rw-r--r--core/src/testFixtures/java/io/grpc/internal/TestUtils.java5
-rw-r--r--examples/android/strictmode/app/build.gradle1
-rw-r--r--examples/android/strictmode/app/proguard-rules.pro1
-rw-r--r--util/build.gradle12
-rw-r--r--util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java241
-rw-r--r--util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java234
-rw-r--r--util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java4
-rw-r--r--util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java220
-rw-r--r--util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java156
-rw-r--r--xds/build.gradle3
-rw-r--r--xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java92
-rw-r--r--xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java243
-rw-r--r--xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java8
-rw-r--r--xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java391
15 files changed, 629 insertions, 984 deletions
diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
index df35afae1..da2bc072a 100644
--- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
+++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
@@ -161,7 +161,7 @@ import org.mockito.stubbing.Answer;
/** Unit tests for {@link ManagedChannelImpl}. */
@RunWith(JUnit4.class)
// TODO(creamsoup) remove backward compatible check when fully migrated
-@SuppressWarnings({"deprecation", "DataFlowIssue"})
+@SuppressWarnings("deprecation")
public class ManagedChannelImplTest {
private static final int DEFAULT_PORT = 447;
diff --git a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java
index 02df28f2e..974f36e59 100644
--- a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java
+++ b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java
@@ -24,7 +24,6 @@ import static org.mockito.Mockito.when;
import io.grpc.CallOptions;
import io.grpc.ChannelLogger;
import io.grpc.ClientStreamTracer;
-import io.grpc.EquivalentAddressGroup;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer.PickResult;
import io.grpc.LoadBalancer.PickSubchannelArgs;
@@ -144,10 +143,6 @@ public final class TestUtils {
return captor;
}
- public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
- return new EquivalentAddressGroup(eag.getAddresses());
- }
-
private TestUtils() {
}
diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle
index 85e283b11..c00b8fbd9 100644
--- a/examples/android/strictmode/app/build.gradle
+++ b/examples/android/strictmode/app/build.gradle
@@ -53,7 +53,6 @@ dependencies {
implementation 'androidx.appcompat:appcompat:1.0.0'
// You need to build grpc-java to obtain these libraries below.
- implementation 'io.grpc:grpc-core:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION
implementation 'io.grpc:grpc-okhttp:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION
implementation 'io.grpc:grpc-protobuf-lite:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION
implementation 'io.grpc:grpc-stub:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION
diff --git a/examples/android/strictmode/app/proguard-rules.pro b/examples/android/strictmode/app/proguard-rules.pro
index d5715fd16..1507a5267 100644
--- a/examples/android/strictmode/app/proguard-rules.pro
+++ b/examples/android/strictmode/app/proguard-rules.pro
@@ -15,4 +15,3 @@
-dontwarn javax.naming.**
-dontwarn okio.**
-dontwarn sun.misc.Unsafe
-
diff --git a/util/build.gradle b/util/build.gradle
index cdd32e0ce..a05c55b27 100644
--- a/util/build.gradle
+++ b/util/build.gradle
@@ -1,6 +1,5 @@
plugins {
id "java-library"
- id "java-test-fixtures"
id "maven-publish"
id "me.champeau.jmh"
@@ -20,18 +19,11 @@ dependencies {
implementation libraries.animalsniffer.annotations,
libraries.guava
- testImplementation libraries.guava.testlib,
- testFixtures(project(':grpc-api')),
+ testImplementation testFixtures(project(':grpc-api')),
testFixtures(project(':grpc-core')),
project(':grpc-testing')
+ testImplementation libraries.guava.testlib
- testFixturesApi project(':grpc-core')
- testFixturesImplementation libraries.guava,
- libraries.junit,
- libraries.mockito.core,
- testFixtures(project(':grpc-api')),
- testFixtures(project(':grpc-core')),
- project(':grpc-testing')
jmh project(':grpc-testing')
signature libraries.signature.java
diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java
index 2f0aa04cf..8f2269af2 100644
--- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java
+++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java
@@ -16,29 +16,25 @@
package io.grpc.util;
-import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
-import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.ImmutableList;
-import io.grpc.Attributes;
import io.grpc.ConnectivityState;
-import io.grpc.EquivalentAddressGroup;
import io.grpc.Internal;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerProvider;
import io.grpc.Status;
-import io.grpc.internal.PickFirstLoadBalancerProvider;
-import java.util.Collection;
-import java.util.Collections;
+import io.grpc.SynchronizationContext;
+import io.grpc.SynchronizationContext.ScheduledHandle;
+import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@@ -50,34 +46,23 @@ import javax.annotation.Nullable;
@Internal
public abstract class MultiChildLoadBalancer extends LoadBalancer {
+ @VisibleForTesting
+ public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15;
private static final Logger logger = Logger.getLogger(MultiChildLoadBalancer.class.getName());
private final Map<Object, ChildLbState> childLbStates = new HashMap<>();
private final Helper helper;
+ protected final SynchronizationContext syncContext;
+ private final ScheduledExecutorService timeService;
// Set to true if currently in the process of handling resolved addresses.
- @VisibleForTesting
- boolean resolvingAddresses;
-
- protected final PickFirstLoadBalancerProvider pickFirstLbProvider =
- new PickFirstLoadBalancerProvider();
-
+ private boolean resolvingAddresses;
protected MultiChildLoadBalancer(Helper helper) {
this.helper = checkNotNull(helper, "helper");
+ this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
+ this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
logger.log(Level.FINE, "Created");
}
- @SuppressWarnings("ReferenceEquality")
- protected static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
- if (eag.getAttributes() == Attributes.EMPTY) {
- return eag;
- } else {
- return new EquivalentAddressGroup(eag.getAddresses());
- }
- }
-
- protected abstract SubchannelPicker getSubchannelPicker(
- Map<Object, SubchannelPicker> childPickers);
-
protected SubchannelPicker getInitialPicker() {
return EMPTY_PICKER;
}
@@ -86,43 +71,11 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return new FixedResultPicker(PickResult.withError(error));
}
- @VisibleForTesting
- protected Collection<ChildLbState> getChildLbStates() {
- return childLbStates.values();
- }
-
- protected ChildLbState getChildLbState(Object key) {
- if (key == null) {
- return null;
- }
- return childLbStates.get(key);
- }
-
- protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) {
- return getChildLbState(stripAttrs(eag));
- }
-
- /**
- * Override to utilize parsing of the policy configuration or alternative helper/lb generation.
- */
- protected Map<Object, ChildLbState> createChildLbMap(ResolvedAddresses resolvedAddresses) {
- Map<Object, ChildLbState> childLbMap = new HashMap<>();
- List<EquivalentAddressGroup> addresses = resolvedAddresses.getAddresses();
- Object policyConfig = resolvedAddresses.getLoadBalancingPolicyConfig();
- for (EquivalentAddressGroup eag : addresses) {
- EquivalentAddressGroup strippedEag = stripAttrs(eag); // keys need to be just addresses
- if (!childLbMap.containsKey(strippedEag)) {
- childLbMap.put(strippedEag,
- createChildLbState(strippedEag, policyConfig, getInitialPicker()));
- }
- }
- return childLbMap;
- }
+ protected abstract Map<Object, PolicySelection> getPolicySelectionMap(
+ ResolvedAddresses resolvedAddresses);
- protected ChildLbState createChildLbState(Object key, Object policyConfig,
- SubchannelPicker initialPicker) {
- return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker);
- }
+ protected abstract SubchannelPicker getSubchannelPicker(
+ Map<Object, SubchannelPicker> childPickers);
@Override
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
@@ -134,61 +87,25 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
}
}
- protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses,
- Object childConfig) {
- checkArgument(key instanceof EquivalentAddressGroup, "key is wrong type");
-
- // Retrieve the non-stripped version
- EquivalentAddressGroup eag = null;
- for (EquivalentAddressGroup equivalentAddressGroup : resolvedAddresses.getAddresses()) {
- if (stripAttrs(equivalentAddressGroup).equals(key)) {
- eag = equivalentAddressGroup;
- break;
- }
- }
-
- checkNotNull(eag, key.toString() + " no longer present in load balancer children");
-
- return resolvedAddresses.toBuilder()
- .setAddresses(Collections.singletonList(eag))
- .setLoadBalancingPolicyConfig(childConfig)
- .build();
- }
-
-
-
private boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) {
logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses);
- Map<Object, ChildLbState> newChildren = createChildLbMap(resolvedAddresses);
-
- if (newChildren.isEmpty()) {
- handleNameResolutionError(Status.UNAVAILABLE.withDescription(
- "NameResolver returned no usable address. " + resolvedAddresses));
- return false;
- }
-
- // Do adds and updates
- for (Map.Entry<Object, ChildLbState> entry : newChildren.entrySet()) {
+ Map<Object, PolicySelection> newChildPolicies = getPolicySelectionMap(resolvedAddresses);
+ for (Map.Entry<Object, PolicySelection> entry : newChildPolicies.entrySet()) {
final Object key = entry.getKey();
- LoadBalancerProvider childPolicyProvider = entry.getValue().getPolicyProvider();
+ LoadBalancerProvider childPolicyProvider = entry.getValue().getProvider();
Object childConfig = entry.getValue().getConfig();
if (!childLbStates.containsKey(key)) {
- childLbStates.put(key, entry.getValue());
+ childLbStates.put(key, new ChildLbState(key, childPolicyProvider, getInitialPicker()));
} else {
- // Reuse the existing one
- ChildLbState existingChildLbState = childLbStates.get(key);
- if (existingChildLbState.isDeactivated()) {
- existingChildLbState.reactivate(childPolicyProvider);
- }
+ childLbStates.get(key).reactivate(childPolicyProvider);
}
-
LoadBalancer childLb = childLbStates.get(key).lb;
- childLb.handleResolvedAddresses(getChildAddresses(key, resolvedAddresses, childConfig));
+ ResolvedAddresses childAddresses =
+ resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build();
+ childLb.handleResolvedAddresses(childAddresses);
}
-
- // Do removals
- for (Object key : ImmutableList.copyOf(childLbStates.keySet())) {
- if (!newChildren.containsKey(key)) {
+ for (Object key : childLbStates.keySet()) {
+ if (!newChildPolicies.containsKey(key)) {
childLbStates.get(key).deactivate();
}
}
@@ -222,10 +139,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
childLbStates.clear();
}
- protected void updateOverallBalancingState() {
+ private void updateOverallBalancingState() {
ConnectivityState overallState = null;
final Map<Object, SubchannelPicker> childPickers = new HashMap<>();
- for (ChildLbState childLbState : getChildLbStates()) {
+ for (ChildLbState childLbState : childLbStates.values()) {
if (childLbState.deactivated) {
continue;
}
@@ -238,7 +155,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
}
@Nullable
- protected static ConnectivityState aggregateState(
+ private static ConnectivityState aggregateState(
@Nullable ConnectivityState overallState, ConnectivityState childState) {
if (overallState == null) {
return childState;
@@ -255,109 +172,67 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return overallState;
}
- protected Helper getHelper() {
- return helper;
- }
-
- protected void removeChild(Object key) {
- childLbStates.remove(key);
- }
-
-
- public class ChildLbState {
+ private final class ChildLbState {
private final Object key;
- private final Object config;
private final GracefulSwitchLoadBalancer lb;
private LoadBalancerProvider policyProvider;
private ConnectivityState currentState = CONNECTING;
private SubchannelPicker currentPicker;
private boolean deactivated;
+ @Nullable
+ ScheduledHandle deletionTimer;
- public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
- SubchannelPicker initialPicker) {
+ ChildLbState(Object key, LoadBalancerProvider policyProvider, SubchannelPicker initialPicker) {
this.key = key;
this.policyProvider = policyProvider;
lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper());
lb.switchTo(policyProvider);
currentPicker = initialPicker;
- config = childConfig;
- }
-
-
- @Override
- public String toString() {
- return "Address = " + key
- + ", state = " + currentState
- + ", picker type: " + currentPicker.getClass()
- + ", lb: " + lb.delegate().getClass()
- + (deactivated ? ", deactivated" : "");
- }
-
- public Object getKey() {
- return key;
- }
-
- Object getConfig() {
- return config;
- }
-
- public LoadBalancerProvider getPolicyProvider() {
- return policyProvider;
- }
-
- protected Subchannel getSubchannels(PickSubchannelArgs args) {
- return getCurrentPicker().pickSubchannel(args).getSubchannel();
- }
-
- ConnectivityState getCurrentState() {
- return currentState;
}
- public SubchannelPicker getCurrentPicker() {
- return currentPicker;
- }
-
- public boolean isDeactivated() {
- return deactivated;
- }
-
- @VisibleForTesting
- LoadBalancer getLb() {
- return this.lb;
- }
-
- protected void setDeactivated() {
- deactivated = true;
- }
-
- protected void deactivate() {
+ void deactivate() {
if (deactivated) {
return;
}
- shutdown();
- childLbStates.remove(key);
+ class DeletionTask implements Runnable {
+ @Override
+ public void run() {
+ shutdown();
+ childLbStates.remove(key);
+ }
+ }
+
+ deletionTimer =
+ syncContext.schedule(
+ new DeletionTask(),
+ DELAYED_CHILD_DELETION_TIME_MINUTES,
+ TimeUnit.MINUTES,
+ timeService);
deactivated = true;
logger.log(Level.FINE, "Child balancer {0} deactivated", key);
}
- protected void reactivate(LoadBalancerProvider policyProvider) {
+ void reactivate(LoadBalancerProvider policyProvider) {
+ if (deletionTimer != null && deletionTimer.isPending()) {
+ deletionTimer.cancel();
+ deactivated = false;
+ logger.log(Level.FINE, "Child balancer {0} reactivated", key);
+ }
if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) {
Object[] objects = {
key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()};
logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects);
lb.switchTo(policyProvider);
this.policyProvider = policyProvider;
- } else {
- logger.log(Level.FINE, "Child balancer {0} reactivated", key);
}
-
- deactivated = false;
}
- protected void shutdown() {
+ void shutdown() {
+ if (deletionTimer != null && deletionTimer.isPending()) {
+ deletionTimer.cancel();
+ }
lb.shutdown();
- this.currentState = SHUTDOWN;
logger.log(Level.FINE, "Child balancer {0} deleted", key);
}
diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java
index 9873e3e45..560970849 100644
--- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java
+++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java
@@ -16,9 +16,11 @@
package io.grpc.util;
+import static com.google.common.base.Preconditions.checkNotNull;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
+import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.annotations.VisibleForTesting;
@@ -35,10 +37,13 @@ import io.grpc.NameResolver;
import io.grpc.Status;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
+import java.util.Set;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import javax.annotation.Nonnull;
@@ -47,23 +52,131 @@ import javax.annotation.Nonnull;
* EquivalentAddressGroup}s from the {@link NameResolver}.
*/
@Internal
-public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
+public class RoundRobinLoadBalancer extends LoadBalancer {
@VisibleForTesting
static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
Attributes.Key.create("state-info");
+ private final Helper helper;
+ private final Map<EquivalentAddressGroup, Subchannel> subchannels =
+ new HashMap<>();
private final Random random;
private ConnectivityState currentState;
protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK);
public RoundRobinLoadBalancer(Helper helper) {
- super(helper);
+ this.helper = checkNotNull(helper, "helper");
this.random = new Random();
}
@Override
- protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) {
- throw new UnsupportedOperationException(); // local updateOverallBalancingState doesn't use this
+ public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
+ if (resolvedAddresses.getAddresses().isEmpty()) {
+ handleNameResolutionError(Status.UNAVAILABLE.withDescription(
+ "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses()
+ + ", attrs=" + resolvedAddresses.getAttributes()));
+ return false;
+ }
+
+ List<EquivalentAddressGroup> servers = resolvedAddresses.getAddresses();
+ Set<EquivalentAddressGroup> currentAddrs = subchannels.keySet();
+ Map<EquivalentAddressGroup, EquivalentAddressGroup> latestAddrs = stripAttrs(servers);
+ Set<EquivalentAddressGroup> removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet());
+
+ for (Map.Entry<EquivalentAddressGroup, EquivalentAddressGroup> latestEntry :
+ latestAddrs.entrySet()) {
+ EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey();
+ EquivalentAddressGroup originalAddressGroup = latestEntry.getValue();
+ Subchannel existingSubchannel = subchannels.get(strippedAddressGroup);
+ if (existingSubchannel != null) {
+ // EAG's Attributes may have changed.
+ existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup));
+ continue;
+ }
+ // Create new subchannels for new addresses.
+
+ // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel
+ // doesn't need them. They're describing the resolved server list but we're not taking
+ // any action based on this information.
+ Attributes.Builder subchannelAttrs = Attributes.newBuilder()
+ // NB(lukaszx0): because attributes are immutable we can't set new value for the key
+ // after creation but since we can mutate the values we leverage that and set
+ // AtomicReference which will allow mutating state info for given channel.
+ .set(STATE_INFO,
+ new Ref<>(ConnectivityStateInfo.forNonError(IDLE)));
+
+ final Subchannel subchannel = checkNotNull(
+ helper.createSubchannel(CreateSubchannelArgs.newBuilder()
+ .setAddresses(originalAddressGroup)
+ .setAttributes(subchannelAttrs.build())
+ .build()),
+ "subchannel");
+ subchannel.start(new SubchannelStateListener() {
+ @Override
+ public void onSubchannelState(ConnectivityStateInfo state) {
+ processSubchannelState(subchannel, state);
+ }
+ });
+ subchannels.put(strippedAddressGroup, subchannel);
+ subchannel.requestConnection();
+ }
+
+ ArrayList<Subchannel> removedSubchannels = new ArrayList<>();
+ for (EquivalentAddressGroup addressGroup : removedAddrs) {
+ removedSubchannels.add(subchannels.remove(addressGroup));
+ }
+
+ // Update the picker before shutting down the subchannels, to reduce the chance of the race
+ // between picking a subchannel and shutting it down.
+ updateBalancingState();
+
+ // Shutdown removed subchannels
+ for (Subchannel removedSubchannel : removedSubchannels) {
+ shutdownSubchannel(removedSubchannel);
+ }
+
+ return true;
+ }
+
+ @Override
+ public void handleNameResolutionError(Status error) {
+ if (currentState != READY) {
+ updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error));
+ }
+ }
+
+ private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
+ if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) {
+ return;
+ }
+ if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) {
+ helper.refreshNameResolution();
+ }
+ if (stateInfo.getState() == IDLE) {
+ subchannel.requestConnection();
+ }
+ Ref<ConnectivityStateInfo> subchannelStateRef = getSubchannelStateInfoRef(subchannel);
+ if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) {
+ if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) {
+ return;
+ }
+ }
+ subchannelStateRef.value = stateInfo;
+ updateBalancingState();
+ }
+
+ private void shutdownSubchannel(Subchannel subchannel) {
+ subchannel.shutdown();
+ getSubchannelStateInfoRef(subchannel).value =
+ ConnectivityStateInfo.forNonError(SHUTDOWN);
+ }
+
+ @Override
+ public void shutdown() {
+ for (Subchannel subchannel : getSubchannels()) {
+ shutdownSubchannel(subchannel);
+ }
+ subchannels.clear();
}
private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready");
@@ -71,27 +184,29 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
/**
* Updates picker with the list of active subchannels (state == READY).
*/
- @Override
- protected void updateOverallBalancingState() {
- List<ChildLbState> activeList = getReadyChildren();
+ @SuppressWarnings("ReferenceEquality")
+ private void updateBalancingState() {
+ List<Subchannel> activeList = filterNonFailingSubchannels(getSubchannels());
if (activeList.isEmpty()) {
- // No READY subchannels
-
- // RRLB will request connection immediately on subchannel IDLE.
+ // No READY subchannels, determine aggregate state and error status
boolean isConnecting = false;
- for (ChildLbState childLbState : getChildLbStates()) {
- ConnectivityState state = childLbState.getCurrentState();
- if (state == CONNECTING || state == IDLE) {
+ Status aggStatus = EMPTY_OK;
+ for (Subchannel subchannel : getSubchannels()) {
+ ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value;
+ // This subchannel IDLE is not because of channel IDLE_TIMEOUT,
+ // in which case LB is already shutdown.
+ // RRLB will request connection immediately on subchannel IDLE.
+ if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) {
isConnecting = true;
- break;
+ }
+ if (aggStatus == EMPTY_OK || !aggStatus.isOk()) {
+ aggStatus = stateInfo.getStatus();
}
}
-
- if (isConnecting) {
- updateBalancingState(CONNECTING, new EmptyPicker(Status.OK));
- } else {
- updateBalancingState(TRANSIENT_FAILURE, createReadyPicker(getChildLbStates()));
- }
+ updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE,
+ // If all subchannels are TRANSIENT_FAILURE, return the Status associated with
+ // an arbitrary subchannel, otherwise return OK.
+ new EmptyPicker(aggStatus));
} else {
updateBalancingState(READY, createReadyPicker(activeList));
}
@@ -99,39 +214,72 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
private void updateBalancingState(ConnectivityState state, RoundRobinPicker picker) {
if (state != currentState || !picker.isEquivalentTo(currentPicker)) {
- getHelper().updateBalancingState(state, picker);
+ helper.updateBalancingState(state, picker);
currentState = state;
currentPicker = picker;
}
}
- protected RoundRobinPicker createReadyPicker(Collection<ChildLbState> children) {
+ protected RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
// initialize the Picker to a random start index to ensure that a high frequency of Picker
// churn does not skew subchannel selection.
- int startIndex = random.nextInt(children.size());
+ int startIndex = random.nextInt(activeList.size());
+ return new ReadyPicker(activeList, startIndex);
+ }
- List<SubchannelPicker> pickerList = new ArrayList<>();
- for (ChildLbState child : children) {
- SubchannelPicker picker = child.getCurrentPicker();
- pickerList.add(picker);
+ /**
+ * Filters out non-ready subchannels.
+ */
+ private static List<Subchannel> filterNonFailingSubchannels(
+ Collection<Subchannel> subchannels) {
+ List<Subchannel> readySubchannels = new ArrayList<>(subchannels.size());
+ for (Subchannel subchannel : subchannels) {
+ if (isReady(subchannel)) {
+ readySubchannels.add(subchannel);
+ }
}
-
- return new ReadyPicker(pickerList, startIndex);
+ return readySubchannels;
}
/**
- * Filters out non-ready and deactivated child load balancers (subchannels).
+ * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
+ * remove all attributes. The values are the original EAGs.
*/
- private List<ChildLbState> getReadyChildren() {
- List<ChildLbState> activeChildren = new ArrayList<>();
- for (ChildLbState child : getChildLbStates()) {
- if (!child.isDeactivated() && child.getCurrentState() == READY) {
- activeChildren.add(child);
- }
+ private static Map<EquivalentAddressGroup, EquivalentAddressGroup> stripAttrs(
+ List<EquivalentAddressGroup> groupList) {
+ Map<EquivalentAddressGroup, EquivalentAddressGroup> addrs = new HashMap<>(groupList.size() * 2);
+ for (EquivalentAddressGroup group : groupList) {
+ addrs.put(stripAttrs(group), group);
}
- return activeChildren;
+ return addrs;
+ }
+
+ private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
+ return new EquivalentAddressGroup(eag.getAddresses());
+ }
+
+ @VisibleForTesting
+ protected Collection<Subchannel> getSubchannels() {
+ return subchannels.values();
+ }
+
+ private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef(
+ Subchannel subchannel) {
+ return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
+ }
+
+ // package-private to avoid synthetic access
+ static boolean isReady(Subchannel subchannel) {
+ return getSubchannelStateInfoRef(subchannel).value.getState() == READY;
+ }
+
+ private static <T> Set<T> setsDifference(Set<T> a, Set<T> b) {
+ Set<T> aCopy = new HashSet<>(a);
+ aCopy.removeAll(b);
+ return aCopy;
}
+ // Only subclasses are ReadyPicker or EmptyPicker
public abstract static class RoundRobinPicker extends SubchannelPicker {
public abstract boolean isEquivalentTo(RoundRobinPicker picker);
}
@@ -141,11 +289,11 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
private static final AtomicIntegerFieldUpdater<ReadyPicker> indexUpdater =
AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index");
- private final List<SubchannelPicker> list; // non-empty
+ private final List<Subchannel> list; // non-empty
@SuppressWarnings("unused")
private volatile int index;
- public ReadyPicker(List<SubchannelPicker> list, int startIndex) {
+ public ReadyPicker(List<Subchannel> list, int startIndex) {
Preconditions.checkArgument(!list.isEmpty(), "empty list");
this.list = list;
this.index = startIndex - 1;
@@ -153,7 +301,7 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
- return list.get(nextIndex()).pickSubchannel(args);
+ return PickResult.withSubchannel(nextSubchannel());
}
@Override
@@ -161,7 +309,7 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
return MoreObjects.toStringHelper(ReadyPicker.class).add("list", list).toString();
}
- private int nextIndex() {
+ private Subchannel nextSubchannel() {
int size = list.size();
int i = indexUpdater.incrementAndGet(this);
if (i >= size) {
@@ -169,11 +317,11 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
i %= size;
indexUpdater.compareAndSet(this, oldi, i);
}
- return i;
+ return list.get(i);
}
@VisibleForTesting
- List<SubchannelPicker> getList() {
+ List<Subchannel> getList() {
return list;
}
diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java
index ac5bd8b98..13f13421a 100644
--- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java
+++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java
@@ -512,7 +512,7 @@ public class OutlierDetectionLoadBalancerTest {
loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers));
- generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 12);
+ generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8);
// Move forward in time to a point where the detection timer has fired.
forwardTime(config);
@@ -546,7 +546,7 @@ public class OutlierDetectionLoadBalancerTest {
assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0)));
// Now we produce more load, but the subchannel start working and is no longer an outlier.
- generateLoad(ImmutableMap.of(), 12);
+ generateLoad(ImmutableMap.of(), 8);
// Move forward in time to a point where the detection timer has fired.
fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS);
diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
index 3b7f6599d..23b6e1c10 100644
--- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
+++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
@@ -22,21 +22,23 @@ import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
+import static io.grpc.util.RoundRobinLoadBalancer.STATE_INFO;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
-import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@@ -53,19 +55,16 @@ import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.Status;
-import io.grpc.internal.TestUtils;
-import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker;
import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker;
+import io.grpc.util.RoundRobinLoadBalancer.Ref;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
-import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@@ -76,8 +75,10 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
+import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
+import org.mockito.stubbing.Answer;
/** Unit test for {@link RoundRobinLoadBalancer}. */
@RunWith(JUnit4.class)
@@ -88,9 +89,7 @@ public class RoundRobinLoadBalancerTest {
private RoundRobinLoadBalancer loadBalancer;
private final List<EquivalentAddressGroup> servers = Lists.newArrayList();
- private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels =
- new ConcurrentHashMap<>();
- private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>();
+ private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
private final Attributes affinity =
@@ -102,7 +101,8 @@ public class RoundRobinLoadBalancerTest {
private ArgumentCaptor<ConnectivityState> stateCaptor;
@Captor
private ArgumentCaptor<CreateSubchannelArgs> createArgsCaptor;
- private Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));
+ @Mock
+ private Helper mockHelper;
@Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown().
private PickSubchannelArgs mockArgs;
@@ -113,14 +113,32 @@ public class RoundRobinLoadBalancerTest {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.add(eag);
+ Subchannel sc = mock(Subchannel.class);
+ subchannels.put(Arrays.asList(eag), sc);
}
- loadBalancer = new RoundRobinLoadBalancer(mockHelper);
- }
+ when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class)))
+ .then(new Answer<Subchannel>() {
+ @Override
+ public Subchannel answer(InvocationOnMock invocation) throws Throwable {
+ CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
+ final Subchannel subchannel = subchannels.get(args.getAddresses());
+ when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
+ when(subchannel.getAttributes()).thenReturn(args.getAttributes());
+ doAnswer(
+ new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ subchannelStateListeners.put(
+ subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
+ return null;
+ }
+ }).when(subchannel).start(any(SubchannelStateListener.class));
+ return subchannel;
+ }
+ });
- private boolean acceptAddresses(List<EquivalentAddressGroup> eagList, Attributes attrs) {
- return loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder().setAddresses(eagList).setAttributes(attrs).build());
+ loadBalancer = new RoundRobinLoadBalancer(mockHelper);
}
@After
@@ -130,9 +148,10 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterResolved() throws Exception {
- boolean addressesAccepted = acceptAddresses(servers, affinity);
- assertThat(addressesAccepted).isTrue();
final Subchannel readySubchannel = subchannels.values().iterator().next();
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build());
+ assertThat(addressesAccepted).isTrue();
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture());
@@ -159,6 +178,10 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterResolvedUpdatedHosts() throws Exception {
+ Subchannel removedSubchannel = mock(Subchannel.class);
+ Subchannel oldSubchannel = mock(Subchannel.class);
+ Subchannel newSubchannel = mock(Subchannel.class);
+
Attributes.Key<String> key = Attributes.Key.create("check-that-it-is-propagated");
FakeSocketAddress removedAddr = new FakeSocketAddress("removed");
EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr);
@@ -170,13 +193,6 @@ public class RoundRobinLoadBalancerTest {
EquivalentAddressGroup newEag = new EquivalentAddressGroup(
newAddr, Attributes.newBuilder().set(key, "newattr").build());
- Subchannel removedSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder()
- .setAddresses(removedEag).build());
- Subchannel oldSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder()
- .setAddresses(oldEag1).build());
- Subchannel newSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder()
- .setAddresses(newEag).build());
-
subchannels.put(Collections.singletonList(removedEag), removedSubchannel);
subchannels.put(Collections.singletonList(oldEag1), oldSubchannel);
subchannels.put(Collections.singletonList(newEag), newSubchannel);
@@ -185,7 +201,9 @@ public class RoundRobinLoadBalancerTest {
InOrder inOrder = inOrder(mockHelper);
- boolean addressesAccepted = acceptAddresses(currentServers, affinity);
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity)
+ .build());
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
@@ -200,11 +218,8 @@ public class RoundRobinLoadBalancerTest {
verify(removedSubchannel, times(1)).requestConnection();
verify(oldSubchannel, times(1)).requestConnection();
- assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2);
- assertThat(loadBalancer.getChildLbStateEag(removedEag).getCurrentPicker().pickSubchannel(null)
- .getSubchannel()).isEqualTo(removedSubchannel);
- assertThat(loadBalancer.getChildLbStateEag(oldEag1).getCurrentPicker().pickSubchannel(null)
- .getSubchannel()).isEqualTo(oldSubchannel);
+ assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel,
+ oldSubchannel);
// This time with Attributes
List<EquivalentAddressGroup> latestServers = Lists.newArrayList(oldEag2, newEag);
@@ -217,15 +232,13 @@ public class RoundRobinLoadBalancerTest {
verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2));
verify(removedSubchannel, times(1)).shutdown();
+ deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN));
deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY));
- assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2);
- assertThat(loadBalancer.getChildLbStateEag(newEag).getCurrentPicker()
- .pickSubchannel(null).getSubchannel()).isEqualTo(newSubchannel);
- assertThat(loadBalancer.getChildLbStateEag(oldEag2).getCurrentPicker()
- .pickSubchannel(null).getSubchannel()).isEqualTo(oldSubchannel);
+ assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel,
+ newSubchannel);
- verify(mockHelper, times(6)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture());
picker = pickerCaptor.getValue();
@@ -237,26 +250,29 @@ public class RoundRobinLoadBalancerTest {
@Test
public void pickAfterStateChange() throws Exception {
InOrder inOrder = inOrder(mockHelper);
- boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
+ .build());
assertThat(addressesAccepted).isTrue();
-
- // TODO figure out if this method testing the right things
-
- ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next();
- Subchannel subchannel = childLbState.getCurrentPicker().pickSubchannel(null).getSubchannel();
+ Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
+ Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
+ STATE_INFO);
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
- assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING);
+ assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));
- deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
+ deliverSubchannelState(subchannel,
+ ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class);
- assertThat(childLbState.getCurrentState()).isEqualTo(READY);
+ assertThat(subchannelStateInfo.value).isEqualTo(
+ ConnectivityStateInfo.forNonError(READY));
Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯");
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forTransientFailure(error));
- assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
+ assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
+ assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error);
inOrder.verify(mockHelper).refreshNameResolution();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
@@ -264,7 +280,8 @@ public class RoundRobinLoadBalancerTest {
deliverSubchannelState(subchannel,
ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).refreshNameResolution();
- assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE);
+ assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
+ assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error);
verify(subchannel, times(2)).requestConnection();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
@@ -274,14 +291,15 @@ public class RoundRobinLoadBalancerTest {
@Test
public void ignoreShutdownSubchannelStateChange() {
InOrder inOrder = inOrder(mockHelper);
- boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
+ .build());
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
loadBalancer.shutdown();
- for (ChildLbState child : loadBalancer.getChildLbStates()) {
- Subchannel sc = child.getCurrentPicker().pickSubchannel(null).getSubchannel();
- verify(child).shutdown();
+ for (Subchannel sc : loadBalancer.getSubchannels()) {
+ verify(sc).shutdown();
// When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered
// back to the subchannel state listener.
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN));
@@ -293,34 +311,36 @@ public class RoundRobinLoadBalancerTest {
@Test
public void stayTransientFailureUntilReady() {
InOrder inOrder = inOrder(mockHelper);
- boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
+ .build());
assertThat(addressesAccepted).isTrue();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
- Map<ChildLbState, Subchannel> childToSubChannelMap = new HashMap<>();
// Simulate state transitions for each subchannel individually.
- for ( ChildLbState child : loadBalancer.getChildLbStates()) {
- Subchannel sc = child.getSubchannels(mockArgs);
- childToSubChannelMap.put(child, sc);
+ for (Subchannel sc : loadBalancer.getSubchannels()) {
Status error = Status.UNKNOWN.withDescription("connection broken");
deliverSubchannelState(
sc,
ConnectivityStateInfo.forTransientFailure(error));
- assertEquals(TRANSIENT_FAILURE, child.getCurrentState());
inOrder.verify(mockHelper).refreshNameResolution();
deliverSubchannelState(
sc,
ConnectivityStateInfo.forNonError(CONNECTING));
- assertEquals(TRANSIENT_FAILURE, child.getCurrentState());
+ Ref<ConnectivityStateInfo> scStateInfo = sc.getAttributes().get(
+ STATE_INFO);
+ assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE);
+ assertThat(scStateInfo.value.getStatus()).isEqualTo(error);
}
- inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class));
+ inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class));
inOrder.verifyNoMoreInteractions();
- ChildLbState child = loadBalancer.getChildLbStates().iterator().next();
- Subchannel subchannel = childToSubChannelMap.get(child);
+ Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
- assertThat(child.getCurrentState()).isEqualTo(READY);
+ Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
+ STATE_INFO);
+ assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY));
inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class));
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
@@ -330,15 +350,16 @@ public class RoundRobinLoadBalancerTest {
@Test
public void refreshNameResolutionWhenSubchannelConnectionBroken() {
InOrder inOrder = inOrder(mockHelper);
- boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
+ .build());
assertThat(addressesAccepted).isTrue();
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
// Simulate state transitions for each subchannel individually.
- for (ChildLbState child : loadBalancer.getChildLbStates()) {
- Subchannel sc = child.getSubchannels(mockArgs);
+ for (Subchannel sc : loadBalancer.getSubchannels()) {
verify(sc).requestConnection();
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING));
Status error = Status.UNKNOWN.withDescription("connection broken");
@@ -349,7 +370,7 @@ public class RoundRobinLoadBalancerTest {
// Simulate receiving go-away so READY subchannels transit to IDLE.
deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(mockHelper).refreshNameResolution();
- verify(sc, times(1)).requestConnection();
+ verify(sc, times(2)).requestConnection();
inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
}
@@ -362,13 +383,12 @@ public class RoundRobinLoadBalancerTest {
Subchannel subchannel1 = mock(Subchannel.class);
Subchannel subchannel2 = mock(Subchannel.class);
- ArrayList<SubchannelPicker> pickers = Lists.newArrayList(
- TestUtils.pickerOf(subchannel), TestUtils.pickerOf(subchannel1),
- TestUtils.pickerOf(subchannel2));
-
- ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(pickers),
+ ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(
+ Lists.newArrayList(subchannel, subchannel1, subchannel2)),
0 /* startIndex */);
+ assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2);
+
assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(subchannel1, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(subchannel2, picker.pickSubchannel(mockArgs).getSubchannel());
@@ -379,7 +399,7 @@ public class RoundRobinLoadBalancerTest {
public void pickerEmptyList() throws Exception {
SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN);
- assertNull(picker.pickSubchannel(mockArgs).getSubchannel());
+ assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel());
assertEquals(Status.UNKNOWN,
picker.pickSubchannel(mockArgs).getStatus());
}
@@ -397,13 +417,12 @@ public class RoundRobinLoadBalancerTest {
@Test
public void nameResolutionErrorWithActiveChannels() throws Exception {
- boolean addressesAccepted = acceptAddresses(servers, affinity);
final Subchannel readySubchannel = subchannels.values().iterator().next();
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build());
assertThat(addressesAccepted).isTrue();
deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
- loadBalancer.resolvingAddresses = true;
loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
- loadBalancer.resolvingAddresses = false;
verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
verify(mockHelper, times(2))
@@ -424,14 +443,15 @@ public class RoundRobinLoadBalancerTest {
@Test
public void subchannelStateIsolation() throws Exception {
- boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY);
- assertThat(addressesAccepted).isTrue();
-
Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
Subchannel sc1 = subchannelIterator.next();
Subchannel sc2 = subchannelIterator.next();
Subchannel sc3 = subchannelIterator.next();
+ boolean addressesAccepted = loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY)
+ .build());
+ assertThat(addressesAccepted).isTrue();
verify(sc1, times(1)).requestConnection();
verify(sc2, times(1)).requestConnection();
verify(sc3, times(1)).requestConnection();
@@ -458,7 +478,7 @@ public class RoundRobinLoadBalancerTest {
// The IDLE subchannel is dropped from the picker, but a reconnection is requested
assertEquals(READY, stateIterator.next());
assertThat(getList(pickers.next())).containsExactly(sc1, sc3);
- verify(sc2, times(1)).requestConnection();
+ verify(sc2, times(2)).requestConnection();
// The failing subchannel is dropped from the picker, with no requested reconnect
assertEquals(READY, stateIterator.next());
assertThat(getList(pickers.next())).containsExactly(sc1);
@@ -471,7 +491,7 @@ public class RoundRobinLoadBalancerTest {
public void readyPicker_emptyList() {
// ready picker list must be non-empty
try {
- new ReadyPicker(Collections.emptyList(), 0);
+ new ReadyPicker(Collections.<Subchannel>emptyList(), 0);
fail();
} catch (IllegalArgumentException expected) {
}
@@ -483,10 +503,9 @@ public class RoundRobinLoadBalancerTest {
EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK"));
EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"));
- acceptAddresses(servers, Attributes.EMPTY); // create subchannels
Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
- SubchannelPicker sc1 = TestUtils.pickerOf(subchannelIterator.next());
- SubchannelPicker sc2 = TestUtils.pickerOf(subchannelIterator.next());
+ Subchannel sc1 = subchannelIterator.next();
+ Subchannel sc2 = subchannelIterator.next();
ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 0);
ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 0);
ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 1);
@@ -507,27 +526,18 @@ public class RoundRobinLoadBalancerTest {
public void emptyAddresses() {
assertThat(loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder()
- .setAddresses(Collections.emptyList())
+ .setAddresses(Collections.<EquivalentAddressGroup>emptyList())
.setAttributes(affinity)
.build())).isFalse();
}
- private List<Subchannel> getList(SubchannelPicker picker) {
-
- if (picker instanceof ReadyPicker) {
- List<Subchannel> subchannelList = new ArrayList<>();
- for (SubchannelPicker childPicker : ((ReadyPicker) picker).getList()) {
- subchannelList.add(childPicker.pickSubchannel(mockArgs).getSubchannel());
- }
- return subchannelList;
- } else {
- return new ArrayList<>();
- }
+ private static List<Subchannel> getList(SubchannelPicker picker) {
+ return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() :
+ Collections.<Subchannel>emptyList();
}
private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) {
- Subchannel realSc = mockToRealSubChannelMap.get(subchannel);
- subchannelStateListeners.get(realSc).onSubchannelState(newState);
+ subchannelStateListeners.get(subchannel).onSubchannelState(newState);
}
private static class FakeSocketAddress extends SocketAddress {
@@ -542,22 +552,4 @@ public class RoundRobinLoadBalancerTest {
return "FakeSocketAddress-" + name;
}
}
-
- private class TestHelper extends AbstractTestHelper {
-
- @Override
- public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() {
- return subchannels;
- }
-
- @Override
- public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() {
- return mockToRealSubChannelMap;
- }
-
- @Override
- public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() {
- return subchannelStateListeners;
- }
- }
}
diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java
deleted file mode 100644
index 409861783..000000000
--- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Copyright 2023 The gRPC Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.grpc.util;
-
-import static org.mockito.AdditionalAnswers.delegatesTo;
-import static org.mockito.Mockito.mock;
-
-import io.grpc.Attributes;
-import io.grpc.Channel;
-import io.grpc.ChannelLogger;
-import io.grpc.ConnectivityState;
-import io.grpc.EquivalentAddressGroup;
-import io.grpc.LoadBalancer.CreateSubchannelArgs;
-import io.grpc.LoadBalancer.Helper;
-import io.grpc.LoadBalancer.Subchannel;
-import io.grpc.LoadBalancer.SubchannelPicker;
-import io.grpc.LoadBalancer.SubchannelStateListener;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-
-/**
- * A real class that can be used as a delegate of a mock Helper to provide more real representation
- * and track the subchannels as is needed with petiole policies where the subchannels are no
- * longer direct children of the loadbalancer.
- * <br>
- * To use it replace <br>
- * \@mock Helper mockHelper<br>
- * with<br>
- * <p>Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));</p>
- * <br>
- * TestHelper will need to define accessors for the maps that information is store within as
- * those maps need to be defined in the Test class.
- */
-public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
-
- public abstract Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap();
-
- public abstract Map<Subchannel, Subchannel> getMockToRealSubChannelMap();
-
- public abstract Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners();
-
- @Override
- public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
- // do nothing, should have been done in the wrapper helpers
- }
-
- @Override
- protected Helper delegate() {
- throw new UnsupportedOperationException("This helper class is only for use in this test");
- }
-
- @Override
- public Subchannel createSubchannel(CreateSubchannelArgs args) {
- Subchannel subchannel = getSubchannelMap().get(args.getAddresses());
- if (subchannel == null) {
- TestSubchannel delegate = new TestSubchannel(args);
- subchannel = mock(Subchannel.class, delegatesTo(delegate));
- getSubchannelMap().put(args.getAddresses(), subchannel);
- getMockToRealSubChannelMap().put(subchannel, delegate);
- }
-
- return subchannel;
- }
-
- @Override
- public void refreshNameResolution() {
- // no-op
- }
-
- public void setChannel(Subchannel subchannel, Channel channel) {
- ((TestSubchannel)subchannel).channel = channel;
- }
-
- @Override
- public String toString() {
- return "Test Helper";
- }
-
- private class TestSubchannel extends ForwardingSubchannel {
- final CreateSubchannelArgs args;
- Channel channel;
-
- public TestSubchannel(CreateSubchannelArgs args) {
- this.args = args;
- }
-
- @Override
- protected Subchannel delegate() {
- throw new UnsupportedOperationException("Only to be used in tests");
- }
-
- @Override
- public List<EquivalentAddressGroup> getAllAddresses() {
- return args.getAddresses();
- }
-
- @Override
- public Attributes getAttributes() {
- return args.getAttributes();
- }
-
- @Override
- public void requestConnection() {
- // Ignore, we will manually update state
- }
-
- @Override
- public void updateAddresses(List<EquivalentAddressGroup> addrs) {
- // Do nothing, will be handled in wrappers
- }
-
- @Override
- public void start(SubchannelStateListener listener) {
- getSubchannelStateListeners().put(this, listener);
- }
-
- @Override
- public void shutdown() {
- getSubchannelStateListeners().remove(this);
- for (EquivalentAddressGroup eag : getAllAddresses()) {
- getSubchannelMap().remove(Collections.singletonList(eag));
- }
- }
-
- @Override
- public Channel asChannel() {
- return channel;
- }
-
- @Override
- public ChannelLogger getChannelLogger() {
- return mock(ChannelLogger.class);
- }
-
- @Override
- public String toString() {
- return "Mock Subchannel" + args.toString();
- }
- }
-}
-
diff --git a/xds/build.gradle b/xds/build.gradle
index a6db9db99..3f3cf6a0f 100644
--- a/xds/build.gradle
+++ b/xds/build.gradle
@@ -58,8 +58,7 @@ dependencies {
def nettyDependency = implementation project(':grpc-netty')
testImplementation project(':grpc-rls')
- testImplementation testFixtures(project(':grpc-core')),
- testFixtures(project(':grpc-util'))
+ testImplementation testFixtures(project(':grpc-core'))
annotationProcessor libraries.auto.value
// At runtime use the epoll included in grpc-netty-shaded
diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
index 895125d32..a44892042 100644
--- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
@@ -16,68 +16,36 @@
package io.grpc.xds;
-import static com.google.common.base.Preconditions.checkNotNull;
-
-import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import io.grpc.InternalLogId;
-import io.grpc.LoadBalancerProvider;
import io.grpc.Status;
-import io.grpc.SynchronizationContext;
-import io.grpc.SynchronizationContext.ScheduledHandle;
import io.grpc.internal.ServiceConfigUtil.PolicySelection;
import io.grpc.util.MultiChildLoadBalancer;
import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig;
import io.grpc.xds.XdsLogger.XdsLogLevel;
import java.util.HashMap;
import java.util.Map;
-import java.util.Map.Entry;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.TimeUnit;
-import javax.annotation.Nullable;
/**
* The top-level load balancing policy.
*/
class ClusterManagerLoadBalancer extends MultiChildLoadBalancer {
- @VisibleForTesting
- public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15;
- protected final SynchronizationContext syncContext;
- private final ScheduledExecutorService timeService;
private final XdsLogger logger;
ClusterManagerLoadBalancer(Helper helper) {
super(helper);
- this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
- this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
logger = XdsLogger.withLogId(
InternalLogId.allocate("cluster_manager-lb", helper.getAuthority()));
-
logger.log(XdsLogLevel.INFO, "Created");
}
@Override
- protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses,
- Object childConfig) {
- return resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build();
- }
-
- @Override
- protected Map<Object, ChildLbState> createChildLbMap(ResolvedAddresses resolvedAddresses) {
+ protected Map<Object, PolicySelection> getPolicySelectionMap(
+ ResolvedAddresses resolvedAddresses) {
ClusterManagerConfig config = (ClusterManagerConfig)
resolvedAddresses.getLoadBalancingPolicyConfig();
- Map<Object, ChildLbState> newChildPolicies = new HashMap<>();
- if (config != null) {
- for (Entry<String, PolicySelection> entry : config.childPolicies.entrySet()) {
- ChildLbState child = getChildLbState(entry.getKey());
- if (child == null) {
- child = new ClusterManagerLbState(entry.getKey(),
- entry.getValue().getProvider(), entry.getValue().getConfig(), getInitialPicker());
- }
- newChildPolicies.put(entry.getKey(), child);
- }
- }
+ Map<Object, PolicySelection> newChildPolicies = new HashMap<>(config.childPolicies);
logger.log(
XdsLogLevel.INFO,
"Received cluster_manager lb config: child names={0}", newChildPolicies.keySet());
@@ -107,58 +75,4 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer {
}
};
}
-
- private class ClusterManagerLbState extends ChildLbState {
- @Nullable
- ScheduledHandle deletionTimer;
-
- public ClusterManagerLbState(Object key, LoadBalancerProvider policyProvider,
- Object childConfig, SubchannelPicker initialPicker) {
- super(key, policyProvider, childConfig, initialPicker);
- }
-
- @Override
- protected void shutdown() {
- if (deletionTimer != null && deletionTimer.isPending()) {
- deletionTimer.cancel();
- }
- super.shutdown();
- }
-
- @Override
- protected void reactivate(LoadBalancerProvider policyProvider) {
- if (deletionTimer != null && deletionTimer.isPending()) {
- deletionTimer.cancel();
- logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey());
- }
-
- super.reactivate(policyProvider);
- }
-
- @Override
- protected void deactivate() {
- if (isDeactivated()) {
- return;
- }
-
- class DeletionTask implements Runnable {
-
- @Override
- public void run() {
- shutdown();
- removeChild(getKey());
- }
- }
-
- deletionTimer =
- syncContext.schedule(
- new DeletionTask(),
- DELAYED_CHILD_DELETION_TIME_MINUTES,
- TimeUnit.MINUTES,
- timeService);
- setDeactivated();
- logger.log(XdsLogLevel.DEBUG, "Child balancer {0} deactivated", getKey());
- }
-
- }
}
diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
index 216221d25..833683729 100644
--- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
@@ -17,20 +17,17 @@
package io.grpc.xds;
import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkElementIndex;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.Deadline.Ticker;
import io.grpc.EquivalentAddressGroup;
import io.grpc.ExperimentalApi;
import io.grpc.LoadBalancer;
-import io.grpc.LoadBalancerProvider;
import io.grpc.NameResolver;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
@@ -43,13 +40,11 @@ import io.grpc.xds.orca.OrcaOobUtil;
import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener;
import io.grpc.xds.orca.OrcaPerRequestUtil;
import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener;
-import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
-import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@@ -96,14 +91,6 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
@Override
- protected ChildLbState createChildLbState(Object key, Object policyConfig,
- SubchannelPicker initialPicker) {
- ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig,
- initialPicker);
- return childLbState;
- }
-
- @Override
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) {
handleNameResolutionError(Status.UNAVAILABLE.withDescription(
@@ -124,100 +111,9 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
@Override
- public RoundRobinPicker createReadyPicker(Collection<ChildLbState> activeList) {
- return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList),
- config.enableOobLoadReport, config.errorUtilizationPenalty);
- }
-
- @Override
- protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) {
- return super.getChildLbStateEag(eag);
- }
-
- @VisibleForTesting
- final class WeightedChildLbState extends ChildLbState {
-
- private final Set<WrrSubchannel> subchannels = new HashSet<>();
- private volatile long lastUpdated;
- private volatile long nonEmptySince;
- private volatile double weight = 0;
-
- private OrcaReportListener orcaReportListener;
-
- public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
- SubchannelPicker initialPicker) {
- super(key, policyProvider, childConfig, initialPicker);
- }
-
- @VisibleForTesting
- EquivalentAddressGroup getEag() {
- return stripAttrs((EquivalentAddressGroup) getKey());
- }
-
- private double getWeight() {
- if (config == null) {
- return 0;
- }
- long now = ticker.nanoTime();
- if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
- nonEmptySince = infTime;
- return 0;
- } else if (now - nonEmptySince < config.blackoutPeriodNanos
- && config.blackoutPeriodNanos > 0) {
- return 0;
- } else {
- return weight;
- }
- }
-
- public void addSubchannel(WrrSubchannel wrrSubchannel) {
- subchannels.add(wrrSubchannel);
- }
-
- public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) {
- if (orcaReportListener != null
- && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) {
- return orcaReportListener;
- }
- orcaReportListener = new OrcaReportListener(errorUtilizationPenalty);
- return orcaReportListener;
- }
-
- public void removeSubchannel(WrrSubchannel wrrSubchannel) {
- subchannels.remove(wrrSubchannel);
- }
-
- final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
- private final float errorUtilizationPenalty;
-
- OrcaReportListener(float errorUtilizationPenalty) {
- this.errorUtilizationPenalty = errorUtilizationPenalty;
- }
-
- @Override
- public void onLoadReport(MetricReport report) {
- double newWeight = 0;
- // Prefer application utilization and fallback to CPU utilization if unset.
- double utilization =
- report.getApplicationUtilization() > 0 ? report.getApplicationUtilization()
- : report.getCpuUtilization();
- if (utilization > 0 && report.getQps() > 0) {
- double penalty = 0;
- if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
- penalty = report.getEps() / report.getQps() * errorUtilizationPenalty;
- }
- newWeight = report.getQps() / (utilization + penalty);
- }
- if (newWeight == 0) {
- return;
- }
- if (nonEmptySince == infTime) {
- nonEmptySince = ticker.nanoTime();
- }
- lastUpdated = ticker.nanoTime();
- weight = newWeight;
- }
- }
+ public RoundRobinPicker createReadyPicker(List<Subchannel> activeList) {
+ return new WeightedRoundRobinPicker(activeList, config.enableOobLoadReport,
+ config.errorUtilizationPenalty);
}
private final class UpdateWeightTask implements Runnable {
@@ -232,18 +128,16 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
private void afterAcceptAddresses() {
- for (ChildLbState child : getChildLbStates()) {
- WeightedChildLbState wChild = (WeightedChildLbState) child;
- for (WrrSubchannel weightedSubchannel : wChild.subchannels) {
- if (config.enableOobLoadReport) {
- OrcaOobUtil.setListener(weightedSubchannel,
- wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty),
- OrcaOobUtil.OrcaReportingConfig.newBuilder()
- .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
- .build());
- } else {
- OrcaOobUtil.setListener(weightedSubchannel, null, null);
- }
+ for (Subchannel subchannel : getSubchannels()) {
+ WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel;
+ if (config.enableOobLoadReport) {
+ OrcaOobUtil.setListener(weightedSubchannel,
+ weightedSubchannel.new OrcaReportListener(config.errorUtilizationPenalty),
+ OrcaOobUtil.OrcaReportingConfig.newBuilder()
+ .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS)
+ .build());
+ } else {
+ OrcaOobUtil.setListener(weightedSubchannel, null, null);
}
}
}
@@ -275,69 +169,105 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override
public Subchannel createSubchannel(CreateSubchannelArgs args) {
- checkElementIndex(0, args.getAddresses().size(), "Empty address group");
- WeightedChildLbState childLbState =
- (WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0));
- return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState);
+ return wrr.new WrrSubchannel(delegate().createSubchannel(args));
}
}
@VisibleForTesting
final class WrrSubchannel extends ForwardingSubchannel {
private final Subchannel delegate;
- private final WeightedChildLbState owner;
+ private volatile long lastUpdated;
+ private volatile long nonEmptySince;
+ private volatile double weight;
- WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) {
+ WrrSubchannel(Subchannel delegate) {
this.delegate = checkNotNull(delegate, "delegate");
- this.owner = checkNotNull(owner, "owner");
}
@Override
public void start(SubchannelStateListener listener) {
- owner.addSubchannel(this);
delegate().start(new SubchannelStateListener() {
@Override
public void onSubchannelState(ConnectivityStateInfo newState) {
if (newState.getState().equals(ConnectivityState.READY)) {
- owner.nonEmptySince = infTime;
+ nonEmptySince = infTime;
}
listener.onSubchannelState(newState);
}
});
}
+ private double getWeight() {
+ if (config == null) {
+ return 0;
+ }
+ long now = ticker.nanoTime();
+ if (now - lastUpdated >= config.weightExpirationPeriodNanos) {
+ nonEmptySince = infTime;
+ return 0;
+ } else if (now - nonEmptySince < config.blackoutPeriodNanos
+ && config.blackoutPeriodNanos > 0) {
+ return 0;
+ } else {
+ return weight;
+ }
+ }
+
@Override
protected Subchannel delegate() {
return delegate;
}
- @Override
- public void shutdown() {
- super.shutdown();
- owner.removeSubchannel(this);
+ final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener {
+ private final float errorUtilizationPenalty;
+
+ OrcaReportListener(float errorUtilizationPenalty) {
+ this.errorUtilizationPenalty = errorUtilizationPenalty;
+ }
+
+ @Override
+ public void onLoadReport(MetricReport report) {
+ double newWeight = 0;
+ // Prefer application utilization and fallback to CPU utilization if unset.
+ double utilization =
+ report.getApplicationUtilization() > 0 ? report.getApplicationUtilization()
+ : report.getCpuUtilization();
+ if (utilization > 0 && report.getQps() > 0) {
+ double penalty = 0;
+ if (report.getEps() > 0 && errorUtilizationPenalty > 0) {
+ penalty = report.getEps() / report.getQps() * errorUtilizationPenalty;
+ }
+ newWeight = report.getQps() / (utilization + penalty);
+ }
+ if (newWeight == 0) {
+ return;
+ }
+ if (nonEmptySince == infTime) {
+ nonEmptySince = ticker.nanoTime();
+ }
+ lastUpdated = ticker.nanoTime();
+ weight = newWeight;
+ }
}
}
@VisibleForTesting
final class WeightedRoundRobinPicker extends RoundRobinPicker {
- private final List<ChildLbState> children;
+ private final List<Subchannel> list;
private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap =
new HashMap<>();
private final boolean enableOobLoadReport;
private final float errorUtilizationPenalty;
private volatile StaticStrideScheduler scheduler;
- WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport,
+ WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport,
float errorUtilizationPenalty) {
- checkNotNull(children, "children");
- Preconditions.checkArgument(!children.isEmpty(), "empty child list");
- this.children = children;
- for (ChildLbState child : children) {
- WeightedChildLbState wChild = (WeightedChildLbState) child;
- for (WrrSubchannel subchannel : wChild.subchannels) {
- this.subchannelToReportListenerMap
- .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty));
- }
+ checkNotNull(list, "list");
+ Preconditions.checkArgument(!list.isEmpty(), "empty list");
+ this.list = list;
+ for (Subchannel subchannel : list) {
+ this.subchannelToReportListenerMap.put(subchannel,
+ ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty));
}
this.enableOobLoadReport = enableOobLoadReport;
this.errorUtilizationPenalty = errorUtilizationPenalty;
@@ -346,24 +276,22 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override
public PickResult pickSubchannel(PickSubchannelArgs args) {
- ChildLbState childLbState = children.get(scheduler.pick());
- WeightedChildLbState wChild = (WeightedChildLbState) childLbState;
- PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args);
- Subchannel subchannel = pickResult.getSubchannel();
+ Subchannel subchannel = list.get(scheduler.pick());
if (!enableOobLoadReport) {
return PickResult.withSubchannel(subchannel,
- OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
+ OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
subchannelToReportListenerMap.getOrDefault(subchannel,
- wChild.getOrCreateOrcaListener(errorUtilizationPenalty))));
+ ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty))));
} else {
return PickResult.withSubchannel(subchannel);
}
}
private void updateWeight() {
- float[] newWeights = new float[children.size()];
- for (int i = 0; i < children.size(); i++) {
- double newWeight = ((WeightedChildLbState)children.get(i)).getWeight();
+ float[] newWeights = new float[list.size()];
+ for (int i = 0; i < list.size(); i++) {
+ WrrSubchannel subchannel = (WrrSubchannel) list.get(i);
+ double newWeight = subchannel.getWeight();
newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f;
}
this.scheduler = new StaticStrideScheduler(newWeights, sequence);
@@ -374,12 +302,12 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class)
.add("enableOobLoadReport", enableOobLoadReport)
.add("errorUtilizationPenalty", errorUtilizationPenalty)
- .add("list", children).toString();
+ .add("list", list).toString();
}
@VisibleForTesting
- List<ChildLbState> getChildren() {
- return children;
+ List<Subchannel> getList() {
+ return list;
}
@Override
@@ -394,8 +322,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
// the lists cannot contain duplicate subchannels
return enableOobLoadReport == other.enableOobLoadReport
&& Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0
- && children.size() == other.children.size() && new HashSet<>(
- children).containsAll(other.children);
+ && list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list);
}
}
@@ -577,13 +504,11 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
- @SuppressWarnings("UnusedReturnValue")
Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) {
this.blackoutPeriodNanos = blackoutPeriodNanos;
return this;
}
- @SuppressWarnings("UnusedReturnValue")
Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) {
this.weightExpirationPeriodNanos = weightExpirationPeriodNanos;
return this;
diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java
index 32e905225..c90a9f58d 100644
--- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java
+++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java
@@ -202,9 +202,7 @@ public final class OrcaOobUtil {
*/
public static void setListener(Subchannel subchannel, OrcaOobReportListener listener,
OrcaReportingConfig config) {
- Attributes attributes = subchannel.getAttributes();
- SubchannelImpl orcaSubchannel =
- (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY);
+ SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY);
if (orcaSubchannel == null) {
throw new IllegalArgumentException("Subchannel does not have orca Out-Of-Band stream enabled."
+ " Try to use a subchannel created by OrcaOobUtil.OrcaHelper.");
@@ -243,9 +241,7 @@ public final class OrcaOobUtil {
public Subchannel createSubchannel(CreateSubchannelArgs args) {
syncContext.throwIfNotInThisSynchronizationContext();
Subchannel subchannel = super.createSubchannel(args);
- Attributes attributes = subchannel.getAttributes();
- SubchannelImpl orcaSubchannel =
- (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY);
+ SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY);
OrcaReportingState orcaState;
if (orcaSubchannel == null) {
// Only the first load balancing policy requesting ORCA reports instantiates an
diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
index c59ad1318..ac08f69f8 100644
--- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
+++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
@@ -17,10 +17,11 @@
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -34,6 +35,7 @@ import com.google.common.collect.Maps;
import com.google.protobuf.Duration;
import io.grpc.Attributes;
import io.grpc.Channel;
+import io.grpc.ChannelLogger;
import io.grpc.ClientCall;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
@@ -48,15 +50,12 @@ import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.SynchronizationContext;
import io.grpc.internal.FakeClock;
-import io.grpc.internal.TestUtils;
import io.grpc.services.InternalCallMetricRecorder;
import io.grpc.services.MetricReport;
-import io.grpc.util.AbstractTestHelper;
-import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler;
-import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig;
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
+import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.HashMap;
@@ -68,7 +67,6 @@ import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CyclicBarrier;
-import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Before;
@@ -89,8 +87,8 @@ public class WeightedRoundRobinLoadBalancerTest {
@Rule
public final MockitoRule mockito = MockitoJUnit.rule();
- private final TestHelper testHelperInstance = new TestHelper();
- private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance));
+ @Mock
+ Helper helper;
@Mock
private LoadBalancer.PickSubchannelArgs mockArgs;
@@ -101,8 +99,9 @@ public class WeightedRoundRobinLoadBalancerTest {
private ArgumentCaptor<SubchannelPicker> pickerCaptor2;
private final List<EquivalentAddressGroup> servers = Lists.newArrayList();
+
private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap();
- private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>();
+
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
@@ -135,8 +134,7 @@ public class WeightedRoundRobinLoadBalancerTest {
SocketAddress addr = new FakeSocketAddress("server" + i);
EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
servers.add(eag);
- Subchannel sc = helper.createSubchannel(CreateSubchannelArgs.newBuilder().setAddresses(eag)
- .build());
+ Subchannel sc = mock(Subchannel.class);
Channel channel = mock(Channel.class);
when(channel.newCall(any(), any())).then(
new Answer<ClientCall<OrcaLoadReportRequest, OrcaLoadReport>>() {
@@ -149,13 +147,35 @@ public class WeightedRoundRobinLoadBalancerTest {
return clientCall;
}
});
- testHelperInstance.setChannel(mockToRealSubChannelMap.get(sc), channel);
+ when(sc.asChannel()).thenReturn(channel);
subchannels.put(Arrays.asList(eag), sc);
}
+ when(helper.getSynchronizationContext()).thenReturn(syncContext);
+ when(helper.getScheduledExecutorService()).thenReturn(
+ fakeClock.getScheduledExecutorService());
+ when(helper.createSubchannel(any(CreateSubchannelArgs.class)))
+ .then(new Answer<Subchannel>() {
+ @Override
+ public Subchannel answer(InvocationOnMock invocation) throws Throwable {
+ CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
+ final Subchannel subchannel = subchannels.get(args.getAddresses());
+ when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
+ when(subchannel.getAttributes()).thenReturn(args.getAttributes());
+ when(subchannel.getChannelLogger()).thenReturn(mock(ChannelLogger.class));
+ doAnswer(
+ new Answer<Void>() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ subchannelStateListeners.put(
+ subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
+ return null;
+ }
+ }).when(subchannel).start(any(SubchannelStateListener.class));
+ return subchannel;
+ }
+ });
wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(),
new FakeRandom(0));
-
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
}
@Test
@@ -163,44 +183,44 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
- getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0);
- assertThat(weightedPicker.getChildren().size()).isEqualTo(1);
+ assertThat(weightedPicker.getList().size()).isEqualTo(1);
weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
- assertThat(weightedPicker.getChildren().size()).isEqualTo(2);
+ assertThat(weightedPicker.getList().size()).isEqualTo(2);
String weightedPickerStr = weightedPicker.toString();
assertThat(weightedPickerStr).contains("enableOobLoadReport=false");
assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0");
assertThat(weightedPickerStr).contains("list=");
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
-
- assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(weightedChild1.getEag());
+ assertThat(weightedPicker.pickSubchannel(mockArgs)
+ .getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s
@@ -218,44 +238,35 @@ public class WeightedRoundRobinLoadBalancerTest {
verifyNoMoreInteractions(mockArgs);
}
- /**
- * Picks subchannel using mockArgs, gets its EAG, and then strips the Attrs to make a key.
- */
- private EquivalentAddressGroup getAddressesFromPick(WeightedRoundRobinPicker weightedPicker) {
- return TestUtils.stripAttrs(
- weightedPicker.pickSubchannel(mockArgs).getSubchannel().getAddresses());
- }
-
@Test
public void enableOobLoadReportConfig() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
- assertThat(getAddresses(pickResult))
- .isEqualTo(weightedChild1.getEag());
+ assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener
assertThat(oobCalls.isEmpty()).isTrue();
@@ -269,8 +280,7 @@ public class WeightedRoundRobinLoadBalancerTest {
eq(ConnectivityState.READY), pickerCaptor2.capture());
weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2);
pickResult = weightedPicker.pickSubchannel(mockArgs);
- assertThat(getAddresses(pickResult))
- .isEqualTo(weightedChild1.getEag());
+ assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(pickResult.getStreamTracerFactory()).isNull();
OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval(
Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build();
@@ -285,52 +295,46 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
- getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(3)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1);
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2);
- weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3);
-
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ r1);
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ r2);
+ weightedSubchannel3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ r3);
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
- Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
+ Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 10000; i++) {
- EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
+ Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(3);
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - subchannel1PickRatio))
- .isAtMost(0.0002);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - subchannel2PickRatio ))
- .isAtMost(0.0002);
- assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 10000.0 - subchannel3PickRatio ))
- .isAtMost(0.0002);
- }
-
- private SubchannelStateListener getSubchannelStateListener(Subchannel mockSubChannel) {
- return subchannelStateListeners.get(mockToRealSubChannelMap.get(mockSubChannel));
- }
-
- private static ChildLbState getChild(WeightedRoundRobinPicker picker, int index) {
- return picker.getChildren().get(index);
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio))
+ .isLessThan(0.0002);
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio ))
+ .isLessThan(0.0002);
+ assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio ))
+ .isLessThan(0.0002);
}
@Test
@@ -468,14 +472,14 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(null)
.setAttributes(affinity).build())).isFalse();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class));
verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any());
assertThat(fakeClock.getPendingTasks()).isEmpty();
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue().getClass().getName())
@@ -488,51 +492,51 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1);
- Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
- for (int i = 0; i < 10000; i++) {
- EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
+ Map<Subchannel, Integer> pickCount = new HashMap<>();
+ for (int i = 0; i < 1000; i++) {
+ Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
// within blackout period, fallback to simple round robin
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 0.5)).isLessThan(0.002);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 0.5)).isLessThan(0.002);
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)).isLessThan(0.002);
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)).isLessThan(0.002);
assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1);
pickCount = new HashMap<>();
- for (int i = 0; i < 10000; i++) {
- EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
+ for (int i = 0; i < 1000; i++) {
+ Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
// after blackout period
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 2.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
.isLessThan(0.002);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 1.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
.isLessThan(0.002);
}
@@ -541,39 +545,39 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
- getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0);
- assertThat(weightedPicker.getChildren().size()).isEqualTo(1);
+ assertThat(weightedPicker.getList().size()).isEqualTo(1);
weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
- assertThat(weightedPicker.getChildren().size()).isEqualTo(2);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ assertThat(weightedPicker.getList().size()).isEqualTo(2);
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
- assertThat(getAddressesFromPick(weightedPicker))
- .isEqualTo(weightedChild1.getEag());
+ assertThat(weightedPicker.pickSubchannel(mockArgs)
+ .getSubchannel()).isEqualTo(weightedSubchannel1);
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder()
.setWeightUpdatePeriodNanos(500_000_000L) //.5s
@@ -582,18 +586,17 @@ public class WeightedRoundRobinLoadBalancerTest {
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
//timer fires, new weight updated
assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1);
- assertThat(getAddressesFromPick(weightedPicker))
- .isEqualTo(weightedChild2.getEag());
- assertThat(getAddressesFromPick(weightedPicker))
- .isEqualTo(weightedChild1.getEag());
+ assertThat(weightedPicker.pickSubchannel(mockArgs)
+ .getSubchannel()).isEqualTo(weightedSubchannel2);
+
}
@Test
@@ -601,52 +604,52 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
- Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
+ Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
- EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
+ Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
.isLessThan(0.002);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
.isLessThan(0.002);
// weight expired, fallback to simple round robin
assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1);
pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
- EquivalentAddressGroup result = getAddressesFromPick(weightedPicker);
+ Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(2);
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 0.5))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5))
.isLessThan(0.002);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 0.5))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5))
.isLessThan(0.002);
}
@@ -655,113 +658,107 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- Map<EquivalentAddressGroup, Integer> qpsByChannel = ImmutableMap.of(weightedChild1.getEag(), 2,
- weightedChild2.getEag(), 1);
- Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ Map<WrrSubchannel, Integer> qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2,
+ weightedSubchannel2, 1);
+ Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
- EquivalentAddressGroup addresses = getAddresses(pickResult);
- pickCount.merge(addresses, 1, Integer::sum);
+ pickCount.put(pickResult.getSubchannel(),
+ pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
- WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses);
- childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel();
+ subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
- 0.1, 0, 0.1, qpsByChannel.get(addresses), 0,
+ 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0,
new HashMap<>(), new HashMap<>(), new HashMap<>()));
}
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 1.0 / 2))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 2))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2))
.isAtMost(0.1);
-
- // Identical to above except forwards time after each pick
pickCount.clear();
for (int i = 0; i < 1000; i++) {
PickResult pickResult = weightedPicker.pickSubchannel(mockArgs);
- EquivalentAddressGroup addresses = getAddresses(pickResult);
- pickCount.merge(addresses, 1, Integer::sum);
+ pickCount.put(pickResult.getSubchannel(),
+ pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1);
assertThat(pickResult.getStreamTracerFactory()).isNotNull();
- WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses);
- childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel subchannel = (WrrSubchannel) pickResult.getSubchannel();
+ subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
- 0.1, 0, 0.1, qpsByChannel.get(addresses), 0,
+ 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0,
new HashMap<>(), new HashMap<>(), new HashMap<>()));
fakeClock.forwardTime(50, TimeUnit.MILLISECONDS);
}
assertThat(pickCount.size()).isEqualTo(2);
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3))
.isAtMost(0.1);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3))
.isAtMost(0.1);
}
- private static EquivalentAddressGroup getAddresses(PickResult pickResult) {
- return TestUtils.stripAttrs(pickResult.getSubchannel().getAddresses());
- }
-
@Test
public void unknownWeightIsAvgWeight() {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
- any(CreateSubchannelArgs.class)); // 3 from setup plus 3 from the execute
+ verify(helper, times(3)).createSubchannel(
+ any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1)
- .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ .forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2)
- .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ .forNonError(ConnectivityState.READY));
Subchannel readySubchannel3 = it.next();
- getSubchannelStateListener(readySubchannel3)
- .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY));
+ subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo
+ .forNonError(ConnectivityState.READY));
verify(helper, times(3)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
- Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>();
+ Map<Subchannel, Integer> pickCount = new HashMap<>();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
- pickCount.merge(result.getAddresses(), 1, Integer::sum);
+ pickCount.put(result, pickCount.getOrDefault(result, 0) + 1);
}
assertThat(pickCount.size()).isEqualTo(3);
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 4.0 / 9))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9))
.isLessThan(0.002);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 2.0 / 9))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9))
.isLessThan(0.002);
// subchannel3's weight is average of subchannel1 and subchannel2
- assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 1000.0 - 3.0 / 9))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9))
.isLessThan(0.002);
}
@@ -770,33 +767,33 @@ public class WeightedRoundRobinLoadBalancerTest {
syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
.setAttributes(affinity).build()));
- verify(helper, times(6)).createSubchannel(
+ verify(helper, times(3)).createSubchannel(
any(CreateSubchannelArgs.class));
assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1);
Iterator<Subchannel> it = subchannels.values().iterator();
Subchannel readySubchannel1 = it.next();
- getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
Subchannel readySubchannel2 = it.next();
- getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo
+ subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo
.forNonError(ConnectivityState.READY));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
WeightedRoundRobinPicker weightedPicker =
(WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1);
- WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0);
- WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1);
- weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0);
+ WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1);
+ weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
- weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
+ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(
InternalCallMetricRecorder.createMetricReport(
0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>()));
CyclicBarrier barrier = new CyclicBarrier(2);
- Map<EquivalentAddressGroup, AtomicInteger> pickCount = new ConcurrentHashMap<>();
- pickCount.put(weightedChild1.getEag(), new AtomicInteger(0));
- pickCount.put(weightedChild2.getEag(), new AtomicInteger(0));
+ Map<Subchannel, AtomicInteger> pickCount = new ConcurrentHashMap<>();
+ pickCount.put(weightedSubchannel1, new AtomicInteger(0));
+ pickCount.put(weightedSubchannel2, new AtomicInteger(0));
new Thread(new Runnable() {
@Override
public void run() {
@@ -805,7 +802,7 @@ public class WeightedRoundRobinLoadBalancerTest {
barrier.await();
for (int i = 0; i < 1000; i++) {
Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
- pickCount.get(result.getAddresses()).addAndGet(1);
+ pickCount.get(result).addAndGet(1);
}
barrier.await();
} catch (Exception ex) {
@@ -816,15 +813,15 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1);
barrier.await();
for (int i = 0; i < 1000; i++) {
- EquivalentAddressGroup result = getAddresses(weightedPicker.pickSubchannel(mockArgs));
+ Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel();
pickCount.get(result).addAndGet(1);
}
barrier.await();
assertThat(pickCount.size()).isEqualTo(2);
// after blackout period
- assertThat(Math.abs(pickCount.get(weightedChild1.getEag()).get() / 2000.0 - 2.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel1).get() / 2000.0 - 2.0 / 3))
.isLessThan(0.002);
- assertThat(Math.abs(pickCount.get(weightedChild2.getEag()).get() / 2000.0 - 1.0 / 3))
+ assertThat(Math.abs(pickCount.get(weightedSubchannel2).get() / 2000.0 - 1.0 / 3))
.isLessThan(0.002);
}
@@ -1107,34 +1104,4 @@ public class WeightedRoundRobinLoadBalancerTest {
return nextInt;
}
}
-
- private class TestHelper extends AbstractTestHelper {
-
- @Override
- public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() {
- return subchannels;
- }
-
- @Override
- public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() {
- return mockToRealSubChannelMap;
- }
-
- @Override
- public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() {
- return subchannelStateListeners;
- }
-
- @Override
- public SynchronizationContext getSynchronizationContext() {
- return syncContext;
- }
-
- @Override
- public ScheduledExecutorService getScheduledExecutorService() {
- return fakeClock.getScheduledExecutorService();
- }
-
-
- }
}