mirror of
https://github.com/Jozufozu/Flywheel.git
synced 2025-01-27 21:37:56 +01:00
Needs to wait
- Commit to non-blocking waitgroup impl - Debug log when await takes suspiciously long
This commit is contained in:
parent
d7f8c9fcea
commit
494c5a68e0
2 changed files with 30 additions and 39 deletions
|
@ -8,6 +8,7 @@ import java.util.concurrent.ConcurrentLinkedDeque;
|
||||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.jetbrains.annotations.Nullable;
|
import org.jetbrains.annotations.Nullable;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
|
|
||||||
|
@ -99,10 +100,12 @@ public class ParallelTaskExecutor implements TaskExecutor {
|
||||||
|
|
||||||
threads.clear();
|
threads.clear();
|
||||||
taskQueue.clear();
|
taskQueue.clear();
|
||||||
|
mainThreadQueue.clear();
|
||||||
|
waitGroup._reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void execute(Runnable task) {
|
public void execute(@NotNull Runnable task) {
|
||||||
if (!running.get()) {
|
if (!running.get()) {
|
||||||
throw new IllegalStateException("Executor is stopped");
|
throw new IllegalStateException("Executor is stopped");
|
||||||
}
|
}
|
||||||
|
@ -130,39 +133,35 @@ public class ParallelTaskExecutor implements TaskExecutor {
|
||||||
@Override
|
@Override
|
||||||
public void syncPoint() {
|
public void syncPoint() {
|
||||||
Runnable task;
|
Runnable task;
|
||||||
// Finish everyone else's work...
|
|
||||||
while (true) {
|
while (true) {
|
||||||
if ((task = mainThreadQueue.poll()) != null) {
|
if ((task = mainThreadQueue.poll()) != null) {
|
||||||
|
// Prioritize main thread tasks.
|
||||||
processMainThreadTask(task);
|
processMainThreadTask(task);
|
||||||
} else if ((task = taskQueue.pollLast()) != null) {
|
} else if ((task = taskQueue.pollLast()) != null) {
|
||||||
|
// then work on tasks from the queue.
|
||||||
processTask(task);
|
processTask(task);
|
||||||
} else {
|
} else {
|
||||||
// and wait for any stragglers.
|
// then wait for the other threads to finish.
|
||||||
waitGroup.await();
|
waitGroup.await();
|
||||||
|
// at this point there will be no more tasks in the queue, but
|
||||||
|
// one of the worker threads may have submitted a main thread task.
|
||||||
if (mainThreadQueue.isEmpty()) {
|
if (mainThreadQueue.isEmpty()) {
|
||||||
|
// if they didn't, we're done.
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
|
||||||
private Runnable pollForSyncPoint() {
|
|
||||||
Runnable task = mainThreadQueue.poll();
|
|
||||||
if (task != null) {
|
|
||||||
return task;
|
|
||||||
}
|
|
||||||
return taskQueue.pollLast();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void discardAndAwait() {
|
public void discardAndAwait() {
|
||||||
// Discard everyone else's work...
|
// Discard everyone else's work...
|
||||||
while (taskQueue.pollLast() != null) {
|
while (taskQueue.pollLast() != null) {
|
||||||
waitGroup.done();
|
waitGroup.done();
|
||||||
}
|
}
|
||||||
|
|
||||||
// and wait for any stragglers.
|
// ...wait for any stragglers...
|
||||||
waitGroup.await();
|
waitGroup.await();
|
||||||
|
// ...and clear the main thread queue.
|
||||||
mainThreadQueue.clear();
|
mainThreadQueue.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,7 +182,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
|
||||||
return task;
|
return task;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: task context
|
|
||||||
private void processTask(Runnable task) {
|
private void processTask(Runnable task) {
|
||||||
try {
|
try {
|
||||||
task.run();
|
task.run();
|
||||||
|
|
|
@ -2,8 +2,14 @@ package com.jozufozu.flywheel.lib.task;
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
// https://stackoverflow.com/questions/29655531
|
import org.slf4j.Logger;
|
||||||
|
|
||||||
|
import com.jozufozu.flywheel.util.StringUtil;
|
||||||
|
import com.mojang.logging.LogUtils;
|
||||||
|
|
||||||
public class WaitGroup {
|
public class WaitGroup {
|
||||||
|
private static final Logger LOGGER = LogUtils.getLogger();
|
||||||
|
|
||||||
private final AtomicInteger counter = new AtomicInteger(0);
|
private final AtomicInteger counter = new AtomicInteger(0);
|
||||||
|
|
||||||
public void add() {
|
public void add() {
|
||||||
|
@ -19,38 +25,25 @@ public class WaitGroup {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void done() {
|
public void done() {
|
||||||
var result = counter.decrementAndGet();
|
if (counter.decrementAndGet() < 0) {
|
||||||
if (result == 0) {
|
|
||||||
synchronized (this) {
|
|
||||||
this.notifyAll();
|
|
||||||
}
|
|
||||||
} else if (result < 0) {
|
|
||||||
throw new IllegalStateException("WaitGroup counter is negative!");
|
throw new IllegalStateException("WaitGroup counter is negative!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void await() {
|
public void await() {
|
||||||
try {
|
// TODO: comprehensive performance tracking for tasks
|
||||||
awaitInternal();
|
long start = System.nanoTime();
|
||||||
} catch (InterruptedException ignored) {
|
int count = 0;
|
||||||
// noop
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void awaitInternal() throws InterruptedException {
|
|
||||||
// var start = System.nanoTime();
|
|
||||||
while (counter.get() > 0) {
|
while (counter.get() > 0) {
|
||||||
// spin in place to avoid sleeping the main thread
|
// spin in place to avoid sleeping the main thread
|
||||||
// synchronized (this) {
|
count++;
|
||||||
// this.wait(timeoutMs);
|
}
|
||||||
// }
|
long end = System.nanoTime();
|
||||||
|
long elapsed = end - start;
|
||||||
|
|
||||||
|
if (elapsed > 1000000) { // > 1ms
|
||||||
|
LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times");
|
||||||
}
|
}
|
||||||
// var end = System.nanoTime();
|
|
||||||
// var elapsed = end - start;
|
|
||||||
//
|
|
||||||
// if (elapsed > 1000000) {
|
|
||||||
// Flywheel.LOGGER.info("Waited " + StringUtil.formatTime(elapsed));
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void _reset() {
|
public void _reset() {
|
||||||
|
|
Loading…
Reference in a new issue