Barely better batches

- TransformCall now uses ForEachSlicePlan to reduce the number of
  objects created.
- WaitGroup#await can now timeout. This allows the main thread to
  contribute more work in a syncPoint.
- Don't normalize in transformNormal, things already are normalized.
This commit is contained in:
Jozufozu 2023-05-29 20:29:46 -07:00
parent 257ee07e0e
commit fcd70cccd0
7 changed files with 172 additions and 67 deletions

View file

@ -13,7 +13,8 @@ import com.jozufozu.flywheel.api.material.Material;
import com.jozufozu.flywheel.api.material.MaterialVertexTransformer;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.vertex.MutableVertexList;
import com.jozufozu.flywheel.lib.task.ForEachPlan;
import com.jozufozu.flywheel.api.vertex.ReusableVertexList;
import com.jozufozu.flywheel.lib.task.ForEachSlicePlan;
import com.jozufozu.flywheel.lib.vertex.VertexTransformations;
import com.mojang.blaze3d.vertex.PoseStack;
import com.mojang.math.Matrix3f;
@ -37,22 +38,26 @@ public class TransformCall<I extends Instance> {
meshVertexCount = mesh.getVertexCount();
Vector4fc meshBoundingSphere = mesh.boundingSphere();
drawPlan = ForEachPlan.of(instancer::getAll, (instance, ctx) -> {
var boundingSphere = new Vector4f(meshBoundingSphere);
drawPlan = ForEachSlicePlan.of(instancer::getAll, (subList, ctx) -> {
ReusableVertexList vertexList = ctx.buffer.slice(0, meshVertexCount);
Vector4f boundingSphere = new Vector4f();
for (I instance : subList) {
boundingSphere.set(meshBoundingSphere);
boundingSphereTransformer.transform(boundingSphere, instance);
if (!ctx.frustum
.testSphere(boundingSphere.x, boundingSphere.y, boundingSphere.z, boundingSphere.w)) {
return;
if (!ctx.frustum.testSphere(boundingSphere.x, boundingSphere.y, boundingSphere.z, boundingSphere.w)) {
continue;
}
final int baseVertex = ctx.vertexCounter.getAndAdd(meshVertexCount);
var sub = ctx.buffer.slice(baseVertex, meshVertexCount);
vertexList.ptr(ctx.buffer.ptrForVertex(baseVertex));
mesh.copyTo(sub.ptr());
instanceVertexTransformer.transform(sub, instance);
materialVertexTransformer.transform(sub, ctx.level);
applyMatrices(sub, ctx.matrices);
mesh.copyTo(vertexList.ptr());
instanceVertexTransformer.transform(vertexList, instance);
materialVertexTransformer.transform(vertexList, ctx.level);
applyMatrices(vertexList, ctx.matrices);
}
});
}

View file

@ -145,10 +145,9 @@ public class ParallelTaskExecutor implements TaskExecutor {
processTask(task);
} else {
// then wait for the other threads to finish.
waitGroup.await();
// at this point we know taskQueue is empty,
// but one of the worker threads may have submitted a main thread task.
if (mainThreadQueue.isEmpty()) {
boolean done = waitGroup.await(10_000);
// If we timed-out tasks may have been added to the queue, so check again.
if (done && mainThreadQueue.isEmpty()) {
// if they didn't, we're done.
break;
}
@ -157,13 +156,17 @@ public class ParallelTaskExecutor implements TaskExecutor {
}
public void discardAndAwait() {
while (true) {
// Discard everyone else's work...
while (taskQueue.pollLast() != null) {
waitGroup.done();
}
// ...wait for any stragglers...
waitGroup.await();
if (waitGroup.await(100_000)) {
break;
}
}
// ...and clear the main thread queue.
mainThreadQueue.clear();
}

View file

@ -8,8 +8,17 @@ import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
public record ForEachPlan<T, C>(Supplier<List<T>> listSupplier,
BiConsumer<T, C> action) implements SimplyComposedPlan<C> {
/**
* A plan that executes code on each element of a provided list.
* <p>
* Operations are dynamically batched based on the number of available threads.
*
* @param listSupplier A supplier of the list to iterate over.
* @param action The action to perform on each element.
* @param <T> The type of the list elements.
* @param <C> The type of the context object.
*/
public record ForEachPlan<T, C>(Supplier<List<T>> listSupplier, BiConsumer<T, C> action) implements SimplyComposedPlan<C> {
public static <T, C> Plan<C> of(Supplier<List<T>> iterable, BiConsumer<T, C> forEach) {
return new ForEachPlan<>(iterable, forEach);
}

View file

@ -0,0 +1,30 @@
package com.jozufozu.flywheel.lib.task;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
/**
* A plan that executes code over many slices of a provided list.
* <p>
* The size of the slice is dynamically determined based on the number of available threads.
*
* @param listSupplier A supplier of the list to iterate over.
* @param action The action to perform on each sub list.
* @param <T> The type of the list elements.
* @param <C> The type of the context object.
*/
public record ForEachSlicePlan<T, C>(Supplier<List<T>> listSupplier,
BiConsumer<List<T>, C> action) implements SimplyComposedPlan<C> {
public static <T, C> Plan<C> of(Supplier<List<T>> iterable, BiConsumer<List<T>, C> forEach) {
return new ForEachSlicePlan<>(iterable, forEach);
}
@Override
public void execute(TaskExecutor taskExecutor, C context, Runnable onCompletion) {
taskExecutor.execute(() -> PlanUtil.distributeSlices(taskExecutor, context, onCompletion, listSupplier.get(), action));
}
}

View file

@ -1,5 +1,6 @@
package com.jozufozu.flywheel.lib.task;
import java.util.Collections;
import java.util.List;
import java.util.function.BiConsumer;
@ -12,38 +13,84 @@ public class PlanUtil {
if (size == 0) {
onCompletion.run();
} else if (size <= getChunkSize(taskExecutor, size)) {
processList(context, onCompletion, list, action);
return;
}
final int sliceSize = sliceSize(taskExecutor, size);
if (size <= sliceSize) {
for (T t : list) {
action.accept(t, context);
}
onCompletion.run();
} else if (sliceSize == 1) {
var synchronizer = new Synchronizer(size, onCompletion);
for (T t : list) {
taskExecutor.execute(() -> {
action.accept(t, context);
synchronizer.decrementAndEventuallyRun();
});
}
} else {
dispatchChunks(taskExecutor, context, onCompletion, list, action);
}
}
public static int getChunkSize(TaskExecutor taskExecutor, int totalSize) {
return MoreMath.ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
}
static <C, T> void dispatchChunks(TaskExecutor taskExecutor, C context, Runnable onCompletion, List<T> list, BiConsumer<T, C> action) {
final int size = list.size();
final int chunkSize = getChunkSize(taskExecutor, size);
var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, chunkSize), onCompletion);
var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, sliceSize), onCompletion);
int remaining = size;
while (remaining > 0) {
int end = remaining;
remaining -= chunkSize;
remaining -= sliceSize;
int start = Math.max(remaining, 0);
var subList = list.subList(start, end);
taskExecutor.execute(() -> processList(context, synchronizer, subList, action));
taskExecutor.execute(() -> {
for (T t : subList) {
action.accept(t, context);
}
synchronizer.decrementAndEventuallyRun();
});
}
}
}
static <C, T> void processList(C context, Runnable onCompletion, List<T> list, BiConsumer<T, C> action) {
for (var t : list) {
action.accept(t, context);
}
public static <C, T> void distributeSlices(TaskExecutor taskExecutor, C context, Runnable onCompletion, List<T> list, BiConsumer<List<T>, C> action) {
final int size = list.size();
if (size == 0) {
onCompletion.run();
return;
}
final int sliceSize = sliceSize(taskExecutor, size);
if (size <= sliceSize) {
action.accept(list, context);
onCompletion.run();
} else if (sliceSize == 1) {
var synchronizer = new Synchronizer(size, onCompletion);
for (T t : list) {
taskExecutor.execute(() -> {
action.accept(Collections.singletonList(t), context);
synchronizer.decrementAndEventuallyRun();
});
}
} else {
var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, sliceSize), onCompletion);
int remaining = size;
while (remaining > 0) {
int end = remaining;
remaining -= sliceSize;
int start = Math.max(remaining, 0);
var subList = list.subList(start, end);
taskExecutor.execute(() -> {
action.accept(subList, context);
synchronizer.decrementAndEventuallyRun();
});
}
}
}
public static int sliceSize(TaskExecutor taskExecutor, int totalSize) {
return MoreMath.ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
}
}

View file

@ -31,12 +31,22 @@ public class WaitGroup {
}
}
public void await() {
// TODO: comprehensive performance tracking for tasks
/**
* Spins for up to the given number of nanoseconds before returning.
*
* @param nsTimeout How long to wait for the counter to reach 0.
* @return {@code true} if the counter reached 0, {@code false} if the timeout was reached.
*/
public boolean await(int nsTimeout) {
long startTime = System.nanoTime();
while (counter.get() > 0) {
if (System.nanoTime() - startTime > nsTimeout) {
return false;
}
// spin in place to avoid sleeping the main thread
Thread.onSpinWait();
}
return true;
}
public void _reset() {

View file

@ -1,8 +1,5 @@
package com.jozufozu.flywheel.lib.vertex;
import static org.joml.Math.fma;
import static org.joml.Math.invsqrt;
import com.jozufozu.flywheel.api.vertex.MutableVertexList;
import com.jozufozu.flywheel.lib.math.MatrixUtil;
import com.mojang.math.Matrix3f;
@ -18,6 +15,9 @@ public final class VertexTransformations {
vertexList.z(index, MatrixUtil.transformPositionZ(matrix, x, y, z));
}
/**
* Assumes the matrix preserves scale.
*/
public static void transformNormal(MutableVertexList vertexList, int index, Matrix3f matrix) {
float nx = vertexList.normalX(index);
float ny = vertexList.normalY(index);
@ -25,13 +25,14 @@ public final class VertexTransformations {
float tnx = MatrixUtil.transformNormalX(matrix, nx, ny, nz);
float tny = MatrixUtil.transformNormalY(matrix, nx, ny, nz);
float tnz = MatrixUtil.transformNormalZ(matrix, nx, ny, nz);
float sqrLength = fma(tnx, tnx, fma(tny, tny, tnz * tnz));
if (sqrLength != 0) {
float f = invsqrt(sqrLength);
tnx *= f;
tny *= f;
tnz *= f;
}
// seems to be the case that sqrLength is always ~1.0
// float sqrLength = fma(tnx, tnx, fma(tny, tny, tnz * tnz));
// if (sqrLength != 0) {
// float f = invsqrt(sqrLength);
// tnx *= f;
// tny *= f;
// tnz *= f;
// }
vertexList.normalX(index, tnx);
vertexList.normalY(index, tny);
vertexList.normalZ(index, tnz);