From 4e782b8dcd4acc34613cd46cc1d275da00bca9ec Mon Sep 17 00:00:00 2001 From: Jozufozu Date: Sun, 10 Dec 2023 13:18:19 -0800 Subject: [PATCH] Scattered to the winds - A scatter command is 2 uints: - The first contains the size and source offset in the upper byte and lower 3 bytes respectively. - The destination offset. - All offsets and sizes are in uints, not bytes. - Use ScatterList write scatter commands. - Use TransferList to collect transfers. - Rather than consolidating transfers in a separate pass, do so as they are collected. - Reorganize StagingBuffer. --- .../backend/engine/indirect/ScatterList.java | 105 +++++++ .../engine/indirect/StagingBuffer.java | 279 ++++++++---------- .../backend/engine/indirect/TransferList.java | 124 ++++++++ .../flywheel/internal/indirect/scatter.glsl | 17 +- 4 files changed, 363 insertions(+), 162 deletions(-) create mode 100644 src/main/java/com/jozufozu/flywheel/backend/engine/indirect/ScatterList.java create mode 100644 src/main/java/com/jozufozu/flywheel/backend/engine/indirect/TransferList.java diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/ScatterList.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/ScatterList.java new file mode 100644 index 000000000..4955f4cc7 --- /dev/null +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/ScatterList.java @@ -0,0 +1,105 @@ +package com.jozufozu.flywheel.backend.engine.indirect; + +import org.lwjgl.system.MemoryUtil; + +import com.jozufozu.flywheel.lib.memory.MemoryBlock; + +public class ScatterList { + public static final long STRIDE = Integer.BYTES * 2; + public final long maxBytesPerScatter; + private MemoryBlock block; + private int length; + private long usedBytes; + + public ScatterList() { + this(64); + } + + public ScatterList(long maxBytesPerScatter) { + if ((maxBytesPerScatter & 0b1111111100L) != maxBytesPerScatter) { + throw new IllegalArgumentException("Max bytes per scatter must be a multiple of 4 and less than 1024"); + } + + this.maxBytesPerScatter = maxBytesPerScatter; + } + + /** + * Breaks a transfer into many smaller scatter commands if it is too large, and appends them to this list. + * + * @param transfers The list of transfers to push. + * @param transferIndex The index of the transfer to push. + */ + public void pushTransfer(TransferList transfers, int transferIndex) { + long size = transfers.size(transferIndex); + long srcOffset = transfers.srcOffset(transferIndex); + long dstOffset = transfers.dstOffset(transferIndex); + + long offset = 0; + long remaining = size; + + while (offset < size) { + long copySize = Math.min(remaining, maxBytesPerScatter); + push(copySize, srcOffset + offset, dstOffset + offset); + offset += copySize; + remaining -= copySize; + } + } + + public void push(long sizeBytes, long srcOffsetBytes, long dstOffsetBytes) { + reallocIfNeeded(length); + + long ptr = block.ptr() + length * STRIDE; + MemoryUtil.memPutInt(ptr, packSizeAndSrcOffset(sizeBytes, srcOffsetBytes)); + MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) (dstOffsetBytes >> 2)); + + length++; + usedBytes += STRIDE; + } + + public int copyCount() { + return length; + } + + public long usedBytes() { + return usedBytes; + } + + public boolean isEmpty() { + return length == 0; + } + + public void reset() { + length = 0; + usedBytes = 0; + } + + public long ptr() { + return block.ptr(); + } + + public void delete() { + if (block != null) { + block.free(); + } + } + + private void reallocIfNeeded(int index) { + if (block == null) { + block = MemoryBlock.malloc(neededCapacityForIndex(index + 8)); + } else if (block.size() < neededCapacityForIndex(index)) { + block = block.realloc(neededCapacityForIndex(index + 8)); + } + } + + private static long neededCapacityForIndex(int index) { + return (index + 1) * STRIDE; + } + + private static int packSizeAndSrcOffset(long sizeBytes, long srcOffsetBytes) { + // Divide by 4 and put the offset in the lower 3 bytes. + int out = (int) (srcOffsetBytes >>> 2) & 0xFFFFFF; + // Place the size divided by 4 in the upper byte. + out |= (int) (sizeBytes << 22) & 0xFF000000; + return out; + } +} diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java index 184c9d6f4..db787957a 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java @@ -1,7 +1,5 @@ package com.jozufozu.flywheel.backend.engine.indirect; -import java.util.ArrayList; -import java.util.List; import java.util.function.LongConsumer; import org.jetbrains.annotations.NotNull; @@ -18,8 +16,6 @@ import com.jozufozu.flywheel.lib.memory.FlwMemoryTracker; import com.jozufozu.flywheel.lib.memory.MemoryBlock; import it.unimi.dsi.fastutil.PriorityQueue; -import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap; -import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.objects.ObjectArrayFIFOQueue; // Used https://github.com/CaffeineMC/sodium-fabric/blob/dev/src/main/java/me/jellysquid/mods/sodium/client/gl/arena/staging/MappedStagingBuffer.java @@ -33,18 +29,40 @@ public class StagingBuffer { private final long map; private final long capacity; + /** + * The position in the buffer at the time of the last flush. + */ private long start = 0; + /** + * The current position in the buffer, + * incremented as transfers are enqueued. + */ private long pos = 0; + /** + * The number of bytes used in the buffer since the last flush, + * decremented as transfers are enqueued. + */ + private long usedCapacity = 0; + /** + * The number of bytes available in the buffer. + *
+ * This decreases as transfers are enqueued and increases as fenced regions are reclaimed. + */ private long totalAvailable; + /** + * A scratch buffer for when there is not enough contiguous space + * in the staging buffer for the write the user wants to make. + */ @Nullable private MemoryBlock scratch; - private final GlBuffer copyBuffer = new GlBuffer(); private final OverflowStagingBuffer overflow = new OverflowStagingBuffer(); - private final PriorityQueue transfers = new ObjectArrayFIFOQueue<>(); + private final TransferList transfers = new TransferList(); private final PriorityQueue fencedRegions = new ObjectArrayFIFOQueue<>(); + private final GlBuffer scatterBuffer = new GlBuffer(); + private final ScatterList scatterList = new ScatterList(); public StagingBuffer() { this(DEFAULT_CAPACITY); @@ -75,7 +93,7 @@ public class StagingBuffer { */ public void enqueueCopy(long size, int dstVbo, long dstOffset, LongConsumer write) { // Try to write directly into the staging buffer if there is enough contiguous space. - var direct = reserveForTransferTo(size, dstVbo, dstOffset); + var direct = reserveForCopy(size, dstVbo, dstOffset); if (direct != MemoryUtil.NULL) { write.accept(direct); @@ -88,27 +106,6 @@ public class StagingBuffer { enqueueCopy(block.ptr(), size, dstVbo, dstOffset); } - @NotNull - private MemoryBlock getScratch(long size) { - if (scratch == null) { - scratch = MemoryBlock.malloc(size); - } else if (scratch.size() < size) { - scratch = scratch.realloc(size); - } - return scratch; - } - - /** - * Enqueue a copy from the given pointer to the given VBO. - * - * @param block The block to copy from. - * @param dstVbo The VBO to copy to. - * @param dstOffset The offset in the destination VBO. - */ - public void enqueueCopy(MemoryBlock block, int dstVbo, long dstOffset) { - enqueueCopy(block.ptr(), block.size(), dstVbo, dstOffset); - } - /** * Enqueue a copy from the given pointer to the given VBO. * @@ -132,21 +129,19 @@ public class StagingBuffer { // Put the first span at the tail of the buffer... MemoryUtil.memCopy(ptr, map + pos, remaining); - transfers.enqueue(new Transfer(pos, dstVbo, dstOffset, remaining)); + pushTransfer(dstVbo, pos, dstOffset, remaining); // ... and the rest at the head. MemoryUtil.memCopy(ptr + remaining, map, split); - transfers.enqueue(new Transfer(0, dstVbo, dstOffset + remaining, split)); + pushTransfer(dstVbo, 0, dstOffset + remaining, split); pos = split; } else { MemoryUtil.memCopy(ptr, map + pos, size); - transfers.enqueue(new Transfer(pos, dstVbo, dstOffset, size)); + pushTransfer(dstVbo, pos, dstOffset, size); pos += size; } - - totalAvailable -= size; } /** @@ -163,7 +158,7 @@ public class StagingBuffer { * @param dstOffset The offset in the destination VBO. * @return A pointer to the reserved space, or {@code null} if there is not enough contiguous space. */ - public long reserveForTransferTo(long size, int dstVbo, long dstOffset) { + public long reserveForCopy(long size, int dstVbo, long dstOffset) { assertMultipleOf4(size); // Don't need to check totalAvailable here because that's a looser constraint than the bytes remaining. long remaining = capacity - pos; @@ -173,12 +168,10 @@ public class StagingBuffer { long out = map + pos; - transfers.enqueue(new Transfer(pos, dstVbo, dstOffset, size)); + pushTransfer(dstVbo, pos, dstOffset, size); pos += size; - totalAvailable -= size; - return out; } @@ -189,120 +182,15 @@ public class StagingBuffer { flushUsedRegion(); - var usedCapacity = dispatchComputeCopies(); + dispatchComputeCopies(); + transfers.reset(); fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity)); + usedCapacity = 0; start = pos; } - private long dispatchComputeCopies() { - long usedCapacity = 0; - Int2ObjectMap> copiesPerVbo = new Int2ObjectArrayMap<>(); - - long bytesPerCopy = 64; - - for (var transfer : consolidateCopies(transfers)) { - usedCapacity += transfer.size; - - var forVbo = copiesPerVbo.computeIfAbsent(transfer.dstVbo, k -> new ArrayList<>()); - - long offset = 0; - long remaining = transfer.size; - - while (offset < transfer.size) { - long copySize = Math.min(remaining, bytesPerCopy); - forVbo.add(new Transfer(transfer.srcOffset + offset, transfer.dstVbo, transfer.dstOffset + offset, copySize)); - offset += copySize; - remaining -= copySize; - } - } - - IndirectPrograms.get() - .getScatterProgram() - .bind(); - - for (var entry : copiesPerVbo.int2ObjectEntrySet()) { - var dstVbo = entry.getIntKey(); - var transfers = entry.getValue(); - var copyCount = transfers.size(); - - var size = copyCount * Integer.BYTES * 3L; - var scratch = getScratch(size); - - putTransfers(scratch.ptr(), transfers); - - copyBuffer.upload(scratch.ptr(), size); - - GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 0, copyBuffer.handle()); - GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 1, vbo); - GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 2, dstVbo); - - GL45.glDispatchCompute(GlCompat.getComputeGroupCount(copyCount), 1, 1); - } - return usedCapacity; - } - - private void assertMultipleOf4(long size) { - if (size % 4 != 0) { - throw new IllegalArgumentException("Size must be a multiple of 4"); - } - } - - private long sendCopyCommands() { - long usedCapacity = 0; - - for (Transfer transfer : consolidateCopies(transfers)) { - usedCapacity += transfer.size; - - GL45C.glCopyNamedBufferSubData(vbo, transfer.dstVbo, transfer.srcOffset, transfer.dstOffset, transfer.size); - } - return usedCapacity; - } - - private void flushUsedRegion() { - if (pos < start) { - // we rolled around, need to flush 2 ranges. - GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start); - GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos); - } else { - GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start); - } - } - - private static List consolidateCopies(PriorityQueue queue) { - List merged = new ArrayList<>(); - Transfer last = null; - - while (!queue.isEmpty()) { - Transfer transfer = queue.dequeue(); - - if (last != null) { - if (areContiguous(last, transfer)) { - last.size += transfer.size; - continue; - } - } - - merged.add(last = new Transfer(transfer)); - } - - return merged; - } - - private static void putTransfers(long ptr, List transfers) { - for (Transfer transfer : transfers) { - MemoryUtil.memPutInt(ptr, (int) transfer.srcOffset); - MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) transfer.dstOffset); - MemoryUtil.memPutInt(ptr + Integer.BYTES * 2, (int) transfer.size); - ptr += Integer.BYTES * 3; - } - } - - private static boolean areContiguous(Transfer last, Transfer transfer) { - return last.dstVbo == transfer.dstVbo && last.dstOffset + last.size == transfer.dstOffset && last.srcOffset + last.size == transfer.srcOffset; - } - public void reclaim() { while (!fencedRegions.isEmpty()) { var region = fencedRegions.first(); @@ -322,29 +210,106 @@ public class StagingBuffer { GL45C.glUnmapNamedBuffer(vbo); GL45C.glDeleteBuffers(vbo); overflow.delete(); + scatterBuffer.delete(); if (scratch != null) { scratch.free(); } + transfers.delete(); + scatterList.delete(); + FlwMemoryTracker._freeCPUMemory(capacity); } - private static final class Transfer { - private final long srcOffset; - private final int dstVbo; - private final long dstOffset; - private long size; + @NotNull + private MemoryBlock getScratch(long size) { + if (scratch == null) { + scratch = MemoryBlock.malloc(size); + } else if (scratch.size() < size) { + scratch = scratch.realloc(size); + } + return scratch; + } - private Transfer(long srcOffset, int dstVbo, long dstOffset, long size) { - this.srcOffset = srcOffset; - this.dstVbo = dstVbo; - this.dstOffset = dstOffset; - this.size = size; + private void pushTransfer(int dstVbo, long srcOffset, long dstOffset, long size) { + transfers.push(dstVbo, srcOffset, dstOffset, size); + usedCapacity += size; + totalAvailable -= size; + } + + /** + * We could just use {@link #sendCopyCommands}, but that has significant + * overhead for many small transfers, such as when the object buffer is sparsely updated. + *
+ * Instead, we use a compute shader to scatter the data from the staging buffer to the destination VBOs. + * This approach is recommended by nvidia in + * this presentation + */ + private void dispatchComputeCopies() { + IndirectPrograms.get() + .getScatterProgram() + .bind(); + + // These bindings don't change between dstVbos. + GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 0, scatterBuffer.handle()); + GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 1, vbo); + + int dstVbo; + var transferCount = transfers.length(); + for (int i = 0; i < transferCount; i++) { + dstVbo = transfers.vbo(i); + + scatterList.pushTransfer(transfers, i); + + int nextVbo = i == transferCount - 1 ? -1 : transfers.vbo(i + 1); + + // If we're switching VBOs, dispatch the copies for the previous VBO. + // Generally VBOs don't appear in multiple spans of the list, + // so submitting duplicates is rare. + if (dstVbo != nextVbo) { + dispatchScatter(dstVbo); + } + } + } + + private void dispatchScatter(int dstVbo) { + scatterBuffer.upload(scatterList.ptr(), scatterList.usedBytes()); + + GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 2, dstVbo); + + GL45.glDispatchCompute(GlCompat.getComputeGroupCount(scatterList.copyCount()), 1, 1); + + scatterList.reset(); + } + + private void assertMultipleOf4(long size) { + if (size % 4 != 0) { + throw new IllegalArgumentException("Size must be a multiple of 4"); + } + } + + private long sendCopyCommands() { + long usedCapacity = 0; + + for (int i = 0; i < transfers.length(); i++) { + var size = transfers.size(i); + + usedCapacity += size; + + GL45C.glCopyNamedBufferSubData(vbo, transfers.vbo(i), transfers.srcOffset(i), transfers.dstOffset(i), size); } - public Transfer(Transfer other) { - this(other.srcOffset, other.dstVbo, other.dstOffset, other.size); + return usedCapacity; + } + + private void flushUsedRegion() { + if (pos < start) { + // we rolled around, need to flush 2 ranges. + GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start); + GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos); + } else { + GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start); } } diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/TransferList.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/TransferList.java new file mode 100644 index 000000000..bcba94761 --- /dev/null +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/TransferList.java @@ -0,0 +1,124 @@ +package com.jozufozu.flywheel.backend.engine.indirect; + +import org.lwjgl.system.MemoryUtil; + +import com.jozufozu.flywheel.lib.memory.MemoryBlock; + +public class TransferList { + private static final long STRIDE = Long.BYTES * 4; + private MemoryBlock block; + private int length; + + /** + * Append a transfer to the end of the list, combining with the last transfer if possible. + * + * @param vbo The VBO to transfer to. + * @param srcOffset The offset in the staging buffer. + * @param dstOffset The offset in the VBO. + * @param size The size of the transfer. + */ + public void push(int vbo, long srcOffset, long dstOffset, long size) { + if (continuesLast(vbo, srcOffset, dstOffset)) { + int lastIndex = length - 1; + size(lastIndex, size(lastIndex) + size); + return; + } + + reallocIfNeeded(length); + + vbo(length, vbo); + srcOffset(length, srcOffset); + dstOffset(length, dstOffset); + size(length, size); + + length++; + } + + /** + * @return The number of transfers in the list. + */ + public int length() { + return length; + } + + /** + * @return {@code true} if there are no transfers in the list, {@code false} otherwise. + */ + public boolean isEmpty() { + return length == 0; + } + + /** + * Reset the list to be empty. + */ + public void reset() { + length = 0; + } + + public int vbo(int index) { + return MemoryUtil.memGetInt(ptrForIndex(index)); + } + + public long srcOffset(int index) { + return MemoryUtil.memGetLong(ptrForIndex(index) + Long.BYTES); + } + + public long dstOffset(int index) { + return MemoryUtil.memGetLong(ptrForIndex(index) + Long.BYTES * 2); + } + + public long size(int index) { + return MemoryUtil.memGetLong(ptrForIndex(index) + Long.BYTES * 3); + } + + public void delete() { + if (block != null) { + block.free(); + } + } + + private boolean continuesLast(int vbo, long srcOffset, long dstOffset) { + if (length == 0) { + return false; + } + int lastIndex = length - 1; + var lastSize = size(lastIndex); + return vbo(lastIndex) == vbo && dstOffset(lastIndex) + lastSize == dstOffset && srcOffset(lastIndex) + lastSize == srcOffset; + } + + private void vbo(int index, int vbo) { + MemoryUtil.memPutInt(ptrForIndex(index), vbo); + } + + private void srcOffset(int index, long srcOffset) { + MemoryUtil.memPutLong(ptrForIndex(index) + Long.BYTES, srcOffset); + } + + private void dstOffset(int index, long dstOffset) { + MemoryUtil.memPutLong(ptrForIndex(index) + Long.BYTES * 2, dstOffset); + } + + private void size(int index, long size) { + MemoryUtil.memPutLong(ptrForIndex(index) + Long.BYTES * 3, size); + } + + private void reallocIfNeeded(int index) { + if (block == null) { + block = MemoryBlock.malloc(neededCapacityForIndex(index + 8)); + } else if (block.size() < neededCapacityForIndex(index)) { + block = block.realloc(neededCapacityForIndex(index + 8)); + } + } + + private long ptrForIndex(int index) { + return block.ptr() + bytePosForIndex(index); + } + + private static long bytePosForIndex(int index) { + return index * STRIDE; + } + + private static long neededCapacityForIndex(int index) { + return (index + 1) * STRIDE; + } +} diff --git a/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl b/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl index 4513d7e9e..489cace85 100644 --- a/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl +++ b/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl @@ -1,9 +1,14 @@ layout(local_size_x = _FLW_SUBGROUP_SIZE) in; +const uint SRC_OFFSET_MASK = 0xFFFFFF; + +// Since StagingBuffer is 16MB, a source offset *into an array of uints* can be represented with 22 bits. +// We use 24 here for some wiggle room. +// The lower 24 bits are the offset into the Src buffer. +// The upper 8 bits are the size of the copy. struct Copy { - uint srcOffset; + uint sizeAndSrcOffset; uint dstOffset; - uint byteSize; }; layout(std430, binding = 0) restrict readonly buffer Copies { @@ -25,9 +30,11 @@ void main() { return; } - uint srcOffset = copies[copy].srcOffset >> 2; - uint dstOffset = copies[copy].dstOffset >> 2; - uint size = copies[copy].byteSize >> 2; + uint sizeAndSrcOffset = copies[copy].sizeAndSrcOffset; + uint srcOffset = sizeAndSrcOffset & SRC_OFFSET_MASK; + uint size = sizeAndSrcOffset >> 24; + + uint dstOffset = copies[copy].dstOffset; for (uint i = 0; i < size; i++) { dst[dstOffset + i] = src[srcOffset + i];