Weaving threads

- Restore WaitGroup abstraction
- Simplify WaitGroup to make proper use of atomics
- Fix logic error in ParallelTaskExecutor causing task counter to go
  below zero when executing main thread tasks
- Use ConcurrentHashMap in models to allow parallel access
- Reduce chunk size in RunOnAllPlan
- Only queue transactions in InstanceManager (remove the immediate add/remove/update methods)
- Defer syncPoint within renderStage
- AbstractStorage returns immutable views for tickable/dynamic instances
- Process InstanceManager queues off-thread
This commit is contained in:
Jozufozu 2023-04-13 15:17:02 -07:00
parent 76dcf3fed6
commit f29dcbc486
25 changed files with 241 additions and 121 deletions

View file

@ -1,6 +1,7 @@
package com.jozufozu.flywheel.api.task;
import com.jozufozu.flywheel.lib.task.BarrierPlan;
import com.jozufozu.flywheel.lib.task.NestedPlan;
public interface Plan {
/**
@ -22,13 +23,23 @@ public interface Plan {
* Create a new plan that executes this plan, then the given plan.
*
* @param plan The plan to execute after this plan.
* @return The new, composed plan.
* @return The composed plan.
*/
default Plan then(Plan plan) {
// TODO: AbstractPlan?
return new BarrierPlan(this, plan);
}
/**
 * Compose this plan with another so that both are executed in parallel.
 *
 * <p>The default implementation wraps both plans via {@link NestedPlan#of}.</p>
 *
 * @param plan The plan to execute in parallel with this plan.
 * @return The composed plan.
 */
default Plan and(Plan plan) {
    final Plan composed = NestedPlan.of(this, plan);
    return composed;
}
/**
* If possible, create a new plan that accomplishes everything
* this plan does but with a simpler execution schedule.

View file

@ -81,4 +81,9 @@ public class BatchingDrawTracker {
buffers.clear();
}
}
/**
 * Reports whether anything is registered in {@code activeBuffers} for the given stage.
 *
 * @param stage The render stage to look up.
 * @return {@code true} if the collection stored for {@code stage} is not empty.
 */
public boolean hasStage(RenderStage stage) {
    // NOTE(review): assumes activeBuffers has an entry for every stage; a missing
    // entry would throw a NullPointerException here. Confirm against the field's setup.
    var buffersForStage = activeBuffers.get(stage);
    return !buffersForStage.isEmpty();
}
}

View file

@ -82,6 +82,10 @@ public class BatchingEngine implements Engine {
@Override
public void renderStage(TaskExecutor executor, RenderContext context, RenderStage stage) {
if (!drawTracker.hasStage(stage)) {
return;
}
executor.syncPoint();
drawTracker.draw(stage);
}

View file

@ -185,4 +185,7 @@ public class IndirectCullingGroup<P extends InstancePart> {
meshPool.delete();
}
/**
 * Reports whether this group has the given stage recorded in its {@code drawSet}.
 *
 * @param stage The render stage to look up.
 * @return {@code true} if {@code drawSet} contains {@code stage}.
 */
public boolean hasStage(RenderStage stage) {
    final boolean present = drawSet.contains(stage);
    return present;
}
}

View file

@ -73,6 +73,15 @@ public class IndirectDrawManager {
initializedInstancers.add(instancer);
}
/**
 * Reports whether any of the render lists has something for the given stage.
 *
 * @param stage The render stage to look up.
 * @return {@code true} as soon as one entry of {@code renderLists} reports the stage,
 *         {@code false} if none does.
 */
public boolean hasStage(RenderStage stage) {
    for (var renderList : renderLists.values()) {
        if (!renderList.hasStage(stage)) {
            continue;
        }
        // One match is enough; no need to inspect the remaining lists.
        return true;
    }
    return false;
}
private record UninitializedInstancer(IndirectInstancer<?> instancer, Model model, RenderStage stage) {
}
}

View file

@ -53,6 +53,12 @@ public class IndirectEngine implements Engine {
@Override
public void renderStage(TaskExecutor executor, RenderContext context, RenderStage stage) {
if (!drawManager.hasStage(stage)) {
return;
}
executor.syncPoint();
try (var restoreState = GlStateTracker.getRestoreState()) {
setup();

View file

@ -66,6 +66,8 @@ public class InstancingEngine implements Engine {
return;
}
executor.syncPoint();
try (var state = GlStateTracker.getRestoreState()) {
setup();

View file

@ -13,6 +13,7 @@ import org.slf4j.Logger;
import com.jozufozu.flywheel.Flywheel;
import com.jozufozu.flywheel.api.task.TaskExecutor;
import com.jozufozu.flywheel.lib.task.WaitGroup;
import com.mojang.logging.LogUtils;
import net.minecraft.util.Mth;
@ -29,17 +30,13 @@ public class ParallelTaskExecutor implements TaskExecutor {
* If set to false, the executor will shut down.
*/
private final AtomicBoolean running = new AtomicBoolean(false);
/**
* Synchronized via {@link #tasksCompletedNotifier}.
*/
private int incompleteTaskCounter = 0;
private final List<WorkerThread> threads = new ArrayList<>();
private final Deque<Runnable> taskQueue = new ConcurrentLinkedDeque<>();
private final Queue<Runnable> mainThreadQueue = new ConcurrentLinkedQueue<>();
private final Object taskNotifier = new Object();
private final Object tasksCompletedNotifier = new Object();
private final WaitGroup waitGroup = new WaitGroup();
public ParallelTaskExecutor(String name) {
this.name = name;
@ -102,10 +99,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
threads.clear();
taskQueue.clear();
synchronized (tasksCompletedNotifier) {
incompleteTaskCounter = 0;
tasksCompletedNotifier.notifyAll();
}
}
@Override
@ -114,12 +107,9 @@ public class ParallelTaskExecutor implements TaskExecutor {
throw new IllegalStateException("Executor is stopped");
}
waitGroup.add();
taskQueue.add(task);
synchronized (tasksCompletedNotifier) {
incompleteTaskCounter++;
}
synchronized (taskNotifier) {
taskNotifier.notifyAll();
}
@ -140,19 +130,17 @@ public class ParallelTaskExecutor implements TaskExecutor {
@Override
public void syncPoint() {
Runnable task;
// Finish everyone else's work...
while ((task = pollForSyncPoint()) != null) {
processTask(task);
}
// and wait for any stragglers.
synchronized (tasksCompletedNotifier) {
while (incompleteTaskCounter > 0) {
try {
tasksCompletedNotifier.wait();
} catch (InterruptedException e) {
//
while (true) {
if ((task = mainThreadQueue.poll()) != null) {
processMainThreadTask(task);
} else if ((task = taskQueue.pollLast()) != null) {
processTask(task);
} else {
// and wait for any stragglers.
waitGroup.await();
if (mainThreadQueue.isEmpty()) {
break;
}
}
}
@ -170,23 +158,12 @@ public class ParallelTaskExecutor implements TaskExecutor {
public void discardAndAwait() {
// Discard everyone else's work...
while (taskQueue.pollLast() != null) {
synchronized (tasksCompletedNotifier) {
if (--incompleteTaskCounter == 0) {
tasksCompletedNotifier.notifyAll();
}
}
waitGroup.done();
}
// and wait for any stragglers.
synchronized (tasksCompletedNotifier) {
while (incompleteTaskCounter > 0) {
try {
tasksCompletedNotifier.wait();
} catch (InterruptedException e) {
//
}
}
}
waitGroup.await();
mainThreadQueue.clear();
}
@Nullable
@ -213,11 +190,15 @@ public class ParallelTaskExecutor implements TaskExecutor {
} catch (Exception e) {
Flywheel.LOGGER.error("Error running task", e);
} finally {
synchronized (tasksCompletedNotifier) {
if (--incompleteTaskCounter == 0) {
tasksCompletedNotifier.notifyAll();
}
}
waitGroup.done();
}
}
/**
 * Runs a task taken from the main thread queue.
 *
 * <p>Any {@link Exception} thrown by the task is logged and swallowed. This method
 * does not touch {@code waitGroup}.</p>
 *
 * @param task The task to run.
 */
private void processMainThreadTask(Runnable task) {
    try {
        task.run();
    } catch (Exception failure) {
        Flywheel.LOGGER.error("Error running main thread task", failure);
    }
}
@ -233,7 +214,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
}
private class WorkerThread extends Thread {
private final AtomicBoolean running = ParallelTaskExecutor.this.running;
public WorkerThread(String name) {
super(name);
@ -242,7 +222,7 @@ public class ParallelTaskExecutor implements TaskExecutor {
@Override
public void run() {
// Run until the executor shuts down
while (running.get()) {
while (ParallelTaskExecutor.this.running.get()) {
Runnable task = getNextTask();
if (task == null) {

View file

@ -28,7 +28,7 @@ public class EntityWorldHandler {
if (FlwUtil.canUseInstancing(level)) {
InstancedRenderDispatcher.getEntities(level)
.remove(event.getEntity());
.queueRemove(event.getEntity());
}
}
}

View file

@ -70,11 +70,11 @@ public class InstanceWorld implements AutoCloseable {
* </p>
*/
public void tick(double cameraX, double cameraY, double cameraZ) {
var blockEntityPlan = blockEntities.planThisTick(cameraX, cameraY, cameraZ);
var entityPlan = entities.planThisTick(cameraX, cameraY, cameraZ);
var effectPlan = effects.planThisTick(cameraX, cameraY, cameraZ);
taskExecutor.syncPoint();
PlanUtil.of(blockEntityPlan, entityPlan, effectPlan)
blockEntities.planThisTick(cameraX, cameraY, cameraZ)
.and(entities.planThisTick(cameraX, cameraY, cameraZ))
.and(effects.planThisTick(cameraX, cameraY, cameraZ))
.maybeSimplify()
.execute(taskExecutor);
}
@ -114,7 +114,6 @@ public class InstanceWorld implements AutoCloseable {
* Draw all instances for the given stage.
*/
public void renderStage(RenderContext context, RenderStage stage) {
taskExecutor.syncPoint();
engine.renderStage(taskExecutor, context, stage);
}

View file

@ -160,7 +160,7 @@ public class InstancedRenderDispatcher {
// Block entities are loaded while chunks are baked.
// Entities are loaded with the level, so when chunks are reloaded they need to be re-added.
ClientLevelExtension.getAllLoadedEntities(level)
.forEach(world.getEntities()::add);
.forEach(world.getEntities()::queueAdd);
}
public static void addDebugInfo(List<String> info) {

View file

@ -6,7 +6,6 @@ import java.util.concurrent.ConcurrentLinkedQueue;
import org.joml.FrustumIntersection;
import com.jozufozu.flywheel.api.instance.DynamicInstance;
import com.jozufozu.flywheel.api.instance.Instance;
import com.jozufozu.flywheel.api.instance.TickableInstance;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.config.FlwConfig;
@ -15,7 +14,8 @@ import com.jozufozu.flywheel.impl.instancing.ratelimit.DistanceUpdateLimiter;
import com.jozufozu.flywheel.impl.instancing.ratelimit.NonLimiter;
import com.jozufozu.flywheel.impl.instancing.storage.Storage;
import com.jozufozu.flywheel.impl.instancing.storage.Transaction;
import com.jozufozu.flywheel.lib.task.PlanUtil;
import com.jozufozu.flywheel.lib.task.RunOnAllPlan;
import com.jozufozu.flywheel.lib.task.SimplePlan;
public abstract class InstanceManager<T> {
private final Queue<Transaction<T>> queue = new ConcurrentLinkedQueue<>();
@ -47,14 +47,6 @@ public abstract class InstanceManager<T> {
return getStorage().getAllInstances().size();
}
public void add(T obj) {
if (!getStorage().willAccept(obj)) {
return;
}
getStorage().add(obj);
}
public void queueAdd(T obj) {
if (!getStorage().willAccept(obj)) {
return;
@ -63,33 +55,10 @@ public abstract class InstanceManager<T> {
queue.add(Transaction.add(obj));
}
public void remove(T obj) {
getStorage().remove(obj);
}
/**
 * Enqueues a removal transaction for the given object.
 *
 * @param obj The object to enqueue a removal for.
 */
public void queueRemove(T obj) {
    final Transaction<T> removal = Transaction.remove(obj);
    queue.add(removal);
}
/**
* Update the instance associated with an object.
*
* <p>
* By default this is the only hook an {@link Instance} has to change its internal state. This is the lowest frequency
* update hook {@link Instance} gets. For more frequent updates, see {@link TickableInstance} and
* {@link DynamicInstance}.
* </p>
*
* @param obj the object to update.
*/
public void update(T obj) {
if (!getStorage().willAccept(obj)) {
return;
}
getStorage().update(obj);
}
public void queueUpdate(T obj) {
if (!getStorage().willAccept(obj)) {
return;
@ -115,9 +84,11 @@ public abstract class InstanceManager<T> {
}
public Plan planThisTick(double cameraX, double cameraY, double cameraZ) {
tickLimiter.tick();
processQueue();
return PlanUtil.runOnAll(getStorage()::getTickableInstances, instance -> tickInstance(instance, cameraX, cameraY, cameraZ));
return SimplePlan.of(() -> {
tickLimiter.tick();
processQueue();
})
.then(RunOnAllPlan.of(getStorage()::getTickableInstances, instance -> tickInstance(instance, cameraX, cameraY, cameraZ)));
}
protected void tickInstance(TickableInstance instance, double cameraX, double cameraY, double cameraZ) {
@ -127,9 +98,11 @@ public abstract class InstanceManager<T> {
}
public Plan planThisFrame(double cameraX, double cameraY, double cameraZ, FrustumIntersection frustum) {
frameLimiter.tick();
processQueue();
return PlanUtil.runOnAll(getStorage()::getDynamicInstances, instance -> updateInstance(instance, cameraX, cameraY, cameraZ, frustum));
return SimplePlan.of(() -> {
frameLimiter.tick();
processQueue();
})
.then(RunOnAllPlan.of(getStorage()::getDynamicInstances, instance -> updateInstance(instance, cameraX, cameraY, cameraZ, frustum)));
}
protected void updateInstance(DynamicInstance instance, double cameraX, double cameraY, double cameraZ, FrustumIntersection frustum) {

View file

@ -1,6 +1,7 @@
package com.jozufozu.flywheel.impl.instancing.storage;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import com.jozufozu.flywheel.api.backend.Engine;
@ -11,7 +12,9 @@ import com.jozufozu.flywheel.api.instance.TickableInstance;
public abstract class AbstractStorage<T> implements Storage<T> {
protected final Engine engine;
protected final List<TickableInstance> tickableInstances = new ArrayList<>();
protected final List<TickableInstance> tickableInstancesView = Collections.unmodifiableList(tickableInstances);
protected final List<DynamicInstance> dynamicInstances = new ArrayList<>();
protected final List<DynamicInstance> dynamicInstancesView = Collections.unmodifiableList(dynamicInstances);
protected AbstractStorage(Engine engine) {
this.engine = engine;
@ -19,12 +22,12 @@ public abstract class AbstractStorage<T> implements Storage<T> {
@Override
public List<TickableInstance> getTickableInstances() {
return tickableInstances;
return tickableInstancesView;
}
@Override
public List<DynamicInstance> getDynamicInstances() {
return dynamicInstances;
return dynamicInstancesView;
}
protected void setup(Instance instance) {

View file

@ -1,6 +1,8 @@
package com.jozufozu.flywheel.lib.light;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Stream;
import com.jozufozu.flywheel.lib.box.ImmutableBox;
@ -26,6 +28,8 @@ public class LightUpdater {
private final WeakContainmentMultiMap<LightListener> listenersBySection = new WeakContainmentMultiMap<>();
private final Set<TickingLightListener> tickingListeners = FlwUtil.createWeakHashSet();
private final Queue<LightListener> queue = new ConcurrentLinkedQueue<>();
public static LightUpdater get(LevelAccessor level) {
if (LightUpdated.receivesLightUpdates(level)) {
// The level is valid, add it to the map.
@ -41,8 +45,8 @@ public class LightUpdater {
}
public void tick() {
processQueue();
tickSerial();
//tickParallel();
}
private void tickSerial() {
@ -59,8 +63,20 @@ public class LightUpdater {
* @param listener The object that wants to receive light update notifications.
*/
public void addListener(LightListener listener) {
if (listener instanceof TickingLightListener)
queue.add(listener);
}
/**
 * Drains the pending-listener queue, passing each polled listener to
 * {@code doAdd} in the order it is polled.
 */
private synchronized void processQueue() {
    for (LightListener pending = queue.poll(); pending != null; pending = queue.poll()) {
        doAdd(pending);
    }
}
private void doAdd(LightListener listener) {
if (listener instanceof TickingLightListener) {
tickingListeners.add(((TickingLightListener) listener));
}
ImmutableBox box = listener.getVolume();
@ -94,9 +110,13 @@ public class LightUpdater {
* @param pos The section position where light changed.
*/
public void onLightUpdate(LightLayer type, SectionPos pos) {
processQueue();
Set<LightListener> listeners = listenersBySection.get(pos.asLong());
if (listeners == null || listeners.isEmpty()) return;
if (listeners == null || listeners.isEmpty()) {
return;
}
listeners.removeIf(LightListener::isInvalid);

View file

@ -1,8 +1,8 @@
package com.jozufozu.flywheel.lib.model;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import com.jozufozu.flywheel.api.event.ReloadRenderersEvent;
import com.jozufozu.flywheel.api.model.Model;
@ -16,9 +16,9 @@ import net.minecraft.core.Direction;
import net.minecraft.world.level.block.state.BlockState;
public final class Models {
private static final Map<BlockState, Model> BLOCK_STATE = new HashMap<>();
private static final Map<PartialModel, Model> PARTIAL = new HashMap<>();
private static final Map<Pair<PartialModel, Direction>, Model> PARTIAL_DIR = new HashMap<>();
private static final Map<BlockState, Model> BLOCK_STATE = new ConcurrentHashMap<>();
private static final Map<PartialModel, Model> PARTIAL = new ConcurrentHashMap<>();
private static final Map<Pair<PartialModel, Direction>, Model> PARTIAL_DIR = new ConcurrentHashMap<>();
public static Model block(BlockState state) {
return BLOCK_STATE.computeIfAbsent(state, it -> new BlockModelBuilder(it).build());

View file

@ -4,12 +4,13 @@ import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import com.google.common.collect.ImmutableList;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
public record NestedPlan(List<Plan> parallelPlans) implements Plan {
public static NestedPlan of(Plan... plans) {
return new NestedPlan(List.of(plans));
return new NestedPlan(ImmutableList.copyOf(plans));
}
@Override
@ -33,6 +34,14 @@ public record NestedPlan(List<Plan> parallelPlans) implements Plan {
}
}
/**
 * Builds a new nested plan holding this plan's parallel plans followed by the
 * given plan.
 *
 * @param plan The plan to run in parallel with the existing ones.
 * @return A new {@link NestedPlan} containing all previous plans plus {@code plan}.
 */
@Override
public Plan and(Plan plan) {
    var combined = ImmutableList.<Plan>builder();
    combined.addAll(parallelPlans);
    combined.add(plan);
    return new NestedPlan(combined.build());
}
@Override
public Plan maybeSimplify() {
if (parallelPlans.isEmpty()) {

View file

@ -1,15 +1,10 @@
package com.jozufozu.flywheel.lib.task;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan;
public class PlanUtil {
public static <T> Plan runOnAll(Supplier<List<T>> iterable, Consumer<T> forEach) {
return new RunOnAllPlan<>(iterable, forEach);
}
public static Plan of() {
return UnitPlan.INSTANCE;

View file

@ -8,16 +8,19 @@ import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action) implements Plan {
/**
 * Static factory for a plan that applies an action to every element of a supplied list.
 *
 * @param iterable Supplies the list of elements.
 * @param forEach  The action to apply to each element.
 * @param <T>      The element type.
 * @return A new {@link RunOnAllPlan} wrapping the supplier and action.
 */
public static <T> Plan of(Supplier<List<T>> iterable, Consumer<T> forEach) {
    final Plan plan = new RunOnAllPlan<T>(iterable, forEach);
    return plan;
}
@Override
public void execute(TaskExecutor taskExecutor, Runnable onCompletion) {
// TODO: unit tests, fix CME?
taskExecutor.execute(() -> {
var list = listSupplier.get();
final int size = list.size();
if (size == 0) {
onCompletion.run();
} else if (size <= getChunkingThreshold(taskExecutor)) {
} else if (size <= getChunkingThreshold()) {
processList(list, onCompletion);
} else {
dispatchChunks(list, taskExecutor, onCompletion);
@ -27,12 +30,9 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
private void dispatchChunks(List<T> suppliedList, TaskExecutor taskExecutor, Runnable onCompletion) {
final int size = suppliedList.size();
final int threadCount = taskExecutor.getThreadCount();
final int chunkSize = getChunkSize(taskExecutor, size);
final int chunkSize = (size + threadCount - 1) / threadCount; // ceiling division
final int chunkCount = (size + chunkSize - 1) / chunkSize; // ceiling division
var synchronizer = new Synchronizer(chunkCount, onCompletion);
var synchronizer = new Synchronizer(ceilingDiv(size, chunkSize), onCompletion);
int remaining = size;
while (remaining > 0) {
@ -45,6 +45,14 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
}
}
/**
 * Computes the chunk size obtained by dividing the work, rounding up, by 32 times
 * the executor's thread count.
 *
 * @param taskExecutor The executor whose thread count is used.
 * @param totalSize    The total number of elements to split.
 * @return {@code totalSize} divided by {@code threadCount * 32}, rounded up.
 */
private static int getChunkSize(TaskExecutor taskExecutor, int totalSize) {
    final int targetChunkCount = taskExecutor.getThreadCount() * 32;
    return ceilingDiv(totalSize, targetChunkCount);
}

/**
 * Integer division that rounds up instead of truncating.
 *
 * <p>NOTE(review): written for a non-negative numerator and a positive denominator;
 * the intermediate sum can overflow for values close to {@link Integer#MAX_VALUE}.</p>
 *
 * @param numerator   The value to divide.
 * @param denominator The value to divide by.
 * @return The quotient, rounded up.
 */
private static int ceilingDiv(int numerator, int denominator) {
    final int biased = numerator + denominator - 1;
    return biased / denominator;
}
private void processList(List<T> suppliedList, Runnable onCompletion) {
for (T t : suppliedList) {
action.accept(t);
@ -52,7 +60,7 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
onCompletion.run();
}
private static int getChunkingThreshold(TaskExecutor taskExecutor) {
return 512;
private static int getChunkingThreshold() {
return 256;
}
}

View file

@ -18,7 +18,7 @@ public record SimplePlan(List<Runnable> parallelTasks) implements Plan {
}
var synchronizer = new Synchronizer(parallelTasks.size(), onCompletion);
for (Runnable task : parallelTasks) {
for (var task : parallelTasks) {
taskExecutor.execute(() -> {
task.run();
synchronizer.decrementAndEventuallyRun();

View file

@ -0,0 +1,59 @@
package com.jozufozu.flywheel.lib.task;
import java.util.concurrent.atomic.AtomicInteger;
// https://stackoverflow.com/questions/29655531
/**
 * A counter of outstanding units of work that a thread can wait on.
 *
 * <p>Producers call {@link #add()} before handing off work and {@link #done()} once
 * that work has finished. {@link #await()} returns once the counter is no longer
 * positive.</p>
 *
 * <p>The counter is a single {@link AtomicInteger}, so all methods may be called
 * from any thread.</p>
 */
public class WaitGroup {
    private final AtomicInteger counter = new AtomicInteger(0);

    /**
     * Registers one outstanding unit of work.
     */
    public void add() {
        add(1);
    }

    /**
     * Registers {@code i} outstanding units of work.
     *
     * @param i The amount to add to the counter; {@code 0} is a no-op.
     */
    public void add(int i) {
        if (i == 0) {
            return;
        }

        counter.addAndGet(i);
    }

    /**
     * Marks one unit of work as finished.
     *
     * @throws IllegalStateException if called more often than work was added. The
     *                               surplus decrement is rolled back first, so the
     *                               group stays usable after the error.
     */
    public void done() {
        var result = counter.decrementAndGet();

        if (result == 0) {
            // await() spins rather than blocking on this monitor; the notification
            // is kept for any code that does wait on the instance.
            synchronized (this) {
                this.notifyAll();
            }
        } else if (result < 0) {
            // Undo the surplus decrement. Otherwise the counter would stay negative
            // and the next balanced add()/done() pair would throw as well.
            counter.incrementAndGet();
            throw new IllegalStateException("WaitGroup counter is negative!");
        }
    }

    /**
     * Waits until the counter is no longer positive.
     *
     * <p>This spins in place instead of sleeping, to avoid putting the calling
     * (main) thread to sleep.</p>
     */
    public void await() {
        while (counter.get() > 0) {
            // Tell the runtime this is a busy-wait loop.
            Thread.onSpinWait();
        }
    }

    /**
     * Forces the counter back to zero, regardless of outstanding work.
     */
    public void _reset() {
        counter.set(0);
    }
}

View file

@ -25,7 +25,7 @@ public class InstanceAddMixin {
private void flywheel$onBlockEntityAdded(BlockEntity blockEntity, CallbackInfo ci) {
if (level.isClientSide && FlwUtil.canUseInstancing(level)) {
InstancedRenderDispatcher.getBlockEntities(level)
.add(blockEntity);
.queueAdd(blockEntity);
}
}
}

View file

@ -24,7 +24,7 @@ public class InstanceRemoveMixin {
private void flywheel$removeInstance(CallbackInfo ci) {
if (level instanceof ClientLevel && FlwUtil.canUseInstancing(level)) {
InstancedRenderDispatcher.getBlockEntities(level)
.remove((BlockEntity) (Object) this);
.queueRemove((BlockEntity) (Object) this);
}
}

View file

@ -35,6 +35,6 @@ public class InstanceUpdateMixin {
}
InstancedRenderDispatcher.getBlockEntities(level)
.update(blockEntity);
.queueUpdate(blockEntity);
}
}

View file

@ -0,0 +1,20 @@
package com.jozufozu.flywheel.lib.task;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import com.jozufozu.flywheel.api.task.Plan;
/**
 * Tests for composing plans with {@code and}.
 */
public class PlanCompositionTest {
    /** A runnable that does nothing. */
    public static final Runnable NOOP = () -> {
    };
    /** A simple plan wrapping {@link #NOOP}. */
    public static final Plan SIMPLE = SimplePlan.of(NOOP);

    /**
     * Calling {@code and} on a one-element nested plan must equal a nested plan
     * built directly from both elements.
     */
    @Test
    void nestedPlanAnd() {
        var single = NestedPlan.of(SIMPLE);

        var expected = NestedPlan.of(SIMPLE, SIMPLE);
        var actual = single.and(SIMPLE);

        Assertions.assertEquals(expected, actual);
    }
}

View file

@ -0,0 +1,14 @@
package com.jozufozu.flywheel.lib.task;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
 * Tests for {@link WaitGroup}.
 */
public class WaitGroupTest {
    /**
     * Calling {@code done} once more than {@code add} must throw
     * {@link IllegalStateException}.
     */
    @Test
    public void testExtraDone() {
        var group = new WaitGroup();

        // Balanced pair: the counter goes back to zero.
        group.add();
        group.done();

        // The surplus call is the one that must fail.
        Assertions.assertThrows(IllegalStateException.class, () -> group.done());
    }
}