Barely better batches

- TransformCall now uses ForEachSlicePlan to reduce the number of
  objects created.
- WaitGroup#await can now time out. This allows the main thread to
  contribute more work in a syncPoint.
- Don't normalize in transformNormal; the normals are already normalized.
This commit is contained in:
Jozufozu 2023-05-29 20:29:46 -07:00
parent 257ee07e0e
commit fcd70cccd0
7 changed files with 172 additions and 67 deletions

View file

@ -13,7 +13,8 @@ import com.jozufozu.flywheel.api.material.Material;
import com.jozufozu.flywheel.api.material.MaterialVertexTransformer; import com.jozufozu.flywheel.api.material.MaterialVertexTransformer;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.vertex.MutableVertexList; import com.jozufozu.flywheel.api.vertex.MutableVertexList;
import com.jozufozu.flywheel.lib.task.ForEachPlan; import com.jozufozu.flywheel.api.vertex.ReusableVertexList;
import com.jozufozu.flywheel.lib.task.ForEachSlicePlan;
import com.jozufozu.flywheel.lib.vertex.VertexTransformations; import com.jozufozu.flywheel.lib.vertex.VertexTransformations;
import com.mojang.blaze3d.vertex.PoseStack; import com.mojang.blaze3d.vertex.PoseStack;
import com.mojang.math.Matrix3f; import com.mojang.math.Matrix3f;
@ -37,22 +38,26 @@ public class TransformCall<I extends Instance> {
meshVertexCount = mesh.getVertexCount(); meshVertexCount = mesh.getVertexCount();
Vector4fc meshBoundingSphere = mesh.boundingSphere(); Vector4fc meshBoundingSphere = mesh.boundingSphere();
drawPlan = ForEachPlan.of(instancer::getAll, (instance, ctx) -> { drawPlan = ForEachSlicePlan.of(instancer::getAll, (subList, ctx) -> {
var boundingSphere = new Vector4f(meshBoundingSphere); ReusableVertexList vertexList = ctx.buffer.slice(0, meshVertexCount);
Vector4f boundingSphere = new Vector4f();
for (I instance : subList) {
boundingSphere.set(meshBoundingSphere);
boundingSphereTransformer.transform(boundingSphere, instance); boundingSphereTransformer.transform(boundingSphere, instance);
if (!ctx.frustum if (!ctx.frustum.testSphere(boundingSphere.x, boundingSphere.y, boundingSphere.z, boundingSphere.w)) {
.testSphere(boundingSphere.x, boundingSphere.y, boundingSphere.z, boundingSphere.w)) { continue;
return;
} }
final int baseVertex = ctx.vertexCounter.getAndAdd(meshVertexCount); final int baseVertex = ctx.vertexCounter.getAndAdd(meshVertexCount);
var sub = ctx.buffer.slice(baseVertex, meshVertexCount); vertexList.ptr(ctx.buffer.ptrForVertex(baseVertex));
mesh.copyTo(sub.ptr()); mesh.copyTo(vertexList.ptr());
instanceVertexTransformer.transform(sub, instance); instanceVertexTransformer.transform(vertexList, instance);
materialVertexTransformer.transform(sub, ctx.level); materialVertexTransformer.transform(vertexList, ctx.level);
applyMatrices(sub, ctx.matrices); applyMatrices(vertexList, ctx.matrices);
}
}); });
} }

View file

@ -145,10 +145,9 @@ public class ParallelTaskExecutor implements TaskExecutor {
processTask(task); processTask(task);
} else { } else {
// then wait for the other threads to finish. // then wait for the other threads to finish.
waitGroup.await(); boolean done = waitGroup.await(10_000);
// at this point we know taskQueue is empty, // If we timed-out tasks may have been added to the queue, so check again.
// but one of the worker threads may have submitted a main thread task. if (done && mainThreadQueue.isEmpty()) {
if (mainThreadQueue.isEmpty()) {
// if they didn't, we're done. // if they didn't, we're done.
break; break;
} }
@ -157,13 +156,17 @@ public class ParallelTaskExecutor implements TaskExecutor {
} }
public void discardAndAwait() { public void discardAndAwait() {
while (true) {
// Discard everyone else's work... // Discard everyone else's work...
while (taskQueue.pollLast() != null) { while (taskQueue.pollLast() != null) {
waitGroup.done(); waitGroup.done();
} }
// ...wait for any stragglers... // ...wait for any stragglers...
waitGroup.await(); if (waitGroup.await(100_000)) {
break;
}
}
// ...and clear the main thread queue. // ...and clear the main thread queue.
mainThreadQueue.clear(); mainThreadQueue.clear();
} }

View file

@ -8,8 +8,17 @@ import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor; import com.jozufozu.flywheel.api.task.TaskExecutor;
public record ForEachPlan<T, C>(Supplier<List<T>> listSupplier, /**
BiConsumer<T, C> action) implements SimplyComposedPlan<C> { * A plan that executes code on each element of a provided list.
* <p>
* Operations are dynamically batched based on the number of available threads.
*
* @param listSupplier A supplier of the list to iterate over.
* @param action The action to perform on each element.
* @param <T> The type of the list elements.
* @param <C> The type of the context object.
*/
public record ForEachPlan<T, C>(Supplier<List<T>> listSupplier, BiConsumer<T, C> action) implements SimplyComposedPlan<C> {
public static <T, C> Plan<C> of(Supplier<List<T>> iterable, BiConsumer<T, C> forEach) { public static <T, C> Plan<C> of(Supplier<List<T>> iterable, BiConsumer<T, C> forEach) {
return new ForEachPlan<>(iterable, forEach); return new ForEachPlan<>(iterable, forEach);
} }

View file

@ -0,0 +1,30 @@
package com.jozufozu.flywheel.lib.task;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
/**
 * A plan that runs an action over successive sub lists ("slices") of a supplied list.
 * <p>
 * The list is fetched from the supplier inside a task submitted to the executor, and
 * {@link PlanUtil#distributeSlices} decides how the list is carved up, sizing slices
 * from the executor's thread count.
 *
 * @param listSupplier Supplies the list to be sliced; queried once per execution.
 * @param action Invoked with each sub list and the context object.
 * @param <T> The element type of the list.
 * @param <C> The context object type.
 */
public record ForEachSlicePlan<T, C>(Supplier<List<T>> listSupplier, BiConsumer<List<T>, C> action) implements SimplyComposedPlan<C> {
	/**
	 * Static factory that returns the new plan typed as {@link Plan}.
	 *
	 * @param iterable Supplies the list to be sliced.
	 * @param forEach Invoked with each sub list and the context object.
	 * @return A plan wrapping the given supplier and action.
	 */
	public static <T, C> Plan<C> of(Supplier<List<T>> iterable, BiConsumer<List<T>, C> forEach) {
		return new ForEachSlicePlan<>(iterable, forEach);
	}

	@Override
	public void execute(TaskExecutor taskExecutor, C context, Runnable onCompletion) {
		taskExecutor.execute(() -> {
			// Resolve the list inside the submitted task, not on the calling thread.
			List<T> elements = listSupplier.get();
			PlanUtil.distributeSlices(taskExecutor, context, onCompletion, elements, action);
		});
	}
}

View file

@ -1,5 +1,6 @@
package com.jozufozu.flywheel.lib.task; package com.jozufozu.flywheel.lib.task;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
@ -12,38 +13,84 @@ public class PlanUtil {
if (size == 0) { if (size == 0) {
onCompletion.run(); onCompletion.run();
} else if (size <= getChunkSize(taskExecutor, size)) { return;
processList(context, onCompletion, list, action); }
final int sliceSize = sliceSize(taskExecutor, size);
if (size <= sliceSize) {
for (T t : list) {
action.accept(t, context);
}
onCompletion.run();
} else if (sliceSize == 1) {
var synchronizer = new Synchronizer(size, onCompletion);
for (T t : list) {
taskExecutor.execute(() -> {
action.accept(t, context);
synchronizer.decrementAndEventuallyRun();
});
}
} else { } else {
dispatchChunks(taskExecutor, context, onCompletion, list, action); var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, sliceSize), onCompletion);
}
}
public static int getChunkSize(TaskExecutor taskExecutor, int totalSize) {
return MoreMath.ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
}
static <C, T> void dispatchChunks(TaskExecutor taskExecutor, C context, Runnable onCompletion, List<T> list, BiConsumer<T, C> action) {
final int size = list.size();
final int chunkSize = getChunkSize(taskExecutor, size);
var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, chunkSize), onCompletion);
int remaining = size; int remaining = size;
while (remaining > 0) { while (remaining > 0) {
int end = remaining; int end = remaining;
remaining -= chunkSize; remaining -= sliceSize;
int start = Math.max(remaining, 0); int start = Math.max(remaining, 0);
var subList = list.subList(start, end); var subList = list.subList(start, end);
taskExecutor.execute(() -> processList(context, synchronizer, subList, action)); taskExecutor.execute(() -> {
for (T t : subList) {
action.accept(t, context);
}
synchronizer.decrementAndEventuallyRun();
});
}
} }
} }
static <C, T> void processList(C context, Runnable onCompletion, List<T> list, BiConsumer<T, C> action) { public static <C, T> void distributeSlices(TaskExecutor taskExecutor, C context, Runnable onCompletion, List<T> list, BiConsumer<List<T>, C> action) {
for (var t : list) { final int size = list.size();
action.accept(t, context);
} if (size == 0) {
onCompletion.run(); onCompletion.run();
return;
}
final int sliceSize = sliceSize(taskExecutor, size);
if (size <= sliceSize) {
action.accept(list, context);
onCompletion.run();
} else if (sliceSize == 1) {
var synchronizer = new Synchronizer(size, onCompletion);
for (T t : list) {
taskExecutor.execute(() -> {
action.accept(Collections.singletonList(t), context);
synchronizer.decrementAndEventuallyRun();
});
}
} else {
var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, sliceSize), onCompletion);
int remaining = size;
while (remaining > 0) {
int end = remaining;
remaining -= sliceSize;
int start = Math.max(remaining, 0);
var subList = list.subList(start, end);
taskExecutor.execute(() -> {
action.accept(subList, context);
synchronizer.decrementAndEventuallyRun();
});
}
}
}
public static int sliceSize(TaskExecutor taskExecutor, int totalSize) {
return MoreMath.ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
} }
} }

View file

@ -31,12 +31,22 @@ public class WaitGroup {
} }
} }
public void await() { /**
// TODO: comprehensive performance tracking for tasks * Spins for up to the given number of nanoseconds before returning.
*
* @param nsTimeout How long to wait for the counter to reach 0.
* @return {@code true} if the counter reached 0, {@code false} if the timeout was reached.
*/
public boolean await(int nsTimeout) {
long startTime = System.nanoTime();
while (counter.get() > 0) { while (counter.get() > 0) {
if (System.nanoTime() - startTime > nsTimeout) {
return false;
}
// spin in place to avoid sleeping the main thread // spin in place to avoid sleeping the main thread
Thread.onSpinWait(); Thread.onSpinWait();
} }
return true;
} }
public void _reset() { public void _reset() {

View file

@ -1,8 +1,5 @@
package com.jozufozu.flywheel.lib.vertex; package com.jozufozu.flywheel.lib.vertex;
import static org.joml.Math.fma;
import static org.joml.Math.invsqrt;
import com.jozufozu.flywheel.api.vertex.MutableVertexList; import com.jozufozu.flywheel.api.vertex.MutableVertexList;
import com.jozufozu.flywheel.lib.math.MatrixUtil; import com.jozufozu.flywheel.lib.math.MatrixUtil;
import com.mojang.math.Matrix3f; import com.mojang.math.Matrix3f;
@ -18,6 +15,9 @@ public final class VertexTransformations {
vertexList.z(index, MatrixUtil.transformPositionZ(matrix, x, y, z)); vertexList.z(index, MatrixUtil.transformPositionZ(matrix, x, y, z));
} }
/**
* Assumes the matrix preserves scale.
*/
public static void transformNormal(MutableVertexList vertexList, int index, Matrix3f matrix) { public static void transformNormal(MutableVertexList vertexList, int index, Matrix3f matrix) {
float nx = vertexList.normalX(index); float nx = vertexList.normalX(index);
float ny = vertexList.normalY(index); float ny = vertexList.normalY(index);
@ -25,13 +25,14 @@ public final class VertexTransformations {
float tnx = MatrixUtil.transformNormalX(matrix, nx, ny, nz); float tnx = MatrixUtil.transformNormalX(matrix, nx, ny, nz);
float tny = MatrixUtil.transformNormalY(matrix, nx, ny, nz); float tny = MatrixUtil.transformNormalY(matrix, nx, ny, nz);
float tnz = MatrixUtil.transformNormalZ(matrix, nx, ny, nz); float tnz = MatrixUtil.transformNormalZ(matrix, nx, ny, nz);
float sqrLength = fma(tnx, tnx, fma(tny, tny, tnz * tnz)); // seems to be the case that sqrLength is always ~1.0
if (sqrLength != 0) { // float sqrLength = fma(tnx, tnx, fma(tny, tny, tnz * tnz));
float f = invsqrt(sqrLength); // if (sqrLength != 0) {
tnx *= f; // float f = invsqrt(sqrLength);
tny *= f; // tnx *= f;
tnz *= f; // tny *= f;
} // tnz *= f;
// }
vertexList.normalX(index, tnx); vertexList.normalX(index, tnx);
vertexList.normalY(index, tny); vertexList.normalY(index, tny);
vertexList.normalZ(index, tnz); vertexList.normalZ(index, tnz);