From f9209f7dd60173f9eebf201bb44aacb1b2de43c0 Mon Sep 17 00:00:00 2001 From: Jozufozu Date: Sat, 9 Dec 2023 15:32:21 -0800 Subject: [PATCH] Throw more compute at it - Use a shader to perform copies. - Needs optimization, but it works surprisingly well and avoids the driver hitch. --- .../backend/compile/IndirectPrograms.java | 33 ++++-- .../engine/indirect/IndirectCullingGroup.java | 1 + .../engine/indirect/StagingBuffer.java | 111 ++++++++++++++++-- .../jozufozu/flywheel/gl/buffer/GlBuffer.java | 14 ++- .../jozufozu/flywheel/lib/math/MoreMath.java | 4 + .../flywheel/internal/indirect/scatter.glsl | 35 ++++++ 6 files changed, 174 insertions(+), 24 deletions(-) create mode 100644 src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl diff --git a/src/main/java/com/jozufozu/flywheel/backend/compile/IndirectPrograms.java b/src/main/java/com/jozufozu/flywheel/backend/compile/IndirectPrograms.java index 39453ea5a..144273c69 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/compile/IndirectPrograms.java +++ b/src/main/java/com/jozufozu/flywheel/backend/compile/IndirectPrograms.java @@ -24,18 +24,21 @@ import net.minecraft.resources.ResourceLocation; public class IndirectPrograms { private static final ResourceLocation CULL_SHADER_MAIN = Flywheel.rl("internal/indirect/cull.glsl"); private static final ResourceLocation APPLY_SHADER_MAIN = Flywheel.rl("internal/indirect/apply.glsl"); + private static final ResourceLocation SCATTER_SHADER_MAIN = Flywheel.rl("internal/indirect/scatter.glsl"); public static IndirectPrograms instance; private static final Compile> CULL = new Compile<>(); - private static final Compile APPLY = new Compile<>(); + private static final Compile UNIT = new Compile<>(); private final Map pipeline; private final Map, GlProgram> culling; private final GlProgram apply; + private final GlProgram scatter; - public IndirectPrograms(Map pipeline, Map, GlProgram> culling, GlProgram apply) { + public IndirectPrograms(Map pipeline, Map, GlProgram> culling, GlProgram apply, GlProgram scatter) { this.pipeline = pipeline; this.culling = culling; this.apply = apply; + this.scatter = scatter; } static void reload(ShaderSources sources, ImmutableList pipelineKeys, UniformComponent uniformComponent, List vertexComponents, List fragmentComponents) { @@ -43,14 +46,16 @@ public class IndirectPrograms { var pipelineCompiler = PipelineCompiler.create(sources, Pipelines.INDIRECT, pipelineKeys, uniformComponent, vertexComponents, fragmentComponents); var cullingCompiler = createCullingCompiler(uniformComponent, sources); var applyCompiler = createApplyCompiler(sources); + var scatterCompiler = createScatterCompiler(sources); try { var pipelineResult = pipelineCompiler.compileAndReportErrors(); var cullingResult = cullingCompiler.compileAndReportErrors(); var applyResult = applyCompiler.compileAndReportErrors(); + var scatterResult = scatterCompiler.compileAndReportErrors(); - if (pipelineResult != null && cullingResult != null && applyResult != null) { - instance = new IndirectPrograms(pipelineResult, cullingResult, applyResult.get(Unit.INSTANCE)); + if (pipelineResult != null && cullingResult != null && applyResult != null && scatterResult != null) { + instance = new IndirectPrograms(pipelineResult, cullingResult, applyResult.get(Unit.INSTANCE), scatterResult.get(Unit.INSTANCE)); } } catch (Throwable e) { Flywheel.LOGGER.error("Failed to compile indirect programs", e); @@ -99,15 +104,25 @@ public class IndirectPrograms { } private static CompilationHarness createApplyCompiler(ShaderSources sources) { - return APPLY.harness(sources) + return UNIT.harness(sources) .keys(ImmutableList.of(Unit.INSTANCE)) - .compiler(APPLY.program() - .link(APPLY.shader(GlslVersion.V460, ShaderType.COMPUTE) + .compiler(UNIT.program() + .link(UNIT.shader(GlslVersion.V460, ShaderType.COMPUTE) .define("_FLW_SUBGROUP_SIZE", GlCompat.SUBGROUP_SIZE) .withResource(APPLY_SHADER_MAIN))) .build(); } + private static CompilationHarness createScatterCompiler(ShaderSources sources) { + return UNIT.harness(sources) + .keys(ImmutableList.of(Unit.INSTANCE)) + .compiler(UNIT.program() + .link(UNIT.shader(GlslVersion.V460, ShaderType.COMPUTE) + .define("_FLW_SUBGROUP_SIZE", GlCompat.SUBGROUP_SIZE) + .withResource(SCATTER_SHADER_MAIN))) + .build(); + } + public GlProgram getIndirectProgram(InstanceType instanceType, Context contextShader) { return pipeline.get(new PipelineProgramKey(instanceType, contextShader)); } @@ -120,6 +135,10 @@ public class IndirectPrograms { return apply; } + public GlProgram getScatterProgram() { + return scatter; + } + public void delete() { pipeline.values() .forEach(GlProgram::delete); diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java index f91f79cc3..0e8ed99f4 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java @@ -99,6 +99,7 @@ public class IndirectCullingGroup { UniformBuffer.get().sync(); cullProgram.bind(); buffers.bindForCompute(); + glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT); glDispatchCompute(GlCompat.getComputeGroupCount(instanceCountThisFrame), 1, 1); } diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java index 5d79e1e57..184c9d6f4 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java @@ -4,14 +4,22 @@ import java.util.ArrayList; import java.util.List; import java.util.function.LongConsumer; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.lwjgl.opengl.GL45; import org.lwjgl.opengl.GL45C; import org.lwjgl.system.MemoryUtil; +import com.jozufozu.flywheel.backend.compile.IndirectPrograms; +import com.jozufozu.flywheel.gl.GlCompat; import com.jozufozu.flywheel.gl.GlFence; +import com.jozufozu.flywheel.gl.buffer.GlBuffer; import com.jozufozu.flywheel.lib.memory.FlwMemoryTracker; import com.jozufozu.flywheel.lib.memory.MemoryBlock; import it.unimi.dsi.fastutil.PriorityQueue; +import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.objects.ObjectArrayFIFOQueue; // Used https://github.com/CaffeineMC/sodium-fabric/blob/dev/src/main/java/me/jellysquid/mods/sodium/client/gl/arena/staging/MappedStagingBuffer.java @@ -30,8 +38,10 @@ public class StagingBuffer { private long totalAvailable; + @Nullable private MemoryBlock scratch; + private final GlBuffer copyBuffer = new GlBuffer(); private final OverflowStagingBuffer overflow = new OverflowStagingBuffer(); private final PriorityQueue transfers = new ObjectArrayFIFOQueue<>(); private final PriorityQueue fencedRegions = new ObjectArrayFIFOQueue<>(); @@ -78,6 +88,7 @@ public class StagingBuffer { enqueueCopy(block.ptr(), size, dstVbo, dstOffset); } + @NotNull private MemoryBlock getScratch(long size) { if (scratch == null) { scratch = MemoryBlock.malloc(size); @@ -107,8 +118,10 @@ public class StagingBuffer { * @param dstOffset The offset in the destination VBO. */ public void enqueueCopy(long ptr, long size, int dstVbo, long dstOffset) { + assertMultipleOf4(size); + if (size > totalAvailable) { - overflow.enqueueCopy(ptr, size, dstVbo, dstOffset); + overflow.upload(ptr, size, dstVbo, dstOffset); return; } @@ -151,6 +164,7 @@ public class StagingBuffer { * @return A pointer to the reserved space, or {@code null} if there is not enough contiguous space. */ public long reserveForTransferTo(long size, int dstVbo, long dstOffset) { + assertMultipleOf4(size); // Don't need to check totalAvailable here because that's a looser constraint than the bytes remaining. long remaining = capacity - pos; if (size > remaining) { @@ -173,14 +187,69 @@ public class StagingBuffer { return; } - if (pos < start) { - // we rolled around, need to flush 2 ranges. - GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start); - GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos); - } else { - GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start); + flushUsedRegion(); + + var usedCapacity = dispatchComputeCopies(); + + fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity)); + + start = pos; + } + + private long dispatchComputeCopies() { + long usedCapacity = 0; + Int2ObjectMap> copiesPerVbo = new Int2ObjectArrayMap<>(); + + long bytesPerCopy = 64; + + for (var transfer : consolidateCopies(transfers)) { + usedCapacity += transfer.size; + + var forVbo = copiesPerVbo.computeIfAbsent(transfer.dstVbo, k -> new ArrayList<>()); + + long offset = 0; + long remaining = transfer.size; + + while (offset < transfer.size) { + long copySize = Math.min(remaining, bytesPerCopy); + forVbo.add(new Transfer(transfer.srcOffset + offset, transfer.dstVbo, transfer.dstOffset + offset, copySize)); + offset += copySize; + remaining -= copySize; + } } + IndirectPrograms.get() + .getScatterProgram() + .bind(); + + for (var entry : copiesPerVbo.int2ObjectEntrySet()) { + var dstVbo = entry.getIntKey(); + var transfers = entry.getValue(); + var copyCount = transfers.size(); + + var size = copyCount * Integer.BYTES * 3L; + var scratch = getScratch(size); + + putTransfers(scratch.ptr(), transfers); + + copyBuffer.upload(scratch.ptr(), size); + + GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 0, copyBuffer.handle()); + GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 1, vbo); + GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 2, dstVbo); + + GL45.glDispatchCompute(GlCompat.getComputeGroupCount(copyCount), 1, 1); + } + return usedCapacity; + } + + private void assertMultipleOf4(long size) { + if (size % 4 != 0) { + throw new IllegalArgumentException("Size must be a multiple of 4"); + } + } + + private long sendCopyCommands() { long usedCapacity = 0; for (Transfer transfer : consolidateCopies(transfers)) { @@ -188,10 +257,17 @@ public class StagingBuffer { GL45C.glCopyNamedBufferSubData(vbo, transfer.dstVbo, transfer.srcOffset, transfer.dstOffset, transfer.size); } + return usedCapacity; + } - fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity)); - - start = pos; + private void flushUsedRegion() { + if (pos < start) { + // we rolled around, need to flush 2 ranges. + GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start); + GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos); + } else { + GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start); + } } private static List consolidateCopies(PriorityQueue queue) { @@ -214,6 +290,15 @@ public class StagingBuffer { return merged; } + private static void putTransfers(long ptr, List transfers) { + for (Transfer transfer : transfers) { + MemoryUtil.memPutInt(ptr, (int) transfer.srcOffset); + MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) transfer.dstOffset); + MemoryUtil.memPutInt(ptr + Integer.BYTES * 2, (int) transfer.size); + ptr += Integer.BYTES * 3; + } + } + private static boolean areContiguous(Transfer last, Transfer transfer) { return last.dstVbo == transfer.dstVbo && last.dstOffset + last.size == transfer.dstOffset && last.srcOffset + last.size == transfer.srcOffset; } @@ -238,7 +323,9 @@ public class StagingBuffer { GL45C.glDeleteBuffers(vbo); overflow.delete(); - scratch.free(); + if (scratch != null) { + scratch.free(); + } FlwMemoryTracker._freeCPUMemory(capacity); } @@ -271,7 +358,7 @@ public class StagingBuffer { vbo = GL45C.glCreateBuffers(); } - public void enqueueCopy(long ptr, long size, int dstVbo, long dstOffset) { + public void upload(long ptr, long size, int dstVbo, long dstOffset) { GL45C.nglNamedBufferData(vbo, size, ptr, GL45C.GL_STREAM_COPY); GL45C.glCopyNamedBufferSubData(vbo, dstVbo, 0, dstOffset, size); } diff --git a/src/main/java/com/jozufozu/flywheel/gl/buffer/GlBuffer.java b/src/main/java/com/jozufozu/flywheel/gl/buffer/GlBuffer.java index 001eacf81..0085e0b2b 100644 --- a/src/main/java/com/jozufozu/flywheel/gl/buffer/GlBuffer.java +++ b/src/main/java/com/jozufozu/flywheel/gl/buffer/GlBuffer.java @@ -82,11 +82,15 @@ public class GlBuffer extends GlObject { size = growthFunction.apply(capacity); } - public void upload(MemoryBlock directBuffer) { - FlwMemoryTracker._freeGPUMemory(size); - Buffer.IMPL.data(handle(), directBuffer.size(), directBuffer.ptr(), usage.glEnum); - size = directBuffer.size(); - FlwMemoryTracker._allocGPUMemory(size); + public void upload(MemoryBlock memoryBlock) { + upload(memoryBlock.ptr(), memoryBlock.size()); + } + + public void upload(long ptr, long size) { + FlwMemoryTracker._freeGPUMemory(this.size); + Buffer.IMPL.data(handle(), size, ptr, usage.glEnum); + this.size = size; + FlwMemoryTracker._allocGPUMemory(this.size); } public MappedBuffer map() { diff --git a/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java b/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java index 169a449a1..d0ed97451 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java +++ b/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java @@ -17,6 +17,10 @@ public final class MoreMath { return (numerator + denominator - 1) / denominator; } + public static long ceilingDiv(long numerator, long denominator) { + return (numerator + denominator - 1) / denominator; + } + public static int numDigits(int number) { // cursed but allegedly the fastest algorithm, taken from https://www.baeldung.com/java-number-of-digits-in-int if (number < 100000) { diff --git a/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl b/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl new file mode 100644 index 000000000..4513d7e9e --- /dev/null +++ b/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl @@ -0,0 +1,35 @@ +layout(local_size_x = _FLW_SUBGROUP_SIZE) in; + +struct Copy { + uint srcOffset; + uint dstOffset; + uint byteSize; +}; + +layout(std430, binding = 0) restrict readonly buffer Copies { + Copy copies[]; +}; + +layout(std430, binding = 1) restrict readonly buffer Src { + uint src[]; +}; + +layout(std430, binding = 2) restrict writeonly buffer Dst { + uint dst[]; +}; + +void main() { + uint copy = gl_GlobalInvocationID.x; + + if (copy >= copies.length()) { + return; + } + + uint srcOffset = copies[copy].srcOffset >> 2; + uint dstOffset = copies[copy].dstOffset >> 2; + uint size = copies[copy].byteSize >> 2; + + for (uint i = 0; i < size; i++) { + dst[dstOffset + i] = src[srcOffset + i]; + } +}