Throw more compute at it

- Use a shader to perform copies.
- Needs optimization, but it works surprisingly well and avoids the
  driver hitch.
This commit is contained in:
Jozufozu 2023-12-09 15:32:21 -08:00
parent e8c77f1e89
commit f9209f7dd6
6 changed files with 174 additions and 24 deletions

View file

@ -24,18 +24,21 @@ import net.minecraft.resources.ResourceLocation;
public class IndirectPrograms {
private static final ResourceLocation CULL_SHADER_MAIN = Flywheel.rl("internal/indirect/cull.glsl");
private static final ResourceLocation APPLY_SHADER_MAIN = Flywheel.rl("internal/indirect/apply.glsl");
private static final ResourceLocation SCATTER_SHADER_MAIN = Flywheel.rl("internal/indirect/scatter.glsl");
public static IndirectPrograms instance;
private static final Compile<InstanceType<?>> CULL = new Compile<>();
private static final Compile<Unit> APPLY = new Compile<>();
private static final Compile<Unit> UNIT = new Compile<>();
private final Map<PipelineProgramKey, GlProgram> pipeline;
private final Map<InstanceType<?>, GlProgram> culling;
private final GlProgram apply;
private final GlProgram scatter;
public IndirectPrograms(Map<PipelineProgramKey, GlProgram> pipeline, Map<InstanceType<?>, GlProgram> culling, GlProgram apply) {
/**
 * Bundles the compiled programs used by the indirect backend.
 *
 * @param pipeline the programs looked up by {@link PipelineProgramKey}.
 * @param culling  the culling programs, one per instance type.
 * @param apply    the single apply compute program.
 * @param scatter  the single scatter compute program.
 */
public IndirectPrograms(Map<PipelineProgramKey, GlProgram> pipeline, Map<InstanceType<?>, GlProgram> culling, GlProgram apply, GlProgram scatter) {
    // Stand-alone compute programs.
    this.scatter = scatter;
    this.apply = apply;

    // Keyed program tables.
    this.culling = culling;
    this.pipeline = pipeline;
}
static void reload(ShaderSources sources, ImmutableList<PipelineProgramKey> pipelineKeys, UniformComponent uniformComponent, List<SourceComponent> vertexComponents, List<SourceComponent> fragmentComponents) {
@ -43,14 +46,16 @@ public class IndirectPrograms {
var pipelineCompiler = PipelineCompiler.create(sources, Pipelines.INDIRECT, pipelineKeys, uniformComponent, vertexComponents, fragmentComponents);
var cullingCompiler = createCullingCompiler(uniformComponent, sources);
var applyCompiler = createApplyCompiler(sources);
var scatterCompiler = createScatterCompiler(sources);
try {
var pipelineResult = pipelineCompiler.compileAndReportErrors();
var cullingResult = cullingCompiler.compileAndReportErrors();
var applyResult = applyCompiler.compileAndReportErrors();
var scatterResult = scatterCompiler.compileAndReportErrors();
if (pipelineResult != null && cullingResult != null && applyResult != null) {
instance = new IndirectPrograms(pipelineResult, cullingResult, applyResult.get(Unit.INSTANCE));
if (pipelineResult != null && cullingResult != null && applyResult != null && scatterResult != null) {
instance = new IndirectPrograms(pipelineResult, cullingResult, applyResult.get(Unit.INSTANCE), scatterResult.get(Unit.INSTANCE));
}
} catch (Throwable e) {
Flywheel.LOGGER.error("Failed to compile indirect programs", e);
@ -99,15 +104,25 @@ public class IndirectPrograms {
}
private static CompilationHarness<Unit> createApplyCompiler(ShaderSources sources) {
return APPLY.harness(sources)
return UNIT.harness(sources)
.keys(ImmutableList.of(Unit.INSTANCE))
.compiler(APPLY.program()
.link(APPLY.shader(GlslVersion.V460, ShaderType.COMPUTE)
.compiler(UNIT.program()
.link(UNIT.shader(GlslVersion.V460, ShaderType.COMPUTE)
.define("_FLW_SUBGROUP_SIZE", GlCompat.SUBGROUP_SIZE)
.withResource(APPLY_SHADER_MAIN)))
.build();
}
/**
 * Builds the harness that compiles the scatter compute program.
 *
 * <p>The harness has a single key, {@link Unit#INSTANCE}, and links one GLSL 4.60 compute shader
 * loaded from {@code SCATTER_SHADER_MAIN} with {@code _FLW_SUBGROUP_SIZE} defined to
 * {@link GlCompat#SUBGROUP_SIZE}.
 *
 * @param sources the shader sources handed to the harness.
 * @return the configured harness.
 */
private static CompilationHarness<Unit> createScatterCompiler(ShaderSources sources) {
    // Builders are created in the same order as the original fluent chain evaluated them.
    var harnessBuilder = UNIT.harness(sources)
            .keys(ImmutableList.of(Unit.INSTANCE));

    var programBuilder = UNIT.program();

    var computeShader = UNIT.shader(GlslVersion.V460, ShaderType.COMPUTE)
            .define("_FLW_SUBGROUP_SIZE", GlCompat.SUBGROUP_SIZE)
            .withResource(SCATTER_SHADER_MAIN);

    return harnessBuilder.compiler(programBuilder.link(computeShader))
            .build();
}
/**
 * Looks up the pipeline program for an instance type / context pair.
 *
 * @param instanceType  the instance type half of the lookup key.
 * @param contextShader the context half of the lookup key.
 * @return whatever the {@code pipeline} map holds for that key.
 */
public GlProgram getIndirectProgram(InstanceType<?> instanceType, Context contextShader) {
    var key = new PipelineProgramKey(instanceType, contextShader);
    return pipeline.get(key);
}
@ -120,6 +135,10 @@ public class IndirectPrograms {
return apply;
}
/**
 * @return the scatter compute program this instance was constructed with.
 */
public GlProgram getScatterProgram() {
    return this.scatter;
}
public void delete() {
pipeline.values()
.forEach(GlProgram::delete);

View file

@ -99,6 +99,7 @@ public class IndirectCullingGroup<I extends Instance> {
UniformBuffer.get().sync();
cullProgram.bind();
buffers.bindForCompute();
glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
glDispatchCompute(GlCompat.getComputeGroupCount(instanceCountThisFrame), 1, 1);
}

View file

@ -4,14 +4,22 @@ import java.util.ArrayList;
import java.util.List;
import java.util.function.LongConsumer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.lwjgl.opengl.GL45;
import org.lwjgl.opengl.GL45C;
import org.lwjgl.system.MemoryUtil;
import com.jozufozu.flywheel.backend.compile.IndirectPrograms;
import com.jozufozu.flywheel.gl.GlCompat;
import com.jozufozu.flywheel.gl.GlFence;
import com.jozufozu.flywheel.gl.buffer.GlBuffer;
import com.jozufozu.flywheel.lib.memory.FlwMemoryTracker;
import com.jozufozu.flywheel.lib.memory.MemoryBlock;
import it.unimi.dsi.fastutil.PriorityQueue;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.objects.ObjectArrayFIFOQueue;
// Used https://github.com/CaffeineMC/sodium-fabric/blob/dev/src/main/java/me/jellysquid/mods/sodium/client/gl/arena/staging/MappedStagingBuffer.java
@ -30,8 +38,10 @@ public class StagingBuffer {
private long totalAvailable;
@Nullable
private MemoryBlock scratch;
private final GlBuffer copyBuffer = new GlBuffer();
private final OverflowStagingBuffer overflow = new OverflowStagingBuffer();
private final PriorityQueue<Transfer> transfers = new ObjectArrayFIFOQueue<>();
private final PriorityQueue<FencedRegion> fencedRegions = new ObjectArrayFIFOQueue<>();
@ -78,6 +88,7 @@ public class StagingBuffer {
enqueueCopy(block.ptr(), size, dstVbo, dstOffset);
}
@NotNull
private MemoryBlock getScratch(long size) {
if (scratch == null) {
scratch = MemoryBlock.malloc(size);
@ -107,8 +118,10 @@ public class StagingBuffer {
* @param dstOffset The offset in the destination VBO.
*/
public void enqueueCopy(long ptr, long size, int dstVbo, long dstOffset) {
assertMultipleOf4(size);
if (size > totalAvailable) {
overflow.enqueueCopy(ptr, size, dstVbo, dstOffset);
overflow.upload(ptr, size, dstVbo, dstOffset);
return;
}
@ -151,6 +164,7 @@ public class StagingBuffer {
* @return A pointer to the reserved space, or {@code null} if there is not enough contiguous space.
*/
public long reserveForTransferTo(long size, int dstVbo, long dstOffset) {
assertMultipleOf4(size);
// Don't need to check totalAvailable here because that's a looser constraint than the bytes remaining.
long remaining = capacity - pos;
if (size > remaining) {
@ -173,14 +187,69 @@ public class StagingBuffer {
return;
}
if (pos < start) {
// we rolled around, need to flush 2 ranges.
GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start);
GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos);
} else {
GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start);
flushUsedRegion();
var usedCapacity = dispatchComputeCopies();
fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity));
start = pos;
}
/**
 * Performs all queued transfers by dispatching the scatter compute program.
 *
 * <p>The consolidated transfers are split into chunks of at most 64 bytes and grouped by
 * destination buffer. For each destination the chunks are encoded as
 * (srcOffset, dstOffset, size) int triples, uploaded to {@link #copyBuffer}, and a single
 * compute dispatch is issued with these shader storage bindings:
 * <ul>
 * <li>binding 0: the encoded copy commands,</li>
 * <li>binding 1: this staging buffer ({@code vbo}),</li>
 * <li>binding 2: the destination buffer.</li>
 * </ul>
 *
 * @return the sum of the sizes of the consolidated transfers.
 */
private long dispatchComputeCopies() {
    // Upper bound on the number of bytes a single copy command (one shader invocation) moves.
    final long bytesPerCopy = 64;

    long usedCapacity = 0;
    Int2ObjectMap<List<Transfer>> copiesPerVbo = new Int2ObjectArrayMap<>();

    for (var transfer : consolidateCopies(transfers)) {
        usedCapacity += transfer.size;

        var forVbo = copiesPerVbo.computeIfAbsent(transfer.dstVbo, k -> new ArrayList<>());

        // Break the transfer into fixed-size chunks; only the last chunk may be shorter.
        for (long offset = 0; offset < transfer.size; offset += bytesPerCopy) {
            long copySize = Math.min(transfer.size - offset, bytesPerCopy);

            forVbo.add(new Transfer(transfer.srcOffset + offset, transfer.dstVbo, transfer.dstOffset + offset, copySize));
        }
    }

    IndirectPrograms.get()
            .getScatterProgram()
            .bind();

    for (var entry : copiesPerVbo.int2ObjectEntrySet()) {
        int dstVbo = entry.getIntKey();
        // Named so that they do not shadow the `transfers` and `scratch` fields.
        List<Transfer> copies = entry.getValue();
        int copyCount = copies.size();

        // Three ints per command, matching what putTransfers writes.
        long commandBytes = copyCount * Integer.BYTES * 3L;

        MemoryBlock commands = getScratch(commandBytes);
        putTransfers(commands.ptr(), copies);
        copyBuffer.upload(commands.ptr(), commandBytes);

        GL45C.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 0, copyBuffer.handle());
        GL45C.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 1, vbo);
        GL45C.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 2, dstVbo);

        GL45C.glDispatchCompute(GlCompat.getComputeGroupCount(copyCount), 1, 1);
    }

    return usedCapacity;
}
/**
 * Rejects sizes that are not a whole number of 4-byte words.
 *
 * @param size the size to validate.
 * @throws IllegalArgumentException if {@code size} is not divisible by 4.
 */
private void assertMultipleOf4(long size) {
    // The two low bits are zero exactly when the value is divisible by 4 (also for negatives).
    boolean aligned = (size & 3L) == 0;
    if (aligned) {
        return;
    }
    throw new IllegalArgumentException("Size must be a multiple of 4");
}
private long sendCopyCommands() {
long usedCapacity = 0;
for (Transfer transfer : consolidateCopies(transfers)) {
@ -188,10 +257,17 @@ public class StagingBuffer {
GL45C.glCopyNamedBufferSubData(vbo, transfer.dstVbo, transfer.srcOffset, transfer.dstOffset, transfer.size);
}
return usedCapacity;
}
fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity));
start = pos;
/**
 * Flushes the part of the mapped buffer between {@code start} and {@code pos}.
 */
private void flushUsedRegion() {
    if (pos >= start) {
        // No wrap-around: a single contiguous range covers everything.
        GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start);
        return;
    }

    // The write position rolled around the end of the buffer, so two ranges are needed:
    // from start to the end of the buffer, then from the beginning up to pos.
    GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start);
    GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos);
}
private static List<Transfer> consolidateCopies(PriorityQueue<Transfer> queue) {
@ -214,6 +290,15 @@ public class StagingBuffer {
return merged;
}
/**
 * Writes each transfer to native memory as three consecutive ints:
 * source offset, destination offset, size.
 *
 * <p>The {@code long} fields are narrowed with an {@code (int)} cast, so only their low
 * 32 bits are stored.
 *
 * @param ptr       address of the first int to write.
 * @param transfers the transfers to encode, in order.
 */
private static void putTransfers(long ptr, List<Transfer> transfers) {
    final long stride = Integer.BYTES * 3;

    for (int i = 0, count = transfers.size(); i < count; i++) {
        Transfer command = transfers.get(i);
        long base = ptr + i * stride;

        MemoryUtil.memPutInt(base, (int) command.srcOffset);
        MemoryUtil.memPutInt(base + Integer.BYTES, (int) command.dstOffset);
        MemoryUtil.memPutInt(base + Integer.BYTES * 2, (int) command.size);
    }
}
/**
 * Tells whether {@code transfer} picks up exactly where {@code last} ends, in both the source
 * and the destination, and targets the same destination buffer.
 *
 * @param last     the earlier transfer.
 * @param transfer the candidate that may directly follow it.
 * @return {@code true} if both conditions hold.
 */
private static boolean areContiguous(Transfer last, Transfer transfer) {
    if (last.dstVbo != transfer.dstVbo) {
        return false;
    }

    var dstEnd = last.dstOffset + last.size;
    if (dstEnd != transfer.dstOffset) {
        return false;
    }

    var srcEnd = last.srcOffset + last.size;
    return srcEnd == transfer.srcOffset;
}
@ -238,7 +323,9 @@ public class StagingBuffer {
GL45C.glDeleteBuffers(vbo);
overflow.delete();
if (scratch != null) {
scratch.free();
}
FlwMemoryTracker._freeCPUMemory(capacity);
}
@ -271,7 +358,7 @@ public class StagingBuffer {
vbo = GL45C.glCreateBuffers();
}
public void enqueueCopy(long ptr, long size, int dstVbo, long dstOffset) {
public void upload(long ptr, long size, int dstVbo, long dstOffset) {
GL45C.nglNamedBufferData(vbo, size, ptr, GL45C.GL_STREAM_COPY);
GL45C.glCopyNamedBufferSubData(vbo, dstVbo, 0, dstOffset, size);
}

View file

@ -82,11 +82,15 @@ public class GlBuffer extends GlObject {
size = growthFunction.apply(capacity);
}
public void upload(MemoryBlock directBuffer) {
FlwMemoryTracker._freeGPUMemory(size);
Buffer.IMPL.data(handle(), directBuffer.size(), directBuffer.ptr(), usage.glEnum);
size = directBuffer.size();
FlwMemoryTracker._allocGPUMemory(size);
/**
 * Uploads the entire contents of a memory block.
 *
 * @param memoryBlock the block whose pointer and size are forwarded to {@link #upload(long, long)}.
 */
public void upload(MemoryBlock memoryBlock) {
    final long address = memoryBlock.ptr();
    final long byteCount = memoryBlock.size();
    upload(address, byteCount);
}
/**
 * Uploads {@code size} bytes starting at {@code ptr} and updates the tracked GPU memory.
 *
 * @param ptr  address of the data to upload.
 * @param size number of bytes to upload; becomes this buffer's recorded size.
 */
public void upload(long ptr, long size) {
    // Un-track the old size before the buffer is given new data.
    final long previousSize = this.size;
    FlwMemoryTracker._freeGPUMemory(previousSize);

    Buffer.IMPL.data(handle(), size, ptr, usage.glEnum);

    // Record and track the new size.
    this.size = size;
    FlwMemoryTracker._allocGPUMemory(size);
}
public MappedBuffer map() {

View file

@ -17,6 +17,10 @@ public final class MoreMath {
return (numerator + denominator - 1) / denominator;
}
/**
 * Divides {@code numerator} by {@code denominator}, rounding towards positive infinity.
 *
 * <p>The result is computed without forming {@code numerator + denominator - 1}, so it does not
 * overflow for numerators close to {@link Long#MAX_VALUE}, and it rounds correctly when the
 * operands have different signs.
 *
 * @param numerator   the dividend.
 * @param denominator the divisor.
 * @return the smallest {@code long} that is greater than or equal to the exact quotient.
 * @throws ArithmeticException if {@code denominator} is zero.
 */
public static long ceilingDiv(long numerator, long denominator) {
    long quotient = numerator / denominator;

    // Truncating division already rounds up when the exact quotient is negative, so only
    // a non-negative, inexact quotient needs to be bumped.
    boolean sameSign = (numerator ^ denominator) >= 0;
    if (sameSign && quotient * denominator != numerator) {
        return quotient + 1;
    }
    return quotient;
}
public static int numDigits(int number) {
// cursed but allegedly the fastest algorithm, taken from https://www.baeldung.com/java-number-of-digits-in-int
if (number < 100000) {

View file

@ -0,0 +1,35 @@
// Scatter shader: each invocation executes one copy command, moving 32-bit words from Src to Dst.

// Work group width; _FLW_SUBGROUP_SIZE is a define supplied when the shader is compiled.
layout(local_size_x = _FLW_SUBGROUP_SIZE) in;
// One copy command: three consecutive uints, all expressed in bytes.
struct Copy {
// Byte offset of the first word to read from Src.
uint srcOffset;
// Byte offset of the first word to write in Dst.
uint dstOffset;
// Number of bytes to copy.
uint byteSize;
};
// Binding 0: the list of copy commands; its length bounds the number of active invocations.
layout(std430, binding = 0) restrict readonly buffer Copies {
Copy copies[];
};
// Binding 1: the buffer data is read from, viewed as an array of 32-bit words.
layout(std430, binding = 1) restrict readonly buffer Src {
uint src[];
};
// Binding 2: the buffer data is written to, viewed as an array of 32-bit words.
layout(std430, binding = 2) restrict writeonly buffer Dst {
uint dst[];
};
// Executes the copy command selected by this invocation's global x index.
// Invocations beyond the end of the command list do nothing.
void main() {
    uint commandIndex = gl_GlobalInvocationID.x;

    if (commandIndex < copies.length()) {
        Copy command = copies[commandIndex];

        // Commands are expressed in bytes; the buffers are indexed in 4-byte words.
        uint srcWord = command.srcOffset >> 2;
        uint dstWord = command.dstOffset >> 2;
        uint wordCount = command.byteSize >> 2;

        for (uint word = 0; word < wordCount; ++word) {
            dst[dstWord + word] = src[srcWord + word];
        }
    }
}