Weaving threads

- Restore WaitGroup abstraction
- Simplify WaitGroup to make proper use of atomics
- Fix logic error in ParallelTaskExecutor causing task counter to go
  below zero when executing main thread tasks
- Use ConcurrentHashMap in models to allow parallel access
- Reduce chunk size in RunOnAllPlan
- Only queue in InstanceManager
- Defer syncPoint within renderStage
- AbstractStorage return immutable view for tickable/dynamic instances
- Process InstanceManager queues off-thread
This commit is contained in:
Jozufozu 2023-04-13 15:17:02 -07:00
parent 76dcf3fed6
commit f29dcbc486
25 changed files with 241 additions and 121 deletions

View file

@ -1,6 +1,7 @@
package com.jozufozu.flywheel.api.task; package com.jozufozu.flywheel.api.task;
import com.jozufozu.flywheel.lib.task.BarrierPlan; import com.jozufozu.flywheel.lib.task.BarrierPlan;
import com.jozufozu.flywheel.lib.task.NestedPlan;
public interface Plan { public interface Plan {
/** /**
@ -22,13 +23,23 @@ public interface Plan {
* Create a new plan that executes this plan, then the given plan. * Create a new plan that executes this plan, then the given plan.
* *
* @param plan The plan to execute after this plan. * @param plan The plan to execute after this plan.
* @return The new, composed plan. * @return The composed plan.
*/ */
default Plan then(Plan plan) { default Plan then(Plan plan) {
// TODO: AbstractPlan? // TODO: AbstractPlan?
return new BarrierPlan(this, plan); return new BarrierPlan(this, plan);
} }
/**
* Create a new plan that executes this plan and the given plan in parallel.
*
* @param plan The plan to execute in parallel with this plan.
* @return The composed plan.
*/
default Plan and(Plan plan) {
return NestedPlan.of(this, plan);
}
/** /**
* If possible, create a new plan that accomplishes everything * If possible, create a new plan that accomplishes everything
* this plan does but with a simpler execution schedule. * this plan does but with a simpler execution schedule.

View file

@ -81,4 +81,9 @@ public class BatchingDrawTracker {
buffers.clear(); buffers.clear();
} }
} }
public boolean hasStage(RenderStage stage) {
return !activeBuffers.get(stage)
.isEmpty();
}
} }

View file

@ -82,6 +82,10 @@ public class BatchingEngine implements Engine {
@Override @Override
public void renderStage(TaskExecutor executor, RenderContext context, RenderStage stage) { public void renderStage(TaskExecutor executor, RenderContext context, RenderStage stage) {
if (!drawTracker.hasStage(stage)) {
return;
}
executor.syncPoint();
drawTracker.draw(stage); drawTracker.draw(stage);
} }

View file

@ -185,4 +185,7 @@ public class IndirectCullingGroup<P extends InstancePart> {
meshPool.delete(); meshPool.delete();
} }
public boolean hasStage(RenderStage stage) {
return drawSet.contains(stage);
}
} }

View file

@ -73,6 +73,15 @@ public class IndirectDrawManager {
initializedInstancers.add(instancer); initializedInstancers.add(instancer);
} }
public boolean hasStage(RenderStage stage) {
for (var list : renderLists.values()) {
if (list.hasStage(stage)) {
return true;
}
}
return false;
}
private record UninitializedInstancer(IndirectInstancer<?> instancer, Model model, RenderStage stage) { private record UninitializedInstancer(IndirectInstancer<?> instancer, Model model, RenderStage stage) {
} }
} }

View file

@ -53,6 +53,12 @@ public class IndirectEngine implements Engine {
@Override @Override
public void renderStage(TaskExecutor executor, RenderContext context, RenderStage stage) { public void renderStage(TaskExecutor executor, RenderContext context, RenderStage stage) {
if (!drawManager.hasStage(stage)) {
return;
}
executor.syncPoint();
try (var restoreState = GlStateTracker.getRestoreState()) { try (var restoreState = GlStateTracker.getRestoreState()) {
setup(); setup();

View file

@ -66,6 +66,8 @@ public class InstancingEngine implements Engine {
return; return;
} }
executor.syncPoint();
try (var state = GlStateTracker.getRestoreState()) { try (var state = GlStateTracker.getRestoreState()) {
setup(); setup();

View file

@ -13,6 +13,7 @@ import org.slf4j.Logger;
import com.jozufozu.flywheel.Flywheel; import com.jozufozu.flywheel.Flywheel;
import com.jozufozu.flywheel.api.task.TaskExecutor; import com.jozufozu.flywheel.api.task.TaskExecutor;
import com.jozufozu.flywheel.lib.task.WaitGroup;
import com.mojang.logging.LogUtils; import com.mojang.logging.LogUtils;
import net.minecraft.util.Mth; import net.minecraft.util.Mth;
@ -29,17 +30,13 @@ public class ParallelTaskExecutor implements TaskExecutor {
* If set to false, the executor will shut down. * If set to false, the executor will shut down.
*/ */
private final AtomicBoolean running = new AtomicBoolean(false); private final AtomicBoolean running = new AtomicBoolean(false);
/**
* Synchronized via {@link #tasksCompletedNotifier}.
*/
private int incompleteTaskCounter = 0;
private final List<WorkerThread> threads = new ArrayList<>(); private final List<WorkerThread> threads = new ArrayList<>();
private final Deque<Runnable> taskQueue = new ConcurrentLinkedDeque<>(); private final Deque<Runnable> taskQueue = new ConcurrentLinkedDeque<>();
private final Queue<Runnable> mainThreadQueue = new ConcurrentLinkedQueue<>(); private final Queue<Runnable> mainThreadQueue = new ConcurrentLinkedQueue<>();
private final Object taskNotifier = new Object(); private final Object taskNotifier = new Object();
private final Object tasksCompletedNotifier = new Object(); private final WaitGroup waitGroup = new WaitGroup();
public ParallelTaskExecutor(String name) { public ParallelTaskExecutor(String name) {
this.name = name; this.name = name;
@ -102,10 +99,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
threads.clear(); threads.clear();
taskQueue.clear(); taskQueue.clear();
synchronized (tasksCompletedNotifier) {
incompleteTaskCounter = 0;
tasksCompletedNotifier.notifyAll();
}
} }
@Override @Override
@ -114,12 +107,9 @@ public class ParallelTaskExecutor implements TaskExecutor {
throw new IllegalStateException("Executor is stopped"); throw new IllegalStateException("Executor is stopped");
} }
waitGroup.add();
taskQueue.add(task); taskQueue.add(task);
synchronized (tasksCompletedNotifier) {
incompleteTaskCounter++;
}
synchronized (taskNotifier) { synchronized (taskNotifier) {
taskNotifier.notifyAll(); taskNotifier.notifyAll();
} }
@ -140,19 +130,17 @@ public class ParallelTaskExecutor implements TaskExecutor {
@Override @Override
public void syncPoint() { public void syncPoint() {
Runnable task; Runnable task;
// Finish everyone else's work... // Finish everyone else's work...
while ((task = pollForSyncPoint()) != null) { while (true) {
if ((task = mainThreadQueue.poll()) != null) {
processMainThreadTask(task);
} else if ((task = taskQueue.pollLast()) != null) {
processTask(task); processTask(task);
} } else {
// and wait for any stragglers. // and wait for any stragglers.
synchronized (tasksCompletedNotifier) { waitGroup.await();
while (incompleteTaskCounter > 0) { if (mainThreadQueue.isEmpty()) {
try { break;
tasksCompletedNotifier.wait();
} catch (InterruptedException e) {
//
} }
} }
} }
@ -170,23 +158,12 @@ public class ParallelTaskExecutor implements TaskExecutor {
public void discardAndAwait() { public void discardAndAwait() {
// Discard everyone else's work... // Discard everyone else's work...
while (taskQueue.pollLast() != null) { while (taskQueue.pollLast() != null) {
synchronized (tasksCompletedNotifier) { waitGroup.done();
if (--incompleteTaskCounter == 0) {
tasksCompletedNotifier.notifyAll();
}
}
} }
// and wait for any stragglers. // and wait for any stragglers.
synchronized (tasksCompletedNotifier) { waitGroup.await();
while (incompleteTaskCounter > 0) { mainThreadQueue.clear();
try {
tasksCompletedNotifier.wait();
} catch (InterruptedException e) {
//
}
}
}
} }
@Nullable @Nullable
@ -213,11 +190,15 @@ public class ParallelTaskExecutor implements TaskExecutor {
} catch (Exception e) { } catch (Exception e) {
Flywheel.LOGGER.error("Error running task", e); Flywheel.LOGGER.error("Error running task", e);
} finally { } finally {
synchronized (tasksCompletedNotifier) { waitGroup.done();
if (--incompleteTaskCounter == 0) {
tasksCompletedNotifier.notifyAll();
} }
} }
private void processMainThreadTask(Runnable task) {
try {
task.run();
} catch (Exception e) {
Flywheel.LOGGER.error("Error running main thread task", e);
} }
} }
@ -233,7 +214,6 @@ public class ParallelTaskExecutor implements TaskExecutor {
} }
private class WorkerThread extends Thread { private class WorkerThread extends Thread {
private final AtomicBoolean running = ParallelTaskExecutor.this.running;
public WorkerThread(String name) { public WorkerThread(String name) {
super(name); super(name);
@ -242,7 +222,7 @@ public class ParallelTaskExecutor implements TaskExecutor {
@Override @Override
public void run() { public void run() {
// Run until the executor shuts down // Run until the executor shuts down
while (running.get()) { while (ParallelTaskExecutor.this.running.get()) {
Runnable task = getNextTask(); Runnable task = getNextTask();
if (task == null) { if (task == null) {

View file

@ -28,7 +28,7 @@ public class EntityWorldHandler {
if (FlwUtil.canUseInstancing(level)) { if (FlwUtil.canUseInstancing(level)) {
InstancedRenderDispatcher.getEntities(level) InstancedRenderDispatcher.getEntities(level)
.remove(event.getEntity()); .queueRemove(event.getEntity());
} }
} }
} }

View file

@ -70,11 +70,11 @@ public class InstanceWorld implements AutoCloseable {
* </p> * </p>
*/ */
public void tick(double cameraX, double cameraY, double cameraZ) { public void tick(double cameraX, double cameraY, double cameraZ) {
var blockEntityPlan = blockEntities.planThisTick(cameraX, cameraY, cameraZ); taskExecutor.syncPoint();
var entityPlan = entities.planThisTick(cameraX, cameraY, cameraZ);
var effectPlan = effects.planThisTick(cameraX, cameraY, cameraZ);
PlanUtil.of(blockEntityPlan, entityPlan, effectPlan) blockEntities.planThisTick(cameraX, cameraY, cameraZ)
.and(entities.planThisTick(cameraX, cameraY, cameraZ))
.and(effects.planThisTick(cameraX, cameraY, cameraZ))
.maybeSimplify() .maybeSimplify()
.execute(taskExecutor); .execute(taskExecutor);
} }
@ -114,7 +114,6 @@ public class InstanceWorld implements AutoCloseable {
* Draw all instances for the given stage. * Draw all instances for the given stage.
*/ */
public void renderStage(RenderContext context, RenderStage stage) { public void renderStage(RenderContext context, RenderStage stage) {
taskExecutor.syncPoint();
engine.renderStage(taskExecutor, context, stage); engine.renderStage(taskExecutor, context, stage);
} }

View file

@ -160,7 +160,7 @@ public class InstancedRenderDispatcher {
// Block entities are loaded while chunks are baked. // Block entities are loaded while chunks are baked.
// Entities are loaded with the level, so when chunks are reloaded they need to be re-added. // Entities are loaded with the level, so when chunks are reloaded they need to be re-added.
ClientLevelExtension.getAllLoadedEntities(level) ClientLevelExtension.getAllLoadedEntities(level)
.forEach(world.getEntities()::add); .forEach(world.getEntities()::queueAdd);
} }
public static void addDebugInfo(List<String> info) { public static void addDebugInfo(List<String> info) {

View file

@ -6,7 +6,6 @@ import java.util.concurrent.ConcurrentLinkedQueue;
import org.joml.FrustumIntersection; import org.joml.FrustumIntersection;
import com.jozufozu.flywheel.api.instance.DynamicInstance; import com.jozufozu.flywheel.api.instance.DynamicInstance;
import com.jozufozu.flywheel.api.instance.Instance;
import com.jozufozu.flywheel.api.instance.TickableInstance; import com.jozufozu.flywheel.api.instance.TickableInstance;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.config.FlwConfig; import com.jozufozu.flywheel.config.FlwConfig;
@ -15,7 +14,8 @@ import com.jozufozu.flywheel.impl.instancing.ratelimit.DistanceUpdateLimiter;
import com.jozufozu.flywheel.impl.instancing.ratelimit.NonLimiter; import com.jozufozu.flywheel.impl.instancing.ratelimit.NonLimiter;
import com.jozufozu.flywheel.impl.instancing.storage.Storage; import com.jozufozu.flywheel.impl.instancing.storage.Storage;
import com.jozufozu.flywheel.impl.instancing.storage.Transaction; import com.jozufozu.flywheel.impl.instancing.storage.Transaction;
import com.jozufozu.flywheel.lib.task.PlanUtil; import com.jozufozu.flywheel.lib.task.RunOnAllPlan;
import com.jozufozu.flywheel.lib.task.SimplePlan;
public abstract class InstanceManager<T> { public abstract class InstanceManager<T> {
private final Queue<Transaction<T>> queue = new ConcurrentLinkedQueue<>(); private final Queue<Transaction<T>> queue = new ConcurrentLinkedQueue<>();
@ -47,14 +47,6 @@ public abstract class InstanceManager<T> {
return getStorage().getAllInstances().size(); return getStorage().getAllInstances().size();
} }
public void add(T obj) {
if (!getStorage().willAccept(obj)) {
return;
}
getStorage().add(obj);
}
public void queueAdd(T obj) { public void queueAdd(T obj) {
if (!getStorage().willAccept(obj)) { if (!getStorage().willAccept(obj)) {
return; return;
@ -63,33 +55,10 @@ public abstract class InstanceManager<T> {
queue.add(Transaction.add(obj)); queue.add(Transaction.add(obj));
} }
public void remove(T obj) {
getStorage().remove(obj);
}
public void queueRemove(T obj) { public void queueRemove(T obj) {
queue.add(Transaction.remove(obj)); queue.add(Transaction.remove(obj));
} }
/**
* Update the instance associated with an object.
*
* <p>
* By default this is the only hook an {@link Instance} has to change its internal state. This is the lowest frequency
* update hook {@link Instance} gets. For more frequent updates, see {@link TickableInstance} and
* {@link DynamicInstance}.
* </p>
*
* @param obj the object to update.
*/
public void update(T obj) {
if (!getStorage().willAccept(obj)) {
return;
}
getStorage().update(obj);
}
public void queueUpdate(T obj) { public void queueUpdate(T obj) {
if (!getStorage().willAccept(obj)) { if (!getStorage().willAccept(obj)) {
return; return;
@ -115,9 +84,11 @@ public abstract class InstanceManager<T> {
} }
public Plan planThisTick(double cameraX, double cameraY, double cameraZ) { public Plan planThisTick(double cameraX, double cameraY, double cameraZ) {
return SimplePlan.of(() -> {
tickLimiter.tick(); tickLimiter.tick();
processQueue(); processQueue();
return PlanUtil.runOnAll(getStorage()::getTickableInstances, instance -> tickInstance(instance, cameraX, cameraY, cameraZ)); })
.then(RunOnAllPlan.of(getStorage()::getTickableInstances, instance -> tickInstance(instance, cameraX, cameraY, cameraZ)));
} }
protected void tickInstance(TickableInstance instance, double cameraX, double cameraY, double cameraZ) { protected void tickInstance(TickableInstance instance, double cameraX, double cameraY, double cameraZ) {
@ -127,9 +98,11 @@ public abstract class InstanceManager<T> {
} }
public Plan planThisFrame(double cameraX, double cameraY, double cameraZ, FrustumIntersection frustum) { public Plan planThisFrame(double cameraX, double cameraY, double cameraZ, FrustumIntersection frustum) {
return SimplePlan.of(() -> {
frameLimiter.tick(); frameLimiter.tick();
processQueue(); processQueue();
return PlanUtil.runOnAll(getStorage()::getDynamicInstances, instance -> updateInstance(instance, cameraX, cameraY, cameraZ, frustum)); })
.then(RunOnAllPlan.of(getStorage()::getDynamicInstances, instance -> updateInstance(instance, cameraX, cameraY, cameraZ, frustum)));
} }
protected void updateInstance(DynamicInstance instance, double cameraX, double cameraY, double cameraZ, FrustumIntersection frustum) { protected void updateInstance(DynamicInstance instance, double cameraX, double cameraY, double cameraZ, FrustumIntersection frustum) {

View file

@ -1,6 +1,7 @@
package com.jozufozu.flywheel.impl.instancing.storage; package com.jozufozu.flywheel.impl.instancing.storage;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import com.jozufozu.flywheel.api.backend.Engine; import com.jozufozu.flywheel.api.backend.Engine;
@ -11,7 +12,9 @@ import com.jozufozu.flywheel.api.instance.TickableInstance;
public abstract class AbstractStorage<T> implements Storage<T> { public abstract class AbstractStorage<T> implements Storage<T> {
protected final Engine engine; protected final Engine engine;
protected final List<TickableInstance> tickableInstances = new ArrayList<>(); protected final List<TickableInstance> tickableInstances = new ArrayList<>();
protected final List<TickableInstance> tickableInstancesView = Collections.unmodifiableList(tickableInstances);
protected final List<DynamicInstance> dynamicInstances = new ArrayList<>(); protected final List<DynamicInstance> dynamicInstances = new ArrayList<>();
protected final List<DynamicInstance> dynamicInstancesView = Collections.unmodifiableList(dynamicInstances);
protected AbstractStorage(Engine engine) { protected AbstractStorage(Engine engine) {
this.engine = engine; this.engine = engine;
@ -19,12 +22,12 @@ public abstract class AbstractStorage<T> implements Storage<T> {
@Override @Override
public List<TickableInstance> getTickableInstances() { public List<TickableInstance> getTickableInstances() {
return tickableInstances; return tickableInstancesView;
} }
@Override @Override
public List<DynamicInstance> getDynamicInstances() { public List<DynamicInstance> getDynamicInstances() {
return dynamicInstances; return dynamicInstancesView;
} }
protected void setup(Instance instance) { protected void setup(Instance instance) {

View file

@ -1,6 +1,8 @@
package com.jozufozu.flywheel.lib.light; package com.jozufozu.flywheel.lib.light;
import java.util.Queue;
import java.util.Set; import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.stream.Stream; import java.util.stream.Stream;
import com.jozufozu.flywheel.lib.box.ImmutableBox; import com.jozufozu.flywheel.lib.box.ImmutableBox;
@ -26,6 +28,8 @@ public class LightUpdater {
private final WeakContainmentMultiMap<LightListener> listenersBySection = new WeakContainmentMultiMap<>(); private final WeakContainmentMultiMap<LightListener> listenersBySection = new WeakContainmentMultiMap<>();
private final Set<TickingLightListener> tickingListeners = FlwUtil.createWeakHashSet(); private final Set<TickingLightListener> tickingListeners = FlwUtil.createWeakHashSet();
private final Queue<LightListener> queue = new ConcurrentLinkedQueue<>();
public static LightUpdater get(LevelAccessor level) { public static LightUpdater get(LevelAccessor level) {
if (LightUpdated.receivesLightUpdates(level)) { if (LightUpdated.receivesLightUpdates(level)) {
// The level is valid, add it to the map. // The level is valid, add it to the map.
@ -41,8 +45,8 @@ public class LightUpdater {
} }
public void tick() { public void tick() {
processQueue();
tickSerial(); tickSerial();
//tickParallel();
} }
private void tickSerial() { private void tickSerial() {
@ -59,8 +63,20 @@ public class LightUpdater {
* @param listener The object that wants to receive light update notifications. * @param listener The object that wants to receive light update notifications.
*/ */
public void addListener(LightListener listener) { public void addListener(LightListener listener) {
if (listener instanceof TickingLightListener) queue.add(listener);
}
private synchronized void processQueue() {
LightListener listener;
while ((listener = queue.poll()) != null) {
doAdd(listener);
}
}
private void doAdd(LightListener listener) {
if (listener instanceof TickingLightListener) {
tickingListeners.add(((TickingLightListener) listener)); tickingListeners.add(((TickingLightListener) listener));
}
ImmutableBox box = listener.getVolume(); ImmutableBox box = listener.getVolume();
@ -94,9 +110,13 @@ public class LightUpdater {
* @param pos The section position where light changed. * @param pos The section position where light changed.
*/ */
public void onLightUpdate(LightLayer type, SectionPos pos) { public void onLightUpdate(LightLayer type, SectionPos pos) {
processQueue();
Set<LightListener> listeners = listenersBySection.get(pos.asLong()); Set<LightListener> listeners = listenersBySection.get(pos.asLong());
if (listeners == null || listeners.isEmpty()) return; if (listeners == null || listeners.isEmpty()) {
return;
}
listeners.removeIf(LightListener::isInvalid); listeners.removeIf(LightListener::isInvalid);

View file

@ -1,8 +1,8 @@
package com.jozufozu.flywheel.lib.model; package com.jozufozu.flywheel.lib.model;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import com.jozufozu.flywheel.api.event.ReloadRenderersEvent; import com.jozufozu.flywheel.api.event.ReloadRenderersEvent;
import com.jozufozu.flywheel.api.model.Model; import com.jozufozu.flywheel.api.model.Model;
@ -16,9 +16,9 @@ import net.minecraft.core.Direction;
import net.minecraft.world.level.block.state.BlockState; import net.minecraft.world.level.block.state.BlockState;
public final class Models { public final class Models {
private static final Map<BlockState, Model> BLOCK_STATE = new HashMap<>(); private static final Map<BlockState, Model> BLOCK_STATE = new ConcurrentHashMap<>();
private static final Map<PartialModel, Model> PARTIAL = new HashMap<>(); private static final Map<PartialModel, Model> PARTIAL = new ConcurrentHashMap<>();
private static final Map<Pair<PartialModel, Direction>, Model> PARTIAL_DIR = new HashMap<>(); private static final Map<Pair<PartialModel, Direction>, Model> PARTIAL_DIR = new ConcurrentHashMap<>();
public static Model block(BlockState state) { public static Model block(BlockState state) {
return BLOCK_STATE.computeIfAbsent(state, it -> new BlockModelBuilder(it).build()); return BLOCK_STATE.computeIfAbsent(state, it -> new BlockModelBuilder(it).build());

View file

@ -4,12 +4,13 @@ import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import com.google.common.collect.ImmutableList;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor; import com.jozufozu.flywheel.api.task.TaskExecutor;
public record NestedPlan(List<Plan> parallelPlans) implements Plan { public record NestedPlan(List<Plan> parallelPlans) implements Plan {
public static NestedPlan of(Plan... plans) { public static NestedPlan of(Plan... plans) {
return new NestedPlan(List.of(plans)); return new NestedPlan(ImmutableList.copyOf(plans));
} }
@Override @Override
@ -33,6 +34,14 @@ public record NestedPlan(List<Plan> parallelPlans) implements Plan {
} }
} }
@Override
public Plan and(Plan plan) {
return new NestedPlan(ImmutableList.<Plan>builder()
.addAll(parallelPlans)
.add(plan)
.build());
}
@Override @Override
public Plan maybeSimplify() { public Plan maybeSimplify() {
if (parallelPlans.isEmpty()) { if (parallelPlans.isEmpty()) {

View file

@ -1,15 +1,10 @@
package com.jozufozu.flywheel.lib.task; package com.jozufozu.flywheel.lib.task;
import java.util.List; import java.util.List;
import java.util.function.Consumer;
import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
public class PlanUtil { public class PlanUtil {
public static <T> Plan runOnAll(Supplier<List<T>> iterable, Consumer<T> forEach) {
return new RunOnAllPlan<>(iterable, forEach);
}
public static Plan of() { public static Plan of() {
return UnitPlan.INSTANCE; return UnitPlan.INSTANCE;

View file

@ -8,16 +8,19 @@ import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor; import com.jozufozu.flywheel.api.task.TaskExecutor;
public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action) implements Plan { public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action) implements Plan {
public static <T> Plan of(Supplier<List<T>> iterable, Consumer<T> forEach) {
return new RunOnAllPlan<>(iterable, forEach);
}
@Override @Override
public void execute(TaskExecutor taskExecutor, Runnable onCompletion) { public void execute(TaskExecutor taskExecutor, Runnable onCompletion) {
// TODO: unit tests, fix CME?
taskExecutor.execute(() -> { taskExecutor.execute(() -> {
var list = listSupplier.get(); var list = listSupplier.get();
final int size = list.size(); final int size = list.size();
if (size == 0) { if (size == 0) {
onCompletion.run(); onCompletion.run();
} else if (size <= getChunkingThreshold(taskExecutor)) { } else if (size <= getChunkingThreshold()) {
processList(list, onCompletion); processList(list, onCompletion);
} else { } else {
dispatchChunks(list, taskExecutor, onCompletion); dispatchChunks(list, taskExecutor, onCompletion);
@ -27,12 +30,9 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
private void dispatchChunks(List<T> suppliedList, TaskExecutor taskExecutor, Runnable onCompletion) { private void dispatchChunks(List<T> suppliedList, TaskExecutor taskExecutor, Runnable onCompletion) {
final int size = suppliedList.size(); final int size = suppliedList.size();
final int threadCount = taskExecutor.getThreadCount(); final int chunkSize = getChunkSize(taskExecutor, size);
final int chunkSize = (size + threadCount - 1) / threadCount; // ceiling division var synchronizer = new Synchronizer(ceilingDiv(size, chunkSize), onCompletion);
final int chunkCount = (size + chunkSize - 1) / chunkSize; // ceiling division
var synchronizer = new Synchronizer(chunkCount, onCompletion);
int remaining = size; int remaining = size;
while (remaining > 0) { while (remaining > 0) {
@ -45,6 +45,14 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
} }
} }
private static int getChunkSize(TaskExecutor taskExecutor, int totalSize) {
return ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
}
private static int ceilingDiv(int numerator, int denominator) {
return (numerator + denominator - 1) / denominator;
}
private void processList(List<T> suppliedList, Runnable onCompletion) { private void processList(List<T> suppliedList, Runnable onCompletion) {
for (T t : suppliedList) { for (T t : suppliedList) {
action.accept(t); action.accept(t);
@ -52,7 +60,7 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
onCompletion.run(); onCompletion.run();
} }
private static int getChunkingThreshold(TaskExecutor taskExecutor) { private static int getChunkingThreshold() {
return 512; return 256;
} }
} }

View file

@ -18,7 +18,7 @@ public record SimplePlan(List<Runnable> parallelTasks) implements Plan {
} }
var synchronizer = new Synchronizer(parallelTasks.size(), onCompletion); var synchronizer = new Synchronizer(parallelTasks.size(), onCompletion);
for (Runnable task : parallelTasks) { for (var task : parallelTasks) {
taskExecutor.execute(() -> { taskExecutor.execute(() -> {
task.run(); task.run();
synchronizer.decrementAndEventuallyRun(); synchronizer.decrementAndEventuallyRun();

View file

@ -0,0 +1,59 @@
package com.jozufozu.flywheel.lib.task;
import java.util.concurrent.atomic.AtomicInteger;
// https://stackoverflow.com/questions/29655531
public class WaitGroup {
private final AtomicInteger counter = new AtomicInteger(0);
public void add() {
add(1);
}
public void add(int i) {
if (i == 0) {
return;
}
counter.addAndGet(i);
}
public void done() {
var result = counter.decrementAndGet();
if (result == 0) {
synchronized (this) {
this.notifyAll();
}
} else if (result < 0) {
throw new IllegalStateException("WaitGroup counter is negative!");
}
}
public void await() {
try {
awaitInternal();
} catch (InterruptedException ignored) {
// noop
}
}
private void awaitInternal() throws InterruptedException {
// var start = System.nanoTime();
while (counter.get() > 0) {
// spin in place to avoid sleeping the main thread
// synchronized (this) {
// this.wait(timeoutMs);
// }
}
// var end = System.nanoTime();
// var elapsed = end - start;
//
// if (elapsed > 1000000) {
// Flywheel.LOGGER.info("Waited " + StringUtil.formatTime(elapsed));
// }
}
public void _reset() {
counter.set(0);
}
}

View file

@ -25,7 +25,7 @@ public class InstanceAddMixin {
private void flywheel$onBlockEntityAdded(BlockEntity blockEntity, CallbackInfo ci) { private void flywheel$onBlockEntityAdded(BlockEntity blockEntity, CallbackInfo ci) {
if (level.isClientSide && FlwUtil.canUseInstancing(level)) { if (level.isClientSide && FlwUtil.canUseInstancing(level)) {
InstancedRenderDispatcher.getBlockEntities(level) InstancedRenderDispatcher.getBlockEntities(level)
.add(blockEntity); .queueAdd(blockEntity);
} }
} }
} }

View file

@ -24,7 +24,7 @@ public class InstanceRemoveMixin {
private void flywheel$removeInstance(CallbackInfo ci) { private void flywheel$removeInstance(CallbackInfo ci) {
if (level instanceof ClientLevel && FlwUtil.canUseInstancing(level)) { if (level instanceof ClientLevel && FlwUtil.canUseInstancing(level)) {
InstancedRenderDispatcher.getBlockEntities(level) InstancedRenderDispatcher.getBlockEntities(level)
.remove((BlockEntity) (Object) this); .queueRemove((BlockEntity) (Object) this);
} }
} }

View file

@ -35,6 +35,6 @@ public class InstanceUpdateMixin {
} }
InstancedRenderDispatcher.getBlockEntities(level) InstancedRenderDispatcher.getBlockEntities(level)
.update(blockEntity); .queueUpdate(blockEntity);
} }
} }

View file

@ -0,0 +1,20 @@
package com.jozufozu.flywheel.lib.task;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import com.jozufozu.flywheel.api.task.Plan;
public class PlanCompositionTest {
public static final Runnable NOOP = () -> {
};
public static final Plan SIMPLE = SimplePlan.of(NOOP);
@Test
void nestedPlanAnd() {
var empty = NestedPlan.of(SIMPLE);
Assertions.assertEquals(NestedPlan.of(SIMPLE, SIMPLE), empty.and(SIMPLE));
}
}

View file

@ -0,0 +1,14 @@
package com.jozufozu.flywheel.lib.task;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class WaitGroupTest {
@Test
public void testExtraDone() {
WaitGroup wg = new WaitGroup();
wg.add();
wg.done();
Assertions.assertThrows(IllegalStateException.class, wg::done);
}
}