Throw more compute at it

- Use a shader to perform copies.
- Needs optimization, but it works surprisingly well and avoids the
  driver hitch.
This commit is contained in:
Jozufozu 2023-12-09 15:32:21 -08:00
parent ec6dbfbf49
commit 1bfb6db6d1
6 changed files with 174 additions and 24 deletions

View file

@ -24,18 +24,21 @@ import net.minecraft.resources.ResourceLocation;
public class IndirectPrograms { public class IndirectPrograms {
private static final ResourceLocation CULL_SHADER_MAIN = Flywheel.rl("internal/indirect/cull.glsl"); private static final ResourceLocation CULL_SHADER_MAIN = Flywheel.rl("internal/indirect/cull.glsl");
private static final ResourceLocation APPLY_SHADER_MAIN = Flywheel.rl("internal/indirect/apply.glsl"); private static final ResourceLocation APPLY_SHADER_MAIN = Flywheel.rl("internal/indirect/apply.glsl");
private static final ResourceLocation SCATTER_SHADER_MAIN = Flywheel.rl("internal/indirect/scatter.glsl");
public static IndirectPrograms instance; public static IndirectPrograms instance;
private static final Compile<InstanceType<?>> CULL = new Compile<>(); private static final Compile<InstanceType<?>> CULL = new Compile<>();
private static final Compile<Unit> APPLY = new Compile<>(); private static final Compile<Unit> UNIT = new Compile<>();
private final Map<PipelineProgramKey, GlProgram> pipeline; private final Map<PipelineProgramKey, GlProgram> pipeline;
private final Map<InstanceType<?>, GlProgram> culling; private final Map<InstanceType<?>, GlProgram> culling;
private final GlProgram apply; private final GlProgram apply;
private final GlProgram scatter;
public IndirectPrograms(Map<PipelineProgramKey, GlProgram> pipeline, Map<InstanceType<?>, GlProgram> culling, GlProgram apply) { public IndirectPrograms(Map<PipelineProgramKey, GlProgram> pipeline, Map<InstanceType<?>, GlProgram> culling, GlProgram apply, GlProgram scatter) {
this.pipeline = pipeline; this.pipeline = pipeline;
this.culling = culling; this.culling = culling;
this.apply = apply; this.apply = apply;
this.scatter = scatter;
} }
static void reload(ShaderSources sources, ImmutableList<PipelineProgramKey> pipelineKeys, UniformComponent uniformComponent, List<SourceComponent> vertexComponents, List<SourceComponent> fragmentComponents) { static void reload(ShaderSources sources, ImmutableList<PipelineProgramKey> pipelineKeys, UniformComponent uniformComponent, List<SourceComponent> vertexComponents, List<SourceComponent> fragmentComponents) {
@ -43,14 +46,16 @@ public class IndirectPrograms {
var pipelineCompiler = PipelineCompiler.create(sources, Pipelines.INDIRECT, pipelineKeys, uniformComponent, vertexComponents, fragmentComponents); var pipelineCompiler = PipelineCompiler.create(sources, Pipelines.INDIRECT, pipelineKeys, uniformComponent, vertexComponents, fragmentComponents);
var cullingCompiler = createCullingCompiler(uniformComponent, sources); var cullingCompiler = createCullingCompiler(uniformComponent, sources);
var applyCompiler = createApplyCompiler(sources); var applyCompiler = createApplyCompiler(sources);
var scatterCompiler = createScatterCompiler(sources);
try { try {
var pipelineResult = pipelineCompiler.compileAndReportErrors(); var pipelineResult = pipelineCompiler.compileAndReportErrors();
var cullingResult = cullingCompiler.compileAndReportErrors(); var cullingResult = cullingCompiler.compileAndReportErrors();
var applyResult = applyCompiler.compileAndReportErrors(); var applyResult = applyCompiler.compileAndReportErrors();
var scatterResult = scatterCompiler.compileAndReportErrors();
if (pipelineResult != null && cullingResult != null && applyResult != null) { if (pipelineResult != null && cullingResult != null && applyResult != null && scatterResult != null) {
instance = new IndirectPrograms(pipelineResult, cullingResult, applyResult.get(Unit.INSTANCE)); instance = new IndirectPrograms(pipelineResult, cullingResult, applyResult.get(Unit.INSTANCE), scatterResult.get(Unit.INSTANCE));
} }
} catch (Throwable e) { } catch (Throwable e) {
Flywheel.LOGGER.error("Failed to compile indirect programs", e); Flywheel.LOGGER.error("Failed to compile indirect programs", e);
@ -99,15 +104,25 @@ public class IndirectPrograms {
} }
private static CompilationHarness<Unit> createApplyCompiler(ShaderSources sources) { private static CompilationHarness<Unit> createApplyCompiler(ShaderSources sources) {
return APPLY.harness(sources) return UNIT.harness(sources)
.keys(ImmutableList.of(Unit.INSTANCE)) .keys(ImmutableList.of(Unit.INSTANCE))
.compiler(APPLY.program() .compiler(UNIT.program()
.link(APPLY.shader(GlslVersion.V460, ShaderType.COMPUTE) .link(UNIT.shader(GlslVersion.V460, ShaderType.COMPUTE)
.define("_FLW_SUBGROUP_SIZE", GlCompat.SUBGROUP_SIZE) .define("_FLW_SUBGROUP_SIZE", GlCompat.SUBGROUP_SIZE)
.withResource(APPLY_SHADER_MAIN))) .withResource(APPLY_SHADER_MAIN)))
.build(); .build();
} }
private static CompilationHarness<Unit> createScatterCompiler(ShaderSources sources) {
return UNIT.harness(sources)
.keys(ImmutableList.of(Unit.INSTANCE))
.compiler(UNIT.program()
.link(UNIT.shader(GlslVersion.V460, ShaderType.COMPUTE)
.define("_FLW_SUBGROUP_SIZE", GlCompat.SUBGROUP_SIZE)
.withResource(SCATTER_SHADER_MAIN)))
.build();
}
public GlProgram getIndirectProgram(InstanceType<?> instanceType, Context contextShader) { public GlProgram getIndirectProgram(InstanceType<?> instanceType, Context contextShader) {
return pipeline.get(new PipelineProgramKey(instanceType, contextShader)); return pipeline.get(new PipelineProgramKey(instanceType, contextShader));
} }
@ -120,6 +135,10 @@ public class IndirectPrograms {
return apply; return apply;
} }
public GlProgram getScatterProgram() {
return scatter;
}
public void delete() { public void delete() {
pipeline.values() pipeline.values()
.forEach(GlProgram::delete); .forEach(GlProgram::delete);

View file

@ -99,6 +99,7 @@ public class IndirectCullingGroup<I extends Instance> {
UniformBuffer.get().sync(); UniformBuffer.get().sync();
cullProgram.bind(); cullProgram.bind();
buffers.bindForCompute(); buffers.bindForCompute();
glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
glDispatchCompute(GlCompat.getComputeGroupCount(instanceCountThisFrame), 1, 1); glDispatchCompute(GlCompat.getComputeGroupCount(instanceCountThisFrame), 1, 1);
} }

View file

@ -4,14 +4,22 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.function.LongConsumer; import java.util.function.LongConsumer;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.lwjgl.opengl.GL45;
import org.lwjgl.opengl.GL45C; import org.lwjgl.opengl.GL45C;
import org.lwjgl.system.MemoryUtil; import org.lwjgl.system.MemoryUtil;
import com.jozufozu.flywheel.backend.compile.IndirectPrograms;
import com.jozufozu.flywheel.gl.GlCompat;
import com.jozufozu.flywheel.gl.GlFence; import com.jozufozu.flywheel.gl.GlFence;
import com.jozufozu.flywheel.gl.buffer.GlBuffer;
import com.jozufozu.flywheel.lib.memory.FlwMemoryTracker; import com.jozufozu.flywheel.lib.memory.FlwMemoryTracker;
import com.jozufozu.flywheel.lib.memory.MemoryBlock; import com.jozufozu.flywheel.lib.memory.MemoryBlock;
import it.unimi.dsi.fastutil.PriorityQueue; import it.unimi.dsi.fastutil.PriorityQueue;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.objects.ObjectArrayFIFOQueue; import it.unimi.dsi.fastutil.objects.ObjectArrayFIFOQueue;
// Used https://github.com/CaffeineMC/sodium-fabric/blob/dev/src/main/java/me/jellysquid/mods/sodium/client/gl/arena/staging/MappedStagingBuffer.java // Used https://github.com/CaffeineMC/sodium-fabric/blob/dev/src/main/java/me/jellysquid/mods/sodium/client/gl/arena/staging/MappedStagingBuffer.java
@ -30,8 +38,10 @@ public class StagingBuffer {
private long totalAvailable; private long totalAvailable;
@Nullable
private MemoryBlock scratch; private MemoryBlock scratch;
private final GlBuffer copyBuffer = new GlBuffer();
private final OverflowStagingBuffer overflow = new OverflowStagingBuffer(); private final OverflowStagingBuffer overflow = new OverflowStagingBuffer();
private final PriorityQueue<Transfer> transfers = new ObjectArrayFIFOQueue<>(); private final PriorityQueue<Transfer> transfers = new ObjectArrayFIFOQueue<>();
private final PriorityQueue<FencedRegion> fencedRegions = new ObjectArrayFIFOQueue<>(); private final PriorityQueue<FencedRegion> fencedRegions = new ObjectArrayFIFOQueue<>();
@ -78,6 +88,7 @@ public class StagingBuffer {
enqueueCopy(block.ptr(), size, dstVbo, dstOffset); enqueueCopy(block.ptr(), size, dstVbo, dstOffset);
} }
@NotNull
private MemoryBlock getScratch(long size) { private MemoryBlock getScratch(long size) {
if (scratch == null) { if (scratch == null) {
scratch = MemoryBlock.malloc(size); scratch = MemoryBlock.malloc(size);
@ -107,8 +118,10 @@ public class StagingBuffer {
* @param dstOffset The offset in the destination VBO. * @param dstOffset The offset in the destination VBO.
*/ */
public void enqueueCopy(long ptr, long size, int dstVbo, long dstOffset) { public void enqueueCopy(long ptr, long size, int dstVbo, long dstOffset) {
assertMultipleOf4(size);
if (size > totalAvailable) { if (size > totalAvailable) {
overflow.enqueueCopy(ptr, size, dstVbo, dstOffset); overflow.upload(ptr, size, dstVbo, dstOffset);
return; return;
} }
@ -151,6 +164,7 @@ public class StagingBuffer {
* @return A pointer to the reserved space, or {@code null} if there is not enough contiguous space. * @return A pointer to the reserved space, or {@code null} if there is not enough contiguous space.
*/ */
public long reserveForTransferTo(long size, int dstVbo, long dstOffset) { public long reserveForTransferTo(long size, int dstVbo, long dstOffset) {
assertMultipleOf4(size);
// Don't need to check totalAvailable here because that's a looser constraint than the bytes remaining. // Don't need to check totalAvailable here because that's a looser constraint than the bytes remaining.
long remaining = capacity - pos; long remaining = capacity - pos;
if (size > remaining) { if (size > remaining) {
@ -173,14 +187,69 @@ public class StagingBuffer {
return; return;
} }
if (pos < start) { flushUsedRegion();
// we rolled around, need to flush 2 ranges.
GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start); var usedCapacity = dispatchComputeCopies();
GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos);
} else { fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity));
GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start);
start = pos;
}
private long dispatchComputeCopies() {
long usedCapacity = 0;
Int2ObjectMap<List<Transfer>> copiesPerVbo = new Int2ObjectArrayMap<>();
long bytesPerCopy = 64;
for (var transfer : consolidateCopies(transfers)) {
usedCapacity += transfer.size;
var forVbo = copiesPerVbo.computeIfAbsent(transfer.dstVbo, k -> new ArrayList<>());
long offset = 0;
long remaining = transfer.size;
while (offset < transfer.size) {
long copySize = Math.min(remaining, bytesPerCopy);
forVbo.add(new Transfer(transfer.srcOffset + offset, transfer.dstVbo, transfer.dstOffset + offset, copySize));
offset += copySize;
remaining -= copySize;
}
} }
IndirectPrograms.get()
.getScatterProgram()
.bind();
for (var entry : copiesPerVbo.int2ObjectEntrySet()) {
var dstVbo = entry.getIntKey();
var transfers = entry.getValue();
var copyCount = transfers.size();
var size = copyCount * Integer.BYTES * 3L;
var scratch = getScratch(size);
putTransfers(scratch.ptr(), transfers);
copyBuffer.upload(scratch.ptr(), size);
GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 0, copyBuffer.handle());
GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 1, vbo);
GL45.glBindBufferBase(GL45C.GL_SHADER_STORAGE_BUFFER, 2, dstVbo);
GL45.glDispatchCompute(GlCompat.getComputeGroupCount(copyCount), 1, 1);
}
return usedCapacity;
}
private void assertMultipleOf4(long size) {
if (size % 4 != 0) {
throw new IllegalArgumentException("Size must be a multiple of 4");
}
}
private long sendCopyCommands() {
long usedCapacity = 0; long usedCapacity = 0;
for (Transfer transfer : consolidateCopies(transfers)) { for (Transfer transfer : consolidateCopies(transfers)) {
@ -188,10 +257,17 @@ public class StagingBuffer {
GL45C.glCopyNamedBufferSubData(vbo, transfer.dstVbo, transfer.srcOffset, transfer.dstOffset, transfer.size); GL45C.glCopyNamedBufferSubData(vbo, transfer.dstVbo, transfer.srcOffset, transfer.dstOffset, transfer.size);
} }
return usedCapacity;
}
fencedRegions.enqueue(new FencedRegion(new GlFence(), usedCapacity)); private void flushUsedRegion() {
if (pos < start) {
start = pos; // we rolled around, need to flush 2 ranges.
GL45C.glFlushMappedNamedBufferRange(vbo, start, capacity - start);
GL45C.glFlushMappedNamedBufferRange(vbo, 0, pos);
} else {
GL45C.glFlushMappedNamedBufferRange(vbo, start, pos - start);
}
} }
private static List<Transfer> consolidateCopies(PriorityQueue<Transfer> queue) { private static List<Transfer> consolidateCopies(PriorityQueue<Transfer> queue) {
@ -214,6 +290,15 @@ public class StagingBuffer {
return merged; return merged;
} }
private static void putTransfers(long ptr, List<Transfer> transfers) {
for (Transfer transfer : transfers) {
MemoryUtil.memPutInt(ptr, (int) transfer.srcOffset);
MemoryUtil.memPutInt(ptr + Integer.BYTES, (int) transfer.dstOffset);
MemoryUtil.memPutInt(ptr + Integer.BYTES * 2, (int) transfer.size);
ptr += Integer.BYTES * 3;
}
}
private static boolean areContiguous(Transfer last, Transfer transfer) { private static boolean areContiguous(Transfer last, Transfer transfer) {
return last.dstVbo == transfer.dstVbo && last.dstOffset + last.size == transfer.dstOffset && last.srcOffset + last.size == transfer.srcOffset; return last.dstVbo == transfer.dstVbo && last.dstOffset + last.size == transfer.dstOffset && last.srcOffset + last.size == transfer.srcOffset;
} }
@ -238,7 +323,9 @@ public class StagingBuffer {
GL45C.glDeleteBuffers(vbo); GL45C.glDeleteBuffers(vbo);
overflow.delete(); overflow.delete();
scratch.free(); if (scratch != null) {
scratch.free();
}
FlwMemoryTracker._freeCPUMemory(capacity); FlwMemoryTracker._freeCPUMemory(capacity);
} }
@ -271,7 +358,7 @@ public class StagingBuffer {
vbo = GL45C.glCreateBuffers(); vbo = GL45C.glCreateBuffers();
} }
public void enqueueCopy(long ptr, long size, int dstVbo, long dstOffset) { public void upload(long ptr, long size, int dstVbo, long dstOffset) {
GL45C.nglNamedBufferData(vbo, size, ptr, GL45C.GL_STREAM_COPY); GL45C.nglNamedBufferData(vbo, size, ptr, GL45C.GL_STREAM_COPY);
GL45C.glCopyNamedBufferSubData(vbo, dstVbo, 0, dstOffset, size); GL45C.glCopyNamedBufferSubData(vbo, dstVbo, 0, dstOffset, size);
} }

View file

@ -82,11 +82,15 @@ public class GlBuffer extends GlObject {
size = growthFunction.apply(capacity); size = growthFunction.apply(capacity);
} }
public void upload(MemoryBlock directBuffer) { public void upload(MemoryBlock memoryBlock) {
FlwMemoryTracker._freeGPUMemory(size); upload(memoryBlock.ptr(), memoryBlock.size());
Buffer.IMPL.data(handle(), directBuffer.size(), directBuffer.ptr(), usage.glEnum); }
size = directBuffer.size();
FlwMemoryTracker._allocGPUMemory(size); public void upload(long ptr, long size) {
FlwMemoryTracker._freeGPUMemory(this.size);
Buffer.IMPL.data(handle(), size, ptr, usage.glEnum);
this.size = size;
FlwMemoryTracker._allocGPUMemory(this.size);
} }
public MappedBuffer map() { public MappedBuffer map() {

View file

@ -17,6 +17,10 @@ public final class MoreMath {
return (numerator + denominator - 1) / denominator; return (numerator + denominator - 1) / denominator;
} }
public static long ceilingDiv(long numerator, long denominator) {
return (numerator + denominator - 1) / denominator;
}
public static int numDigits(int number) { public static int numDigits(int number) {
// cursed but allegedly the fastest algorithm, taken from https://www.baeldung.com/java-number-of-digits-in-int // cursed but allegedly the fastest algorithm, taken from https://www.baeldung.com/java-number-of-digits-in-int
if (number < 100000) { if (number < 100000) {

View file

@ -0,0 +1,35 @@
layout(local_size_x = _FLW_SUBGROUP_SIZE) in;
struct Copy {
uint srcOffset;
uint dstOffset;
uint byteSize;
};
layout(std430, binding = 0) restrict readonly buffer Copies {
Copy copies[];
};
layout(std430, binding = 1) restrict readonly buffer Src {
uint src[];
};
layout(std430, binding = 2) restrict writeonly buffer Dst {
uint dst[];
};
void main() {
uint copy = gl_GlobalInvocationID.x;
if (copy >= copies.length()) {
return;
}
uint srcOffset = copies[copy].srcOffset >> 2;
uint dstOffset = copies[copy].dstOffset >> 2;
uint size = copies[copy].byteSize >> 2;
for (uint i = 0; i < size; i++) {
dst[dstOffset + i] = src[srcOffset + i];
}
}