Needs to wait

- Commit to non-blocking waitgroup impl
- Debug log when await takes suspiciously long
This commit is contained in:
Jozufozu 2023-04-14 17:14:12 -07:00
parent f29dcbc486
commit 0861d8bfd2
2 changed files with 30 additions and 39 deletions

View file

@ -8,6 +8,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -99,10 +100,12 @@ public class ParallelTaskExecutor implements TaskExecutor {
threads.clear(); threads.clear();
taskQueue.clear(); taskQueue.clear();
mainThreadQueue.clear();
waitGroup._reset();
} }
@Override @Override
public void execute(Runnable task) { public void execute(@NotNull Runnable task) {
if (!running.get()) { if (!running.get()) {
throw new IllegalStateException("Executor is stopped"); throw new IllegalStateException("Executor is stopped");
} }
@ -130,39 +133,35 @@ public class ParallelTaskExecutor implements TaskExecutor {
@Override @Override
public void syncPoint() { public void syncPoint() {
Runnable task; Runnable task;
// Finish everyone else's work...
while (true) { while (true) {
if ((task = mainThreadQueue.poll()) != null) { if ((task = mainThreadQueue.poll()) != null) {
// Prioritize main thread tasks.
processMainThreadTask(task); processMainThreadTask(task);
} else if ((task = taskQueue.pollLast()) != null) { } else if ((task = taskQueue.pollLast()) != null) {
// then work on tasks from the queue.
processTask(task); processTask(task);
} else { } else {
// and wait for any stragglers. // then wait for the other threads to finish.
waitGroup.await(); waitGroup.await();
// at this point there will be no more tasks in the queue, but
// one of the worker threads may have submitted a main thread task.
if (mainThreadQueue.isEmpty()) { if (mainThreadQueue.isEmpty()) {
// if they didn't, we're done.
break; break;
} }
} }
} }
} }
@Nullable
private Runnable pollForSyncPoint() {
Runnable task = mainThreadQueue.poll();
if (task != null) {
return task;
}
return taskQueue.pollLast();
}
public void discardAndAwait() { public void discardAndAwait() {
// Discard everyone else's work... // Discard everyone else's work...
while (taskQueue.pollLast() != null) { while (taskQueue.pollLast() != null) {
waitGroup.done(); waitGroup.done();
} }
// and wait for any stragglers. // ...wait for any stragglers...
waitGroup.await(); waitGroup.await();
// ...and clear the main thread queue.
mainThreadQueue.clear(); mainThreadQueue.clear();
} }
@ -183,7 +182,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
return task; return task;
} }
// TODO: task context
private void processTask(Runnable task) { private void processTask(Runnable task) {
try { try {
task.run(); task.run();

View file

@ -2,8 +2,14 @@ package com.jozufozu.flywheel.lib.task;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
// https://stackoverflow.com/questions/29655531 import org.slf4j.Logger;
import com.jozufozu.flywheel.util.StringUtil;
import com.mojang.logging.LogUtils;
public class WaitGroup { public class WaitGroup {
private static final Logger LOGGER = LogUtils.getLogger();
private final AtomicInteger counter = new AtomicInteger(0); private final AtomicInteger counter = new AtomicInteger(0);
public void add() { public void add() {
@ -19,38 +25,25 @@ public class WaitGroup {
} }
public void done() { public void done() {
var result = counter.decrementAndGet(); if (counter.decrementAndGet() < 0) {
if (result == 0) {
synchronized (this) {
this.notifyAll();
}
} else if (result < 0) {
throw new IllegalStateException("WaitGroup counter is negative!"); throw new IllegalStateException("WaitGroup counter is negative!");
} }
} }
public void await() { public void await() {
try { // TODO: comprehensive performance tracking for tasks
awaitInternal(); long start = System.nanoTime();
} catch (InterruptedException ignored) { int count = 0;
// noop
}
}
private void awaitInternal() throws InterruptedException {
// var start = System.nanoTime();
while (counter.get() > 0) { while (counter.get() > 0) {
// spin in place to avoid sleeping the main thread // spin in place to avoid sleeping the main thread
// synchronized (this) { count++;
// this.wait(timeoutMs); }
// } long end = System.nanoTime();
long elapsed = end - start;
if (elapsed > 1000000) { // > 1ms
LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times");
} }
// var end = System.nanoTime();
// var elapsed = end - start;
//
// if (elapsed > 1000000) {
// Flywheel.LOGGER.info("Waited " + StringUtil.formatTime(elapsed));
// }
} }
public void _reset() { public void _reset() {