Thought I'd need this

- Consolidate common logic from transfer list and scatter list into
  MemoryBuffer
This commit is contained in:
Jozufozu 2024-03-02 21:30:08 -08:00
parent 64aa8d4242
commit 5d142d2f13
3 changed files with 67 additions and 53 deletions

View file

@ -2,12 +2,12 @@ package com.jozufozu.flywheel.backend.engine.indirect;
import org.lwjgl.system.MemoryUtil; import org.lwjgl.system.MemoryUtil;
import com.jozufozu.flywheel.lib.memory.MemoryBlock; import com.jozufozu.flywheel.backend.util.MemoryBuffer;
public class ScatterList { public class ScatterList {
public static final long STRIDE = Integer.BYTES * 2; public static final long STRIDE = Integer.BYTES * 2;
public final long maxBytesPerScatter; public final long maxBytesPerScatter;
private MemoryBlock block; private final MemoryBuffer block = new MemoryBuffer(STRIDE);
private int length; private int length;
private long usedBytes; private long usedBytes;
@ -46,9 +46,9 @@ public class ScatterList {
} }
public void push(long sizeBytes, long srcOffsetBytes, long dstOffsetBytes) { public void push(long sizeBytes, long srcOffsetBytes, long dstOffsetBytes) {
reallocIfNeeded(length); block.reallocIfNeeded(length);
long ptr = block.ptr() + length * STRIDE; long ptr = block.ptrForIndex(length);
MemoryUtil.memPutInt(ptr, packSizeAndSrcOffset(sizeBytes, srcOffsetBytes)); MemoryUtil.memPutInt(ptr, packSizeAndSrcOffset(sizeBytes, srcOffsetBytes));
MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) (dstOffsetBytes >> 2)); MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) (dstOffsetBytes >> 2));
@ -78,21 +78,7 @@ public class ScatterList {
} }
public void delete() { public void delete() {
if (block != null) { block.delete();
block.free();
}
}
private void reallocIfNeeded(int index) {
if (block == null) {
block = MemoryBlock.malloc(neededCapacityForIndex(index + 8));
} else if (block.size() < neededCapacityForIndex(index)) {
block = block.realloc(neededCapacityForIndex(index + 8));
}
}
private static long neededCapacityForIndex(int index) {
return (index + 1) * STRIDE;
} }
private static int packSizeAndSrcOffset(long sizeBytes, long srcOffsetBytes) { private static int packSizeAndSrcOffset(long sizeBytes, long srcOffsetBytes) {

View file

@ -2,11 +2,11 @@ package com.jozufozu.flywheel.backend.engine.indirect;
import org.lwjgl.system.MemoryUtil; import org.lwjgl.system.MemoryUtil;
import com.jozufozu.flywheel.lib.memory.MemoryBlock; import com.jozufozu.flywheel.backend.util.MemoryBuffer;
public class TransferList { public class TransferList {
private static final long STRIDE = Long.BYTES * 4; private static final long STRIDE = Long.BYTES * 4;
private MemoryBlock block; private final MemoryBuffer block = new MemoryBuffer(STRIDE);
private int length; private int length;
/** /**
@ -24,7 +24,7 @@ public class TransferList {
return; return;
} }
reallocIfNeeded(length); block.reallocIfNeeded(length);
vbo(length, vbo); vbo(length, vbo);
srcOffset(length, srcOffset); srcOffset(length, srcOffset);
@ -56,25 +56,23 @@ public class TransferList {
} }
public int vbo(int index) { public int vbo(int index) {
return MemoryUtil.memGetInt(ptrForIndex(index)); return MemoryUtil.memGetInt(block.ptrForIndex(index));
} }
public long srcOffset(int index) { public long srcOffset(int index) {
return MemoryUtil.memGetLong(ptrForIndex(index) + Long.BYTES); return MemoryUtil.memGetLong(block.ptrForIndex(index) + Long.BYTES);
} }
public long dstOffset(int index) { public long dstOffset(int index) {
return MemoryUtil.memGetLong(ptrForIndex(index) + Long.BYTES * 2); return MemoryUtil.memGetLong(block.ptrForIndex(index) + Long.BYTES * 2);
} }
public long size(int index) { public long size(int index) {
return MemoryUtil.memGetLong(ptrForIndex(index) + Long.BYTES * 3); return MemoryUtil.memGetLong(block.ptrForIndex(index) + Long.BYTES * 3);
} }
public void delete() { public void delete() {
if (block != null) { block.delete();
block.free();
}
} }
private boolean continuesLast(int vbo, long srcOffset, long dstOffset) { private boolean continuesLast(int vbo, long srcOffset, long dstOffset) {
@ -87,38 +85,18 @@ public class TransferList {
} }
private void vbo(int index, int vbo) { private void vbo(int index, int vbo) {
MemoryUtil.memPutInt(ptrForIndex(index), vbo); MemoryUtil.memPutInt(block.ptrForIndex(index), vbo);
} }
private void srcOffset(int index, long srcOffset) { private void srcOffset(int index, long srcOffset) {
MemoryUtil.memPutLong(ptrForIndex(index) + Long.BYTES, srcOffset); MemoryUtil.memPutLong(block.ptrForIndex(index) + Long.BYTES, srcOffset);
} }
private void dstOffset(int index, long dstOffset) { private void dstOffset(int index, long dstOffset) {
MemoryUtil.memPutLong(ptrForIndex(index) + Long.BYTES * 2, dstOffset); MemoryUtil.memPutLong(block.ptrForIndex(index) + Long.BYTES * 2, dstOffset);
} }
private void size(int index, long size) { private void size(int index, long size) {
MemoryUtil.memPutLong(ptrForIndex(index) + Long.BYTES * 3, size); MemoryUtil.memPutLong(block.ptrForIndex(index) + Long.BYTES * 3, size);
}
private void reallocIfNeeded(int index) {
if (block == null) {
block = MemoryBlock.malloc(neededCapacityForIndex(index + 8));
} else if (block.size() < neededCapacityForIndex(index)) {
block = block.realloc(neededCapacityForIndex(index + 8));
}
}
private long ptrForIndex(int index) {
return block.ptr() + bytePosForIndex(index);
}
private static long bytePosForIndex(int index) {
return index * STRIDE;
}
private static long neededCapacityForIndex(int index) {
return (index + 1) * STRIDE;
} }
} }

View file

@ -0,0 +1,50 @@
package com.jozufozu.flywheel.backend.util;
import org.jetbrains.annotations.Nullable;
import com.jozufozu.flywheel.lib.memory.MemoryBlock;
public class MemoryBuffer {
private final long stride;
@Nullable
private MemoryBlock block;
public MemoryBuffer(long stride) {
this.stride = stride;
}
public boolean reallocIfNeeded(int index) {
if (block == null) {
block = MemoryBlock.malloc(neededCapacityForIndex(index + 8));
return true;
} else if (block.size() < neededCapacityForIndex(index)) {
block = block.realloc(neededCapacityForIndex(index + 8));
return true;
}
return false;
}
public long ptr() {
return block.ptr();
}
public long ptrForIndex(int index) {
return block.ptr() + bytePosForIndex(index);
}
public long bytePosForIndex(int index) {
return index * stride;
}
public long neededCapacityForIndex(int index) {
return (index + 1) * stride;
}
public void delete() {
if (block != null) {
block.free();
}
}
}