From 494c5a68e05c23c6cad9789388c1cade49ff1ac9 Mon Sep 17 00:00:00 2001 From: Jozufozu Date: Fri, 14 Apr 2023 17:14:12 -0700 Subject: [PATCH] Needs to wait - Commit to non-blocking waitgroup impl - Debug log when await takes suspiciously long --- .../backend/task/ParallelTaskExecutor.java | 26 ++++++----- .../jozufozu/flywheel/lib/task/WaitGroup.java | 43 ++++++++----------- 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java b/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java index 9de43e445..7e7e16d68 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java +++ b/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java @@ -8,6 +8,7 @@ import java.util.concurrent.ConcurrentLinkedDeque; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicBoolean; +import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import org.slf4j.Logger; @@ -99,10 +100,12 @@ public class ParallelTaskExecutor implements TaskExecutor { threads.clear(); taskQueue.clear(); + mainThreadQueue.clear(); + waitGroup._reset(); } @Override - public void execute(Runnable task) { + public void execute(@NotNull Runnable task) { if (!running.get()) { throw new IllegalStateException("Executor is stopped"); } @@ -130,39 +133,35 @@ public class ParallelTaskExecutor implements TaskExecutor { @Override public void syncPoint() { Runnable task; - // Finish everyone else's work... while (true) { if ((task = mainThreadQueue.poll()) != null) { + // Prioritize main thread tasks. processMainThreadTask(task); } else if ((task = taskQueue.pollLast()) != null) { + // then work on tasks from the queue. processTask(task); } else { - // and wait for any stragglers. + // then wait for the other threads to finish. waitGroup.await(); + // at this point there will be no more tasks in the queue, but + // one of the worker threads may have submitted a main thread task. if (mainThreadQueue.isEmpty()) { + // if they didn't, we're done. break; } } } } - @Nullable - private Runnable pollForSyncPoint() { - Runnable task = mainThreadQueue.poll(); - if (task != null) { - return task; - } - return taskQueue.pollLast(); - } - public void discardAndAwait() { // Discard everyone else's work... while (taskQueue.pollLast() != null) { waitGroup.done(); } - // and wait for any stragglers. + // ...wait for any stragglers... waitGroup.await(); + // ...and clear the main thread queue. mainThreadQueue.clear(); } @@ -183,7 +182,6 @@ public class ParallelTaskExecutor implements TaskExecutor { return task; } - // TODO: task context private void processTask(Runnable task) { try { task.run(); diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java b/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java index 28ce14729..0d34ab324 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java @@ -2,8 +2,14 @@ package com.jozufozu.flywheel.lib.task; import java.util.concurrent.atomic.AtomicInteger; -// https://stackoverflow.com/questions/29655531 +import org.slf4j.Logger; + +import com.jozufozu.flywheel.util.StringUtil; +import com.mojang.logging.LogUtils; + public class WaitGroup { + private static final Logger LOGGER = LogUtils.getLogger(); + private final AtomicInteger counter = new AtomicInteger(0); public void add() { @@ -19,38 +25,25 @@ public class WaitGroup { } public void done() { - var result = counter.decrementAndGet(); - if (result == 0) { - synchronized (this) { - this.notifyAll(); - } - } else if (result < 0) { + if (counter.decrementAndGet() < 0) { throw new IllegalStateException("WaitGroup counter is negative!"); } } public void await() { - try { - awaitInternal(); - } catch (InterruptedException ignored) { - // noop - } - } - - private void awaitInternal() throws InterruptedException { - // var start = System.nanoTime(); + // TODO: comprehensive performance tracking for tasks + long start = System.nanoTime(); + int count = 0; while (counter.get() > 0) { // spin in place to avoid sleeping the main thread - // synchronized (this) { - // this.wait(timeoutMs); - // } + count++; + } + long end = System.nanoTime(); + long elapsed = end - start; + + if (elapsed > 1000000) { // > 1ms + LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times"); } - // var end = System.nanoTime(); - // var elapsed = end - start; - // - // if (elapsed > 1000000) { - // Flywheel.LOGGER.info("Waited " + StringUtil.formatTime(elapsed)); - // } } public void _reset() {