diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/batching/TransformCall.java b/src/main/java/com/jozufozu/flywheel/backend/engine/batching/TransformCall.java index 12358edfc..d014514b5 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/batching/TransformCall.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/batching/TransformCall.java @@ -13,7 +13,8 @@ import com.jozufozu.flywheel.api.material.Material; import com.jozufozu.flywheel.api.material.MaterialVertexTransformer; import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.vertex.MutableVertexList; -import com.jozufozu.flywheel.lib.task.ForEachPlan; +import com.jozufozu.flywheel.api.vertex.ReusableVertexList; +import com.jozufozu.flywheel.lib.task.ForEachSlicePlan; import com.jozufozu.flywheel.lib.vertex.VertexTransformations; import com.mojang.blaze3d.vertex.PoseStack; import com.mojang.math.Matrix3f; @@ -37,22 +38,26 @@ public class TransformCall { meshVertexCount = mesh.getVertexCount(); Vector4fc meshBoundingSphere = mesh.boundingSphere(); - drawPlan = ForEachPlan.of(instancer::getAll, (instance, ctx) -> { - var boundingSphere = new Vector4f(meshBoundingSphere); - boundingSphereTransformer.transform(boundingSphere, instance); + drawPlan = ForEachSlicePlan.of(instancer::getAll, (subList, ctx) -> { + ReusableVertexList vertexList = ctx.buffer.slice(0, meshVertexCount); + Vector4f boundingSphere = new Vector4f(); - if (!ctx.frustum - .testSphere(boundingSphere.x, boundingSphere.y, boundingSphere.z, boundingSphere.w)) { - return; + for (I instance : subList) { + boundingSphere.set(meshBoundingSphere); + boundingSphereTransformer.transform(boundingSphere, instance); + + if (!ctx.frustum.testSphere(boundingSphere.x, boundingSphere.y, boundingSphere.z, boundingSphere.w)) { + continue; + } + + final int baseVertex = ctx.vertexCounter.getAndAdd(meshVertexCount); + vertexList.ptr(ctx.buffer.ptrForVertex(baseVertex)); + + mesh.copyTo(vertexList.ptr()); + instanceVertexTransformer.transform(vertexList, instance); + materialVertexTransformer.transform(vertexList, ctx.level); + applyMatrices(vertexList, ctx.matrices); } - - final int baseVertex = ctx.vertexCounter.getAndAdd(meshVertexCount); - var sub = ctx.buffer.slice(baseVertex, meshVertexCount); - - mesh.copyTo(sub.ptr()); - instanceVertexTransformer.transform(sub, instance); - materialVertexTransformer.transform(sub, ctx.level); - applyMatrices(sub, ctx.matrices); }); } diff --git a/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java b/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java index 2661c1313..bfc6aa064 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java +++ b/src/main/java/com/jozufozu/flywheel/backend/task/ParallelTaskExecutor.java @@ -145,10 +145,9 @@ public class ParallelTaskExecutor implements TaskExecutor { processTask(task); } else { // then wait for the other threads to finish. - waitGroup.await(); - // at this point we know taskQueue is empty, - // but one of the worker threads may have submitted a main thread task. - if (mainThreadQueue.isEmpty()) { + boolean done = waitGroup.await(10_000); + // If we timed-out tasks may have been added to the queue, so check again. + if (done && mainThreadQueue.isEmpty()) { // if they didn't, we're done. break; } @@ -157,13 +156,17 @@ public class ParallelTaskExecutor implements TaskExecutor { } public void discardAndAwait() { - // Discard everyone else's work... - while (taskQueue.pollLast() != null) { - waitGroup.done(); - } + while (true) { + // Discard everyone else's work... + while (taskQueue.pollLast() != null) { + waitGroup.done(); + } - // ...wait for any stragglers... - waitGroup.await(); + // ...wait for any stragglers... + if (waitGroup.await(100_000)) { + break; + } + } // ...and clear the main thread queue. mainThreadQueue.clear(); } diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/ForEachPlan.java b/src/main/java/com/jozufozu/flywheel/lib/task/ForEachPlan.java index 5ae77ba4b..221322ad6 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/ForEachPlan.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/ForEachPlan.java @@ -8,8 +8,17 @@ import java.util.function.Supplier; import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.TaskExecutor; -public record ForEachPlan(Supplier> listSupplier, - BiConsumer action) implements SimplyComposedPlan { +/** + * A plan that executes code on each element of a provided list. + *

+ * Operations are dynamically batched based on the number of available threads. + * + * @param listSupplier A supplier of the list to iterate over. + * @param action The action to perform on each element. + * @param The type of the list elements. + * @param The type of the context object. + */ +public record ForEachPlan(Supplier> listSupplier, BiConsumer action) implements SimplyComposedPlan { public static Plan of(Supplier> iterable, BiConsumer forEach) { return new ForEachPlan<>(iterable, forEach); } diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/ForEachSlicePlan.java b/src/main/java/com/jozufozu/flywheel/lib/task/ForEachSlicePlan.java new file mode 100644 index 000000000..d090834e0 --- /dev/null +++ b/src/main/java/com/jozufozu/flywheel/lib/task/ForEachSlicePlan.java @@ -0,0 +1,30 @@ +package com.jozufozu.flywheel.lib.task; + +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import com.jozufozu.flywheel.api.task.Plan; +import com.jozufozu.flywheel.api.task.TaskExecutor; + +/** + * A plan that executes code over many slices of a provided list. + *

+ * The size of the slice is dynamically determined based on the number of available threads. + * + * @param listSupplier A supplier of the list to iterate over. + * @param action The action to perform on each sub list. + * @param The type of the list elements. + * @param The type of the context object. + */ +public record ForEachSlicePlan(Supplier> listSupplier, + BiConsumer, C> action) implements SimplyComposedPlan { + public static Plan of(Supplier> iterable, BiConsumer, C> forEach) { + return new ForEachSlicePlan<>(iterable, forEach); + } + + @Override + public void execute(TaskExecutor taskExecutor, C context, Runnable onCompletion) { + taskExecutor.execute(() -> PlanUtil.distributeSlices(taskExecutor, context, onCompletion, listSupplier.get(), action)); + } +} diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java b/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java index 430dff57a..352cc249e 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/PlanUtil.java @@ -1,5 +1,6 @@ package com.jozufozu.flywheel.lib.task; +import java.util.Collections; import java.util.List; import java.util.function.BiConsumer; @@ -12,38 +13,84 @@ public class PlanUtil { if (size == 0) { onCompletion.run(); - } else if (size <= getChunkSize(taskExecutor, size)) { - processList(context, onCompletion, list, action); + return; + } + + final int sliceSize = sliceSize(taskExecutor, size); + + if (size <= sliceSize) { + for (T t : list) { + action.accept(t, context); + } + onCompletion.run(); + } else if (sliceSize == 1) { + var synchronizer = new Synchronizer(size, onCompletion); + for (T t : list) { + taskExecutor.execute(() -> { + action.accept(t, context); + synchronizer.decrementAndEventuallyRun(); + }); + } } else { - dispatchChunks(taskExecutor, context, onCompletion, list, action); + var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, sliceSize), onCompletion); + int remaining = size; + + while (remaining > 0) { + int end = remaining; + remaining -= sliceSize; + int start = Math.max(remaining, 0); + + var subList = list.subList(start, end); + taskExecutor.execute(() -> { + for (T t : subList) { + action.accept(t, context); + } + synchronizer.decrementAndEventuallyRun(); + }); + } } } - public static int getChunkSize(TaskExecutor taskExecutor, int totalSize) { + public static void distributeSlices(TaskExecutor taskExecutor, C context, Runnable onCompletion, List list, BiConsumer, C> action) { + final int size = list.size(); + + if (size == 0) { + onCompletion.run(); + return; + } + + final int sliceSize = sliceSize(taskExecutor, size); + + if (size <= sliceSize) { + action.accept(list, context); + onCompletion.run(); + } else if (sliceSize == 1) { + var synchronizer = new Synchronizer(size, onCompletion); + for (T t : list) { + taskExecutor.execute(() -> { + action.accept(Collections.singletonList(t), context); + synchronizer.decrementAndEventuallyRun(); + }); + } + } else { + var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, sliceSize), onCompletion); + int remaining = size; + + while (remaining > 0) { + int end = remaining; + remaining -= sliceSize; + int start = Math.max(remaining, 0); + + var subList = list.subList(start, end); + taskExecutor.execute(() -> { + action.accept(subList, context); + synchronizer.decrementAndEventuallyRun(); + }); + } + } + } + + public static int sliceSize(TaskExecutor taskExecutor, int totalSize) { return MoreMath.ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32); } - - static void dispatchChunks(TaskExecutor taskExecutor, C context, Runnable onCompletion, List list, BiConsumer action) { - final int size = list.size(); - final int chunkSize = getChunkSize(taskExecutor, size); - - var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, chunkSize), onCompletion); - int remaining = size; - - while (remaining > 0) { - int end = remaining; - remaining -= chunkSize; - int start = Math.max(remaining, 0); - - var subList = list.subList(start, end); - taskExecutor.execute(() -> processList(context, synchronizer, subList, action)); - } - } - - static void processList(C context, Runnable onCompletion, List list, BiConsumer action) { - for (var t : list) { - action.accept(t, context); - } - onCompletion.run(); - } } diff --git a/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java b/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java index 92411bac2..793932fc0 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java +++ b/src/main/java/com/jozufozu/flywheel/lib/task/WaitGroup.java @@ -31,12 +31,22 @@ public class WaitGroup { } } - public void await() { - // TODO: comprehensive performance tracking for tasks + /** + * Spins for up to the given number of nanoseconds before returning. + * + * @param nsTimeout How long to wait for the counter to reach 0. + * @return {@code true} if the counter reached 0, {@code false} if the timeout was reached. + */ + public boolean await(int nsTimeout) { + long startTime = System.nanoTime(); while (counter.get() > 0) { + if (System.nanoTime() - startTime > nsTimeout) { + return false; + } // spin in place to avoid sleeping the main thread Thread.onSpinWait(); } + return true; } public void _reset() { diff --git a/src/main/java/com/jozufozu/flywheel/lib/vertex/VertexTransformations.java b/src/main/java/com/jozufozu/flywheel/lib/vertex/VertexTransformations.java index 9842fa430..6d87ca504 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/vertex/VertexTransformations.java +++ b/src/main/java/com/jozufozu/flywheel/lib/vertex/VertexTransformations.java @@ -1,8 +1,5 @@ package com.jozufozu.flywheel.lib.vertex; -import static org.joml.Math.fma; -import static org.joml.Math.invsqrt; - import com.jozufozu.flywheel.api.vertex.MutableVertexList; import com.jozufozu.flywheel.lib.math.MatrixUtil; import com.mojang.math.Matrix3f; @@ -18,6 +15,9 @@ public final class VertexTransformations { vertexList.z(index, MatrixUtil.transformPositionZ(matrix, x, y, z)); } + /** + * Assumes the matrix preserves scale. + */ public static void transformNormal(MutableVertexList vertexList, int index, Matrix3f matrix) { float nx = vertexList.normalX(index); float ny = vertexList.normalY(index); @@ -25,13 +25,14 @@ public final class VertexTransformations { float tnx = MatrixUtil.transformNormalX(matrix, nx, ny, nz); float tny = MatrixUtil.transformNormalY(matrix, nx, ny, nz); float tnz = MatrixUtil.transformNormalZ(matrix, nx, ny, nz); - float sqrLength = fma(tnx, tnx, fma(tny, tny, tnz * tnz)); - if (sqrLength != 0) { - float f = invsqrt(sqrLength); - tnx *= f; - tny *= f; - tnz *= f; - } + // seems to be the case that sqrLength is always ~1.0 + // float sqrLength = fma(tnx, tnx, fma(tny, tny, tnz * tnz)); + // if (sqrLength != 0) { + // float f = invsqrt(sqrLength); + // tnx *= f; + // tny *= f; + // tnz *= f; + // } vertexList.normalX(index, tnx); vertexList.normalY(index, tny); vertexList.normalZ(index, tnz);