mirror of
https://github.com/Jozufozu/Flywheel.git
synced 2025-01-07 12:56:31 +01:00
Needs to wait
- Commit to non-blocking waitgroup impl - Debug log when await takes suspiciously long
This commit is contained in:
parent
f29dcbc486
commit
0861d8bfd2
2 changed files with 30 additions and 39 deletions
|
@ -8,6 +8,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
|
|||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.slf4j.Logger;
|
||||
|
||||
|
@ -99,10 +100,12 @@ public class ParallelTaskExecutor implements TaskExecutor {
|
|||
|
||||
threads.clear();
|
||||
taskQueue.clear();
|
||||
mainThreadQueue.clear();
|
||||
waitGroup._reset();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute(Runnable task) {
|
||||
public void execute(@NotNull Runnable task) {
|
||||
if (!running.get()) {
|
||||
throw new IllegalStateException("Executor is stopped");
|
||||
}
|
||||
|
@ -130,39 +133,35 @@ public class ParallelTaskExecutor implements TaskExecutor {
|
|||
@Override
|
||||
public void syncPoint() {
|
||||
Runnable task;
|
||||
// Finish everyone else's work...
|
||||
while (true) {
|
||||
if ((task = mainThreadQueue.poll()) != null) {
|
||||
// Prioritize main thread tasks.
|
||||
processMainThreadTask(task);
|
||||
} else if ((task = taskQueue.pollLast()) != null) {
|
||||
// then work on tasks from the queue.
|
||||
processTask(task);
|
||||
} else {
|
||||
// and wait for any stragglers.
|
||||
// then wait for the other threads to finish.
|
||||
waitGroup.await();
|
||||
// at this point there will be no more tasks in the queue, but
|
||||
// one of the worker threads may have submitted a main thread task.
|
||||
if (mainThreadQueue.isEmpty()) {
|
||||
// if they didn't, we're done.
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private Runnable pollForSyncPoint() {
|
||||
Runnable task = mainThreadQueue.poll();
|
||||
if (task != null) {
|
||||
return task;
|
||||
}
|
||||
return taskQueue.pollLast();
|
||||
}
|
||||
|
||||
public void discardAndAwait() {
|
||||
// Discard everyone else's work...
|
||||
while (taskQueue.pollLast() != null) {
|
||||
waitGroup.done();
|
||||
}
|
||||
|
||||
// and wait for any stragglers.
|
||||
// ...wait for any stragglers...
|
||||
waitGroup.await();
|
||||
// ...and clear the main thread queue.
|
||||
mainThreadQueue.clear();
|
||||
}
|
||||
|
||||
|
@ -183,7 +182,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
|
|||
return task;
|
||||
}
|
||||
|
||||
// TODO: task context
|
||||
private void processTask(Runnable task) {
|
||||
try {
|
||||
task.run();
|
||||
|
|
|
@ -2,8 +2,14 @@ package com.jozufozu.flywheel.lib.task;
|
|||
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
// https://stackoverflow.com/questions/29655531
|
||||
import org.slf4j.Logger;
|
||||
|
||||
import com.jozufozu.flywheel.util.StringUtil;
|
||||
import com.mojang.logging.LogUtils;
|
||||
|
||||
public class WaitGroup {
|
||||
private static final Logger LOGGER = LogUtils.getLogger();
|
||||
|
||||
private final AtomicInteger counter = new AtomicInteger(0);
|
||||
|
||||
public void add() {
|
||||
|
@ -19,38 +25,25 @@ public class WaitGroup {
|
|||
}
|
||||
|
||||
public void done() {
|
||||
var result = counter.decrementAndGet();
|
||||
if (result == 0) {
|
||||
synchronized (this) {
|
||||
this.notifyAll();
|
||||
}
|
||||
} else if (result < 0) {
|
||||
if (counter.decrementAndGet() < 0) {
|
||||
throw new IllegalStateException("WaitGroup counter is negative!");
|
||||
}
|
||||
}
|
||||
|
||||
public void await() {
|
||||
try {
|
||||
awaitInternal();
|
||||
} catch (InterruptedException ignored) {
|
||||
// noop
|
||||
}
|
||||
}
|
||||
|
||||
private void awaitInternal() throws InterruptedException {
|
||||
// var start = System.nanoTime();
|
||||
// TODO: comprehensive performance tracking for tasks
|
||||
long start = System.nanoTime();
|
||||
int count = 0;
|
||||
while (counter.get() > 0) {
|
||||
// spin in place to avoid sleeping the main thread
|
||||
// synchronized (this) {
|
||||
// this.wait(timeoutMs);
|
||||
// }
|
||||
count++;
|
||||
}
|
||||
long end = System.nanoTime();
|
||||
long elapsed = end - start;
|
||||
|
||||
if (elapsed > 1000000) { // > 1ms
|
||||
LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times");
|
||||
}
|
||||
// var end = System.nanoTime();
|
||||
// var elapsed = end - start;
|
||||
//
|
||||
// if (elapsed > 1000000) {
|
||||
// Flywheel.LOGGER.info("Waited " + StringUtil.formatTime(elapsed));
|
||||
// }
|
||||
}
|
||||
|
||||
public void _reset() {
|
||||
|
|
Loading…
Reference in a new issue