diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/ScatterList.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/ScatterList.java
new file mode 100644
index 000000000..4955f4cc7
--- /dev/null
+++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/ScatterList.java
@@ -0,0 +1,154 @@
+package com.jozufozu.flywheel.backend.engine.indirect;
+
+import org.lwjgl.system.MemoryUtil;
+
+import com.jozufozu.flywheel.lib.memory.MemoryBlock;
+
+/**
+ * A growable list of packed scatter commands, stored in native memory.
+ *
+ * <p>Each command is {@link #STRIDE} bytes: one int holding the copy size and source offset,
+ * followed by one int holding the destination offset. All three values are stored in units of 4 bytes.
+ */
+public class ScatterList {
+    /**
+     * The size of one packed scatter command in bytes.
+     */
+    public static final long STRIDE = Integer.BYTES * 2;
+    /**
+     * The largest number of bytes a single scatter command may copy.
+     */
+    public final long maxBytesPerScatter;
+    private MemoryBlock block;
+    private int length;
+    private long usedBytes;
+
+    /**
+     * Creates a list whose scatter commands copy at most 64 bytes each.
+     */
+    public ScatterList() {
+        this(64);
+    }
+
+    /**
+     * @param maxBytesPerScatter The largest number of bytes a single scatter command may copy.
+     *                           Must be a positive multiple of 4 and less than 1024.
+     * @throws IllegalArgumentException If {@code maxBytesPerScatter} violates that constraint.
+     */
+    public ScatterList(long maxBytesPerScatter) {
+        // Zero passes the mask test on its own, but pushTransfer could then never advance through a transfer.
+        if (maxBytesPerScatter <= 0 || (maxBytesPerScatter & 0b1111111100L) != maxBytesPerScatter) {
+            throw new IllegalArgumentException("Max bytes per scatter must be a positive multiple of 4 and less than 1024");
+        }
+
+        this.maxBytesPerScatter = maxBytesPerScatter;
+    }
+
+    /**
+     * Breaks a transfer into many smaller scatter commands if it is too large, and appends them to this list.
+     *
+     * @param transfers The list of transfers to push.
+     * @param transferIndex The index of the transfer to push.
+     */
+    public void pushTransfer(TransferList transfers, int transferIndex) {
+        long size = transfers.size(transferIndex);
+        long srcOffset = transfers.srcOffset(transferIndex);
+        long dstOffset = transfers.dstOffset(transferIndex);
+
+        // maxBytesPerScatter is positive (checked in the constructor), so this always terminates.
+        for (long offset = 0; offset < size; offset += maxBytesPerScatter) {
+            long copySize = Math.min(size - offset, maxBytesPerScatter);
+            push(copySize, srcOffset + offset, dstOffset + offset);
+        }
+    }
+
+    /**
+     * Appends a single scatter command to this list, growing the backing memory if necessary.
+     *
+     * <p>Offsets and size are stored divided by 4. The packed size keeps only 8 bits and the
+     * packed source offset only 24 bits; anything beyond that is discarded by the masks.
+     *
+     * @param sizeBytes The number of bytes to copy.
+     * @param srcOffsetBytes The byte offset to copy from.
+     * @param dstOffsetBytes The byte offset to copy to.
+     */
+    public void push(long sizeBytes, long srcOffsetBytes, long dstOffsetBytes) {
+        reallocIfNeeded(length);
+
+        long ptr = block.ptr() + length * STRIDE;
+        MemoryUtil.memPutInt(ptr, packSizeAndSrcOffset(sizeBytes, srcOffsetBytes));
+        MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) (dstOffsetBytes >> 2));
+
+        length++;
+        usedBytes += STRIDE;
+    }
+
+    /**
+     * @return The number of scatter commands in this list.
+     */
+    public int copyCount() {
+        return length;
+    }
+
+    /**
+     * @return The number of bytes occupied by the commands in this list.
+     */
+    public long usedBytes() {
+        return usedBytes;
+    }
+
+    /**
+     * @return {@code true} if this list holds no scatter commands, {@code false} otherwise.
+     */
+    public boolean isEmpty() {
+        return length == 0;
+    }
+
+    /**
+     * Empties the list. The backing memory is kept for reuse.
+     */
+    public void reset() {
+        length = 0;
+        usedBytes = 0;
+    }
+
+    /**
+     * @return A pointer to the first command, or {@link MemoryUtil#NULL} if no memory is currently allocated.
+     */
+    public long ptr() {
+        return block == null ? MemoryUtil.NULL : block.ptr();
+    }
+
+    /**
+     * Frees the backing memory and empties the list. Safe to call more than once.
+     */
+    public void delete() {
+        if (block != null) {
+            block.free();
+            // Drop the reference so a repeated delete() or a later ptr() cannot touch freed memory.
+            block = null;
+        }
+        reset();
+    }
+
+    private void reallocIfNeeded(int index) {
+        // Allocate room for a few entries past the requested one so that not every push reallocates.
+        if (block == null) {
+            block = MemoryBlock.malloc(neededCapacityForIndex(index + 8));
+        } else if (block.size() < neededCapacityForIndex(index)) {
+            block = block.realloc(neededCapacityForIndex(index + 8));
+        }
+    }
+
+    private static long neededCapacityForIndex(int index) {
+        return (index + 1) * STRIDE;
+    }
+
+    private static int packSizeAndSrcOffset(long sizeBytes, long srcOffsetBytes) {
+        // Divide by 4 and put the offset in the lower 3 bytes.
+        int out = (int) (srcOffsetBytes >>> 2) & 0xFFFFFF;
+        // Place the size divided by 4 in the upper byte: (size >> 2) << 24 is size << 22 for multiples of 4.
+        out |= (int) (sizeBytes << 22) & 0xFF000000;
+        return out;
+    }
+}
diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java
index 184c9d6f4..db787957a 100644
--- a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java
+++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/StagingBuffer.java
@@ -1,7 +1,5 @@
package com.jozufozu.flywheel.backend.engine.indirect;
-import java.util.ArrayList;
-import java.util.List;
import java.util.function.LongConsumer;
import org.jetbrains.annotations.NotNull;
@@ -18,8 +16,6 @@ import com.jozufozu.flywheel.lib.memory.FlwMemoryTracker;
import com.jozufozu.flywheel.lib.memory.MemoryBlock;
import it.unimi.dsi.fastutil.PriorityQueue;
-import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
-import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.objects.ObjectArrayFIFOQueue;
// Used https://github.com/CaffeineMC/sodium-fabric/blob/dev/src/main/java/me/jellysquid/mods/sodium/client/gl/arena/staging/MappedStagingBuffer.java
@@ -33,18 +29,40 @@ public class StagingBuffer {
private final long map;
private final long capacity;
+ /**
+ * The position in the buffer at the time of the last flush.
+ */
private long start = 0;
+ /**
+ * The current position in the buffer,
+ * incremented as transfers are enqueued.
+ */
private long pos = 0;
+ /**
+ * The number of bytes used in the buffer since the last flush,
+ * incremented as transfers are enqueued.
+ */
+ private long usedCapacity = 0;
+ /**
+ * The number of bytes available in the buffer.
+ *
+ * This decreases as transfers are enqueued and increases as fenced regions are reclaimed.
+ */
private long totalAvailable;
+ /**
+ * A scratch buffer for when there is not enough contiguous space
+ * in the staging buffer for the write the user wants to make.
+ */
@Nullable
private MemoryBlock scratch;
- private final GlBuffer copyBuffer = new GlBuffer();
private final OverflowStagingBuffer overflow = new OverflowStagingBuffer();
- private final PriorityQueue transfers = new ObjectArrayFIFOQueue<>();
+ private final TransferList transfers = new TransferList();
private final PriorityQueue fencedRegions = new ObjectArrayFIFOQueue<>();
+ private final GlBuffer scatterBuffer = new GlBuffer();
+ private final ScatterList scatterList = new ScatterList();
public StagingBuffer() {
this(DEFAULT_CAPACITY);
@@ -75,7 +93,7 @@ public class StagingBuffer {
*/
public void enqueueCopy(long size, int dstVbo, long dstOffset, LongConsumer write) {
// Try to write directly into the staging buffer if there is enough contiguous space.
- var direct = reserveForTransferTo(size, dstVbo, dstOffset);
+ var direct = reserveForCopy(size, dstVbo, dstOffset);
if (direct != MemoryUtil.NULL) {
write.accept(direct);
@@ -88,27 +106,6 @@ public class StagingBuffer {
enqueueCopy(block.ptr(), size, dstVbo, dstOffset);
}
- @NotNull
- private MemoryBlock getScratch(long size) {
- if (scratch == null) {
- scratch = MemoryBlock.malloc(size);
- } else if (scratch.size() < size) {
- scratch = scratch.realloc(size);
- }
- return scratch;
- }
-
- /**
- * Enqueue a copy from the given pointer to the given VBO.
- *
- * @param block The block to copy from.
- * @param dstVbo The VBO to copy to.
- * @param dstOffset The offset in the destination VBO.
- */
- public void enqueueCopy(MemoryBlock block, int dstVbo, long dstOffset) {
- enqueueCopy(block.ptr(), block.size(), dstVbo, dstOffset);
- }
-
/**
* Enqueue a copy from the given pointer to the given VBO.
*
@@ -132,21 +129,19 @@ public class StagingBuffer {
// Put the first span at the tail of the buffer...
MemoryUtil.memCopy(ptr, map + pos, remaining);
- transfers.enqueue(new Transfer(pos, dstVbo, dstOffset, remaining));
+ pushTransfer(dstVbo, pos, dstOffset, remaining);
// ... and the rest at the head.
MemoryUtil.memCopy(ptr + remaining, map, split);
- transfers.enqueue(new Transfer(0, dstVbo, dstOffset + remaining, split));
+ pushTransfer(dstVbo, 0, dstOffset + remaining, split);
pos = split;
} else {
MemoryUtil.memCopy(ptr, map + pos, size);
- transfers.enqueue(new Transfer(pos, dstVbo, dstOffset, size));
+ pushTransfer(dstVbo, pos, dstOffset, size);
pos += size;
}
-
- totalAvailable -= size;
}
/**
@@ -163,7 +158,7 @@ public class StagingBuffer {
* @param dstOffset The offset in the destination VBO.
* @return A pointer to the reserved space, or {@code null} if there is not enough contiguous space.
*/
- public long reserveForTransferTo(long size, int dstVbo, long dstOffset) {
+ public long reserveForCopy(long size, int dstVbo, long dstOffset) {
assertMultipleOf4(size);
// Don't need to check totalAvailable here because that's a looser constraint than the bytes remaining.
long remaining = capacity - pos;
@@ -173,12 +168,10 @@ public class StagingBuffer {
long out = map + pos;
- transfers.enqueue(new Transfer(pos, dstVbo, dstOffset, size));
+ pushTransfer(dstVbo, pos, dstOffset, size);
pos += size;
- totalAvailable -= size;
-
return out;
}
@@ -189,120 +182,15 @@ public class StagingBuffer {
flushUsedRegion();
- var usedCapacity = dispatchComputeCopies();
+ dispatchComputeCopies();
+ transfers.reset();
fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity));
+ usedCapacity = 0;
start = pos;
}
- private long dispatchComputeCopies() {
- long usedCapacity = 0;
- Int2ObjectMap> copiesPerVbo = new Int2ObjectArrayMap<>();
-
- long bytesPerCopy = 64;
-
- for (var transfer : consolidateCopies(transfers)) {
- usedCapacity += transfer.size;
-
- var forVbo = copiesPerVbo.computeIfAbsent(transfer.dstVbo, k -> new ArrayList<>());
-
- long offset = 0;
- long remaining = transfer.size;
-
- while (offset < transfer.size) {
- long copySize = Math.min(remaining, bytesPerCopy);
- forVbo.add(new Transfer(transfer.srcOffset + offset, transfer.dstVbo, transfer.dstOffset + offset, copySize));
- offset += copySize;
- remaining -= copySize;
- }
- }
-
- IndirectPrograms.get()
- .getScatterProgram()
- .bind();
-
- for (var entry : copiesPerVbo.int2ObjectEntrySet()) {
- var dstVbo = entry.getIntKey();
- var transfers = entry.getValue();
- var copyCount = transfers.size();
-
- var size = copyCount * Integer.BYTES * 3L;
- var scratch = getScratch(size);
-
- putTransfers(scratch.ptr(), transfers);
-
- copyBuffer.upload(scratch.ptr(), size);
-
- GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 0, copyBuffer.handle());
- GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 1, vbo);
- GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 2, dstVbo);
-
- GL45.glDispatchCompute(GlCompat.getComputeGroupCount(copyCount), 1, 1);
- }
- return usedCapacity;
- }
-
- private void assertMultipleOf4(long size) {
- if (size % 4 != 0) {
- throw new IllegalArgumentException("Size must be a multiple of 4");
- }
- }
-
- private long sendCopyCommands() {
- long usedCapacity = 0;
-
- for (Transfer transfer : consolidateCopies(transfers)) {
- usedCapacity += transfer.size;
-
- GL45C.glCopyNamedBufferSubData(vbo, transfer.dstVbo, transfer.srcOffset, transfer.dstOffset, transfer.size);
- }
- return usedCapacity;
- }
-
- private void flushUsedRegion() {
- if (pos < start) {
- // we rolled around, need to flush 2 ranges.
- GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start);
- GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos);
- } else {
- GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start);
- }
- }
-
- private static List consolidateCopies(PriorityQueue queue) {
- List merged = new ArrayList<>();
- Transfer last = null;
-
- while (!queue.isEmpty()) {
- Transfer transfer = queue.dequeue();
-
- if (last != null) {
- if (areContiguous(last, transfer)) {
- last.size += transfer.size;
- continue;
- }
- }
-
- merged.add(last = new Transfer(transfer));
- }
-
- return merged;
- }
-
- private static void putTransfers(long ptr, List transfers) {
- for (Transfer transfer : transfers) {
- MemoryUtil.memPutInt(ptr, (int) transfer.srcOffset);
- MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) transfer.dstOffset);
- MemoryUtil.memPutInt(ptr + Integer.BYTES * 2, (int) transfer.size);
- ptr += Integer.BYTES * 3;
- }
- }
-
- private static boolean areContiguous(Transfer last, Transfer transfer) {
- return last.dstVbo == transfer.dstVbo && last.dstOffset + last.size == transfer.dstOffset && last.srcOffset + last.size == transfer.srcOffset;
- }
-
public void reclaim() {
while (!fencedRegions.isEmpty()) {
var region = fencedRegions.first();
@@ -322,29 +210,106 @@ public class StagingBuffer {
GL45C.glUnmapNamedBuffer(vbo);
GL45C.glDeleteBuffers(vbo);
overflow.delete();
+ scatterBuffer.delete();
if (scratch != null) {
scratch.free();
}
+ transfers.delete();
+ scatterList.delete();
+
FlwMemoryTracker._freeCPUMemory(capacity);
}
- private static final class Transfer {
- private final long srcOffset;
- private final int dstVbo;
- private final long dstOffset;
- private long size;
+    /**
+     * Provides the scratch block, allocating it on first use and reallocating it when it is smaller than {@code size}.
+     *
+     * @param size The minimum number of bytes the returned block must hold.
+     * @return The scratch block.
+     */
+    @NotNull
+    private MemoryBlock getScratch(long size) {
+        MemoryBlock current = scratch;
+
+        if (current == null) {
+            current = MemoryBlock.malloc(size);
+        } else if (current.size() < size) {
+            current = current.realloc(size);
+        }
+
+        scratch = current;
+        return current;
+    }
- private Transfer(long srcOffset, int dstVbo, long dstOffset, long size) {
- this.srcOffset = srcOffset;
- this.dstVbo = dstVbo;
- this.dstOffset = dstOffset;
- this.size = size;
+    /**
+     * Appends a transfer to the transfer list and updates the byte accounting for it.
+     *
+     * @param dstVbo The VBO to transfer to.
+     * @param srcOffset The offset in the staging buffer.
+     * @param dstOffset The offset in the destination VBO.
+     * @param size The size of the transfer in bytes.
+     */
+    private void pushTransfer(int dstVbo, long srcOffset, long dstOffset, long size) {
+        transfers.push(dstVbo, srcOffset, dstOffset, size);
+
+        // The same bytes leave the available pool and count as used until the next flush.
+        totalAvailable -= size;
+        usedCapacity += size;
+    }
+
+ /**
+ * We could just use {@link #sendCopyCommands}, but that has significant
+ * overhead for many small transfers, such as when the object buffer is sparsely updated.
+ *
+ * Instead, we use a compute shader to scatter the data from the staging buffer to the destination VBOs.
+ * This approach is recommended by nvidia in
+ * this presentation
+ */
+    private void dispatchComputeCopies() {
+        IndirectPrograms.get()
+                .getScatterProgram()
+                .bind();
+
+        // Bindings 0 and 1 are shared by every dispatch; only the destination binding changes.
+        GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 0, scatterBuffer.handle());
+        GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 1, vbo);
+
+        final int lastIndex = transfers.length() - 1;
+        for (int index = 0; index <= lastIndex; index++) {
+            scatterList.pushTransfer(transfers, index);
+
+            int currentVbo = transfers.vbo(index);
+            // -1 stands in for "no following transfer" after the final entry.
+            int followingVbo = index < lastIndex ? transfers.vbo(index + 1) : -1;
+
+            // A run of transfers into one VBO ends here, so submit everything gathered for it.
+            // A VBO rarely shows up in more than one run, so duplicate submissions are uncommon.
+            if (followingVbo != currentVbo) {
+                dispatchScatter(currentVbo);
+            }
+        }
+    }
+
+    /**
+     * Uploads the pending scatter commands and dispatches the scatter program to copy them into {@code dstVbo},
+     * then clears the scatter list.
+     *
+     * <p>Binds only storage buffer binding 2; the program and bindings 0 and 1 are left to the caller.
+     *
+     * @param dstVbo The VBO the pending scatter commands write into.
+     */
+    private void dispatchScatter(int dstVbo) {
+        // A zero-sized transfer yields no commands. There is nothing to upload or dispatch then,
+        // and the scatter list may not have allocated any memory to point at yet.
+        if (scatterList.isEmpty()) {
+            return;
+        }
+
+        scatterBuffer.upload(scatterList.ptr(), scatterList.usedBytes());
+
+        GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 2, dstVbo);
+
+        GL45.glDispatchCompute(GlCompat.getComputeGroupCount(scatterList.copyCount()), 1, 1);
+
+        scatterList.reset();
+    }
+
+    /**
+     * Rejects sizes that are not a whole number of 4-byte words.
+     *
+     * @param size The size in bytes to check.
+     * @throws IllegalArgumentException If {@code size} is not a multiple of 4.
+     */
+    private void assertMultipleOf4(long size) {
+        boolean wholeWords = (size & 3L) == 0;
+        if (wholeWords) {
+            return;
+        }
+        throw new IllegalArgumentException("Size must be a multiple of 4");
+    }
+
+ /**
+ * Copies every queued transfer out of the staging VBO with one
+ * {@code glCopyNamedBufferSubData} call per transfer.
+ *
+ * @return The combined size in bytes of all queued transfers.
+ */
+ private long sendCopyCommands() {
+ long usedCapacity = 0;
+
+ for (int i = 0; i < transfers.length(); i++) {
+ var size = transfers.size(i);
+
+ usedCapacity += size;
+
+ GL45C.glCopyNamedBufferSubData(vbo, transfers.vbo(i), transfers.srcOffset(i), transfers.dstOffset(i), size);
}
- public Transfer(Transfer other) {
- this(other.srcOffset, other.dstVbo, other.dstOffset, other.size);
+ return usedCapacity;
+ }
+
+ /**
+ * Flushes the mapped range of the staging VBO between {@code start} and {@code pos}.
+ * When {@code pos} is behind {@code start}, the range is flushed as two pieces:
+ * from {@code start} to the end of the buffer, and from the beginning of the buffer to {@code pos}.
+ */
+ private void flushUsedRegion() {
+ if (pos < start) {
+ // we rolled around, need to flush 2 ranges.
+ GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start);
+ GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos);
+ } else {
+ GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start);
}
}
diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/TransferList.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/TransferList.java
new file mode 100644
index 000000000..bcba94761
--- /dev/null
+++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/TransferList.java
@@ -0,0 +1,153 @@
+package com.jozufozu.flywheel.backend.engine.indirect;
+
+import org.lwjgl.system.MemoryUtil;
+
+import com.jozufozu.flywheel.lib.memory.MemoryBlock;
+
+/**
+ * A growable list of transfers, stored in native memory.
+ *
+ * <p>Every entry is four 8-byte slots wide and holds, in order: the VBO (written as an int),
+ * the source offset, the destination offset and the size.
+ */
+public class TransferList {
+    private static final long STRIDE = Long.BYTES * 4;
+    private static final long SRC_OFFSET_SLOT = Long.BYTES;
+    private static final long DST_OFFSET_SLOT = Long.BYTES * 2;
+    private static final long SIZE_SLOT = Long.BYTES * 3;
+
+    private MemoryBlock block;
+    private int length;
+
+    /**
+     * Adds a transfer to the list. A transfer that picks up exactly where the last one ended,
+     * in the same VBO, is folded into that last entry rather than stored separately.
+     *
+     * @param vbo The VBO to transfer to.
+     * @param srcOffset The offset in the staging buffer.
+     * @param dstOffset The offset in the VBO.
+     * @param size The size of the transfer.
+     */
+    public void push(int vbo, long srcOffset, long dstOffset, long size) {
+        if (continuesLast(vbo, srcOffset, dstOffset)) {
+            int tail = length - 1;
+            size(tail, size(tail) + size);
+        } else {
+            int slot = length;
+            reallocIfNeeded(slot);
+
+            vbo(slot, vbo);
+            srcOffset(slot, srcOffset);
+            dstOffset(slot, dstOffset);
+            size(slot, size);
+
+            length = slot + 1;
+        }
+    }
+
+    /**
+     * @return How many transfers the list currently holds.
+     */
+    public int length() {
+        return length;
+    }
+
+    /**
+     * @return {@code true} when the list holds no transfers, {@code false} otherwise.
+     */
+    public boolean isEmpty() {
+        return length == 0;
+    }
+
+    /**
+     * Forget every transfer in the list. The backing memory is left allocated.
+     */
+    public void reset() {
+        length = 0;
+    }
+
+    /**
+     * @return The VBO of the transfer at {@code index}.
+     */
+    public int vbo(int index) {
+        return MemoryUtil.memGetInt(ptrForIndex(index));
+    }
+
+    /**
+     * @return The source offset of the transfer at {@code index}.
+     */
+    public long srcOffset(int index) {
+        return MemoryUtil.memGetLong(ptrForIndex(index) + SRC_OFFSET_SLOT);
+    }
+
+    /**
+     * @return The destination offset of the transfer at {@code index}.
+     */
+    public long dstOffset(int index) {
+        return MemoryUtil.memGetLong(ptrForIndex(index) + DST_OFFSET_SLOT);
+    }
+
+    /**
+     * @return The size of the transfer at {@code index}.
+     */
+    public long size(int index) {
+        return MemoryUtil.memGetLong(ptrForIndex(index) + SIZE_SLOT);
+    }
+
+    /**
+     * Frees the backing memory, if any was allocated.
+     */
+    public void delete() {
+        if (block == null) {
+            return;
+        }
+        block.free();
+    }
+
+    private boolean continuesLast(int vbo, long srcOffset, long dstOffset) {
+        if (length > 0) {
+            int tail = length - 1;
+            long tailSize = size(tail);
+            // Same VBO, and both the source and destination ranges start where the last entry ends.
+            return vbo == vbo(tail)
+                    && dstOffset == dstOffset(tail) + tailSize
+                    && srcOffset == srcOffset(tail) + tailSize;
+        }
+        return false;
+    }
+
+    private void vbo(int index, int vbo) {
+        MemoryUtil.memPutInt(ptrForIndex(index), vbo);
+    }
+
+    private void srcOffset(int index, long srcOffset) {
+        MemoryUtil.memPutLong(ptrForIndex(index) + SRC_OFFSET_SLOT, srcOffset);
+    }
+
+    private void dstOffset(int index, long dstOffset) {
+        MemoryUtil.memPutLong(ptrForIndex(index) + DST_OFFSET_SLOT, dstOffset);
+    }
+
+    private void size(int index, long size) {
+        MemoryUtil.memPutLong(ptrForIndex(index) + SIZE_SLOT, size);
+    }
+
+    private void reallocIfNeeded(int index) {
+        long roomy = neededCapacityForIndex(index + 8);
+        if (block == null) {
+            block = MemoryBlock.malloc(roomy);
+            return;
+        }
+        if (block.size() < neededCapacityForIndex(index)) {
+            block = block.realloc(roomy);
+        }
+    }
+
+    private long ptrForIndex(int index) {
+        return block.ptr() + index * STRIDE;
+    }
+
+    private static long neededCapacityForIndex(int index) {
+        return STRIDE * (index + 1);
+    }
+}
diff --git a/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl b/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl
index 4513d7e9e..489cace85 100644
--- a/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl
+++ b/src/main/resources/assets/flywheel/flywheel/internal/indirect/scatter.glsl
@@ -1,9 +1,14 @@
layout(local_size_x = _FLW_SUBGROUP_SIZE) in;
+const uint SRC_OFFSET_MASK = 0xFFFFFF;
+
+// Since StagingBuffer is 16MB, a source offset *into an array of uints* can be represented with 22 bits.
+// We use 24 here for some wiggle room.
+// The lower 24 bits are the offset into the Src buffer.
+// The upper 8 bits are the size of the copy.
struct Copy {
- uint srcOffset;
+ uint sizeAndSrcOffset;
uint dstOffset;
- uint byteSize;
};
layout(std430, binding = 0) restrict readonly buffer Copies {
@@ -25,9 +30,11 @@ void main() {
return;
}
- uint srcOffset = copies[copy].srcOffset >> 2;
- uint dstOffset = copies[copy].dstOffset >> 2;
- uint size = copies[copy].byteSize >> 2;
+ uint sizeAndSrcOffset = copies[copy].sizeAndSrcOffset;
+ uint srcOffset = sizeAndSrcOffset & SRC_OFFSET_MASK;
+ uint size = sizeAndSrcOffset >> 24;
+
+ uint dstOffset = copies[copy].dstOffset;
for (uint i = 0; i < size; i++) {
dst[dstOffset + i] = src[srcOffset + i];