Smaller batches

- Extract TransformSet to full class
- Move planning code into BatchingTransformManager
- Reduce exposure of BatchingTransformManager internals
- Comment out WaitGroup log :ioa:
- Move ceilingDiv to MoreMath
- Use dynamic chunk size for TransformCalls
This commit is contained in:
Jozufozu 2023-04-16 12:46:20 -07:00
parent cb74fa4603
commit e03590b270
10 changed files with 114 additions and 126 deletions

View file

@ -22,7 +22,7 @@ public class BatchedMeshPool {
private final Map<Mesh, BufferedMesh> meshes = new HashMap<>(); private final Map<Mesh, BufferedMesh> meshes = new HashMap<>();
private final List<BufferedMesh> allBuffered = new ArrayList<>(); private final List<BufferedMesh> allBuffered = new ArrayList<>();
private final List<BufferedMesh> pendingUpload = new ArrayList<>(); private final List<BufferedMesh> pendingBuffer = new ArrayList<>();
private MemoryBlock memory; private MemoryBlock memory;
private long byteSize; private long byteSize;
@ -54,7 +54,7 @@ public class BatchedMeshPool {
BufferedMesh bufferedMesh = new BufferedMesh(m, byteSize); BufferedMesh bufferedMesh = new BufferedMesh(m, byteSize);
byteSize += bufferedMesh.size(); byteSize += bufferedMesh.size();
allBuffered.add(bufferedMesh); allBuffered.add(bufferedMesh);
pendingUpload.add(bufferedMesh); pendingBuffer.add(bufferedMesh);
dirty = true; dirty = true;
return bufferedMesh; return bufferedMesh;
@ -73,10 +73,10 @@ public class BatchedMeshPool {
} }
realloc(); realloc();
uploadPending(); bufferPending();
dirty = false; dirty = false;
pendingUpload.clear(); pendingBuffer.clear();
} }
} }
@ -94,7 +94,7 @@ public class BatchedMeshPool {
int byteIndex = 0; int byteIndex = 0;
for (BufferedMesh mesh : allBuffered) { for (BufferedMesh mesh : allBuffered) {
if (mesh.byteIndex != byteIndex) { if (mesh.byteIndex != byteIndex) {
pendingUpload.add(mesh); pendingBuffer.add(mesh);
} }
mesh.byteIndex = byteIndex; mesh.byteIndex = byteIndex;
@ -122,13 +122,13 @@ public class BatchedMeshPool {
} }
} }
private void uploadPending() { private void bufferPending() {
try { try {
for (BufferedMesh mesh : pendingUpload) { for (BufferedMesh mesh : pendingBuffer) {
mesh.buffer(vertexList); mesh.buffer(vertexList);
} }
pendingUpload.clear(); pendingBuffer.clear();
} catch (Exception e) { } catch (Exception e) {
Flywheel.LOGGER.error("Error uploading pooled meshes:", e); Flywheel.LOGGER.error("Error uploading pooled meshes:", e);
} }
@ -140,7 +140,7 @@ public class BatchedMeshPool {
} }
meshes.clear(); meshes.clear();
allBuffered.clear(); allBuffered.clear();
pendingUpload.clear(); pendingBuffer.clear();
} }
@Override @Override

View file

@ -1,6 +1,5 @@
package com.jozufozu.flywheel.backend.engine.batching; package com.jozufozu.flywheel.backend.engine.batching;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import com.jozufozu.flywheel.api.event.RenderContext; import com.jozufozu.flywheel.api.event.RenderContext;
@ -12,8 +11,6 @@ import com.jozufozu.flywheel.api.model.Model;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor; import com.jozufozu.flywheel.api.task.TaskExecutor;
import com.jozufozu.flywheel.backend.engine.AbstractEngine; import com.jozufozu.flywheel.backend.engine.AbstractEngine;
import com.jozufozu.flywheel.lib.task.NestedPlan;
import com.jozufozu.flywheel.lib.task.PlanUtil;
import com.jozufozu.flywheel.util.FlwUtil; import com.jozufozu.flywheel.util.FlwUtil;
import net.minecraft.world.phys.Vec3; import net.minecraft.world.phys.Vec3;
@ -33,52 +30,12 @@ public class BatchingEngine extends AbstractEngine {
@Override @Override
public Plan planThisFrame(RenderContext context) { public Plan planThisFrame(RenderContext context) {
return PlanUtil.of(transformManager::flush)
.then(planTransformers(context));
}
private Plan planTransformers(RenderContext context) {
Vec3 cameraPos = context.camera() Vec3 cameraPos = context.camera()
.getPosition(); .getPosition();
var stack = FlwUtil.copyPoseStack(context.stack()); var stack = FlwUtil.copyPoseStack(context.stack());
stack.translate(renderOrigin.getX() - cameraPos.x, renderOrigin.getY() - cameraPos.y, renderOrigin.getZ() - cameraPos.z); stack.translate(renderOrigin.getX() - cameraPos.x, renderOrigin.getY() - cameraPos.y, renderOrigin.getZ() - cameraPos.z);
var matrices = stack.last(); return transformManager.plan(stack.last(), context.level(), drawTracker);
var level = context.level();
var plans = new ArrayList<Plan>();
for (var transformSetEntry : transformManager.getTransformSetsView()
.entrySet()) {
var stage = transformSetEntry.getKey();
var transformSet = transformSetEntry.getValue();
for (var entry : transformSet) {
var renderType = entry.getKey();
var transformCalls = entry.getValue();
int vertices = 0;
for (var transformCall : transformCalls) {
transformCall.setup();
vertices += transformCall.getTotalVertexCount();
}
if (vertices == 0) {
continue;
}
DrawBuffer buffer = drawTracker.getBuffer(renderType, stage);
buffer.prepare(vertices);
int startVertex = 0;
for (var transformCall : transformCalls) {
plans.add(transformCall.getPlan(buffer, startVertex, matrices, level));
startVertex += transformCall.getTotalVertexCount();
}
}
}
return new NestedPlan(plans);
} }
@Override @Override

View file

@ -1,28 +1,24 @@
package com.jozufozu.flywheel.backend.engine.batching; package com.jozufozu.flywheel.backend.engine.batching;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap; import java.util.EnumMap;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.jetbrains.annotations.NotNull;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import com.jozufozu.flywheel.api.event.RenderStage; import com.jozufozu.flywheel.api.event.RenderStage;
import com.jozufozu.flywheel.api.instance.Instance; import com.jozufozu.flywheel.api.instance.Instance;
import com.jozufozu.flywheel.api.instance.InstanceType; import com.jozufozu.flywheel.api.instance.InstanceType;
import com.jozufozu.flywheel.api.instance.Instancer; import com.jozufozu.flywheel.api.instance.Instancer;
import com.jozufozu.flywheel.api.model.Mesh; import com.jozufozu.flywheel.api.model.Mesh;
import com.jozufozu.flywheel.api.model.Model; import com.jozufozu.flywheel.api.model.Model;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.backend.engine.InstancerKey; import com.jozufozu.flywheel.backend.engine.InstancerKey;
import com.jozufozu.flywheel.lib.task.NestedPlan;
import com.mojang.blaze3d.vertex.PoseStack;
import com.mojang.blaze3d.vertex.VertexFormat; import com.mojang.blaze3d.vertex.VertexFormat;
import net.minecraft.client.multiplayer.ClientLevel;
import net.minecraft.client.renderer.RenderType; import net.minecraft.client.renderer.RenderType;
public class BatchingTransformManager { public class BatchingTransformManager {
@ -30,15 +26,17 @@ public class BatchingTransformManager {
private final List<UninitializedInstancer> uninitializedInstancers = new ArrayList<>(); private final List<UninitializedInstancer> uninitializedInstancers = new ArrayList<>();
private final List<CPUInstancer<?>> initializedInstancers = new ArrayList<>(); private final List<CPUInstancer<?>> initializedInstancers = new ArrayList<>();
private final Map<RenderStage, TransformSet> transformSets = new EnumMap<>(RenderStage.class); private final Map<RenderStage, TransformSet> transformSets = new EnumMap<>(RenderStage.class);
private final Map<RenderStage, TransformSet> transformSetsView = Collections.unmodifiableMap(transformSets);
private final Map<VertexFormat, BatchedMeshPool> meshPools = new HashMap<>(); private final Map<VertexFormat, BatchedMeshPool> meshPools = new HashMap<>();
public TransformSet get(RenderStage stage) { public Plan plan(PoseStack.Pose matrices, ClientLevel level, BatchingDrawTracker tracker) {
return transformSets.getOrDefault(stage, TransformSet.EMPTY); flush();
var plans = new ArrayList<Plan>();
for (var transformSet : transformSets.values()) {
plans.add(transformSet.plan(matrices, level, tracker));
} }
public Map<RenderStage, TransformSet> getTransformSetsView() { return new NestedPlan(plans);
return transformSetsView;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -95,36 +93,6 @@ public class BatchingTransformManager {
.alloc(mesh); .alloc(mesh);
} }
public static class TransformSet implements Iterable<Map.Entry<RenderType, Collection<TransformCall<?>>>> {
public static final TransformSet EMPTY = new TransformSet(ImmutableListMultimap.of());
private final ListMultimap<RenderType, TransformCall<?>> transformCalls;
public TransformSet(RenderStage renderStage) {
transformCalls = ArrayListMultimap.create();
}
public TransformSet(ListMultimap<RenderType, TransformCall<?>> transformCalls) {
this.transformCalls = transformCalls;
}
public void put(RenderType shaderState, TransformCall<?> transformCall) {
transformCalls.put(shaderState, transformCall);
}
public boolean isEmpty() {
return transformCalls.isEmpty();
}
@NotNull
@Override
public Iterator<Map.Entry<RenderType, Collection<TransformCall<?>>>> iterator() {
return transformCalls.asMap()
.entrySet()
.iterator();
}
}
private record UninitializedInstancer(CPUInstancer<?> instancer, Model model, RenderStage stage) { private record UninitializedInstancer(CPUInstancer<?> instancer, Model model, RenderStage stage) {
} }
} }

View file

@ -19,6 +19,10 @@ public class CPUInstancer<I extends Instance> extends AbstractInstancer<I> {
return instances; return instances;
} }
/**
 * Looks up a single instance by its position in this instancer's {@code instances} collection.
 *
 * @param index position forwarded unchanged to {@code instances.get(int)}
 * @return the element that {@code instances.get(index)} yields
 */
public I get(int index) {
	final I instance = instances.get(index);
	return instance;
}
public void update() { public void update() {
removeDeletedInstances(); removeDeletedInstances();
} }

View file

@ -75,7 +75,7 @@ public class DrawBuffer {
} }
ReusableVertexList vertexList = provider.createVertexList(); ReusableVertexList vertexList = provider.createVertexList();
vertexList.ptr(memory.ptr() + startVertex * stride); vertexList.ptr(memory.ptr() + (long) startVertex * stride);
vertexList.vertexCount(vertexCount); vertexList.vertexCount(vertexCount);
return vertexList; return vertexList;
} }

View file

@ -9,6 +9,7 @@ import com.jozufozu.flywheel.api.material.Material;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.vertex.MutableVertexList; import com.jozufozu.flywheel.api.vertex.MutableVertexList;
import com.jozufozu.flywheel.api.vertex.ReusableVertexList; import com.jozufozu.flywheel.api.vertex.ReusableVertexList;
import com.jozufozu.flywheel.lib.math.MoreMath;
import com.jozufozu.flywheel.lib.task.SimplePlan; import com.jozufozu.flywheel.lib.task.SimplePlan;
import com.jozufozu.flywheel.lib.vertex.VertexTransformations; import com.jozufozu.flywheel.lib.vertex.VertexTransformations;
import com.mojang.blaze3d.vertex.PoseStack; import com.mojang.blaze3d.vertex.PoseStack;
@ -42,37 +43,32 @@ public class TransformCall<I extends Instance> {
instancer.update(); instancer.update();
} }
public Plan getPlan(DrawBuffer buffer, int startVertex, PoseStack.Pose matrices, ClientLevel level) { public Plan plan(DrawBuffer buffer, int startVertex, PoseStack.Pose matrices, ClientLevel level) {
int instances = instancer.getInstanceCount(); final int totalCount = instancer.getInstanceCount();
final int chunkSize = MoreMath.ceilingDiv(totalCount, 6 * 32);
var out = new ArrayList<Runnable>(); final var out = new ArrayList<Runnable>();
int remaining = totalCount;
while (instances > 0) { while (remaining > 0) {
int end = instances; int end = remaining;
instances -= 512; remaining -= chunkSize;
int start = Math.max(instances, 0); int start = Math.max(remaining, 0);
int vertexCount = meshVertexCount * (end - start); int vertexCount = meshVertexCount * (end - start);
ReusableVertexList sub = buffer.slice(startVertex, vertexCount); ReusableVertexList sub = buffer.slice(startVertex, vertexCount);
startVertex += vertexCount; startVertex += vertexCount;
out.add(() -> transformRange(sub, start, end, matrices, level)); out.add(() -> transform(sub, matrices, level, instancer.getRange(start, end)));
} }
return new SimplePlan(out); return new SimplePlan(out);
} }
public void transformRange(ReusableVertexList vertexList, int from, int to, PoseStack.Pose matrices, ClientLevel level) { private void transform(ReusableVertexList vertexList, PoseStack.Pose matrices, ClientLevel level, List<I> instances) {
transformList(vertexList, instancer.getRange(from, to), matrices, level); // save the total size of the slice for later.
} final long anchorPtr = vertexList.ptr();
final int totalVertexCount = vertexList.vertexCount();
public void transformAll(ReusableVertexList vertexList, PoseStack.Pose matrices, ClientLevel level) {
transformList(vertexList, instancer.getAll(), matrices, level);
}
public void transformList(ReusableVertexList vertexList, List<I> instances, PoseStack.Pose matrices, ClientLevel level) {
long anchorPtr = vertexList.ptr();
int totalVertexCount = vertexList.vertexCount();
// while working on individual instances, the vertex list should expose just a single copy of the mesh.
vertexList.vertexCount(meshVertexCount); vertexList.vertexCount(meshVertexCount);
InstanceVertexTransformer<I> instanceVertexTransformer = instancer.type.getVertexTransformer(); InstanceVertexTransformer<I> instanceVertexTransformer = instancer.type.getVertexTransformer();
@ -85,6 +81,7 @@ public class TransformCall<I extends Instance> {
vertexList.ptr(vertexList.ptr() + meshByteSize); vertexList.ptr(vertexList.ptr() + meshByteSize);
} }
// restore the original size of the slice to apply per-vertex transformations.
vertexList.ptr(anchorPtr); vertexList.ptr(anchorPtr);
vertexList.vertexCount(totalVertexCount); vertexList.vertexCount(totalVertexCount);
material.getVertexTransformer().transform(vertexList, level); material.getVertexTransformer().transform(vertexList, level);

View file

@ -0,0 +1,62 @@
package com.jozufozu.flywheel.backend.engine.batching;
import java.util.ArrayList;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.jozufozu.flywheel.api.event.RenderStage;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.lib.task.NestedPlan;
import com.mojang.blaze3d.vertex.PoseStack;
import net.minecraft.client.multiplayer.ClientLevel;
import net.minecraft.client.renderer.RenderType;
/**
 * The {@link TransformCall}s registered for one {@link RenderStage}, grouped by the
 * {@link RenderType} they were registered under.
 */
public class TransformSet {
	private final RenderStage stage;
	private final ListMultimap<RenderType, TransformCall<?>> transformCalls;

	/**
	 * Creates an empty set bound to the given stage.
	 *
	 * @param renderStage the stage used when requesting draw buffers in {@link #plan}
	 */
	public TransformSet(RenderStage renderStage) {
		this.stage = renderStage;
		this.transformCalls = ArrayListMultimap.create();
	}

	/**
	 * Builds one plan covering every transform call in this set.
	 *
	 * <p>For each render type, all of its calls are set up and their vertex counts summed.
	 * Render types that contribute no vertices are skipped. Otherwise the tracker's buffer for
	 * (render type, stage) is prepared with the total, and each call is planned at a start
	 * vertex equal to the running sum of the vertex counts of the calls before it.
	 *
	 * @param matrices forwarded to every {@link TransformCall#plan} invocation
	 * @param level forwarded to every {@link TransformCall#plan} invocation
	 * @param tracker source of the {@link DrawBuffer} for each render type
	 * @return a {@link NestedPlan} wrapping the per-call plans in iteration order
	 */
	public Plan plan(PoseStack.Pose matrices, ClientLevel level, BatchingDrawTracker tracker) {
		var collected = new ArrayList<Plan>();

		for (var group : transformCalls.asMap().entrySet()) {
			var calls = group.getValue();

			int totalVertices = setupAndCountVertices(calls);
			if (totalVertices != 0) {
				DrawBuffer buffer = tracker.getBuffer(group.getKey(), stage);
				buffer.prepare(totalVertices);

				// Each call is handed the offset at which the previous call's vertices end.
				int offset = 0;
				for (var call : calls) {
					collected.add(call.plan(buffer, offset, matrices, level));
					offset += call.getTotalVertexCount();
				}
			}
		}

		return new NestedPlan(collected);
	}

	/** Runs {@code setup()} on every call and returns the sum of their total vertex counts. */
	private static int setupAndCountVertices(Iterable<TransformCall<?>> calls) {
		int total = 0;
		for (TransformCall<?> call : calls) {
			call.setup();
			total += call.getTotalVertexCount();
		}
		return total;
	}

	/**
	 * Registers a transform call under the given render type.
	 *
	 * @param shaderState the render type key to group the call under
	 * @param transformCall the call to add
	 */
	public void put(RenderType shaderState, TransformCall<?> transformCall) {
		transformCalls.put(shaderState, transformCall);
	}

	/**
	 * @return {@code true} when no transform calls have been registered
	 */
	public boolean isEmpty() {
		return transformCalls.isEmpty();
	}
}

View file

@ -163,4 +163,8 @@ public final class MoreMath {
MemoryUtil.memPutFloat(ptr + 88, nzW); MemoryUtil.memPutFloat(ptr + 88, nzW);
MemoryUtil.memPutFloat(ptr + 92, pzW); MemoryUtil.memPutFloat(ptr + 92, pzW);
} }
/**
 * Divides {@code numerator} by {@code denominator}, rounding the result up towards positive
 * infinity.
 *
 * <p>The quotient and remainder are computed separately rather than via
 * {@code (numerator + denominator - 1) / denominator}, because that sum overflows {@code int}
 * for large numerators and gives a wrong result when the operands have different signs.
 *
 * <p>As with plain integer division, {@code ceilingDiv(Integer.MIN_VALUE, -1)} overflows.
 *
 * @param numerator the value to divide
 * @param denominator the value to divide by; must not be zero
 * @return the smallest {@code int} that is greater than or equal to the exact quotient
 * @throws ArithmeticException if {@code denominator} is zero
 */
public static int ceilingDiv(int numerator, int denominator) {
	int quotient = numerator / denominator;
	// Truncation already rounds up when the exact quotient is negative, so only bump when
	// there is a remainder and both operands share a sign (exact quotient is positive).
	if (numerator % denominator != 0 && (numerator ^ denominator) >= 0) {
		quotient++;
	}
	return quotient;
}
} }

View file

@ -6,6 +6,7 @@ import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan; import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor; import com.jozufozu.flywheel.api.task.TaskExecutor;
import com.jozufozu.flywheel.lib.math.MoreMath;
public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action) implements Plan { public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action) implements Plan {
public static <T> Plan of(Supplier<List<T>> iterable, Consumer<T> forEach) { public static <T> Plan of(Supplier<List<T>> iterable, Consumer<T> forEach) {
@ -32,7 +33,7 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
final int size = suppliedList.size(); final int size = suppliedList.size();
final int chunkSize = getChunkSize(taskExecutor, size); final int chunkSize = getChunkSize(taskExecutor, size);
var synchronizer = new Synchronizer(ceilingDiv(size, chunkSize), onCompletion); var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, chunkSize), onCompletion);
int remaining = size; int remaining = size;
while (remaining > 0) { while (remaining > 0) {
@ -46,11 +47,7 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
} }
private static int getChunkSize(TaskExecutor taskExecutor, int totalSize) { private static int getChunkSize(TaskExecutor taskExecutor, int totalSize) {
return ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32); return MoreMath.ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
}
private static int ceilingDiv(int numerator, int denominator) {
return (numerator + denominator - 1) / denominator;
} }
private void processList(List<T> suppliedList, Runnable onCompletion) { private void processList(List<T> suppliedList, Runnable onCompletion) {

View file

@ -4,7 +4,6 @@ import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger; import org.slf4j.Logger;
import com.jozufozu.flywheel.util.StringUtil;
import com.mojang.logging.LogUtils; import com.mojang.logging.LogUtils;
public class WaitGroup { public class WaitGroup {
@ -42,7 +41,7 @@ public class WaitGroup {
long elapsed = end - start; long elapsed = end - start;
if (elapsed > 1000000) { // > 1ms if (elapsed > 1000000) { // > 1ms
LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times"); // LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times");
} }
} }