diff options
author | Larry Safran <lsafran@google.com> | 2023-09-15 10:27:36 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 10:27:36 -0700 |
commit | e1334eae7bba39d85a952bc5ab5aeb4cb05a56d8 (patch) | |
tree | ed895999b1784030271c31addb3bac1fa7a84995 | |
parent | 69986b542e322d37fbb3029aaf37d37aefe14157 (diff) | |
download | grpc-grpc-java-e1334eae7bba39d85a952bc5ab5aeb4cb05a56d8.tar.gz |
Change Round Robin and WeightedRoundRobin into petiole policies (#10528)
* Change Round Robin and WeightedRoundRobin into petiole policies
15 files changed, 983 insertions, 629 deletions
diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index da2bc072a..df35afae1 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -161,7 +161,7 @@ import org.mockito.stubbing.Answer; /** Unit tests for {@link ManagedChannelImpl}. */ @RunWith(JUnit4.class) // TODO(creamsoup) remove backward compatible check when fully migrated -@SuppressWarnings("deprecation") +@SuppressWarnings({"deprecation", "DataFlowIssue"}) public class ManagedChannelImplTest { private static final int DEFAULT_PORT = 447; diff --git a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java index 974f36e59..02df28f2e 100644 --- a/core/src/testFixtures/java/io/grpc/internal/TestUtils.java +++ b/core/src/testFixtures/java/io/grpc/internal/TestUtils.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.when; import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ClientStreamTracer; +import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -143,6 +144,10 @@ public final class TestUtils { return captor; } + public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { + return new EquivalentAddressGroup(eag.getAddresses()); + } + private TestUtils() { } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index c00b8fbd9..85e283b11 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -53,6 +53,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. + implementation 'io.grpc:grpc-core:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'io.grpc:grpc-okhttp:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'io.grpc:grpc-protobuf-lite:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'io.grpc:grpc-stub:1.59.0-SNAPSHOT' // CURRENT_GRPC_VERSION diff --git a/examples/android/strictmode/app/proguard-rules.pro b/examples/android/strictmode/app/proguard-rules.pro index 1507a5267..d5715fd16 100644 --- a/examples/android/strictmode/app/proguard-rules.pro +++ b/examples/android/strictmode/app/proguard-rules.pro @@ -15,3 +15,4 @@ -dontwarn javax.naming.** -dontwarn okio.** -dontwarn sun.misc.Unsafe + diff --git a/util/build.gradle b/util/build.gradle index a05c55b27..cdd32e0ce 100644 --- a/util/build.gradle +++ b/util/build.gradle @@ -1,5 +1,6 @@ plugins { id "java-library" + id "java-test-fixtures" id "maven-publish" id "me.champeau.jmh" @@ -19,11 +20,18 @@ dependencies { implementation libraries.animalsniffer.annotations, libraries.guava - testImplementation testFixtures(project(':grpc-api')), + testImplementation libraries.guava.testlib, + testFixtures(project(':grpc-api')), testFixtures(project(':grpc-core')), project(':grpc-testing') - testImplementation libraries.guava.testlib + testFixturesApi project(':grpc-core') + testFixturesImplementation libraries.guava, + libraries.junit, + libraries.mockito.core, + testFixtures(project(':grpc-api')), + testFixtures(project(':grpc-core')), + project(':grpc-testing') jmh project(':grpc-testing') signature libraries.signature.java diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index be0a23a16..fa31755ea 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -16,25 +16,29 @@ package io.grpc.util; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; +import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import io.grpc.Attributes; import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.internal.PickFirstLoadBalancerProvider; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -46,23 +50,34 @@ import javax.annotation.Nullable; @Internal public abstract class MultiChildLoadBalancer extends LoadBalancer { - @VisibleForTesting - public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15; private static final Logger logger = Logger.getLogger(MultiChildLoadBalancer.class.getName()); private final Map<Object, ChildLbState> childLbStates = new HashMap<>(); private final Helper helper; - protected final SynchronizationContext syncContext; - private final ScheduledExecutorService timeService; // Set to true if currently in the process of handling resolved addresses. - private boolean resolvingAddresses; + @VisibleForTesting + boolean resolvingAddresses; + + protected final PickFirstLoadBalancerProvider pickFirstLbProvider = + new PickFirstLoadBalancerProvider(); + protected MultiChildLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); - this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); - this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); logger.log(Level.FINE, "Created"); } + @SuppressWarnings("ReferenceEquality") + protected static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { + if (eag.getAttributes() == Attributes.EMPTY) { + return eag; + } else { + return new EquivalentAddressGroup(eag.getAddresses()); + } + } + + protected abstract SubchannelPicker getSubchannelPicker( + Map<Object, SubchannelPicker> childPickers); + protected SubchannelPicker getInitialPicker() { return EMPTY_PICKER; } @@ -71,11 +86,42 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { return new ErrorPicker(error); } - protected abstract Map<Object, PolicySelection> getPolicySelectionMap( - ResolvedAddresses resolvedAddresses); + @VisibleForTesting + protected Collection<ChildLbState> getChildLbStates() { + return childLbStates.values(); + } - protected abstract SubchannelPicker getSubchannelPicker( - Map<Object, SubchannelPicker> childPickers); + protected ChildLbState getChildLbState(Object key) { + if (key == null) { + return null; + } + return childLbStates.get(key); + } + + protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { + return getChildLbState(stripAttrs(eag)); + } + + /** + * Override to utilize parsing of the policy configuration or alternative helper/lb generation. + */ + protected Map<Object, ChildLbState> createChildLbMap(ResolvedAddresses resolvedAddresses) { + Map<Object, ChildLbState> childLbMap = new HashMap<>(); + List<EquivalentAddressGroup> addresses = resolvedAddresses.getAddresses(); + Object policyConfig = resolvedAddresses.getLoadBalancingPolicyConfig(); + for (EquivalentAddressGroup eag : addresses) { + EquivalentAddressGroup strippedEag = stripAttrs(eag); // keys need to be just addresses + ChildLbState childLbState = childLbMap.getOrDefault(strippedEag, + createChildLbState(strippedEag, policyConfig, getInitialPicker())); + childLbMap.put(strippedEag, childLbState); + } + return childLbMap; + } + + protected ChildLbState createChildLbState(Object key, Object policyConfig, + SubchannelPicker initialPicker) { + return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker); + } @Override public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { @@ -87,25 +133,61 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { } } + protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, + Object childConfig) { + checkArgument(key instanceof EquivalentAddressGroup, "key is wrong type"); + + // Retrieve the non-stripped version + EquivalentAddressGroup eag = null; + for (EquivalentAddressGroup equivalentAddressGroup : resolvedAddresses.getAddresses()) { + if (stripAttrs(equivalentAddressGroup).equals(key)) { + eag = equivalentAddressGroup; + break; + } + } + + checkNotNull(eag, key.toString() + " no longer present in load balancer children"); + + return resolvedAddresses.toBuilder() + .setAddresses(Collections.singletonList(eag)) + .setLoadBalancingPolicyConfig(childConfig) + .build(); + } + + + private boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); - Map<Object, PolicySelection> newChildPolicies = getPolicySelectionMap(resolvedAddresses); - for (Map.Entry<Object, PolicySelection> entry : newChildPolicies.entrySet()) { + Map<Object, ChildLbState> newChildren = createChildLbMap(resolvedAddresses); + + if (newChildren.isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. " + resolvedAddresses)); + return false; + } + + // Do adds and updates + for (Map.Entry<Object, ChildLbState> entry : newChildren.entrySet()) { final Object key = entry.getKey(); - LoadBalancerProvider childPolicyProvider = entry.getValue().getProvider(); + LoadBalancerProvider childPolicyProvider = entry.getValue().getPolicyProvider(); Object childConfig = entry.getValue().getConfig(); if (!childLbStates.containsKey(key)) { - childLbStates.put(key, new ChildLbState(key, childPolicyProvider, getInitialPicker())); + childLbStates.put(key, entry.getValue()); } else { - childLbStates.get(key).reactivate(childPolicyProvider); + // Reuse the existing one + ChildLbState existingChildLbState = childLbStates.get(key); + if (existingChildLbState.isDeactivated()) { + existingChildLbState.reactivate(childPolicyProvider); + } } + LoadBalancer childLb = childLbStates.get(key).lb; - ResolvedAddresses childAddresses = - resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); - childLb.handleResolvedAddresses(childAddresses); + childLb.handleResolvedAddresses(getChildAddresses(key, resolvedAddresses, childConfig)); } - for (Object key : childLbStates.keySet()) { - if (!newChildPolicies.containsKey(key)) { + + // Do removals + for (Object key : ImmutableList.copyOf(childLbStates.keySet())) { + if (!newChildren.containsKey(key)) { childLbStates.get(key).deactivate(); } } @@ -139,10 +221,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { childLbStates.clear(); } - private void updateOverallBalancingState() { + protected void updateOverallBalancingState() { ConnectivityState overallState = null; final Map<Object, SubchannelPicker> childPickers = new HashMap<>(); - for (ChildLbState childLbState : childLbStates.values()) { + for (ChildLbState childLbState : getChildLbStates()) { if (childLbState.deactivated) { continue; } @@ -155,7 +237,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { } @Nullable - private static ConnectivityState aggregateState( + protected static ConnectivityState aggregateState( @Nullable ConnectivityState overallState, ConnectivityState childState) { if (overallState == null) { return childState; @@ -172,67 +254,109 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { return overallState; } - private final class ChildLbState { + protected Helper getHelper() { + return helper; + } + + protected void removeChild(Object key) { + childLbStates.remove(key); + } + + + public class ChildLbState { private final Object key; + private final Object config; private final GracefulSwitchLoadBalancer lb; private LoadBalancerProvider policyProvider; private ConnectivityState currentState = CONNECTING; private SubchannelPicker currentPicker; private boolean deactivated; - @Nullable - ScheduledHandle deletionTimer; - ChildLbState(Object key, LoadBalancerProvider policyProvider, SubchannelPicker initialPicker) { + public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, + SubchannelPicker initialPicker) { this.key = key; this.policyProvider = policyProvider; lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper()); lb.switchTo(policyProvider); currentPicker = initialPicker; + config = childConfig; + } + + + @Override + public String toString() { + return "Address = " + key + + ", state = " + currentState + + ", picker type: " + currentPicker.getClass() + + ", lb: " + lb.delegate().getClass() + + (deactivated ? ", deactivated" : ""); + } + + public Object getKey() { + return key; + } + + Object getConfig() { + return config; + } + + public LoadBalancerProvider getPolicyProvider() { + return policyProvider; + } + + protected Subchannel getSubchannels(PickSubchannelArgs args) { + return getCurrentPicker().pickSubchannel(args).getSubchannel(); + } + + ConnectivityState getCurrentState() { + return currentState; } - void deactivate() { + public SubchannelPicker getCurrentPicker() { + return currentPicker; + } + + public boolean isDeactivated() { + return deactivated; + } + + @VisibleForTesting + LoadBalancer getLb() { + return this.lb; + } + + protected void setDeactivated() { + deactivated = true; + } + + protected void deactivate() { if (deactivated) { return; } - class DeletionTask implements Runnable { - @Override - public void run() { - shutdown(); - childLbStates.remove(key); - } - } - - deletionTimer = - syncContext.schedule( - new DeletionTask(), - DELAYED_CHILD_DELETION_TIME_MINUTES, - TimeUnit.MINUTES, - timeService); + shutdown(); + childLbStates.remove(key); deactivated = true; logger.log(Level.FINE, "Child balancer {0} deactivated", key); } - void reactivate(LoadBalancerProvider policyProvider) { - if (deletionTimer != null && deletionTimer.isPending()) { - deletionTimer.cancel(); - deactivated = false; - logger.log(Level.FINE, "Child balancer {0} reactivated", key); - } + protected void reactivate(LoadBalancerProvider policyProvider) { if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) { Object[] objects = { key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()}; logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects); lb.switchTo(policyProvider); this.policyProvider = policyProvider; + } else { + logger.log(Level.FINE, "Child balancer {0} reactivated", key); } + + deactivated = false; } - void shutdown() { - if (deletionTimer != null && deletionTimer.isPending()) { - deletionTimer.cancel(); - } + protected void shutdown() { lb.shutdown(); + this.currentState = SHUTDOWN; logger.log(Level.FINE, "Child balancer {0} deleted", key); } diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 560970849..9873e3e45 100644 --- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -16,11 +16,9 @@ package io.grpc.util; -import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; -import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; @@ -37,13 +35,10 @@ import io.grpc.NameResolver; import io.grpc.Status; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import javax.annotation.Nonnull; @@ -52,131 +47,23 @@ import javax.annotation.Nonnull; * EquivalentAddressGroup}s from the {@link NameResolver}. */ @Internal -public class RoundRobinLoadBalancer extends LoadBalancer { +public class RoundRobinLoadBalancer extends MultiChildLoadBalancer { @VisibleForTesting static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO = Attributes.Key.create("state-info"); - private final Helper helper; - private final Map<EquivalentAddressGroup, Subchannel> subchannels = - new HashMap<>(); private final Random random; private ConnectivityState currentState; protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK); public RoundRobinLoadBalancer(Helper helper) { - this.helper = checkNotNull(helper, "helper"); + super(helper); this.random = new Random(); } @Override - public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { - if (resolvedAddresses.getAddresses().isEmpty()) { - handleNameResolutionError(Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() - + ", attrs=" + resolvedAddresses.getAttributes())); - return false; - } - - List<EquivalentAddressGroup> servers = resolvedAddresses.getAddresses(); - Set<EquivalentAddressGroup> currentAddrs = subchannels.keySet(); - Map<EquivalentAddressGroup, EquivalentAddressGroup> latestAddrs = stripAttrs(servers); - Set<EquivalentAddressGroup> removedAddrs = setsDifference(currentAddrs, latestAddrs.keySet()); - - for (Map.Entry<EquivalentAddressGroup, EquivalentAddressGroup> latestEntry : - latestAddrs.entrySet()) { - EquivalentAddressGroup strippedAddressGroup = latestEntry.getKey(); - EquivalentAddressGroup originalAddressGroup = latestEntry.getValue(); - Subchannel existingSubchannel = subchannels.get(strippedAddressGroup); - if (existingSubchannel != null) { - // EAG's Attributes may have changed. - existingSubchannel.updateAddresses(Collections.singletonList(originalAddressGroup)); - continue; - } - // Create new subchannels for new addresses. - - // NB(lukaszx0): we don't merge `attributes` with `subchannelAttr` because subchannel - // doesn't need them. They're describing the resolved server list but we're not taking - // any action based on this information. - Attributes.Builder subchannelAttrs = Attributes.newBuilder() - // NB(lukaszx0): because attributes are immutable we can't set new value for the key - // after creation but since we can mutate the values we leverage that and set - // AtomicReference which will allow mutating state info for given channel. - .set(STATE_INFO, - new Ref<>(ConnectivityStateInfo.forNonError(IDLE))); - - final Subchannel subchannel = checkNotNull( - helper.createSubchannel(CreateSubchannelArgs.newBuilder() - .setAddresses(originalAddressGroup) - .setAttributes(subchannelAttrs.build()) - .build()), - "subchannel"); - subchannel.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo state) { - processSubchannelState(subchannel, state); - } - }); - subchannels.put(strippedAddressGroup, subchannel); - subchannel.requestConnection(); - } - - ArrayList<Subchannel> removedSubchannels = new ArrayList<>(); - for (EquivalentAddressGroup addressGroup : removedAddrs) { - removedSubchannels.add(subchannels.remove(addressGroup)); - } - - // Update the picker before shutting down the subchannels, to reduce the chance of the race - // between picking a subchannel and shutting it down. - updateBalancingState(); - - // Shutdown removed subchannels - for (Subchannel removedSubchannel : removedSubchannels) { - shutdownSubchannel(removedSubchannel); - } - - return true; - } - - @Override - public void handleNameResolutionError(Status error) { - if (currentState != READY) { - updateBalancingState(TRANSIENT_FAILURE, new EmptyPicker(error)); - } - } - - private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { - return; - } - if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { - helper.refreshNameResolution(); - } - if (stateInfo.getState() == IDLE) { - subchannel.requestConnection(); - } - Ref<ConnectivityStateInfo> subchannelStateRef = getSubchannelStateInfoRef(subchannel); - if (subchannelStateRef.value.getState().equals(TRANSIENT_FAILURE)) { - if (stateInfo.getState().equals(CONNECTING) || stateInfo.getState().equals(IDLE)) { - return; - } - } - subchannelStateRef.value = stateInfo; - updateBalancingState(); - } - - private void shutdownSubchannel(Subchannel subchannel) { - subchannel.shutdown(); - getSubchannelStateInfoRef(subchannel).value = - ConnectivityStateInfo.forNonError(SHUTDOWN); - } - - @Override - public void shutdown() { - for (Subchannel subchannel : getSubchannels()) { - shutdownSubchannel(subchannel); - } - subchannels.clear(); + protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) { + throw new UnsupportedOperationException(); // local updateOverallBalancingState doesn't use this } private static final Status EMPTY_OK = Status.OK.withDescription("no subchannels ready"); @@ -184,29 +71,27 @@ public class RoundRobinLoadBalancer extends LoadBalancer { /** * Updates picker with the list of active subchannels (state == READY). */ - @SuppressWarnings("ReferenceEquality") - private void updateBalancingState() { - List<Subchannel> activeList = filterNonFailingSubchannels(getSubchannels()); + @Override + protected void updateOverallBalancingState() { + List<ChildLbState> activeList = getReadyChildren(); if (activeList.isEmpty()) { - // No READY subchannels, determine aggregate state and error status + // No READY subchannels + + // RRLB will request connection immediately on subchannel IDLE. boolean isConnecting = false; - Status aggStatus = EMPTY_OK; - for (Subchannel subchannel : getSubchannels()) { - ConnectivityStateInfo stateInfo = getSubchannelStateInfoRef(subchannel).value; - // This subchannel IDLE is not because of channel IDLE_TIMEOUT, - // in which case LB is already shutdown. - // RRLB will request connection immediately on subchannel IDLE. - if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { + for (ChildLbState childLbState : getChildLbStates()) { + ConnectivityState state = childLbState.getCurrentState(); + if (state == CONNECTING || state == IDLE) { isConnecting = true; - } - if (aggStatus == EMPTY_OK || !aggStatus.isOk()) { - aggStatus = stateInfo.getStatus(); + break; } } - updateBalancingState(isConnecting ? CONNECTING : TRANSIENT_FAILURE, - // If all subchannels are TRANSIENT_FAILURE, return the Status associated with - // an arbitrary subchannel, otherwise return OK. - new EmptyPicker(aggStatus)); + + if (isConnecting) { + updateBalancingState(CONNECTING, new EmptyPicker(Status.OK)); + } else { + updateBalancingState(TRANSIENT_FAILURE, createReadyPicker(getChildLbStates())); + } } else { updateBalancingState(READY, createReadyPicker(activeList)); } @@ -214,72 +99,39 @@ public class RoundRobinLoadBalancer extends LoadBalancer { private void updateBalancingState(ConnectivityState state, RoundRobinPicker picker) { if (state != currentState || !picker.isEquivalentTo(currentPicker)) { - helper.updateBalancingState(state, picker); + getHelper().updateBalancingState(state, picker); currentState = state; currentPicker = picker; } } - protected RoundRobinPicker createReadyPicker(List<Subchannel> activeList) { + protected RoundRobinPicker createReadyPicker(Collection<ChildLbState> children) { // initialize the Picker to a random start index to ensure that a high frequency of Picker // churn does not skew subchannel selection. - int startIndex = random.nextInt(activeList.size()); - return new ReadyPicker(activeList, startIndex); - } + int startIndex = random.nextInt(children.size()); - /** - * Filters out non-ready subchannels. - */ - private static List<Subchannel> filterNonFailingSubchannels( - Collection<Subchannel> subchannels) { - List<Subchannel> readySubchannels = new ArrayList<>(subchannels.size()); - for (Subchannel subchannel : subchannels) { - if (isReady(subchannel)) { - readySubchannels.add(subchannel); - } + List<SubchannelPicker> pickerList = new ArrayList<>(); + for (ChildLbState child : children) { + SubchannelPicker picker = child.getCurrentPicker(); + pickerList.add(picker); } - return readySubchannels; + + return new ReadyPicker(pickerList, startIndex); } /** - * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and - * remove all attributes. The values are the original EAGs. + * Filters out non-ready and deactivated child load balancers (subchannels). */ - private static Map<EquivalentAddressGroup, EquivalentAddressGroup> stripAttrs( - List<EquivalentAddressGroup> groupList) { - Map<EquivalentAddressGroup, EquivalentAddressGroup> addrs = new HashMap<>(groupList.size() * 2); - for (EquivalentAddressGroup group : groupList) { - addrs.put(stripAttrs(group), group); + private List<ChildLbState> getReadyChildren() { + List<ChildLbState> activeChildren = new ArrayList<>(); + for (ChildLbState child : getChildLbStates()) { + if (!child.isDeactivated() && child.getCurrentState() == READY) { + activeChildren.add(child); + } } - return addrs; - } - - private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { - return new EquivalentAddressGroup(eag.getAddresses()); - } - - @VisibleForTesting - protected Collection<Subchannel> getSubchannels() { - return subchannels.values(); - } - - private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef( - Subchannel subchannel) { - return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO"); - } - - // package-private to avoid synthetic access - static boolean isReady(Subchannel subchannel) { - return getSubchannelStateInfoRef(subchannel).value.getState() == READY; - } - - private static <T> Set<T> setsDifference(Set<T> a, Set<T> b) { - Set<T> aCopy = new HashSet<>(a); - aCopy.removeAll(b); - return aCopy; + return activeChildren; } - // Only subclasses are ReadyPicker or EmptyPicker public abstract static class RoundRobinPicker extends SubchannelPicker { public abstract boolean isEquivalentTo(RoundRobinPicker picker); } @@ -289,11 +141,11 @@ public class RoundRobinLoadBalancer extends LoadBalancer { private static final AtomicIntegerFieldUpdater<ReadyPicker> indexUpdater = AtomicIntegerFieldUpdater.newUpdater(ReadyPicker.class, "index"); - private final List<Subchannel> list; // non-empty + private final List<SubchannelPicker> list; // non-empty @SuppressWarnings("unused") private volatile int index; - public ReadyPicker(List<Subchannel> list, int startIndex) { + public ReadyPicker(List<SubchannelPicker> list, int startIndex) { Preconditions.checkArgument(!list.isEmpty(), "empty list"); this.list = list; this.index = startIndex - 1; @@ -301,7 +153,7 @@ public class RoundRobinLoadBalancer extends LoadBalancer { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(nextSubchannel()); + return list.get(nextIndex()).pickSubchannel(args); } @Override @@ -309,7 +161,7 @@ public class RoundRobinLoadBalancer extends LoadBalancer { return MoreObjects.toStringHelper(ReadyPicker.class).add("list", list).toString(); } - private Subchannel nextSubchannel() { + private int nextIndex() { int size = list.size(); int i = indexUpdater.incrementAndGet(this); if (i >= size) { @@ -317,11 +169,11 @@ public class RoundRobinLoadBalancer extends LoadBalancer { i %= size; indexUpdater.compareAndSet(this, oldi, i); } - return list.get(i); + return i; } @VisibleForTesting - List<Subchannel> getList() { + List<SubchannelPicker> getList() { return list; } diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index 13f13421a..ac5bd8b98 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -512,7 +512,7 @@ public class OutlierDetectionLoadBalancerTest { loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers)); - generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 8); + generateLoad(ImmutableMap.of(subchannel2, Status.DEADLINE_EXCEEDED), 12); // Move forward in time to a point where the detection timer has fired. forwardTime(config); @@ -546,7 +546,7 @@ public class OutlierDetectionLoadBalancerTest { assertEjectedSubchannels(ImmutableSet.of(servers.get(0).getAddresses().get(0))); // Now we produce more load, but the subchannel start working and is no longer an outlier. - generateLoad(ImmutableMap.of(), 8); + generateLoad(ImmutableMap.of(), 12); // Move forward in time to a point where the detection timer has fired. fakeClock.forwardTime(config.maxEjectionTimeNanos + 1, TimeUnit.NANOSECONDS); diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 23b6e1c10..3b7f6599d 100644 --- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -22,23 +22,21 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.RoundRobinLoadBalancer.STATE_INFO; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -55,16 +53,19 @@ import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Status; +import io.grpc.internal.TestUtils; +import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker; import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker; -import io.grpc.util.RoundRobinLoadBalancer.Ref; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -75,10 +76,8 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -import org.mockito.stubbing.Answer; /** Unit test for {@link RoundRobinLoadBalancer}. */ @RunWith(JUnit4.class) @@ -89,7 +88,9 @@ public class RoundRobinLoadBalancerTest { private RoundRobinLoadBalancer loadBalancer; private final List<EquivalentAddressGroup> servers = Lists.newArrayList(); - private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap(); + private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = + new ConcurrentHashMap<>(); + private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>(); private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners = Maps.newLinkedHashMap(); private final Attributes affinity = @@ -101,8 +102,7 @@ public class RoundRobinLoadBalancerTest { private ArgumentCaptor<ConnectivityState> stateCaptor; @Captor private ArgumentCaptor<CreateSubchannelArgs> createArgsCaptor; - @Mock - private Helper mockHelper; + private Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper())); @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown(). private PickSubchannelArgs mockArgs; @@ -113,34 +113,16 @@ public class RoundRobinLoadBalancerTest { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); - Subchannel sc = mock(Subchannel.class); - subchannels.put(Arrays.asList(eag), sc); } - when(mockHelper.createSubchannel(any(CreateSubchannelArgs.class))) - .then(new Answer<Subchannel>() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; - final Subchannel subchannel = subchannels.get(args.getAddresses()); - when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); - when(subchannel.getAttributes()).thenReturn(args.getAttributes()); - doAnswer( - new Answer<Void>() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - subchannelStateListeners.put( - subchannel, (SubchannelStateListener) invocation.getArguments()[0]); - return null; - } - }).when(subchannel).start(any(SubchannelStateListener.class)); - return subchannel; - } - }); - loadBalancer = new RoundRobinLoadBalancer(mockHelper); } + private boolean acceptAddresses(List<EquivalentAddressGroup> eagList, Attributes attrs) { + return loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(eagList).setAttributes(attrs).build()); + } + @After public void tearDown() throws Exception { verifyNoMoreInteractions(mockArgs); @@ -148,10 +130,9 @@ public class RoundRobinLoadBalancerTest { @Test public void pickAfterResolved() throws Exception { - final Subchannel readySubchannel = subchannels.values().iterator().next(); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + boolean addressesAccepted = acceptAddresses(servers, affinity); assertThat(addressesAccepted).isTrue(); + final Subchannel readySubchannel = subchannels.values().iterator().next(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); @@ -178,10 +159,6 @@ public class RoundRobinLoadBalancerTest { @Test public void pickAfterResolvedUpdatedHosts() throws Exception { - Subchannel removedSubchannel = mock(Subchannel.class); - Subchannel oldSubchannel = mock(Subchannel.class); - Subchannel newSubchannel = mock(Subchannel.class); - Attributes.Key<String> key = Attributes.Key.create("check-that-it-is-propagated"); FakeSocketAddress removedAddr = new FakeSocketAddress("removed"); EquivalentAddressGroup removedEag = new EquivalentAddressGroup(removedAddr); @@ -193,6 +170,13 @@ public class RoundRobinLoadBalancerTest { EquivalentAddressGroup newEag = new EquivalentAddressGroup( newAddr, Attributes.newBuilder().set(key, "newattr").build()); + Subchannel removedSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(removedEag).build()); + Subchannel oldSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(oldEag1).build()); + Subchannel newSubchannel = mockHelper.createSubchannel(CreateSubchannelArgs.newBuilder() + .setAddresses(newEag).build()); + subchannels.put(Collections.singletonList(removedEag), removedSubchannel); subchannels.put(Collections.singletonList(oldEag1), oldSubchannel); subchannels.put(Collections.singletonList(newEag), newSubchannel); @@ -201,9 +185,7 @@ public class RoundRobinLoadBalancerTest { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) - .build()); + boolean addressesAccepted = acceptAddresses(currentServers, affinity); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -218,8 +200,11 @@ public class RoundRobinLoadBalancerTest { verify(removedSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).requestConnection(); - assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel, - oldSubchannel); + assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); + assertThat(loadBalancer.getChildLbStateEag(removedEag).getCurrentPicker().pickSubchannel(null) + .getSubchannel()).isEqualTo(removedSubchannel); + assertThat(loadBalancer.getChildLbStateEag(oldEag1).getCurrentPicker().pickSubchannel(null) + .getSubchannel()).isEqualTo(oldSubchannel); // This time with Attributes List<EquivalentAddressGroup> latestServers = Lists.newArrayList(oldEag2, newEag); @@ -232,13 +217,15 @@ public class RoundRobinLoadBalancerTest { verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); verify(removedSubchannel, times(1)).shutdown(); - deliverSubchannelState(removedSubchannel, ConnectivityStateInfo.forNonError(SHUTDOWN)); deliverSubchannelState(newSubchannel, ConnectivityStateInfo.forNonError(READY)); - assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel, - newSubchannel); + assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(2); + assertThat(loadBalancer.getChildLbStateEag(newEag).getCurrentPicker() + .pickSubchannel(null).getSubchannel()).isEqualTo(newSubchannel); + assertThat(loadBalancer.getChildLbStateEag(oldEag2).getCurrentPicker() + .pickSubchannel(null).getSubchannel()).isEqualTo(oldSubchannel); - verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(mockHelper, times(6)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper, times(2)).updateBalancingState(eq(READY), pickerCaptor.capture()); picker = pickerCaptor.getValue(); @@ -250,29 +237,26 @@ public class RoundRobinLoadBalancerTest { @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); - Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); - Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get( - STATE_INFO); + + // TODO figure out if this method testing the right things + + ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); + Subchannel subchannel = childLbState.getCurrentPicker().pickSubchannel(null).getSubchannel(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); - assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE)); + assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); - deliverSubchannelState(subchannel, - ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class); - assertThat(subchannelStateInfo.value).isEqualTo( - ConnectivityStateInfo.forNonError(READY)); + assertThat(childLbState.getCurrentState()).isEqualTo(READY); Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"); deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(error)); - assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); + assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); inOrder.verify(mockHelper).refreshNameResolution(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); @@ -280,8 +264,7 @@ public class RoundRobinLoadBalancerTest { deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); - assertThat(subchannelStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(subchannelStateInfo.value.getStatus()).isEqualTo(error); + assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); verify(subchannel, times(2)).requestConnection(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); @@ -291,15 +274,14 @@ public class RoundRobinLoadBalancerTest { @Test public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); - for (Subchannel sc : loadBalancer.getSubchannels()) { - verify(sc).shutdown(); + for (ChildLbState child : loadBalancer.getChildLbStates()) { + Subchannel sc = child.getCurrentPicker().pickSubchannel(null).getSubchannel(); + verify(child).shutdown(); // When the subchannel is being shut down, a SHUTDOWN connectivity state is delivered // back to the subchannel state listener. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(SHUTDOWN)); @@ -311,36 +293,34 @@ public class RoundRobinLoadBalancerTest { @Test public void stayTransientFailureUntilReady() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + Map<ChildLbState, Subchannel> childToSubChannelMap = new HashMap<>(); // Simulate state transitions for each subchannel individually. - for (Subchannel sc : loadBalancer.getSubchannels()) { + for ( ChildLbState child : loadBalancer.getChildLbStates()) { + Subchannel sc = child.getSubchannels(mockArgs); + childToSubChannelMap.put(child, sc); Status error = Status.UNKNOWN.withDescription("connection broken"); deliverSubchannelState( sc, ConnectivityStateInfo.forTransientFailure(error)); + assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); inOrder.verify(mockHelper).refreshNameResolution(); deliverSubchannelState( sc, ConnectivityStateInfo.forNonError(CONNECTING)); - Ref<ConnectivityStateInfo> scStateInfo = sc.getAttributes().get( - STATE_INFO); - assertThat(scStateInfo.value.getState()).isEqualTo(TRANSIENT_FAILURE); - assertThat(scStateInfo.value.getStatus()).isEqualTo(error); + assertEquals(TRANSIENT_FAILURE, child.getCurrentState()); } - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), isA(ReadyPicker.class)); inOrder.verifyNoMoreInteractions(); - Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); + ChildLbState child = loadBalancer.getChildLbStates().iterator().next(); + Subchannel subchannel = childToSubChannelMap.get(child); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); - Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get( - STATE_INFO); - assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(READY)); + assertThat(child.getCurrentState()).isEqualTo(READY); inOrder.verify(mockHelper).updateBalancingState(eq(READY), isA(ReadyPicker.class)); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); @@ -350,16 +330,15 @@ public class RoundRobinLoadBalancerTest { @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { InOrder inOrder = inOrder(mockHelper); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); // Simulate state transitions for each subchannel individually. - for (Subchannel sc : loadBalancer.getSubchannels()) { + for (ChildLbState child : loadBalancer.getChildLbStates()) { + Subchannel sc = child.getSubchannels(mockArgs); verify(sc).requestConnection(); deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(CONNECTING)); Status error = Status.UNKNOWN.withDescription("connection broken"); @@ -370,7 +349,7 @@ public class RoundRobinLoadBalancerTest { // Simulate receiving go-away so READY subchannels transit to IDLE. deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); - verify(sc, times(2)).requestConnection(); + verify(sc, times(1)).requestConnection(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); } @@ -383,11 +362,12 @@ public class RoundRobinLoadBalancerTest { Subchannel subchannel1 = mock(Subchannel.class); Subchannel subchannel2 = mock(Subchannel.class); - ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList( - Lists.newArrayList(subchannel, subchannel1, subchannel2)), - 0 /* startIndex */); + ArrayList<SubchannelPicker> pickers = Lists.newArrayList( + TestUtils.pickerOf(subchannel), TestUtils.pickerOf(subchannel1), + TestUtils.pickerOf(subchannel2)); - assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2); + ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(pickers), + 0 /* startIndex */); assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(subchannel1, picker.pickSubchannel(mockArgs).getSubchannel()); @@ -399,7 +379,7 @@ public class RoundRobinLoadBalancerTest { public void pickerEmptyList() throws Exception { SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN); - assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel()); + assertNull(picker.pickSubchannel(mockArgs).getSubchannel()); assertEquals(Status.UNKNOWN, picker.pickSubchannel(mockArgs).getStatus()); } @@ -417,12 +397,13 @@ public class RoundRobinLoadBalancerTest { @Test public void nameResolutionErrorWithActiveChannels() throws Exception { + boolean addressesAccepted = acceptAddresses(servers, affinity); final Subchannel readySubchannel = subchannels.values().iterator().next(); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + loadBalancer.resolvingAddresses = true; loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); + loadBalancer.resolvingAddresses = false; verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(2)) @@ -443,15 +424,14 @@ public class RoundRobinLoadBalancerTest { @Test public void subchannelStateIsolation() throws Exception { + boolean addressesAccepted = acceptAddresses(servers, Attributes.EMPTY); + assertThat(addressesAccepted).isTrue(); + Iterator<Subchannel> subchannelIterator = subchannels.values().iterator(); Subchannel sc1 = subchannelIterator.next(); Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); - boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) - .build()); - assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -478,7 +458,7 @@ public class RoundRobinLoadBalancerTest { // The IDLE subchannel is dropped from the picker, but a reconnection is requested assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1, sc3); - verify(sc2, times(2)).requestConnection(); + verify(sc2, times(1)).requestConnection(); // The failing subchannel is dropped from the picker, with no requested reconnect assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1); @@ -491,7 +471,7 @@ public class RoundRobinLoadBalancerTest { public void readyPicker_emptyList() { // ready picker list must be non-empty try { - new ReadyPicker(Collections.<Subchannel>emptyList(), 0); + new ReadyPicker(Collections.emptyList(), 0); fail(); } catch (IllegalArgumentException expected) { } @@ -503,9 +483,10 @@ public class RoundRobinLoadBalancerTest { EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK")); EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯")); + acceptAddresses(servers, Attributes.EMPTY); // create subchannels Iterator<Subchannel> subchannelIterator = subchannels.values().iterator(); - Subchannel sc1 = subchannelIterator.next(); - Subchannel sc2 = subchannelIterator.next(); + SubchannelPicker sc1 = TestUtils.pickerOf(subchannelIterator.next()); + SubchannelPicker sc2 = TestUtils.pickerOf(subchannelIterator.next()); ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 0); ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 0); ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 1); @@ -526,18 +507,27 @@ public class RoundRobinLoadBalancerTest { public void emptyAddresses() { assertThat(loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() - .setAddresses(Collections.<EquivalentAddressGroup>emptyList()) + .setAddresses(Collections.emptyList()) .setAttributes(affinity) .build())).isFalse(); } - private static List<Subchannel> getList(SubchannelPicker picker) { - return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : - Collections.<Subchannel>emptyList(); + private List<Subchannel> getList(SubchannelPicker picker) { + + if (picker instanceof ReadyPicker) { + List<Subchannel> subchannelList = new ArrayList<>(); + for (SubchannelPicker childPicker : ((ReadyPicker) picker).getList()) { + subchannelList.add(childPicker.pickSubchannel(mockArgs).getSubchannel()); + } + return subchannelList; + } else { + return new ArrayList<>(); + } } private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo newState) { - subchannelStateListeners.get(subchannel).onSubchannelState(newState); + Subchannel realSc = mockToRealSubChannelMap.get(subchannel); + subchannelStateListeners.get(realSc).onSubchannelState(newState); } private static class FakeSocketAddress extends SocketAddress { @@ -552,4 +542,22 @@ public class RoundRobinLoadBalancerTest { return "FakeSocketAddress-" + name; } } + + private class TestHelper extends AbstractTestHelper { + + @Override + public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() { + return subchannels; + } + + @Override + public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() { + return mockToRealSubChannelMap; + } + + @Override + public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() { + return subchannelStateListeners; + } + } } diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java new file mode 100644 index 000000000..409861783 --- /dev/null +++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java @@ -0,0 +1,156 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static org.mockito.AdditionalAnswers.delegatesTo; +import static org.mockito.Mockito.mock; + +import io.grpc.Attributes; +import io.grpc.Channel; +import io.grpc.ChannelLogger; +import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.Subchannel; +import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * A real class that can be used as a delegate of a mock Helper to provide more real representation + * and track the subchannels as is needed with petiole policies where the subchannels are no + * longer direct children of the loadbalancer. + * <br> + * To use it replace <br> + * \@mock Helper mockHelper<br> + * with<br> + * <p>Helper mockHelper = mock(Helper.class, delegatesTo(new TestHelper()));</p> + * <br> + * TestHelper will need to define accessors for the maps that information is store within as + * those maps need to be defined in the Test class. + */ +public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { + + public abstract Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap(); + + public abstract Map<Subchannel, Subchannel> getMockToRealSubChannelMap(); + + public abstract Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners(); + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + // do nothing, should have been done in the wrapper helpers + } + + @Override + protected Helper delegate() { + throw new UnsupportedOperationException("This helper class is only for use in this test"); + } + + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + Subchannel subchannel = getSubchannelMap().get(args.getAddresses()); + if (subchannel == null) { + TestSubchannel delegate = new TestSubchannel(args); + subchannel = mock(Subchannel.class, delegatesTo(delegate)); + getSubchannelMap().put(args.getAddresses(), subchannel); + getMockToRealSubChannelMap().put(subchannel, delegate); + } + + return subchannel; + } + + @Override + public void refreshNameResolution() { + // no-op + } + + public void setChannel(Subchannel subchannel, Channel channel) { + ((TestSubchannel)subchannel).channel = channel; + } + + @Override + public String toString() { + return "Test Helper"; + } + + private class TestSubchannel extends ForwardingSubchannel { + final CreateSubchannelArgs args; + Channel channel; + + public TestSubchannel(CreateSubchannelArgs args) { + this.args = args; + } + + @Override + protected Subchannel delegate() { + throw new UnsupportedOperationException("Only to be used in tests"); + } + + @Override + public List<EquivalentAddressGroup> getAllAddresses() { + return args.getAddresses(); + } + + @Override + public Attributes getAttributes() { + return args.getAttributes(); + } + + @Override + public void requestConnection() { + // Ignore, we will manually update state + } + + @Override + public void updateAddresses(List<EquivalentAddressGroup> addrs) { + // Do nothing, will be handled in wrappers + } + + @Override + public void start(SubchannelStateListener listener) { + getSubchannelStateListeners().put(this, listener); + } + + @Override + public void shutdown() { + getSubchannelStateListeners().remove(this); + for (EquivalentAddressGroup eag : getAllAddresses()) { + getSubchannelMap().remove(Collections.singletonList(eag)); + } + } + + @Override + public Channel asChannel() { + return channel; + } + + @Override + public ChannelLogger getChannelLogger() { + return mock(ChannelLogger.class); + } + + @Override + public String toString() { + return "Mock Subchannel" + args.toString(); + } + } +} + diff --git a/xds/build.gradle b/xds/build.gradle index 3f3cf6a0f..a6db9db99 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -58,7 +58,8 @@ dependencies { def nettyDependency = implementation project(':grpc-netty') testImplementation project(':grpc-rls') - testImplementation testFixtures(project(':grpc-core')) + testImplementation testFixtures(project(':grpc-core')), + testFixtures(project(':grpc-util')) annotationProcessor libraries.auto.value // At runtime use the epoll included in grpc-netty-shaded diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index a44892042..895125d32 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -16,36 +16,68 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import io.grpc.InternalLogId; +import io.grpc.LoadBalancerProvider; import io.grpc.Status; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import io.grpc.xds.XdsLogger.XdsLogLevel; import java.util.HashMap; import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; /** * The top-level load balancing policy. */ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { + @VisibleForTesting + public static final int DELAYED_CHILD_DELETION_TIME_MINUTES = 15; + protected final SynchronizationContext syncContext; + private final ScheduledExecutorService timeService; private final XdsLogger logger; ClusterManagerLoadBalancer(Helper helper) { super(helper); + this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); + this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); logger = XdsLogger.withLogId( InternalLogId.allocate("cluster_manager-lb", helper.getAuthority())); + logger.log(XdsLogLevel.INFO, "Created"); } @Override - protected Map<Object, PolicySelection> getPolicySelectionMap( - ResolvedAddresses resolvedAddresses) { + protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, + Object childConfig) { + return resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); + } + + @Override + protected Map<Object, ChildLbState> createChildLbMap(ResolvedAddresses resolvedAddresses) { ClusterManagerConfig config = (ClusterManagerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Map<Object, PolicySelection> newChildPolicies = new HashMap<>(config.childPolicies); + Map<Object, ChildLbState> newChildPolicies = new HashMap<>(); + if (config != null) { + for (Entry<String, PolicySelection> entry : config.childPolicies.entrySet()) { + ChildLbState child = getChildLbState(entry.getKey()); + if (child == null) { + child = new ClusterManagerLbState(entry.getKey(), + entry.getValue().getProvider(), entry.getValue().getConfig(), getInitialPicker()); + } + newChildPolicies.put(entry.getKey(), child); + } + } logger.log( XdsLogLevel.INFO, "Received cluster_manager lb config: child names={0}", newChildPolicies.keySet()); @@ -75,4 +107,58 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { } }; } + + private class ClusterManagerLbState extends ChildLbState { + @Nullable + ScheduledHandle deletionTimer; + + public ClusterManagerLbState(Object key, LoadBalancerProvider policyProvider, + Object childConfig, SubchannelPicker initialPicker) { + super(key, policyProvider, childConfig, initialPicker); + } + + @Override + protected void shutdown() { + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + } + super.shutdown(); + } + + @Override + protected void reactivate(LoadBalancerProvider policyProvider) { + if (deletionTimer != null && deletionTimer.isPending()) { + deletionTimer.cancel(); + logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey()); + } + + super.reactivate(policyProvider); + } + + @Override + protected void deactivate() { + if (isDeactivated()) { + return; + } + + class DeletionTask implements Runnable { + + @Override + public void run() { + shutdown(); + removeChild(getKey()); + } + } + + deletionTimer = + syncContext.schedule( + new DeletionTask(), + DELAYED_CHILD_DELETION_TIME_MINUTES, + TimeUnit.MINUTES, + timeService); + setDeactivated(); + logger.log(XdsLogLevel.DEBUG, "Child balancer {0} deactivated", getKey()); + } + + } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 833683729..216221d25 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -17,17 +17,20 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkElementIndex; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Deadline.Ticker; import io.grpc.EquivalentAddressGroup; import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; import io.grpc.NameResolver; import io.grpc.Status; import io.grpc.SynchronizationContext; @@ -40,11 +43,13 @@ import io.grpc.xds.orca.OrcaOobUtil; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -91,6 +96,14 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } @Override + protected ChildLbState createChildLbState(Object key, Object policyConfig, + SubchannelPicker initialPicker) { + ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig, + initialPicker); + return childLbState; + } + + @Override public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (resolvedAddresses.getLoadBalancingPolicyConfig() == null) { handleNameResolutionError(Status.UNAVAILABLE.withDescription( @@ -111,9 +124,100 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } @Override - public RoundRobinPicker createReadyPicker(List<Subchannel> activeList) { - return new WeightedRoundRobinPicker(activeList, config.enableOobLoadReport, - config.errorUtilizationPenalty); + public RoundRobinPicker createReadyPicker(Collection<ChildLbState> activeList) { + return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), + config.enableOobLoadReport, config.errorUtilizationPenalty); + } + + @Override + protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { + return super.getChildLbStateEag(eag); + } + + @VisibleForTesting + final class WeightedChildLbState extends ChildLbState { + + private final Set<WrrSubchannel> subchannels = new HashSet<>(); + private volatile long lastUpdated; + private volatile long nonEmptySince; + private volatile double weight = 0; + + private OrcaReportListener orcaReportListener; + + public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, + SubchannelPicker initialPicker) { + super(key, policyProvider, childConfig, initialPicker); + } + + @VisibleForTesting + EquivalentAddressGroup getEag() { + return stripAttrs((EquivalentAddressGroup) getKey()); + } + + private double getWeight() { + if (config == null) { + return 0; + } + long now = ticker.nanoTime(); + if (now - lastUpdated >= config.weightExpirationPeriodNanos) { + nonEmptySince = infTime; + return 0; + } else if (now - nonEmptySince < config.blackoutPeriodNanos + && config.blackoutPeriodNanos > 0) { + return 0; + } else { + return weight; + } + } + + public void addSubchannel(WrrSubchannel wrrSubchannel) { + subchannels.add(wrrSubchannel); + } + + public OrcaReportListener getOrCreateOrcaListener(float errorUtilizationPenalty) { + if (orcaReportListener != null + && orcaReportListener.errorUtilizationPenalty == errorUtilizationPenalty) { + return orcaReportListener; + } + orcaReportListener = new OrcaReportListener(errorUtilizationPenalty); + return orcaReportListener; + } + + public void removeSubchannel(WrrSubchannel wrrSubchannel) { + subchannels.remove(wrrSubchannel); + } + + final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { + private final float errorUtilizationPenalty; + + OrcaReportListener(float errorUtilizationPenalty) { + this.errorUtilizationPenalty = errorUtilizationPenalty; + } + + @Override + public void onLoadReport(MetricReport report) { + double newWeight = 0; + // Prefer application utilization and fallback to CPU utilization if unset. + double utilization = + report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() + : report.getCpuUtilization(); + if (utilization > 0 && report.getQps() > 0) { + double penalty = 0; + if (report.getEps() > 0 && errorUtilizationPenalty > 0) { + penalty = report.getEps() / report.getQps() * errorUtilizationPenalty; + } + newWeight = report.getQps() / (utilization + penalty); + } + if (newWeight == 0) { + return; + } + if (nonEmptySince == infTime) { + nonEmptySince = ticker.nanoTime(); + } + lastUpdated = ticker.nanoTime(); + weight = newWeight; + } + } } private final class UpdateWeightTask implements Runnable { @@ -128,16 +232,18 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } private void afterAcceptAddresses() { - for (Subchannel subchannel : getSubchannels()) { - WrrSubchannel weightedSubchannel = (WrrSubchannel) subchannel; - if (config.enableOobLoadReport) { - OrcaOobUtil.setListener(weightedSubchannel, - weightedSubchannel.new OrcaReportListener(config.errorUtilizationPenalty), - OrcaOobUtil.OrcaReportingConfig.newBuilder() - .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) - .build()); - } else { - OrcaOobUtil.setListener(weightedSubchannel, null, null); + for (ChildLbState child : getChildLbStates()) { + WeightedChildLbState wChild = (WeightedChildLbState) child; + for (WrrSubchannel weightedSubchannel : wChild.subchannels) { + if (config.enableOobLoadReport) { + OrcaOobUtil.setListener(weightedSubchannel, + wChild.getOrCreateOrcaListener(config.errorUtilizationPenalty), + OrcaOobUtil.OrcaReportingConfig.newBuilder() + .setReportInterval(config.oobReportingPeriodNanos, TimeUnit.NANOSECONDS) + .build()); + } else { + OrcaOobUtil.setListener(weightedSubchannel, null, null); + } } } } @@ -169,105 +275,69 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { - return wrr.new WrrSubchannel(delegate().createSubchannel(args)); + checkElementIndex(0, args.getAddresses().size(), "Empty address group"); + WeightedChildLbState childLbState = + (WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0)); + return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState); } } @VisibleForTesting final class WrrSubchannel extends ForwardingSubchannel { private final Subchannel delegate; - private volatile long lastUpdated; - private volatile long nonEmptySince; - private volatile double weight; + private final WeightedChildLbState owner; - WrrSubchannel(Subchannel delegate) { + WrrSubchannel(Subchannel delegate, WeightedChildLbState owner) { this.delegate = checkNotNull(delegate, "delegate"); + this.owner = checkNotNull(owner, "owner"); } @Override public void start(SubchannelStateListener listener) { + owner.addSubchannel(this); delegate().start(new SubchannelStateListener() { @Override public void onSubchannelState(ConnectivityStateInfo newState) { if (newState.getState().equals(ConnectivityState.READY)) { - nonEmptySince = infTime; + owner.nonEmptySince = infTime; } listener.onSubchannelState(newState); } }); } - private double getWeight() { - if (config == null) { - return 0; - } - long now = ticker.nanoTime(); - if (now - lastUpdated >= config.weightExpirationPeriodNanos) { - nonEmptySince = infTime; - return 0; - } else if (now - nonEmptySince < config.blackoutPeriodNanos - && config.blackoutPeriodNanos > 0) { - return 0; - } else { - return weight; - } - } - @Override protected Subchannel delegate() { return delegate; } - final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { - private final float errorUtilizationPenalty; - - OrcaReportListener(float errorUtilizationPenalty) { - this.errorUtilizationPenalty = errorUtilizationPenalty; - } - - @Override - public void onLoadReport(MetricReport report) { - double newWeight = 0; - // Prefer application utilization and fallback to CPU utilization if unset. - double utilization = - report.getApplicationUtilization() > 0 ? report.getApplicationUtilization() - : report.getCpuUtilization(); - if (utilization > 0 && report.getQps() > 0) { - double penalty = 0; - if (report.getEps() > 0 && errorUtilizationPenalty > 0) { - penalty = report.getEps() / report.getQps() * errorUtilizationPenalty; - } - newWeight = report.getQps() / (utilization + penalty); - } - if (newWeight == 0) { - return; - } - if (nonEmptySince == infTime) { - nonEmptySince = ticker.nanoTime(); - } - lastUpdated = ticker.nanoTime(); - weight = newWeight; - } + @Override + public void shutdown() { + super.shutdown(); + owner.removeSubchannel(this); } } @VisibleForTesting final class WeightedRoundRobinPicker extends RoundRobinPicker { - private final List<Subchannel> list; + private final List<ChildLbState> children; private final Map<Subchannel, OrcaPerRequestReportListener> subchannelToReportListenerMap = new HashMap<>(); private final boolean enableOobLoadReport; private final float errorUtilizationPenalty; private volatile StaticStrideScheduler scheduler; - WeightedRoundRobinPicker(List<Subchannel> list, boolean enableOobLoadReport, + WeightedRoundRobinPicker(List<ChildLbState> children, boolean enableOobLoadReport, float errorUtilizationPenalty) { - checkNotNull(list, "list"); - Preconditions.checkArgument(!list.isEmpty(), "empty list"); - this.list = list; - for (Subchannel subchannel : list) { - this.subchannelToReportListenerMap.put(subchannel, - ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)); + checkNotNull(children, "children"); + Preconditions.checkArgument(!children.isEmpty(), "empty child list"); + this.children = children; + for (ChildLbState child : children) { + WeightedChildLbState wChild = (WeightedChildLbState) child; + for (WrrSubchannel subchannel : wChild.subchannels) { + this.subchannelToReportListenerMap + .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); + } } this.enableOobLoadReport = enableOobLoadReport; this.errorUtilizationPenalty = errorUtilizationPenalty; @@ -276,22 +346,24 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - Subchannel subchannel = list.get(scheduler.pick()); + ChildLbState childLbState = children.get(scheduler.pick()); + WeightedChildLbState wChild = (WeightedChildLbState) childLbState; + PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args); + Subchannel subchannel = pickResult.getSubchannel(); if (!enableOobLoadReport) { return PickResult.withSubchannel(subchannel, - OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( + OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( subchannelToReportListenerMap.getOrDefault(subchannel, - ((WrrSubchannel) subchannel).new OrcaReportListener(errorUtilizationPenalty)))); + wChild.getOrCreateOrcaListener(errorUtilizationPenalty)))); } else { return PickResult.withSubchannel(subchannel); } } private void updateWeight() { - float[] newWeights = new float[list.size()]; - for (int i = 0; i < list.size(); i++) { - WrrSubchannel subchannel = (WrrSubchannel) list.get(i); - double newWeight = subchannel.getWeight(); + float[] newWeights = new float[children.size()]; + for (int i = 0; i < children.size(); i++) { + double newWeight = ((WeightedChildLbState)children.get(i)).getWeight(); newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; } this.scheduler = new StaticStrideScheduler(newWeights, sequence); @@ -302,12 +374,12 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) .add("enableOobLoadReport", enableOobLoadReport) .add("errorUtilizationPenalty", errorUtilizationPenalty) - .add("list", list).toString(); + .add("list", children).toString(); } @VisibleForTesting - List<Subchannel> getList() { - return list; + List<ChildLbState> getChildren() { + return children; } @Override @@ -322,7 +394,8 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { // the lists cannot contain duplicate subchannels return enableOobLoadReport == other.enableOobLoadReport && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0 - && list.size() == other.list.size() && new HashSet<>(list).containsAll(other.list); + && children.size() == other.children.size() && new HashSet<>( + children).containsAll(other.children); } } @@ -504,11 +577,13 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } + @SuppressWarnings("UnusedReturnValue") Builder setBlackoutPeriodNanos(long blackoutPeriodNanos) { this.blackoutPeriodNanos = blackoutPeriodNanos; return this; } + @SuppressWarnings("UnusedReturnValue") Builder setWeightExpirationPeriodNanos(long weightExpirationPeriodNanos) { this.weightExpirationPeriodNanos = weightExpirationPeriodNanos; return this; diff --git a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java index c90a9f58d..32e905225 100644 --- a/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java +++ b/xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java @@ -202,7 +202,9 @@ public final class OrcaOobUtil { */ public static void setListener(Subchannel subchannel, OrcaOobReportListener listener, OrcaReportingConfig config) { - SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); + Attributes attributes = subchannel.getAttributes(); + SubchannelImpl orcaSubchannel = + (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY); if (orcaSubchannel == null) { throw new IllegalArgumentException("Subchannel does not have orca Out-Of-Band stream enabled." + " Try to use a subchannel created by OrcaOobUtil.OrcaHelper."); @@ -241,7 +243,9 @@ public final class OrcaOobUtil { public Subchannel createSubchannel(CreateSubchannelArgs args) { syncContext.throwIfNotInThisSynchronizationContext(); Subchannel subchannel = super.createSubchannel(args); - SubchannelImpl orcaSubchannel = subchannel.getAttributes().get(ORCA_REPORTING_STATE_KEY); + Attributes attributes = subchannel.getAttributes(); + SubchannelImpl orcaSubchannel = + (attributes == null) ? null : attributes.get(ORCA_REPORTING_STATE_KEY); OrcaReportingState orcaState; if (orcaSubchannel == null) { // Only the first load balancing policy requesting ORCA reports instantiates an diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index ac08f69f8..c59ad1318 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -17,11 +17,10 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -35,7 +34,6 @@ import com.google.common.collect.Maps; import com.google.protobuf.Duration; import io.grpc.Attributes; import io.grpc.Channel; -import io.grpc.ChannelLogger; import io.grpc.ClientCall; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -50,12 +48,15 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; +import io.grpc.internal.TestUtils; import io.grpc.services.InternalCallMetricRecorder; import io.grpc.services.MetricReport; +import io.grpc.util.AbstractTestHelper; +import io.grpc.util.MultiChildLoadBalancer.ChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.StaticStrideScheduler; +import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedChildLbState; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancerConfig; import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; -import io.grpc.xds.WeightedRoundRobinLoadBalancer.WrrSubchannel; import java.net.SocketAddress; import java.util.Arrays; import java.util.HashMap; @@ -67,6 +68,7 @@ import java.util.Random; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; @@ -87,8 +89,8 @@ public class WeightedRoundRobinLoadBalancerTest { @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - @Mock - Helper helper; + private final TestHelper testHelperInstance = new TestHelper(); + private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance)); @Mock private LoadBalancer.PickSubchannelArgs mockArgs; @@ -99,9 +101,8 @@ public class WeightedRoundRobinLoadBalancerTest { private ArgumentCaptor<SubchannelPicker> pickerCaptor2; private final List<EquivalentAddressGroup> servers = Lists.newArrayList(); - private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = Maps.newLinkedHashMap(); - + private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>(); private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners = Maps.newLinkedHashMap(); @@ -134,7 +135,8 @@ public class WeightedRoundRobinLoadBalancerTest { SocketAddress addr = new FakeSocketAddress("server" + i); EquivalentAddressGroup eag = new EquivalentAddressGroup(addr); servers.add(eag); - Subchannel sc = mock(Subchannel.class); + Subchannel sc = helper.createSubchannel(CreateSubchannelArgs.newBuilder().setAddresses(eag) + .build()); Channel channel = mock(Channel.class); when(channel.newCall(any(), any())).then( new Answer<ClientCall<OrcaLoadReportRequest, OrcaLoadReport>>() { @@ -147,35 +149,13 @@ public class WeightedRoundRobinLoadBalancerTest { return clientCall; } }); - when(sc.asChannel()).thenReturn(channel); + testHelperInstance.setChannel(mockToRealSubChannelMap.get(sc), channel); subchannels.put(Arrays.asList(eag), sc); } - when(helper.getSynchronizationContext()).thenReturn(syncContext); - when(helper.getScheduledExecutorService()).thenReturn( - fakeClock.getScheduledExecutorService()); - when(helper.createSubchannel(any(CreateSubchannelArgs.class))) - .then(new Answer<Subchannel>() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; - final Subchannel subchannel = subchannels.get(args.getAddresses()); - when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); - when(subchannel.getAttributes()).thenReturn(args.getAttributes()); - when(subchannel.getChannelLogger()).thenReturn(mock(ChannelLogger.class)); - doAnswer( - new Answer<Void>() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - subchannelStateListeners.put( - subchannel, (SubchannelStateListener) invocation.getArguments()[0]); - return null; - } - }).when(subchannel).start(any(SubchannelStateListener.class)); - return subchannel; - } - }); wrr = new WeightedRoundRobinLoadBalancer(helper, fakeClock.getDeadlineTicker(), new FakeRandom(0)); + + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); } @Test @@ -183,44 +163,44 @@ public class WeightedRoundRobinLoadBalancerTest { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); - subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); - assertThat(weightedPicker.getList().size()).isEqualTo(1); + assertThat(weightedPicker.getChildren().size()).isEqualTo(1); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - assertThat(weightedPicker.getList().size()).isEqualTo(2); + assertThat(weightedPicker.getChildren().size()).isEqualTo(2); String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0"); assertThat(weightedPickerStr).contains("list="); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - assertThat(weightedPicker.pickSubchannel(mockArgs) - .getSubchannel()).isEqualTo(weightedSubchannel1); + + assertThat(getAddressesFromPick(weightedPicker)).isEqualTo(weightedChild1.getEag()); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -238,35 +218,44 @@ public class WeightedRoundRobinLoadBalancerTest { verifyNoMoreInteractions(mockArgs); } + /** + * Picks subchannel using mockArgs, gets its EAG, and then strips the Attrs to make a key. + */ + private EquivalentAddressGroup getAddressesFromPick(WeightedRoundRobinPicker weightedPicker) { + return TestUtils.stripAttrs( + weightedPicker.pickSubchannel(mockArgs).getSubchannel().getAddresses()); + } + @Test public void enableOobLoadReportConfig() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.9, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(getAddresses(pickResult)) + .isEqualTo(weightedChild1.getEag()); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); // verify per-request listener assertThat(oobCalls.isEmpty()).isTrue(); @@ -280,7 +269,8 @@ public class WeightedRoundRobinLoadBalancerTest { eq(ConnectivityState.READY), pickerCaptor2.capture()); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor2.getAllValues().get(2); pickResult = weightedPicker.pickSubchannel(mockArgs); - assertThat(pickResult.getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(getAddresses(pickResult)) + .isEqualTo(weightedChild1.getEag()); assertThat(pickResult.getStreamTracerFactory()).isNull(); OrcaLoadReportRequest golden = OrcaLoadReportRequest.newBuilder().setReportInterval( Duration.newBuilder().setSeconds(20).setNanos(30000000).build()).build(); @@ -295,46 +285,52 @@ public class WeightedRoundRobinLoadBalancerTest { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel3 = it.next(); - subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel3).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - r1); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - r2); - weightedSubchannel3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( - r3); + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r1); + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r2); + weightedChild3.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport(r3); + assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - Map<Subchannel, Integer> pickCount = new HashMap<>(); + Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>(); for (int i = 0; i < 10000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio)) - .isLessThan(0.0002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio )) - .isLessThan(0.0002); - assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio )) - .isLessThan(0.0002); + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - subchannel1PickRatio)) + .isAtMost(0.0002); + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - subchannel2PickRatio )) + .isAtMost(0.0002); + assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 10000.0 - subchannel3PickRatio )) + .isAtMost(0.0002); + } + + private SubchannelStateListener getSubchannelStateListener(Subchannel mockSubChannel) { + return subchannelStateListeners.get(mockToRealSubChannelMap.get(mockSubChannel)); + } + + private static ChildLbState getChild(WeightedRoundRobinPicker picker, int index) { + return picker.getChildren().get(index); } @Test @@ -472,14 +468,14 @@ public class WeightedRoundRobinLoadBalancerTest { assertThat(wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(null) .setAttributes(affinity).build())).isFalse(); - verify(helper, never()).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(fakeClock.getPendingTasks()).isEmpty(); syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().getClass().getName()) @@ -492,51 +488,51 @@ public class WeightedRoundRobinLoadBalancerTest { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); - Map<Subchannel, Integer> pickCount = new HashMap<>(); - for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>(); + for (int i = 0; i < 10000; i++) { + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); // within blackout period, fallback to simple round robin - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)).isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 0.5)).isLessThan(0.002); assertThat(fakeClock.forwardTime(5, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); - for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + for (int i = 0; i < 10000; i++) { + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 10000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 10000.0 - 1.0 / 3)) .isLessThan(0.002); } @@ -545,39 +541,39 @@ public class WeightedRoundRobinLoadBalancerTest { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); - subchannelStateListeners.get(connectingSubchannel).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(0); - assertThat(weightedPicker.getList().size()).isEqualTo(1); + assertThat(weightedPicker.getChildren().size()).isEqualTo(1); weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - assertThat(weightedPicker.getList().size()).isEqualTo(2); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + assertThat(weightedPicker.getChildren().size()).isEqualTo(2); + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); - assertThat(weightedPicker.pickSubchannel(mockArgs) - .getSubchannel()).isEqualTo(weightedSubchannel1); + assertThat(getAddressesFromPick(weightedPicker)) + .isEqualTo(weightedChild1.getEag()); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); weightedConfig = WeightedRoundRobinLoadBalancerConfig.newBuilder() .setWeightUpdatePeriodNanos(500_000_000L) //.5s @@ -586,17 +582,18 @@ public class WeightedRoundRobinLoadBalancerTest { .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); //timer fires, new weight updated assertThat(fakeClock.forwardTime(500, TimeUnit.MILLISECONDS)).isEqualTo(1); - assertThat(weightedPicker.pickSubchannel(mockArgs) - .getSubchannel()).isEqualTo(weightedSubchannel2); - + assertThat(getAddressesFromPick(weightedPicker)) + .isEqualTo(weightedChild2.getEag()); + assertThat(getAddressesFromPick(weightedPicker)) + .isEqualTo(weightedChild1.getEag()); } @Test @@ -604,52 +601,52 @@ public class WeightedRoundRobinLoadBalancerTest { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - Map<Subchannel, Integer> pickCount = new HashMap<>(); + Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) .isLessThan(0.002); // weight expired, fallback to simple round robin assertThat(fakeClock.forwardTime(300, TimeUnit.SECONDS)).isEqualTo(1); pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddressesFromPick(weightedPicker); pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 0.5)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 0.5)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 0.5)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 0.5)) .isLessThan(0.002); } @@ -658,107 +655,113 @@ public class WeightedRoundRobinLoadBalancerTest { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - Map<WrrSubchannel, Integer> qpsByChannel = ImmutableMap.of(weightedSubchannel1, 2, - weightedSubchannel2, 1); - Map<Subchannel, Integer> pickCount = new HashMap<>(); + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + Map<EquivalentAddressGroup, Integer> qpsByChannel = ImmutableMap.of(weightedChild1.getEag(), 2, + weightedChild2.getEag(), 1); + Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - pickCount.put(pickResult.getSubchannel(), - pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); + EquivalentAddressGroup addresses = getAddresses(pickResult); + pickCount.merge(addresses, 1, Integer::sum); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WrrSubchannel subchannel = (WrrSubchannel)pickResult.getSubchannel(); - subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); + childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0, + 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); } - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 1.0 / 2)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 1.0 / 2)) .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 2)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 2)) .isAtMost(0.1); + + // Identical to above except forwards time after each pick pickCount.clear(); for (int i = 0; i < 1000; i++) { PickResult pickResult = weightedPicker.pickSubchannel(mockArgs); - pickCount.put(pickResult.getSubchannel(), - pickCount.getOrDefault(pickResult.getSubchannel(), 0) + 1); + EquivalentAddressGroup addresses = getAddresses(pickResult); + pickCount.merge(addresses, 1, Integer::sum); assertThat(pickResult.getStreamTracerFactory()).isNotNull(); - WrrSubchannel subchannel = (WrrSubchannel) pickResult.getSubchannel(); - subchannel.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState childLbState = (WeightedChildLbState) wrr.getChildLbStateEag(addresses); + childLbState.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( - 0.1, 0, 0.1, qpsByChannel.get(subchannel), 0, + 0.1, 0, 0.1, qpsByChannel.get(addresses), 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); fakeClock.forwardTime(50, TimeUnit.MILLISECONDS); } assertThat(pickCount.size()).isEqualTo(2); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 2.0 / 3)) .isAtMost(0.1); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 1.0 / 3)) .isAtMost(0.1); } + private static EquivalentAddressGroup getAddresses(PickResult pickResult) { + return TestUtils.stripAttrs(pickResult.getSubchannel().getAddresses()); + } + @Test public void unknownWeightIsAvgWeight() { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( - any(CreateSubchannelArgs.class)); + verify(helper, times(6)).createSubchannel( + any(CreateSubchannelArgs.class)); // 3 from setup plus 3 from the execute assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.READY)); + getSubchannelStateListener(readySubchannel1) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.READY)); + getSubchannelStateListener(readySubchannel2) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); Subchannel readySubchannel3 = it.next(); - subchannelStateListeners.get(readySubchannel3).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.READY)); + getSubchannelStateListener(readySubchannel3) + .onSubchannelState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); verify(helper, times(3)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(2); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - WrrSubchannel weightedSubchannel3 = (WrrSubchannel) weightedPicker.getList().get(2); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + WeightedChildLbState weightedChild3 = (WeightedChildLbState) getChild(weightedPicker, 2); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); - Map<Subchannel, Integer> pickCount = new HashMap<>(); + Map<EquivalentAddressGroup, Integer> pickCount = new HashMap<>(); for (int i = 0; i < 1000; i++) { Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); - pickCount.put(result, pickCount.getOrDefault(result, 0) + 1); + pickCount.merge(result.getAddresses(), 1, Integer::sum); } assertThat(pickCount.size()).isEqualTo(3); - assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()) / 1000.0 - 4.0 / 9)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()) / 1000.0 - 2.0 / 9)) .isLessThan(0.002); // subchannel3's weight is average of subchannel1 and subchannel2 - assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9)) + assertThat(Math.abs(pickCount.get(weightedChild3.getEag()) / 1000.0 - 3.0 / 9)) .isLessThan(0.002); } @@ -767,33 +770,33 @@ public class WeightedRoundRobinLoadBalancerTest { syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) .setAttributes(affinity).build())); - verify(helper, times(3)).createSubchannel( + verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); assertThat(fakeClock.getPendingTasks().size()).isEqualTo(1); Iterator<Subchannel> it = subchannels.values().iterator(); Subchannel readySubchannel1 = it.next(); - subchannelStateListeners.get(readySubchannel1).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel1).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); Subchannel readySubchannel2 = it.next(); - subchannelStateListeners.get(readySubchannel2).onSubchannelState(ConnectivityStateInfo + getSubchannelStateListener(readySubchannel2).onSubchannelState(ConnectivityStateInfo .forNonError(ConnectivityState.READY)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); WeightedRoundRobinPicker weightedPicker = (WeightedRoundRobinPicker) pickerCaptor.getAllValues().get(1); - WrrSubchannel weightedSubchannel1 = (WrrSubchannel) weightedPicker.getList().get(0); - WrrSubchannel weightedSubchannel2 = (WrrSubchannel) weightedPicker.getList().get(1); - weightedSubchannel1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); + WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); + weightedChild1.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.1, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); - weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( + weightedChild2.new OrcaReportListener(weightedConfig.errorUtilizationPenalty).onLoadReport( InternalCallMetricRecorder.createMetricReport( 0.2, 0, 0.1, 1, 0, new HashMap<>(), new HashMap<>(), new HashMap<>())); CyclicBarrier barrier = new CyclicBarrier(2); - Map<Subchannel, AtomicInteger> pickCount = new ConcurrentHashMap<>(); - pickCount.put(weightedSubchannel1, new AtomicInteger(0)); - pickCount.put(weightedSubchannel2, new AtomicInteger(0)); + Map<EquivalentAddressGroup, AtomicInteger> pickCount = new ConcurrentHashMap<>(); + pickCount.put(weightedChild1.getEag(), new AtomicInteger(0)); + pickCount.put(weightedChild2.getEag(), new AtomicInteger(0)); new Thread(new Runnable() { @Override public void run() { @@ -802,7 +805,7 @@ public class WeightedRoundRobinLoadBalancerTest { barrier.await(); for (int i = 0; i < 1000; i++) { Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); - pickCount.get(result).addAndGet(1); + pickCount.get(result.getAddresses()).addAndGet(1); } barrier.await(); } catch (Exception ex) { @@ -813,15 +816,15 @@ public class WeightedRoundRobinLoadBalancerTest { assertThat(fakeClock.forwardTime(10, TimeUnit.SECONDS)).isEqualTo(1); barrier.await(); for (int i = 0; i < 1000; i++) { - Subchannel result = weightedPicker.pickSubchannel(mockArgs).getSubchannel(); + EquivalentAddressGroup result = getAddresses(weightedPicker.pickSubchannel(mockArgs)); pickCount.get(result).addAndGet(1); } barrier.await(); assertThat(pickCount.size()).isEqualTo(2); // after blackout period - assertThat(Math.abs(pickCount.get(weightedSubchannel1).get() / 2000.0 - 2.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild1.getEag()).get() / 2000.0 - 2.0 / 3)) .isLessThan(0.002); - assertThat(Math.abs(pickCount.get(weightedSubchannel2).get() / 2000.0 - 1.0 / 3)) + assertThat(Math.abs(pickCount.get(weightedChild2.getEag()).get() / 2000.0 - 1.0 / 3)) .isLessThan(0.002); } @@ -1104,4 +1107,34 @@ public class WeightedRoundRobinLoadBalancerTest { return nextInt; } } + + private class TestHelper extends AbstractTestHelper { + + @Override + public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() { + return subchannels; + } + + @Override + public Map<Subchannel, Subchannel> getMockToRealSubChannelMap() { + return mockToRealSubChannelMap; + } + + @Override + public Map<Subchannel, SubchannelStateListener> getSubchannelStateListeners() { + return subchannelStateListeners; + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return fakeClock.getScheduledExecutorService(); + } + + + } } |