diff --git a/datastream-server-api/src/main/java/com/linkedin/datastream/server/DatastreamSourceClusterResolver.java b/datastream-server-api/src/main/java/com/linkedin/datastream/server/DatastreamSourceClusterResolver.java deleted file mode 100644 index 580764539..000000000 --- a/datastream-server-api/src/main/java/com/linkedin/datastream/server/DatastreamSourceClusterResolver.java +++ /dev/null @@ -1,19 +0,0 @@ -/** - * Copyright 2021 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD 2-Clause License. See the LICENSE file in the project root for license information. - * See the NOTICE file in the project root for additional information regarding copyright ownership. - */ -package com.linkedin.datastream.server; - -/** - * An interface that resolves the source Kafka cluster from the given {@link DatastreamGroup} instance - */ -public interface DatastreamSourceClusterResolver { - - /** - * Given a datastream group, gets the name of the source cluster - * @param datastreamGroup Datastream group - * @return The name of the source cluster - */ - String getSourceCluster(DatastreamGroup datastreamGroup); -} diff --git a/datastream-server-api/src/main/java/com/linkedin/datastream/server/providers/PartitionThroughputProvider.java b/datastream-server-api/src/main/java/com/linkedin/datastream/server/providers/PartitionThroughputProvider.java index de0fce8ae..199670752 100644 --- a/datastream-server-api/src/main/java/com/linkedin/datastream/server/providers/PartitionThroughputProvider.java +++ b/datastream-server-api/src/main/java/com/linkedin/datastream/server/providers/PartitionThroughputProvider.java @@ -8,6 +8,7 @@ import java.util.Map; import com.linkedin.datastream.server.ClusterThroughputInfo; +import com.linkedin.datastream.server.DatastreamGroup; /** @@ -23,6 +24,13 @@ public interface PartitionThroughputProvider { */ ClusterThroughputInfo getThroughputInfo(String clusterName); + /** + * Retrieves per-partition throughput information for the given 
datastream group + * @param datastreamGroup Datastream group + * @return Throughput information for the provided datastream group + */ + ClusterThroughputInfo getThroughputInfo(DatastreamGroup datastreamGroup); + /** * Retrieves per-partition throughput information for all clusters * @return A map, where keys are cluster names and values are throughput information for the cluster diff --git a/datastream-server/src/main/java/com/linkedin/datastream/server/DummyDatastreamSourceClusterResolver.java b/datastream-server/src/main/java/com/linkedin/datastream/server/DummyDatastreamSourceClusterResolver.java deleted file mode 100644 index 742d89e8b..000000000 --- a/datastream-server/src/main/java/com/linkedin/datastream/server/DummyDatastreamSourceClusterResolver.java +++ /dev/null @@ -1,18 +0,0 @@ -/** - * Copyright 2021 LinkedIn Corporation. All rights reserved. - * Licensed under the BSD 2-Clause License. See the LICENSE file in the project root for license information. - * See the NOTICE file in the project root for additional information regarding copyright ownership. 
- */ -package com.linkedin.datastream.server; - -/** - * A dummy implementation for {@link DatastreamSourceClusterResolver} - */ -public class DummyDatastreamSourceClusterResolver implements DatastreamSourceClusterResolver { - private static final String DUMMY_CLUSTER_NAME = "dummy"; - - @Override - public String getSourceCluster(DatastreamGroup datastreamGroup) { - return DUMMY_CLUSTER_NAME; - } -} diff --git a/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategy.java b/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategy.java index 315c3e3e3..529a10320 100644 --- a/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategy.java +++ b/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategy.java @@ -13,18 +13,21 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.Validate; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; + import com.linkedin.datastream.common.PollUtils; import com.linkedin.datastream.common.RetriesExhaustedException; import com.linkedin.datastream.common.zk.ZkClient; import com.linkedin.datastream.server.ClusterThroughputInfo; import com.linkedin.datastream.server.DatastreamGroup; import com.linkedin.datastream.server.DatastreamGroupPartitionsMetadata; -import com.linkedin.datastream.server.DatastreamSourceClusterResolver; import com.linkedin.datastream.server.DatastreamTask; import com.linkedin.datastream.server.Pair; import com.linkedin.datastream.server.providers.PartitionThroughputProvider; @@ -44,28 +47,28 @@ public class LoadBasedPartitionAssignmentStrategy extends StickyPartitionAssignm private static 
final int TASK_CAPACITY_UTILIZATION_PCT_DEFAULT = 90; private final PartitionThroughputProvider _throughputProvider; - private final DatastreamSourceClusterResolver _sourceClusterResolver; private final int _taskCapacityMBps; private final int _taskCapacityUtilizationPct; private final int _throughputInfoFetchTimeoutMs; private final int _throughputInfoFetchRetryPeriodMs; + // TODO Make these configurable + private final boolean _enableThroughputBasedPartitionAssignment = true; + private final boolean _enablePartitionNumBasedTaskCountEstimation = true; /** * Creates an instance of {@link LoadBasedPartitionAssignmentStrategy} */ public LoadBasedPartitionAssignmentStrategy(PartitionThroughputProvider throughputProvider, - DatastreamSourceClusterResolver sourceClusterResolver, Optional maxTasks, - Optional imbalanceThreshold, Optional maxPartitionPerTask, boolean enableElasticTaskAssignment, - Optional partitionsPerTask, Optional partitionFullnessFactorPct, - Optional taskCapacityMBps, Optional taskCapacityUtilizationPct, - Optional throughputInfoFetchTimeoutMs, Optional throughputInfoFetchRetryPeriodMs, - Optional zkClient, + Optional maxTasks, Optional imbalanceThreshold, Optional maxPartitionPerTask, + boolean enableElasticTaskAssignment, Optional partitionsPerTask, + Optional partitionFullnessFactorPct, Optional taskCapacityMBps, + Optional taskCapacityUtilizationPct, Optional throughputInfoFetchTimeoutMs, + Optional throughputInfoFetchRetryPeriodMs, Optional zkClient, String clusterName) { super(maxTasks, imbalanceThreshold, maxPartitionPerTask, enableElasticTaskAssignment, partitionsPerTask, partitionFullnessFactorPct, zkClient, clusterName); _throughputProvider = throughputProvider; - _sourceClusterResolver = sourceClusterResolver; _taskCapacityMBps = taskCapacityMBps.orElse(TASK_CAPACITY_MBPS_DEFAULT); _taskCapacityUtilizationPct = taskCapacityUtilizationPct.orElse(TASK_CAPACITY_UTILIZATION_PCT_DEFAULT); _throughputInfoFetchTimeoutMs = 
throughputInfoFetchTimeoutMs.orElse(THROUGHPUT_INFO_FETCH_TIMEOUT_MS_DEFAULT); @@ -77,56 +80,91 @@ public LoadBasedPartitionAssignmentStrategy(PartitionThroughputProvider throughp public Map> assignPartitions(Map> currentAssignment, DatastreamGroupPartitionsMetadata datastreamPartitions) { DatastreamGroup datastreamGroup = datastreamPartitions.getDatastreamGroup(); - String datastreamGroupName = datastreamGroup.getName(); - Pair, Integer> assignedPartitionsAndTaskCount = getAssignedPartitionsAndTaskCountForDatastreamGroup( - currentAssignment, datastreamGroupName); - List assignedPartitions = assignedPartitionsAndTaskCount.getKey(); - // Do throughput based assignment only initially, when no partitions have been assigned yet - if (!assignedPartitions.isEmpty()) { - return super.assignPartitions(currentAssignment, datastreamPartitions); - } - - Map partitionThroughputInfo; - // Attempting to retrieve partition throughput info with a fallback mechanism to StickyPartitionAssignmentStrategy - try { - partitionThroughputInfo = fetchPartitionThroughputInfo(); - } catch (RetriesExhaustedException ex) { - LOG.warn("Attempts to fetch partition throughput timed out. Falling back to regular partition assignment strategy"); + // For throughput based partition-assignment to kick in, the following conditions must be met: + // (1) Elastic task assignment must be enabled through configuration + // (2) Throughput-based task assignment must be enabled through configuration + boolean enableElasticTaskAssignment = isElasticTaskAssignmentEnabled(datastreamGroup); + if (!enableElasticTaskAssignment || !_enableThroughputBasedPartitionAssignment) { + LOG.info("Throughput based elastic task assignment not enabled. 
Falling back to sticky partition assignment."); + LOG.info("enableElasticTaskAssignment: {}, enableThroughputBasedPartitionAssignment {}", + enableElasticTaskAssignment, _enableThroughputBasedPartitionAssignment); return super.assignPartitions(currentAssignment, datastreamPartitions); } + String datastreamGroupName = datastreamGroup.getName(); + Pair, Integer> assignedPartitionsAndTaskCount = getAssignedPartitionsAndTaskCountForDatastreamGroup( + currentAssignment, datastreamGroupName); + List assignedPartitions = assignedPartitionsAndTaskCount.getKey(); + int taskCount = assignedPartitionsAndTaskCount.getValue(); LOG.info("Old partition assignment info, assignment: {}", currentAssignment); + Validate.isTrue(taskCount > 0, String.format("No tasks found for datastream group %s", datastreamGroup)); Validate.isTrue(currentAssignment.size() > 0, - "Zero tasks assigned. Retry leader partition assignment."); + "Zero tasks assigned. Retry leader partition assignment"); + + // Calculating unassigned partitions + List unassignedPartitions = new ArrayList<>(datastreamPartitions.getPartitions()); + unassignedPartitions.removeAll(assignedPartitions); - // Resolving cluster name from datastream group - String clusterName = _sourceClusterResolver.getSourceCluster(datastreamPartitions.getDatastreamGroup()); - ClusterThroughputInfo clusterThroughputInfo = partitionThroughputInfo.get(clusterName); + ClusterThroughputInfo clusterThroughputInfo = new ClusterThroughputInfo(StringUtils.EMPTY, Collections.emptyMap()); + if (assignedPartitions.isEmpty()) { + try { + // Attempting to retrieve partition throughput info on initial assignment + clusterThroughputInfo = fetchPartitionThroughputInfo(datastreamGroup); + } catch (RetriesExhaustedException ex) { + LOG.warn("Attempts to fetch partition throughput timed out"); + LOG.info("Throughput information unavailable during initial assignment. 
Falling back to sticky partition assignment"); + return super.assignPartitions(currentAssignment, datastreamPartitions); + } - // TODO Get task count estimate and perform elastic task count validation - // TODO Get task count estimate based on throughput and pick a winner - LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(_taskCapacityMBps, _taskCapacityUtilizationPct); - int maxTaskCount = estimator.getTaskCount(clusterThroughputInfo, Collections.emptyList(), Collections.emptyList()); - LOG.info("Max task count obtained from estimator: {}", maxTaskCount); + // Task count update happens only on initial assignment (when datastream makes the STOPPED -> READY transition). + // The calculation is based on the maximum of: + // (1) Tasks already allocated for the datastream + // (2) Partition number based estimate, if the appropriate config is enabled + // (3) Throughput based task count estimate + int numTasksNeeded = taskCount; + if (_enablePartitionNumBasedTaskCountEstimation) { + numTasksNeeded = getTaskCountEstimateBasedOnNumPartitions(datastreamPartitions, taskCount); + } - // TODO Get unassigned partitions - // Calculating unassigned partitions - List unassignedPartitions = new ArrayList<>(); + LoadBasedTaskCountEstimator estimator = new LoadBasedTaskCountEstimator(_taskCapacityMBps, _taskCapacityUtilizationPct); + numTasksNeeded = Math.max(numTasksNeeded, estimator.getTaskCount(clusterThroughputInfo, assignedPartitions, + unassignedPartitions)); + // Task count is validated against max tasks config + numTasksNeeded = validateNumTasksAgainstMaxTasks(datastreamPartitions, numTasksNeeded); + if (numTasksNeeded > taskCount) { + updateNumTasksAndForceTaskCreation(datastreamPartitions, numTasksNeeded, taskCount); + } + } + + // TODO Implement metrics // Doing assignment - LoadBasedPartitionAssigner partitionAssigner = new LoadBasedPartitionAssigner(); - return partitionAssigner.assignPartitions(clusterThroughputInfo, currentAssignment, + Map> 
newAssignment = doAssignment(clusterThroughputInfo, currentAssignment, unassignedPartitions, datastreamPartitions); + partitionSanityChecks(newAssignment, datastreamPartitions); + return newAssignment; + } + + @VisibleForTesting + Map> doAssignment(ClusterThroughputInfo clusterThroughputInfo, + Map> currentAssignment, List unassignedPartitions, + DatastreamGroupPartitionsMetadata datastreamPartitions) { + LoadBasedPartitionAssigner partitionAssigner = new LoadBasedPartitionAssigner(); + Map> assignment = partitionAssigner.assignPartitions(clusterThroughputInfo, + currentAssignment, unassignedPartitions, datastreamPartitions); + LOG.info("new assignment info, assignment: {}", assignment); + return assignment; } - private Map fetchPartitionThroughputInfo() { + private ClusterThroughputInfo fetchPartitionThroughputInfo(DatastreamGroup datastreamGroup) { + AtomicInteger attemptNum = new AtomicInteger(0); return PollUtils.poll(() -> { try { - return _throughputProvider.getThroughputInfo(); + return _throughputProvider.getThroughputInfo(datastreamGroup); } catch (Exception ex) { - // TODO print exception and retry count - LOG.warn("Failed to fetch partition throughput info."); + attemptNum.set(attemptNum.get() + 1); + LOG.warn(String.format("Failed to fetch partition throughput info on attempt %d", attemptNum.get()), ex); return null; } }, Objects::nonNull, _throughputInfoFetchRetryPeriodMs, _throughputInfoFetchTimeoutMs) diff --git a/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategyFactory.java b/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategyFactory.java index cbe9fc209..106e5c715 100644 --- a/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategyFactory.java +++ b/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/LoadBasedPartitionAssignmentStrategyFactory.java @@ 
-12,8 +12,6 @@ import org.slf4j.LoggerFactory; import com.linkedin.datastream.common.zk.ZkClient; -import com.linkedin.datastream.server.DatastreamSourceClusterResolver; -import com.linkedin.datastream.server.DummyDatastreamSourceClusterResolver; import com.linkedin.datastream.server.api.strategy.AssignmentStrategy; import com.linkedin.datastream.server.providers.NoOpPartitionThroughputProvider; import com.linkedin.datastream.server.providers.PartitionThroughputProvider; @@ -40,9 +38,8 @@ public AssignmentStrategy createStrategy(Properties assignmentStrategyProperties } PartitionThroughputProvider provider = constructPartitionThroughputProvider(); - DatastreamSourceClusterResolver clusterResolver = constructDatastreamSourceClusterResolver(); - return new LoadBasedPartitionAssignmentStrategy(provider, clusterResolver, _config.getMaxTasks(), + return new LoadBasedPartitionAssignmentStrategy(provider, _config.getMaxTasks(), _config.getImbalanceThreshold(), _config.getMaxPartitions(), enableElasticTaskAssignment, _config.getPartitionsPerTask(), _config.getPartitionFullnessThresholdPct(), _config.getTaskCapacityMBps(), _config.getTaskCapacityUtilizationPct(), _config.getThroughputInfoFetchTimeoutMs(), @@ -52,8 +49,4 @@ public AssignmentStrategy createStrategy(Properties assignmentStrategyProperties protected PartitionThroughputProvider constructPartitionThroughputProvider() { return new NoOpPartitionThroughputProvider(); } - - protected DatastreamSourceClusterResolver constructDatastreamSourceClusterResolver() { - return new DummyDatastreamSourceClusterResolver(); - } } diff --git a/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/StickyPartitionAssignmentStrategy.java b/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/StickyPartitionAssignmentStrategy.java index 5ac6380ae..ae03dbfb0 100644 --- a/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/StickyPartitionAssignmentStrategy.java +++ 
b/datastream-server/src/main/java/com/linkedin/datastream/server/assignment/StickyPartitionAssignmentStrategy.java @@ -27,6 +27,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; + import com.linkedin.datastream.common.DatastreamRuntimeException; import com.linkedin.datastream.common.zk.ZkClient; import com.linkedin.datastream.metrics.BrooklinGaugeInfo; @@ -211,9 +213,13 @@ public Map> assignPartitions(Map 0, String.format("No tasks found for datastream group %s", dgName)); - if (getEnableElasticTaskAssignment(datastreamGroup)) { + if (isElasticTaskAssignmentEnabled(datastreamGroup)) { if (assignedPartitions.isEmpty()) { - performElasticTaskCountValidation(datastreamPartitions, totalTaskCount); + int numTasksNeeded = getTaskCountEstimateBasedOnNumPartitions(datastreamPartitions, totalTaskCount); + numTasksNeeded = validateNumTasksAgainstMaxTasks(datastreamPartitions, numTasksNeeded); + if (numTasksNeeded > totalTaskCount) { + updateNumTasksAndForceTaskCreation(datastreamPartitions, numTasksNeeded, totalTaskCount); + } } updateOrRegisterElasticTaskAssignmentMetrics(datastreamPartitions, totalTaskCount); } @@ -514,7 +520,7 @@ public List getMetricInfos() { @Override protected int constructExpectedNumberOfTasks(DatastreamGroup dg, List instances) { - boolean enableElasticTaskAssignment = getEnableElasticTaskAssignment(dg); + boolean enableElasticTaskAssignment = isElasticTaskAssignmentEnabled(dg); int numTasks = enableElasticTaskAssignment ? 
getNumTasksFromCacheOrZK(dg.getTaskPrefix()) : getNumTasks(dg, instances.size()); @@ -565,7 +571,7 @@ protected void updateOrRegisterElasticTaskAssignmentMetrics(DatastreamGroupParti _elasticTaskAssignmentInfoHashMap.put(taskPrefix, elasticTaskAssignmentInfo); } - protected void performElasticTaskCountValidation(DatastreamGroupPartitionsMetadata datastreamPartitions, + protected int getTaskCountEstimateBasedOnNumPartitions(DatastreamGroupPartitionsMetadata datastreamPartitions, int totalTaskCount) { // The partitions have not been assigned to any tasks yet and elastic task assignment has been enabled for this // datastream. Assess the number of tasks needed based on partitionsPerTask and the fullness threshold. If @@ -583,25 +589,35 @@ protected void performElasticTaskCountValidation(DatastreamGroupPartitionsMetada int totalPartitions = datastreamPartitions.getPartitions().size(); int numTasksNeeded = (totalPartitions / allowedPartitionsPerTask) + (((totalPartitions % allowedPartitionsPerTask) == 0) ? 0 : 1); + String dgName = datastreamPartitions.getDatastreamGroup().getName(); + LOG.info("Datastream group: {}, Number of tasks needed: {}, total task count: {}", dgName, numTasksNeeded, + totalTaskCount); + return numTasksNeeded; + } + + protected int validateNumTasksAgainstMaxTasks(DatastreamGroupPartitionsMetadata datastreamPartitions, int numTasks) { int maxTasks = resolveConfigWithMetadata(datastreamPartitions.getDatastreamGroup(), CFG_MAX_TASKS, 0); - if ((maxTasks > 0) && (numTasksNeeded > maxTasks)) { + if (maxTasks > 0 && numTasks > maxTasks) { // Only have the maxTasks override kick in if it's present as part of the datastream metadata. 
- LOG.warn("The number of tasks {} needed to support {} partitions per task with fullness threshold {} " - + "is higher than maxTasks {}, setting numTasks to maxTasks", numTasksNeeded, partitionsPerTask, - partitionFullnessFactorPct, maxTasks); - numTasksNeeded = maxTasks; - } - if (numTasksNeeded > totalTaskCount) { - createOrUpdateNumTasksForDatastreamInZK(datastreamPartitions.getDatastreamGroup().getTaskPrefix(), numTasksNeeded); - setTaskCountForDatastreamGroup(datastreamPartitions.getDatastreamGroup().getTaskPrefix(), numTasksNeeded); - throw new DatastreamRuntimeException( - String.format("Not enough tasks. Existing tasks: %d, tasks needed: %d, total partitions: %d", - totalTaskCount, numTasksNeeded, totalPartitions)); + LOG.warn("The number of tasks needed {} is higher than maxTasks {}. Setting numTasks to maxTasks", + numTasks, maxTasks); + return maxTasks; } - LOG.info("Number of tasks needed: {}, total task count: {}", numTasksNeeded, totalTaskCount); + return numTasks; + } + + protected void updateNumTasksAndForceTaskCreation(DatastreamGroupPartitionsMetadata datastreamPartitions, + int numTasksNeeded, int actualNumTasks) { + createOrUpdateNumTasksForDatastreamInZK(datastreamPartitions.getDatastreamGroup().getTaskPrefix(), numTasksNeeded); + setTaskCountForDatastreamGroup(datastreamPartitions.getDatastreamGroup().getTaskPrefix(), numTasksNeeded); + int totalPartitions = datastreamPartitions.getPartitions().size(); + throw new DatastreamRuntimeException( + String.format("Not enough tasks. 
Existing tasks: %d, tasks needed: %d, total partitions: %d", + actualNumTasks, numTasksNeeded, totalPartitions)); } - protected boolean getEnableElasticTaskAssignment(DatastreamGroup datastreamGroup) { + @VisibleForTesting + boolean isElasticTaskAssignmentEnabled(DatastreamGroup datastreamGroup) { // Enable elastic assignment only if the config enables it and the datastream metadata for minTasks is present // and is greater than 0 int minTasks = resolveConfigWithMetadata(datastreamGroup, CFG_MIN_TASKS, 0); @@ -619,7 +635,7 @@ private int getNumTasksFromCacheOrZK(String taskPrefix) { /** * check if the computed assignment contains all the partitions */ - private void partitionSanityChecks(Map> assignedTasks, + protected void partitionSanityChecks(Map> assignedTasks, DatastreamGroupPartitionsMetadata allPartitions) { int total = 0; diff --git a/datastream-server/src/main/java/com/linkedin/datastream/server/providers/FileBasedPartitionThroughputProvider.java b/datastream-server/src/main/java/com/linkedin/datastream/server/providers/FileBasedPartitionThroughputProvider.java index 3a82f4e02..f4580070f 100644 --- a/datastream-server/src/main/java/com/linkedin/datastream/server/providers/FileBasedPartitionThroughputProvider.java +++ b/datastream-server/src/main/java/com/linkedin/datastream/server/providers/FileBasedPartitionThroughputProvider.java @@ -12,12 +12,14 @@ import java.util.HashMap; import java.util.Iterator; +import org.apache.commons.lang.NotImplementedException; import org.apache.commons.lang3.StringUtils; import org.codehaus.jackson.JsonNode; import org.codehaus.jackson.map.ObjectMapper; import org.codehaus.jackson.type.TypeReference; import com.linkedin.datastream.server.ClusterThroughputInfo; +import com.linkedin.datastream.server.DatastreamGroup; import com.linkedin.datastream.server.PartitionThroughputInfo; @@ -49,6 +51,11 @@ public ClusterThroughputInfo getThroughputInfo(String clusterName) { return readThroughputInfoFromFile(partitionThroughputFile, 
clusterName); } + @Override + public ClusterThroughputInfo getThroughputInfo(DatastreamGroup datastreamGroup) { + throw new NotImplementedException(); + } + /** * {@inheritDoc} */ diff --git a/datastream-server/src/main/java/com/linkedin/datastream/server/providers/NoOpPartitionThroughputProvider.java b/datastream-server/src/main/java/com/linkedin/datastream/server/providers/NoOpPartitionThroughputProvider.java index dc3464632..c8f639d89 100644 --- a/datastream-server/src/main/java/com/linkedin/datastream/server/providers/NoOpPartitionThroughputProvider.java +++ b/datastream-server/src/main/java/com/linkedin/datastream/server/providers/NoOpPartitionThroughputProvider.java @@ -8,6 +8,7 @@ import java.util.HashMap; import com.linkedin.datastream.server.ClusterThroughputInfo; +import com.linkedin.datastream.server.DatastreamGroup; /** @@ -19,6 +20,11 @@ public ClusterThroughputInfo getThroughputInfo(String clusterName) { return null; } + @Override + public ClusterThroughputInfo getThroughputInfo(DatastreamGroup datastreamGroup) { + return null; + } + @Override public HashMap getThroughputInfo() { return null; diff --git a/datastream-server/src/test/java/com/linkedin/datastream/server/assignment/TestLoadBasedPartitionAssignmentStrategy.java b/datastream-server/src/test/java/com/linkedin/datastream/server/assignment/TestLoadBasedPartitionAssignmentStrategy.java new file mode 100644 index 000000000..2288a71e2 --- /dev/null +++ b/datastream-server/src/test/java/com/linkedin/datastream/server/assignment/TestLoadBasedPartitionAssignmentStrategy.java @@ -0,0 +1,267 @@ +/** + * Copyright 2021 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD 2-Clause License. See the LICENSE file in the project root for license information. + * See the NOTICE file in the project root for additional information regarding copyright ownership. 
+ */ +package com.linkedin.datastream.server.assignment; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.commons.lang3.StringUtils; +import org.mockito.Mockito; +import org.testng.Assert; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import com.codahale.metrics.MetricRegistry; + +import com.linkedin.datastream.common.Datastream; +import com.linkedin.datastream.common.DatastreamMetadataConstants; +import com.linkedin.datastream.common.DatastreamRuntimeException; +import com.linkedin.datastream.common.zk.ZkClient; +import com.linkedin.datastream.connectors.DummyConnector; +import com.linkedin.datastream.metrics.DynamicMetricsManager; +import com.linkedin.datastream.server.ClusterThroughputInfo; +import com.linkedin.datastream.server.DatastreamGroup; +import com.linkedin.datastream.server.DatastreamGroupPartitionsMetadata; +import com.linkedin.datastream.server.DatastreamTask; +import com.linkedin.datastream.server.DatastreamTaskImpl; +import com.linkedin.datastream.server.PartitionThroughputInfo; +import com.linkedin.datastream.server.providers.PartitionThroughputProvider; +import com.linkedin.datastream.server.zk.KeyBuilder; +import com.linkedin.datastream.server.zk.ZkAdapter; +import com.linkedin.datastream.testutil.DatastreamTestUtils; +import com.linkedin.datastream.testutil.EmbeddedZookeeper; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyObject; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; + + +/** + * Tests for {@link LoadBasedPartitionAssignmentStrategy} + */ +@Test +public class 
TestLoadBasedPartitionAssignmentStrategy { + private ZkClient _zkClient; + private String _clusterName; + + @BeforeMethod + public void setup() throws IOException { + DynamicMetricsManager.createInstance(new MetricRegistry(), "TestStickyPartitionAssignment"); + _clusterName = "testCluster"; + EmbeddedZookeeper embeddedZookeeper = new EmbeddedZookeeper(); + String zkConnectionString = embeddedZookeeper.getConnection(); + embeddedZookeeper.startup(); + _zkClient = new ZkClient(zkConnectionString); + } + + @Test + public void fallbackToBaseClassWhenElasticTaskAssignmentDisabledTest() { + PartitionThroughputProvider mockProvider = mock(PartitionThroughputProvider.class); + boolean enableElasticTaskAssignment = false; + Optional maxTasks = Optional.of(100); + Optional imbalanceThreshold = Optional.of(50); + Optional maxPartitionPerTask = Optional.of(100); + Optional partitionsPerTask = Optional.of(50); + Optional partitionFullnessFactorPct = Optional.of(80); + Optional taskCapacityMBps = Optional.of(5); + Optional taskCapacityUtilizationPct = Optional.of(90); + Optional throughputInfoFetchTimeoutMs = Optional.of(1000); + Optional throughputInfoFetchRetryPeriodMs = Optional.of(200); + Optional zkClient = Optional.empty(); + + LoadBasedPartitionAssignmentStrategy strategy = new LoadBasedPartitionAssignmentStrategy(mockProvider, + maxTasks, imbalanceThreshold, maxPartitionPerTask, enableElasticTaskAssignment, partitionsPerTask, + partitionFullnessFactorPct, taskCapacityMBps, taskCapacityUtilizationPct, throughputInfoFetchTimeoutMs, + throughputInfoFetchRetryPeriodMs, zkClient, _clusterName); + + Datastream ds1 = DatastreamTestUtils.createDatastreams(DummyConnector.CONNECTOR_TYPE, "ds1")[0]; + ds1.getSource().setPartitions(0); + ds1.getMetadata().put(DatastreamMetadataConstants.TASK_PREFIX, DatastreamTaskImpl.getTaskPrefix(ds1)); + Map> currentAssignment = new HashMap<>(); + currentAssignment.put("instance1", new 
HashSet<>(Collections.singletonList(createTaskForDatastream(ds1)))); + + DatastreamGroup datastreamGroup = new DatastreamGroup(Collections.singletonList(ds1)); + DatastreamGroupPartitionsMetadata metadata = new DatastreamGroupPartitionsMetadata(datastreamGroup, + Collections.singletonList("P1")); + strategy.assignPartitions(currentAssignment, metadata); + Assert.assertFalse(strategy.isElasticTaskAssignmentEnabled(datastreamGroup)); + + // Verify that partition throughput provider is not used when elastic task assignment is disabled + Mockito.verify(mockProvider, times(0)).getThroughputInfo(); + Mockito.verify(mockProvider, times(0)).getThroughputInfo(any(DatastreamGroup.class)); + Mockito.verify(mockProvider, times(0)).getThroughputInfo(any(String.class)); + } + + @Test + public void fallbackToBaseClassWhenThroughputFetchFailsTest() { + PartitionThroughputProvider mockProvider = mock(PartitionThroughputProvider.class); + Mockito.when(mockProvider.getThroughputInfo(any(DatastreamGroup.class))).thenThrow(new RuntimeException()); + boolean enableElasticTaskAssignment = true; + Optional maxTasks = Optional.of(100); + Optional imbalanceThreshold = Optional.of(50); + Optional maxPartitionPerTask = Optional.of(100); + Optional partitionsPerTask = Optional.of(50); + Optional partitionFullnessFactorPct = Optional.of(80); + Optional taskCapacityMBps = Optional.of(5); + Optional taskCapacityUtilizationPct = Optional.of(90); + Optional throughputInfoFetchTimeoutMs = Optional.of(1000); + Optional throughputInfoFetchRetryPeriodMs = Optional.of(200); + Optional zkClient = Optional.of(_zkClient); + + LoadBasedPartitionAssignmentStrategy strategy = Mockito.spy(new LoadBasedPartitionAssignmentStrategy(mockProvider, + maxTasks, imbalanceThreshold, maxPartitionPerTask, enableElasticTaskAssignment, partitionsPerTask, + partitionFullnessFactorPct, taskCapacityMBps, taskCapacityUtilizationPct, throughputInfoFetchTimeoutMs, + throughputInfoFetchRetryPeriodMs, zkClient, _clusterName)); + + 
Datastream ds1 = DatastreamTestUtils.createDatastreams(DummyConnector.CONNECTOR_TYPE, "ds1")[0]; + ds1.getSource().setPartitions(0); + ds1.getMetadata().put(DatastreamMetadataConstants.TASK_PREFIX, DatastreamTaskImpl.getTaskPrefix(ds1)); + ds1.getMetadata().put(StickyPartitionAssignmentStrategy.CFG_MIN_TASKS, String.valueOf(10)); + Map> currentAssignment = new HashMap<>(); + currentAssignment.put("instance1", new HashSet<>(Collections.singletonList(createTaskForDatastream(ds1)))); + + DatastreamGroup datastreamGroup = new DatastreamGroup(Collections.singletonList(ds1)); + DatastreamGroupPartitionsMetadata metadata = new DatastreamGroupPartitionsMetadata(datastreamGroup, + Collections.singletonList("P1")); + Assert.assertTrue(strategy.isElasticTaskAssignmentEnabled(datastreamGroup)); + Map> newAssignment = strategy.assignPartitions(currentAssignment, metadata); + + Mockito.verify(mockProvider, atLeastOnce()).getThroughputInfo(any(DatastreamGroup.class)); + Mockito.verify(strategy, never()).doAssignment(anyObject(), anyObject(), anyObject(), anyObject()); + Assert.assertNotNull(newAssignment); + } + + @Test + public void doesntFetchPartitionInfoOnIncrementalAssignmentTest() { + PartitionThroughputProvider mockProvider = mock(PartitionThroughputProvider.class); + boolean enableElasticTaskAssignment = true; + Optional maxTasks = Optional.of(100); + Optional imbalanceThreshold = Optional.of(50); + Optional maxPartitionPerTask = Optional.of(100); + Optional partitionsPerTask = Optional.of(50); + Optional partitionFullnessFactorPct = Optional.of(80); + Optional taskCapacityMBps = Optional.of(5); + Optional taskCapacityUtilizationPct = Optional.of(90); + Optional throughputInfoFetchTimeoutMs = Optional.of(1000); + Optional throughputInfoFetchRetryPeriodMs = Optional.of(200); + Optional zkClient = Optional.of(_zkClient); + + LoadBasedPartitionAssignmentStrategy strategy = new LoadBasedPartitionAssignmentStrategy(mockProvider, + maxTasks, imbalanceThreshold, 
maxPartitionPerTask, enableElasticTaskAssignment, partitionsPerTask, + partitionFullnessFactorPct, taskCapacityMBps, taskCapacityUtilizationPct, throughputInfoFetchTimeoutMs, + throughputInfoFetchRetryPeriodMs, zkClient, _clusterName); + + Datastream ds1 = DatastreamTestUtils.createDatastreams(DummyConnector.CONNECTOR_TYPE, "ds1")[0]; + ds1.getSource().setPartitions(0); + ds1.getMetadata().put(DatastreamMetadataConstants.TASK_PREFIX, DatastreamTaskImpl.getTaskPrefix(ds1)); + Map> currentAssignment = new HashMap<>(); + /* the existing task already holds P1, so the current assignment passed to the strategy is not empty */ DatastreamTask task = createTaskForDatastream(ds1, Collections.singletonList("P1")); + currentAssignment.put("instance1", new HashSet<>(Collections.singletonList(task))); + + DatastreamGroupPartitionsMetadata metadata = new DatastreamGroupPartitionsMetadata(new DatastreamGroup( + Collections.singletonList(ds1)), Collections.singletonList("P2")); + strategy.assignPartitions(currentAssignment, metadata); + + // Verify that partition throughput provider is not used when the current assignment is not empty + Mockito.verify(mockProvider, times(0)).getThroughputInfo(); + Mockito.verify(mockProvider, times(0)).getThroughputInfo(any(DatastreamGroup.class)); + Mockito.verify(mockProvider, times(0)).getThroughputInfo(any(String.class)); + } + + /** Stubs the provider to return throughput info for partitions P1, P2 and P3, then runs two scenarios against the same strategy. For ds1 (CFG_MIN_TASKS 10, one task in the current assignment, partitions P1 and P2) and for ds2 (CFG_MIN_TASKS 1, CFG_MAX_TASKS 2, partitions P1, P2 and P3), assignPartitions is expected to throw DatastreamRuntimeException, and the numTasks value read back from ZooKeeper is expected to be 2 in both cases. */ @Test + public void updatesNumTasksAndThrowsExceptionWhenNoSufficientTasksTest() { + PartitionThroughputProvider mockProvider = mock(PartitionThroughputProvider.class); + Map partitionThroughputMap = new HashMap<>(); + partitionThroughputMap.put("P1", new PartitionThroughputInfo(100000, 0, "P1")); + partitionThroughputMap.put("P2", new PartitionThroughputInfo(100000, 0, "P2")); + partitionThroughputMap.put("P3", new PartitionThroughputInfo(100000, 0, "P3")); + ClusterThroughputInfo clusterThroughputInfo = new ClusterThroughputInfo(StringUtils.EMPTY, partitionThroughputMap); + Mockito.when(mockProvider.getThroughputInfo(any(DatastreamGroup.class))).thenReturn(clusterThroughputInfo); + boolean enableElasticTaskAssignment = 
true; + Optional maxTasks = Optional.of(100); + Optional imbalanceThreshold = Optional.of(50); + Optional maxPartitionPerTask = Optional.of(100); + Optional partitionsPerTask = Optional.of(50); + Optional partitionFullnessFactorPct = Optional.of(80); + Optional taskCapacityMBps = Optional.of(5); + Optional taskCapacityUtilizationPct = Optional.of(90); + Optional throughputInfoFetchTimeoutMs = Optional.of(1000); + Optional throughputInfoFetchRetryPeriodMs = Optional.of(200); + Optional zkClient = Optional.of(_zkClient); + + LoadBasedPartitionAssignmentStrategy strategy = new LoadBasedPartitionAssignmentStrategy(mockProvider, + maxTasks, imbalanceThreshold, maxPartitionPerTask, enableElasticTaskAssignment, partitionsPerTask, + partitionFullnessFactorPct, taskCapacityMBps, taskCapacityUtilizationPct, throughputInfoFetchTimeoutMs, + throughputInfoFetchRetryPeriodMs, zkClient, _clusterName); + + Datastream ds1 = DatastreamTestUtils.createDatastreams(DummyConnector.CONNECTOR_TYPE, "ds1")[0]; + ds1.getMetadata().put(StickyPartitionAssignmentStrategy.CFG_MIN_TASKS, String.valueOf(10)); + /* NOTE(review): TASK_PREFIX is put again two statements below from the same getTaskPrefix call; this first put looks redundant - confirm */ ds1.getMetadata().put(DatastreamMetadataConstants.TASK_PREFIX, DatastreamTaskImpl.getTaskPrefix(ds1)); + ds1.getSource().setPartitions(0); + String taskPrefix = DatastreamTaskImpl.getTaskPrefix(ds1); + ds1.getMetadata().put(DatastreamMetadataConstants.TASK_PREFIX, taskPrefix); + /* ensurePath is called on the datastream key before assignPartitions runs; the numTasks value is read back from ZooKeeper further down */ _zkClient.ensurePath(KeyBuilder.datastream(_clusterName, taskPrefix)); + Map> currentAssignment = new HashMap<>(); + currentAssignment.put("instance1", new HashSet<>(Collections.singletonList(createTaskForDatastream(ds1)))); + + DatastreamGroupPartitionsMetadata metadata = new DatastreamGroupPartitionsMetadata(new DatastreamGroup( + Collections.singletonList(ds1)), Arrays.asList("P1", "P2")); + Assert.expectThrows(DatastreamRuntimeException.class, () -> strategy.assignPartitions(currentAssignment, metadata)); + /* although assignPartitions threw, the numTasks value stored in ZooKeeper is still expected to be 2 */ int numTasks = getNumTasksForDatastreamFromZK(taskPrefix); + Assert.assertEquals(numTasks, 2); + + // 
make sure throughput info is fetched + Mockito.verify(mockProvider, atLeastOnce()).getThroughputInfo(any(DatastreamGroup.class)); + + // test that strategy honors maxTasks config + Datastream ds2 = DatastreamTestUtils.createDatastreams(DummyConnector.CONNECTOR_TYPE, "ds2")[0]; + ds2.getMetadata().put(StickyPartitionAssignmentStrategy.CFG_MIN_TASKS, String.valueOf(1)); + ds2.getMetadata().put(BroadcastStrategyFactory.CFG_MAX_TASKS, String.valueOf(2)); + ds2.getMetadata().put(DatastreamMetadataConstants.TASK_PREFIX, DatastreamTaskImpl.getTaskPrefix(ds2)); + ds2.getSource().setPartitions(0); + String taskPrefix2 = DatastreamTaskImpl.getTaskPrefix(ds2); + ds2.getMetadata().put(DatastreamMetadataConstants.TASK_PREFIX, taskPrefix2); + _zkClient.ensurePath(KeyBuilder.datastream(_clusterName, taskPrefix2)); + Map> currentAssignment2 = new HashMap<>(); + currentAssignment2.put("instance1", new HashSet<>(Collections.singletonList(createTaskForDatastream(ds2)))); + + DatastreamGroupPartitionsMetadata metadata2 = new DatastreamGroupPartitionsMetadata(new DatastreamGroup( + Collections.singletonList(ds2)), Arrays.asList("P1", "P2", "P3")); + Assert.expectThrows(DatastreamRuntimeException.class, () -> strategy.assignPartitions(currentAssignment2, metadata2)); + int numTasks2 = getNumTasksForDatastreamFromZK(taskPrefix2); + // updated numTasks must be no bigger than 2 + Assert.assertEquals(numTasks2, 2); + } + + /** Convenience overload: creates a task for the datastream with an empty partition list. */ private DatastreamTask createTaskForDatastream(Datastream datastream) { + return createTaskForDatastream(datastream, Collections.emptyList()); + } + + /** Creates a DatastreamTaskImpl for the given datastream, sets its partitions through setPartitionsV2, and attaches a mocked ZkAdapter whose checkIsTaskLocked is stubbed to return true for any string arguments. */ private DatastreamTask createTaskForDatastream(Datastream datastream, List partitions) { + DatastreamTaskImpl task = new DatastreamTaskImpl(Collections.singletonList(datastream)); + task.setPartitionsV2(partitions); + ZkAdapter mockAdapter = Mockito.mock(ZkAdapter.class); + Mockito.when(mockAdapter.checkIsTaskLocked(anyString(), anyString(), anyString())).thenReturn(true); + task.setZkAdapter(mockAdapter); + return task; + 
} + + /** Reads the ZooKeeper node at KeyBuilder.datastreamNumTasks(_clusterName, taskPrefix) through _zkClient and parses its content as an int. @param taskPrefix task prefix of the datastream whose numTasks value is read. @return the parsed numTasks value. */ private int getNumTasksForDatastreamFromZK(String taskPrefix) { + String numTasksPath = KeyBuilder.datastreamNumTasks(_clusterName, taskPrefix); + return Integer.parseInt(_zkClient.readData(numTasksPath)); + } +}