Needs to wait

- Commit to non-blocking waitgroup impl
- Debug log when await takes suspiciously long
This commit is contained in:
Jozufozu 2023-04-14 17:14:12 -07:00
parent d7f8c9fcea
commit 494c5a68e0
2 changed files with 30 additions and 39 deletions

View file

@ -8,6 +8,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
@ -99,10 +100,12 @@ public class ParallelTaskExecutor implements TaskExecutor {
threads.clear();
taskQueue.clear();
mainThreadQueue.clear();
waitGroup._reset();
}
@Override
public void execute(Runnable task) {
public void execute(@NotNull Runnable task) {
if (!running.get()) {
throw new IllegalStateException("Executor is stopped");
}
@ -130,39 +133,35 @@ public class ParallelTaskExecutor implements TaskExecutor {
@Override
public void syncPoint() {
Runnable task;
// Finish everyone else's work...
while (true) {
if ((task = mainThreadQueue.poll()) != null) {
// Prioritize main thread tasks.
processMainThreadTask(task);
} else if ((task = taskQueue.pollLast()) != null) {
// then work on tasks from the queue.
processTask(task);
} else {
// and wait for any stragglers.
// then wait for the other threads to finish.
waitGroup.await();
// at this point there will be no more tasks in the queue, but
// one of the worker threads may have submitted a main thread task.
if (mainThreadQueue.isEmpty()) {
// if they didn't, we're done.
break;
}
}
}
}
@Nullable
private Runnable pollForSyncPoint() {
Runnable task = mainThreadQueue.poll();
if (task != null) {
return task;
}
return taskQueue.pollLast();
}
public void discardAndAwait() {
// Discard everyone else's work...
while (taskQueue.pollLast() != null) {
waitGroup.done();
}
// and wait for any stragglers.
// ...wait for any stragglers...
waitGroup.await();
// ...and clear the main thread queue.
mainThreadQueue.clear();
}
@ -183,7 +182,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
return task;
}
// TODO: task context
private void processTask(Runnable task) {
try {
task.run();

View file

@ -2,8 +2,14 @@ package com.jozufozu.flywheel.lib.task;
import java.util.concurrent.atomic.AtomicInteger;
// https://stackoverflow.com/questions/29655531
import org.slf4j.Logger;
import com.jozufozu.flywheel.util.StringUtil;
import com.mojang.logging.LogUtils;
public class WaitGroup {
private static final Logger LOGGER = LogUtils.getLogger();
private final AtomicInteger counter = new AtomicInteger(0);
public void add() {
@ -19,38 +25,25 @@ public class WaitGroup {
}
public void done() {
var result = counter.decrementAndGet();
if (result == 0) {
synchronized (this) {
this.notifyAll();
}
} else if (result < 0) {
if (counter.decrementAndGet() < 0) {
throw new IllegalStateException("WaitGroup counter is negative!");
}
}
public void await() {
try {
awaitInternal();
} catch (InterruptedException ignored) {
// noop
}
}
private void awaitInternal() throws InterruptedException {
// var start = System.nanoTime();
// TODO: comprehensive performance tracking for tasks
long start = System.nanoTime();
int count = 0;
while (counter.get() > 0) {
// spin in place to avoid sleeping the main thread
// synchronized (this) {
// this.wait(timeoutMs);
// }
count++;
}
long end = System.nanoTime();
long elapsed = end - start;
if (elapsed > 1000000) { // > 1ms
LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times");
}
// var end = System.nanoTime();
// var elapsed = end - start;
//
// if (elapsed > 1000000) {
// Flywheel.LOGGER.info("Waited " + StringUtil.formatTime(elapsed));
// }
}
public void _reset() {