diff --git a/src/main/java/com/jozufozu/flywheel/backend/instancing/AbstractInstancer.java b/src/main/java/com/jozufozu/flywheel/backend/instancing/AbstractInstancer.java index b6e85dbdc..06e6071fd 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/instancing/AbstractInstancer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/instancing/AbstractInstancer.java @@ -66,6 +66,10 @@ public abstract class AbstractInstancer implements Insta return data.size(); } + public int getTotalVertexCount() { + return getModelVertexCount() * numInstances(); + } + protected BitSet getDirtyBitSet() { final int size = data.size(); final BitSet dirtySet = new BitSet(size); diff --git a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchExecutor.java b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchExecutor.java new file mode 100644 index 000000000..60e41f2b4 --- /dev/null +++ b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchExecutor.java @@ -0,0 +1,34 @@ +package com.jozufozu.flywheel.backend.instancing.batching; + +import java.util.concurrent.Executor; + +import org.jetbrains.annotations.NotNull; + +public class BatchExecutor implements Executor { + private final Executor internal; + private final WaitGroup wg; + + public BatchExecutor(Executor internal) { + this.internal = internal; + + wg = new WaitGroup(); + } + + @Override + public void execute(@NotNull Runnable command) { + wg.add(1); + internal.execute(() -> { + // wrapper function to decrement the wait group + try { + command.run(); + } catch (Exception ignored) { + } finally { + wg.done(); + } + }); + } + + public void await() throws InterruptedException { + wg.await(); + } +} diff --git a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterial.java b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterial.java index ad14cba40..18183fa6a 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterial.java +++ b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterial.java @@ -30,7 +30,8 @@ public class BatchedMaterial implements Material { public void render(PoseStack stack, VertexConsumer buffer, FormatContext context) { for (CPUInstancer instancer : models.values()) { - instancer.drawAll(stack, buffer, context); + instancer.setup(context); + instancer.drawAll(stack, buffer); } } diff --git a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterialGroup.java b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterialGroup.java index ea7b65e7c..34f4b71d6 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterialGroup.java +++ b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchedMaterialGroup.java @@ -2,6 +2,7 @@ package com.jozufozu.flywheel.backend.instancing.batching; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.Executor; import com.jozufozu.flywheel.api.InstanceData; import com.jozufozu.flywheel.api.MaterialGroup; @@ -30,15 +31,20 @@ public class BatchedMaterialGroup implements MaterialGroup { return (BatchedMaterial) materials.computeIfAbsent(spec, BatchedMaterial::new); } - public void render(PoseStack stack, MultiBufferSource source) { + public void render(PoseStack stack, MultiBufferSource source, Executor pool) { VertexConsumer buffer = source.getBuffer(state); if (buffer instanceof DirectBufferBuilder direct) { DirectVertexConsumer consumer = direct.intoDirectConsumer(calculateNeededVertices()); + FormatContext context = new FormatContext(consumer.hasOverlay()); - renderInto(stack, consumer, new FormatContext(consumer.hasOverlay())); + for (BatchedMaterial material : materials.values()) { + for (CPUInstancer instancer : material.models.values()) { + instancer.setup(context); - direct.updateAfterWriting(consumer); + instancer.submitTasks(stack, pool, consumer); + } + } } else { renderInto(stack, buffer, FormatContext.defaultContext()); } @@ -48,7 +54,7 @@ public class BatchedMaterialGroup implements MaterialGroup { int total = 0; for (BatchedMaterial material : materials.values()) { for (CPUInstancer instancer : material.models.values()) { - total += instancer.getModelVertexCount() * instancer.numInstances(); + total += instancer.getTotalVertexCount(); } } diff --git a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchingEngine.java b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchingEngine.java index ba915499e..cc5621418 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchingEngine.java +++ b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/BatchingEngine.java @@ -3,6 +3,8 @@ package com.jozufozu.flywheel.backend.instancing.batching; import java.util.EnumMap; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ForkJoinPool; import com.jozufozu.flywheel.api.MaterialGroup; import com.jozufozu.flywheel.backend.RenderLayer; @@ -22,11 +24,15 @@ public class BatchingEngine implements Engine { protected final Map> layers; + private final BatchExecutor pool; + public BatchingEngine() { this.layers = new EnumMap<>(RenderLayer.class); for (RenderLayer value : RenderLayer.values()) { layers.put(value, new HashMap<>()); } + + pool = new BatchExecutor(Executors.newWorkStealingPool(ForkJoinPool.getCommonPoolParallelism())); } @Override @@ -50,7 +56,12 @@ public class BatchingEngine implements Engine { for (Map.Entry entry : layers.get(event.getLayer()).entrySet()) { BatchedMaterialGroup group = entry.getValue(); - group.render(stack, buffers); + group.render(stack, buffers, pool); + } + + try { + pool.await(); + } catch (InterruptedException ignored) { } stack.popPose(); diff --git a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/CPUInstancer.java b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/CPUInstancer.java index 307f3315b..4231c2759 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/CPUInstancer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/CPUInstancer.java @@ -1,27 +1,51 @@ package com.jozufozu.flywheel.backend.instancing.batching; +import java.util.concurrent.Executor; + import com.jozufozu.flywheel.api.InstanceData; import com.jozufozu.flywheel.api.struct.BatchingTransformer; import com.jozufozu.flywheel.api.struct.StructType; import com.jozufozu.flywheel.backend.instancing.AbstractInstancer; +import com.jozufozu.flywheel.backend.model.DirectVertexConsumer; import com.jozufozu.flywheel.core.model.Model; import com.jozufozu.flywheel.core.model.SuperByteBuffer; import com.mojang.blaze3d.vertex.PoseStack; import com.mojang.blaze3d.vertex.VertexConsumer; -import com.mojang.blaze3d.vertex.VertexFormat; public class CPUInstancer extends AbstractInstancer { private final BatchingTransformer transform; private final SuperByteBuffer sbb; + private final SuperByteBuffer.Params defaultParams; public CPUInstancer(StructType type, Model modelData) { super(type, modelData); sbb = new SuperByteBuffer(modelData); + defaultParams = SuperByteBuffer.Params.defaultParams(); transform = type.asBatched() .getTransformer(); + + if (transform == null) { + throw new NullPointerException("Cannot batch " + type.toString()); + } + } + + void submitTasks(PoseStack stack, Executor pool, DirectVertexConsumer consumer) { + int instances = numInstances(); + + while (instances > 0) { + int end = instances; + instances -= 100; + int start = Math.max(instances, 0); + + int verts = getModelVertexCount() * (end - start); + + DirectVertexConsumer sub = consumer.split(verts); + + pool.execute(() -> drawRange(stack, sub, start, end)); + } } @Override @@ -29,23 +53,34 @@ public class CPUInstancer extends AbstractInstancer { // noop } - public void drawAll(PoseStack stack, VertexConsumer buffer, FormatContext context) { - if (transform == null) { - return; - } + private void drawRange(PoseStack stack, VertexConsumer buffer, int from, int to) { + SuperByteBuffer.Params params = defaultParams.copy(); + for (D d : data.subList(from, to)) { + transform.transform(d, params); + + sbb.renderInto(params, stack, buffer); + + params.load(defaultParams); + } + } + + void drawAll(PoseStack stack, VertexConsumer buffer) { + SuperByteBuffer.Params params = defaultParams.copy(); + for (D d : data) { + transform.transform(d, params); + + sbb.renderInto(params, stack, buffer); + + params.load(defaultParams); + } + } + + void setup(FormatContext context) { renderSetup(); if (context.usesOverlay()) { - sbb.getDefaultParams().entityMode(); - } - - sbb.reset(); - - for (D d : data) { - transform.transform(d, sbb.getParams()); - - sbb.renderInto(stack, buffer); + defaultParams.entityMode(); } } diff --git a/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/WaitGroup.java b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/WaitGroup.java new file mode 100644 index 000000000..df7ffb7d9 --- /dev/null +++ b/src/main/java/com/jozufozu/flywheel/backend/instancing/batching/WaitGroup.java @@ -0,0 +1,24 @@ +package com.jozufozu.flywheel.backend.instancing.batching; + +// https://stackoverflow.com/questions/29655531 +public class WaitGroup { + + private int jobs = 0; + + public synchronized void add(int i) { + jobs += i; + } + + public synchronized void done() { + if (--jobs == 0) { + notifyAll(); + } + } + + public synchronized void await() throws InterruptedException { + while (jobs > 0) { + wait(); + } + } + +} diff --git a/src/main/java/com/jozufozu/flywheel/backend/model/DirectBufferBuilder.java b/src/main/java/com/jozufozu/flywheel/backend/model/DirectBufferBuilder.java index b3f08681a..98449155a 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/model/DirectBufferBuilder.java +++ b/src/main/java/com/jozufozu/flywheel/backend/model/DirectBufferBuilder.java @@ -2,7 +2,5 @@ package com.jozufozu.flywheel.backend.model; public interface DirectBufferBuilder { - DirectVertexConsumer intoDirectConsumer(int neededVerts); - - void updateAfterWriting(DirectVertexConsumer complete); + DirectVertexConsumer intoDirectConsumer(int vertexCount); } diff --git a/src/main/java/com/jozufozu/flywheel/backend/model/DirectVertexConsumer.java b/src/main/java/com/jozufozu/flywheel/backend/model/DirectVertexConsumer.java index 0bea53cb4..e0516e69c 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/model/DirectVertexConsumer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/model/DirectVertexConsumer.java @@ -10,7 +10,6 @@ import com.mojang.blaze3d.vertex.VertexFormat; import com.mojang.blaze3d.vertex.VertexFormatElement; public class DirectVertexConsumer implements VertexConsumer { - public final VertexFormat format; private final int stride; public final int startPos; @@ -23,7 +22,6 @@ public class DirectVertexConsumer implements VertexConsumer { private int uv2 = -1; private long vertexBase; - private int vertexCount; public DirectVertexConsumer(ByteBuffer buffer, VertexFormat format) { this.format = format; @@ -49,13 +47,42 @@ public class DirectVertexConsumer implements VertexConsumer { offset += element.getByteSize(); } - this.vertexBase = MemoryUtil.memAddress(buffer); + this.vertexBase = MemoryUtil.memAddress(buffer, startPos); + } + + private DirectVertexConsumer(DirectVertexConsumer parent) { + this.format = parent.format; + this.stride = parent.stride; + this.startPos = parent.startPos; + this.position = parent.position; + this.normal = parent.normal; + this.color = parent.color; + this.uv = parent.uv; + this.uv1 = parent.uv1; + this.uv2 = parent.uv2; + + this.vertexBase = parent.vertexBase; } public boolean hasOverlay() { return uv1 >= 0; } + /** + * Split off the head of this consumer into a new object and advance our write-pointer. + * @param vertexCount The number of vertices that must be written to the head. + * @return The head of this consumer. + */ + public DirectVertexConsumer split(int vertexCount) { + int bytes = vertexCount * stride; + + DirectVertexConsumer head = new DirectVertexConsumer(this); + + this.vertexBase += bytes; + + return head; + } + @Override public VertexConsumer vertex(double x, double y, double z) { if (position < 0) return this; @@ -117,7 +144,6 @@ public class DirectVertexConsumer implements VertexConsumer { @Override public void endVertex() { vertexBase += stride; - vertexCount++; } @Override @@ -129,8 +155,4 @@ public class DirectVertexConsumer implements VertexConsumer { public void unsetDefaultColor() { } - - public int getVertexCount() { - return vertexCount; - } } diff --git a/src/main/java/com/jozufozu/flywheel/core/model/SuperByteBuffer.java b/src/main/java/com/jozufozu/flywheel/core/model/SuperByteBuffer.java index 5135885b5..daf2f3440 100644 --- a/src/main/java/com/jozufozu/flywheel/core/model/SuperByteBuffer.java +++ b/src/main/java/com/jozufozu/flywheel/core/model/SuperByteBuffer.java @@ -22,32 +22,22 @@ public class SuperByteBuffer { private final Model model; private final ModelReader template; - private final Params defaultParams = Params.defaultParams(); - private final Params params = defaultParams.copy(); - - public Params getDefaultParams() { - return defaultParams; - } - - public Params getParams() { - return params; - } - // Temporary private static final Long2IntMap WORLD_LIGHT_CACHE = new Long2IntOpenHashMap(); - private final Vector4f pos = new Vector4f(); - private final Vector3f normal = new Vector3f(); - private final Vector4f lightPos = new Vector4f(); public SuperByteBuffer(Model model) { this.model = model; template = model.getReader(); } - public void renderInto(PoseStack input, VertexConsumer builder) { + public void renderInto(Params params, PoseStack input, VertexConsumer builder) { if (isEmpty()) return; + Vector4f pos = new Vector4f(); + Vector3f normal = new Vector3f(); + Vector4f lightPos = new Vector4f(); + Matrix4f modelMat = input.last() .pose() .copy(); @@ -168,14 +158,6 @@ public class SuperByteBuffer { builder.endVertex(); } - - reset(); - } - - public SuperByteBuffer reset() { - params.load(defaultParams); - - return this; } public boolean isEmpty() { diff --git a/src/main/java/com/jozufozu/flywheel/mixin/BufferBuilderMixin.java b/src/main/java/com/jozufozu/flywheel/mixin/BufferBuilderMixin.java index a22223d2f..95d2874e3 100644 --- a/src/main/java/com/jozufozu/flywheel/mixin/BufferBuilderMixin.java +++ b/src/main/java/com/jozufozu/flywheel/mixin/BufferBuilderMixin.java @@ -38,21 +38,19 @@ public abstract class BufferBuilderMixin implements DirectBufferBuilder { private int nextElementByte; @Override - public DirectVertexConsumer intoDirectConsumer(int neededVerts) { - ensureCapacity(neededVerts * this.format.getVertexSize()); - return new DirectVertexConsumer(this.buffer, this.format); - } + public DirectVertexConsumer intoDirectConsumer(int vertexCount) { + int bytes = vertexCount * format.getVertexSize(); + ensureCapacity(bytes); - @Override - public void updateAfterWriting(DirectVertexConsumer complete) { - int vertexCount = complete.getVertexCount(); - int totalWrittenBytes = vertexCount * format.getVertexSize(); + DirectVertexConsumer consumer = new DirectVertexConsumer(this.buffer, this.format); this.vertices += vertexCount; this.currentElement = format.getElements() .get(0); this.elementIndex = 0; - this.nextElementByte += totalWrittenBytes; - this.buffer.position(complete.startPos + totalWrittenBytes); + this.nextElementByte += bytes; + this.buffer.position(consumer.startPos + bytes); + + return consumer; } }