Smaller batches

- Extract TransformSet to full class
- Move planning code into BatchingTransformManager
- Reduce exposure of BatchingTransformManager internals
- Comment out WaitGroup log :ioa:
- Move ceilingDiv to MoreMath
- Use dynamic chunk size for TransformCalls
This commit is contained in:
Jozufozu 2023-04-16 12:46:20 -07:00
parent cb74fa4603
commit e03590b270
10 changed files with 114 additions and 126 deletions

View file

@ -22,7 +22,7 @@ public class BatchedMeshPool {
private final Map<Mesh, BufferedMesh> meshes = new HashMap<>();
private final List<BufferedMesh> allBuffered = new ArrayList<>();
private final List<BufferedMesh> pendingUpload = new ArrayList<>();
private final List<BufferedMesh> pendingBuffer = new ArrayList<>();
private MemoryBlock memory;
private long byteSize;
@ -54,7 +54,7 @@ public class BatchedMeshPool {
BufferedMesh bufferedMesh = new BufferedMesh(m, byteSize);
byteSize += bufferedMesh.size();
allBuffered.add(bufferedMesh);
pendingUpload.add(bufferedMesh);
pendingBuffer.add(bufferedMesh);
dirty = true;
return bufferedMesh;
@ -73,10 +73,10 @@ public class BatchedMeshPool {
}
realloc();
uploadPending();
bufferPending();
dirty = false;
pendingUpload.clear();
pendingBuffer.clear();
}
}
@ -94,7 +94,7 @@ public class BatchedMeshPool {
int byteIndex = 0;
for (BufferedMesh mesh : allBuffered) {
if (mesh.byteIndex != byteIndex) {
pendingUpload.add(mesh);
pendingBuffer.add(mesh);
}
mesh.byteIndex = byteIndex;
@ -122,13 +122,13 @@ public class BatchedMeshPool {
}
}
private void uploadPending() {
private void bufferPending() {
try {
for (BufferedMesh mesh : pendingUpload) {
for (BufferedMesh mesh : pendingBuffer) {
mesh.buffer(vertexList);
}
pendingUpload.clear();
pendingBuffer.clear();
} catch (Exception e) {
Flywheel.LOGGER.error("Error uploading pooled meshes:", e);
}
@ -140,7 +140,7 @@ public class BatchedMeshPool {
}
meshes.clear();
allBuffered.clear();
pendingUpload.clear();
pendingBuffer.clear();
}
@Override

View file

@ -1,6 +1,5 @@
package com.jozufozu.flywheel.backend.engine.batching;
import java.util.ArrayList;
import java.util.List;
import com.jozufozu.flywheel.api.event.RenderContext;
@ -12,8 +11,6 @@ import com.jozufozu.flywheel.api.model.Model;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
import com.jozufozu.flywheel.backend.engine.AbstractEngine;
import com.jozufozu.flywheel.lib.task.NestedPlan;
import com.jozufozu.flywheel.lib.task.PlanUtil;
import com.jozufozu.flywheel.util.FlwUtil;
import net.minecraft.world.phys.Vec3;
@ -33,52 +30,12 @@ public class BatchingEngine extends AbstractEngine {
@Override
public Plan planThisFrame(RenderContext context) {
return PlanUtil.of(transformManager::flush)
.then(planTransformers(context));
}
private Plan planTransformers(RenderContext context) {
Vec3 cameraPos = context.camera()
.getPosition();
var stack = FlwUtil.copyPoseStack(context.stack());
stack.translate(renderOrigin.getX() - cameraPos.x, renderOrigin.getY() - cameraPos.y, renderOrigin.getZ() - cameraPos.z);
var matrices = stack.last();
var level = context.level();
var plans = new ArrayList<Plan>();
for (var transformSetEntry : transformManager.getTransformSetsView()
.entrySet()) {
var stage = transformSetEntry.getKey();
var transformSet = transformSetEntry.getValue();
for (var entry : transformSet) {
var renderType = entry.getKey();
var transformCalls = entry.getValue();
int vertices = 0;
for (var transformCall : transformCalls) {
transformCall.setup();
vertices += transformCall.getTotalVertexCount();
}
if (vertices == 0) {
continue;
}
DrawBuffer buffer = drawTracker.getBuffer(renderType, stage);
buffer.prepare(vertices);
int startVertex = 0;
for (var transformCall : transformCalls) {
plans.add(transformCall.getPlan(buffer, startVertex, matrices, level));
startVertex += transformCall.getTotalVertexCount();
}
}
}
return new NestedPlan(plans);
return transformManager.plan(stack.last(), context.level(), drawTracker);
}
@Override

View file

@ -1,28 +1,24 @@
package com.jozufozu.flywheel.backend.engine.batching;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.jetbrains.annotations.NotNull;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import com.jozufozu.flywheel.api.event.RenderStage;
import com.jozufozu.flywheel.api.instance.Instance;
import com.jozufozu.flywheel.api.instance.InstanceType;
import com.jozufozu.flywheel.api.instance.Instancer;
import com.jozufozu.flywheel.api.model.Mesh;
import com.jozufozu.flywheel.api.model.Model;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.backend.engine.InstancerKey;
import com.jozufozu.flywheel.lib.task.NestedPlan;
import com.mojang.blaze3d.vertex.PoseStack;
import com.mojang.blaze3d.vertex.VertexFormat;
import net.minecraft.client.multiplayer.ClientLevel;
import net.minecraft.client.renderer.RenderType;
public class BatchingTransformManager {
@ -30,15 +26,17 @@ public class BatchingTransformManager {
private final List<UninitializedInstancer> uninitializedInstancers = new ArrayList<>();
private final List<CPUInstancer<?>> initializedInstancers = new ArrayList<>();
private final Map<RenderStage, TransformSet> transformSets = new EnumMap<>(RenderStage.class);
private final Map<RenderStage, TransformSet> transformSetsView = Collections.unmodifiableMap(transformSets);
private final Map<VertexFormat, BatchedMeshPool> meshPools = new HashMap<>();
public TransformSet get(RenderStage stage) {
return transformSets.getOrDefault(stage, TransformSet.EMPTY);
public Plan plan(PoseStack.Pose matrices, ClientLevel level, BatchingDrawTracker tracker) {
flush();
var plans = new ArrayList<Plan>();
for (var transformSet : transformSets.values()) {
plans.add(transformSet.plan(matrices, level, tracker));
}
public Map<RenderStage, TransformSet> getTransformSetsView() {
return transformSetsView;
return new NestedPlan(plans);
}
@SuppressWarnings("unchecked")
@ -95,36 +93,6 @@ public class BatchingTransformManager {
.alloc(mesh);
}
public static class TransformSet implements Iterable<Map.Entry<RenderType, Collection<TransformCall<?>>>> {
public static final TransformSet EMPTY = new TransformSet(ImmutableListMultimap.of());
private final ListMultimap<RenderType, TransformCall<?>> transformCalls;
public TransformSet(RenderStage renderStage) {
transformCalls = ArrayListMultimap.create();
}
public TransformSet(ListMultimap<RenderType, TransformCall<?>> transformCalls) {
this.transformCalls = transformCalls;
}
public void put(RenderType shaderState, TransformCall<?> transformCall) {
transformCalls.put(shaderState, transformCall);
}
public boolean isEmpty() {
return transformCalls.isEmpty();
}
@NotNull
@Override
public Iterator<Map.Entry<RenderType, Collection<TransformCall<?>>>> iterator() {
return transformCalls.asMap()
.entrySet()
.iterator();
}
}
private record UninitializedInstancer(CPUInstancer<?> instancer, Model model, RenderStage stage) {
}
}

View file

@ -19,6 +19,10 @@ public class CPUInstancer<I extends Instance> extends AbstractInstancer<I> {
return instances;
}
/**
 * Looks up a single instance by its position in {@code instances}.
 *
 * @param index the position to read from {@code instances}.
 * @return whatever {@code instances.get(index)} yields for that position.
 */
public I get(int index) {
	final I instance = this.instances.get(index);
	return instance;
}
/**
 * Brings this instancer up to date by delegating to {@code removeDeletedInstances()}.
 */
public void update() {
	this.removeDeletedInstances();
}

View file

@ -75,7 +75,7 @@ public class DrawBuffer {
}
ReusableVertexList vertexList = provider.createVertexList();
vertexList.ptr(memory.ptr() + startVertex * stride);
vertexList.ptr(memory.ptr() + (long) startVertex * stride);
vertexList.vertexCount(vertexCount);
return vertexList;
}

View file

@ -9,6 +9,7 @@ import com.jozufozu.flywheel.api.material.Material;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.vertex.MutableVertexList;
import com.jozufozu.flywheel.api.vertex.ReusableVertexList;
import com.jozufozu.flywheel.lib.math.MoreMath;
import com.jozufozu.flywheel.lib.task.SimplePlan;
import com.jozufozu.flywheel.lib.vertex.VertexTransformations;
import com.mojang.blaze3d.vertex.PoseStack;
@ -42,37 +43,32 @@ public class TransformCall<I extends Instance> {
instancer.update();
}
public Plan getPlan(DrawBuffer buffer, int startVertex, PoseStack.Pose matrices, ClientLevel level) {
int instances = instancer.getInstanceCount();
public Plan plan(DrawBuffer buffer, int startVertex, PoseStack.Pose matrices, ClientLevel level) {
final int totalCount = instancer.getInstanceCount();
final int chunkSize = MoreMath.ceilingDiv(totalCount, 6 * 32);
var out = new ArrayList<Runnable>();
while (instances > 0) {
int end = instances;
instances -= 512;
int start = Math.max(instances, 0);
final var out = new ArrayList<Runnable>();
int remaining = totalCount;
while (remaining > 0) {
int end = remaining;
remaining -= chunkSize;
int start = Math.max(remaining, 0);
int vertexCount = meshVertexCount * (end - start);
ReusableVertexList sub = buffer.slice(startVertex, vertexCount);
startVertex += vertexCount;
out.add(() -> transformRange(sub, start, end, matrices, level));
out.add(() -> transform(sub, matrices, level, instancer.getRange(start, end)));
}
return new SimplePlan(out);
}
public void transformRange(ReusableVertexList vertexList, int from, int to, PoseStack.Pose matrices, ClientLevel level) {
transformList(vertexList, instancer.getRange(from, to), matrices, level);
}
public void transformAll(ReusableVertexList vertexList, PoseStack.Pose matrices, ClientLevel level) {
transformList(vertexList, instancer.getAll(), matrices, level);
}
public void transformList(ReusableVertexList vertexList, List<I> instances, PoseStack.Pose matrices, ClientLevel level) {
long anchorPtr = vertexList.ptr();
int totalVertexCount = vertexList.vertexCount();
private void transform(ReusableVertexList vertexList, PoseStack.Pose matrices, ClientLevel level, List<I> instances) {
// save the total size of the slice for later.
final long anchorPtr = vertexList.ptr();
final int totalVertexCount = vertexList.vertexCount();
// while working on individual instances, the vertex list should expose just a single copy of the mesh.
vertexList.vertexCount(meshVertexCount);
InstanceVertexTransformer<I> instanceVertexTransformer = instancer.type.getVertexTransformer();
@ -85,6 +81,7 @@ public class TransformCall<I extends Instance> {
vertexList.ptr(vertexList.ptr() + meshByteSize);
}
// restore the original size of the slice to apply per-vertex transformations.
vertexList.ptr(anchorPtr);
vertexList.vertexCount(totalVertexCount);
material.getVertexTransformer().transform(vertexList, level);

View file

@ -0,0 +1,62 @@
package com.jozufozu.flywheel.backend.engine.batching;
import java.util.ArrayList;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import com.jozufozu.flywheel.api.event.RenderStage;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.lib.task.NestedPlan;
import com.mojang.blaze3d.vertex.PoseStack;
import net.minecraft.client.multiplayer.ClientLevel;
import net.minecraft.client.renderer.RenderType;
/**
 * The transform calls registered for one {@link RenderStage}, grouped by {@link RenderType}.
 */
public class TransformSet {
	private final RenderStage stage;
	private final ListMultimap<RenderType, TransformCall<?>> transformCalls;

	/**
	 * Creates an empty set bound to the given stage.
	 *
	 * @param renderStage the stage used when requesting draw buffers from the tracker.
	 */
	public TransformSet(RenderStage renderStage) {
		this.stage = renderStage;
		this.transformCalls = ArrayListMultimap.create();
	}

	/**
	 * Builds one plan covering every transform call in this set.
	 *
	 * <p>For each render type, every call is set up and its vertex count summed. Render types
	 * that contribute no vertices are skipped. Otherwise the matching draw buffer is prepared
	 * for the total vertex count and each call is planned at its own running vertex offset
	 * within that buffer.
	 *
	 * @param matrices the pose forwarded to each transform call's plan.
	 * @param level    the level forwarded to each transform call's plan.
	 * @param tracker  the source of draw buffers, keyed by render type and this set's stage.
	 * @return a {@link NestedPlan} wrapping the plans of all transform calls.
	 */
	public Plan plan(PoseStack.Pose matrices, ClientLevel level, BatchingDrawTracker tracker) {
		final var collected = new ArrayList<Plan>();

		for (var group : transformCalls.asMap().entrySet()) {
			final var calls = group.getValue();

			final int totalVertices = setUpAndCountVertices(calls);
			if (totalVertices == 0) {
				// nothing to draw for this render type, don't touch its buffer.
				continue;
			}

			final DrawBuffer target = tracker.getBuffer(group.getKey(), stage);
			target.prepare(totalVertices);

			// each call writes into its own consecutive region of the buffer.
			int firstVertex = 0;
			for (var call : calls) {
				collected.add(call.plan(target, firstVertex, matrices, level));
				firstVertex += call.getTotalVertexCount();
			}
		}

		return new NestedPlan(collected);
	}

	/**
	 * Runs {@code setup()} on every call and adds up their total vertex counts.
	 */
	private static int setUpAndCountVertices(Iterable<? extends TransformCall<?>> calls) {
		int sum = 0;
		for (var call : calls) {
			call.setup();
			sum += call.getTotalVertexCount();
		}
		return sum;
	}

	/**
	 * Registers a transform call under the given render type.
	 *
	 * @param shaderState   the render type to group the call under.
	 * @param transformCall the call to register.
	 */
	public void put(RenderType shaderState, TransformCall<?> transformCall) {
		this.transformCalls.put(shaderState, transformCall);
	}

	/**
	 * @return {@code true} if no transform calls have been registered.
	 */
	public boolean isEmpty() {
		return this.transformCalls.isEmpty();
	}
}

View file

@ -163,4 +163,8 @@ public final class MoreMath {
MemoryUtil.memPutFloat(ptr + 88, nzW);
MemoryUtil.memPutFloat(ptr + 92, pzW);
}
/**
 * Divides {@code numerator} by {@code denominator}, rounding the result toward positive infinity.
 *
 * <p>The quotient is computed first and then adjusted, so no intermediate sum can overflow
 * (the naive {@code (n + d - 1) / d} form wraps around for large operands).
 *
 * @param numerator   the value to divide.
 * @param denominator the value to divide by.
 * @return the smallest integer greater than or equal to the exact quotient.
 * @throws ArithmeticException if {@code denominator} is zero.
 */
public static int ceilingDiv(int numerator, int denominator) {
	final int quotient = numerator / denominator;
	// Integer division truncates toward zero. That only falls short of the ceiling when the
	// exact quotient is positive (operands share a sign) and the division left a remainder.
	if ((numerator ^ denominator) >= 0 && quotient * denominator != numerator) {
		return quotient + 1;
	}
	return quotient;
}
}

View file

@ -6,6 +6,7 @@ import java.util.function.Supplier;
import com.jozufozu.flywheel.api.task.Plan;
import com.jozufozu.flywheel.api.task.TaskExecutor;
import com.jozufozu.flywheel.lib.math.MoreMath;
public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action) implements Plan {
public static <T> Plan of(Supplier<List<T>> iterable, Consumer<T> forEach) {
@ -32,7 +33,7 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
final int size = suppliedList.size();
final int chunkSize = getChunkSize(taskExecutor, size);
var synchronizer = new Synchronizer(ceilingDiv(size, chunkSize), onCompletion);
var synchronizer = new Synchronizer(MoreMath.ceilingDiv(size, chunkSize), onCompletion);
int remaining = size;
while (remaining > 0) {
@ -46,11 +47,7 @@ public record RunOnAllPlan<T>(Supplier<List<T>> listSupplier, Consumer<T> action
}
private static int getChunkSize(TaskExecutor taskExecutor, int totalSize) {
return ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
}
private static int ceilingDiv(int numerator, int denominator) {
return (numerator + denominator - 1) / denominator;
return MoreMath.ceilingDiv(totalSize, taskExecutor.getThreadCount() * 32);
}
private void processList(List<T> suppliedList, Runnable onCompletion) {

View file

@ -4,7 +4,6 @@ import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import com.jozufozu.flywheel.util.StringUtil;
import com.mojang.logging.LogUtils;
public class WaitGroup {
@ -42,7 +41,7 @@ public class WaitGroup {
long elapsed = end - start;
if (elapsed > 1000000) { // > 1ms
LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times");
// LOGGER.debug("Waited " + StringUtil.formatTime(elapsed) + ", looped " + count + " times");
}
}