From 1d318ecb02aabf5308ac8b8433afdaa988334e90 Mon Sep 17 00:00:00 2001 From: Jozufozu Date: Fri, 8 Dec 2023 17:57:05 -0800 Subject: [PATCH] Bit-ing the atom - Use an atomic bitset in AbstractInstancer. - Add method to iterate over each contiguous span of set bits. --- .../backend/engine/AbstractInstancer.java | 20 +- .../engine/indirect/IndirectCullingGroup.java | 2 +- .../engine/indirect/IndirectInstancer.java | 40 +- .../engine/indirect/IndirectModel.java | 4 +- .../engine/instancing/InstancedInstancer.java | 16 +- .../flywheel/lib/util/AtomicBitset.java | 465 ++++++++++++++++++ 6 files changed, 511 insertions(+), 36 deletions(-) create mode 100644 src/main/java/com/jozufozu/flywheel/lib/util/AtomicBitset.java diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/AbstractInstancer.java b/src/main/java/com/jozufozu/flywheel/backend/engine/AbstractInstancer.java index 444dabb05..795ece566 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/AbstractInstancer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/AbstractInstancer.java @@ -1,11 +1,11 @@ package com.jozufozu.flywheel.backend.engine; import java.util.ArrayList; -import java.util.BitSet; import com.jozufozu.flywheel.api.instance.Instance; import com.jozufozu.flywheel.api.instance.InstanceType; import com.jozufozu.flywheel.api.instance.Instancer; +import com.jozufozu.flywheel.lib.util.AtomicBitset; public abstract class AbstractInstancer implements Instancer { public final InstanceType type; @@ -15,8 +15,8 @@ public abstract class AbstractInstancer implements Instancer protected final ArrayList instances = new ArrayList<>(); protected final ArrayList handles = new ArrayList<>(); - protected final BitSet changed = new BitSet(); - protected final BitSet deleted = new BitSet(); + protected final AtomicBitset changed = new AtomicBitset(); + protected final AtomicBitset deleted = new AtomicBitset(); protected AbstractInstancer(InstanceType type) { this.type = type; @@ -40,23 +40,23 @@ public abstract class AbstractInstancer implements Instancer return instances.size(); } + protected boolean moreThanTwoThirdsChanged() { + return (changed.cardinality() * 3) > (instances.size() * 2); + } + + public void notifyDirty(int index) { if (index < 0 || index >= getInstanceCount()) { return; } - // TODO: Atomic bitset. Synchronizing here blocks the task executor and causes massive overhead. - synchronized (lock) { - changed.set(index); - } + changed.set(index); } public void notifyRemoval(int index) { if (index < 0 || index >= getInstanceCount()) { return; } - synchronized (lock) { - deleted.set(index); - } + deleted.set(index); } protected void removeDeletedInstances() { diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java index 5512b96a1..f91f79cc3 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectCullingGroup.java @@ -203,7 +203,7 @@ public class IndirectCullingGroup { long pos = 0; for (IndirectModel model : indirectModels) { var instanceCount = model.instancer.getInstanceCount(); - model.writeObjects(stagingBuffer, pos, buffers.object.handle()); + model.uploadObjects(stagingBuffer, pos, buffers.object.handle()); pos += instanceCount * objectStride; } diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectInstancer.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectInstancer.java index 806bccaea..548057092 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectInstancer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectInstancer.java @@ -26,41 +26,45 @@ public class IndirectInstancer extends AbstractInstancer removeDeletedInstances(); } - public void write(StagingBuffer stagingBuffer, long startPos, int dstVbo) { - if (shouldWriteAll(startPos)) { - writeAll(stagingBuffer, startPos, dstVbo); + public void upload(StagingBuffer stagingBuffer, long startPos, int dstVbo) { + if (shouldUploadAll(startPos)) { + uploadAll(stagingBuffer, startPos, dstVbo); } else { - writeChanged(stagingBuffer, startPos, dstVbo); + uploadChanged(stagingBuffer, startPos, dstVbo); } changed.clear(); lastStartPos = startPos; } - private boolean shouldWriteAll(long startPos) { + private boolean shouldUploadAll(long startPos) { // If enough of the buffer has changed, write the whole thing to avoid the overhead of a bunch of small writes. + // TODO: The overhead comes from the driver performing many buffer copies. Using a compute shader to scatter + // the data should work much better. return startPos != lastStartPos || moreThanTwoThirdsChanged(); } - private boolean moreThanTwoThirdsChanged() { - return (changed.cardinality() * 3) > (instances.size() * 2); + private void uploadChanged(StagingBuffer stagingBuffer, long baseByte, int dstVbo) { + changed.forEachSetSpan((startInclusive, endInclusive) -> { + var totalSize = (endInclusive - startInclusive + 1) * objectStride; + + stagingBuffer.enqueueCopy(totalSize, dstVbo, baseByte + startInclusive * objectStride, ptr -> { + for (int i = startInclusive; i <= endInclusive; i++) { + var instance = instances.get(i); + writeOne(ptr, instance); + ptr += objectStride; + } + }); + }); } - private void writeChanged(StagingBuffer stagingBuffer, long start, int dstVbo) { - int count = instances.size(); - for (int i = changed.nextSetBit(0); i >= 0 && i < count; i = changed.nextSetBit(i + 1)) { - var instance = instances.get(i); - stagingBuffer.enqueueCopy(objectStride, dstVbo, start + i * objectStride, ptr -> writeOne(ptr, instance)); - } - } - - private void writeAll(StagingBuffer stagingBuffer, long start, int dstVbo) { + private void uploadAll(StagingBuffer stagingBuffer, long start, int dstVbo) { long totalSize = objectStride * instances.size(); - stagingBuffer.enqueueCopy(totalSize, dstVbo, start, this::writeAll); + stagingBuffer.enqueueCopy(totalSize, dstVbo, start, this::uploadAll); } - private void writeAll(long ptr) { + private void uploadAll(long ptr) { for (I instance : instances) { writeOne(ptr, instance); ptr += objectStride; diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectModel.java b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectModel.java index 41d61079b..c7c1932b0 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectModel.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/indirect/IndirectModel.java @@ -25,8 +25,8 @@ public class IndirectModel { this.baseInstance = baseInstance; } - public void writeObjects(StagingBuffer stagingBuffer, long start, int dstVbo) { - instancer.write(stagingBuffer, start, dstVbo); + public void uploadObjects(StagingBuffer stagingBuffer, long start, int dstVbo) { + instancer.upload(stagingBuffer, start, dstVbo); } public void write(long ptr) { diff --git a/src/main/java/com/jozufozu/flywheel/backend/engine/instancing/InstancedInstancer.java b/src/main/java/com/jozufozu/flywheel/backend/engine/instancing/InstancedInstancer.java index 33eb1632c..7dbb1773c 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/engine/instancing/InstancedInstancer.java +++ b/src/main/java/com/jozufozu/flywheel/backend/engine/instancing/InstancedInstancer.java @@ -21,6 +21,7 @@ public class InstancedInstancer extends AbstractInstancer private final int instanceStride; private final Set boundTo = new HashSet<>(); + private final InstanceWriter writer; private GlBuffer vbo; private final List drawCalls = new ArrayList<>(); @@ -29,6 +30,7 @@ public class InstancedInstancer extends AbstractInstancer super(type); instanceFormat = type.getLayout(); instanceStride = instanceFormat.getStride(); + writer = type.getWriter(); } public int getAttributeCount() { @@ -70,12 +72,8 @@ public class InstancedInstancer extends AbstractInstancer try (MappedBuffer buf = vbo.map()) { long ptr = buf.ptr(); - InstanceWriter writer = type.getWriter(); - int count = instances.size(); - for (int i = changed.nextSetBit(0); i >= 0 && i < count; i = changed.nextSetBit(i + 1)) { - writer.write(ptr + (long) instanceStride * i, instances.get(i)); - } + writeChanged(ptr); changed.clear(); } catch (Exception e) { @@ -83,6 +81,14 @@ public class InstancedInstancer extends AbstractInstancer } } + private void writeChanged(long ptr) { + changed.forEachSetSpan((startInclusive, endInclusive) -> { + for (int i = startInclusive; i <= endInclusive; i++) { + writer.write(ptr + (long) instanceStride * i, instances.get(i)); + } + }); + } + /** * Bind this instancer's vbo to the given vao if it hasn't already been bound. * @param vao The vao to bind to. diff --git a/src/main/java/com/jozufozu/flywheel/lib/util/AtomicBitset.java b/src/main/java/com/jozufozu/flywheel/lib/util/AtomicBitset.java new file mode 100644 index 000000000..472f9b5dc --- /dev/null +++ b/src/main/java/com/jozufozu/flywheel/lib/util/AtomicBitset.java @@ -0,0 +1,465 @@ +package com.jozufozu.flywheel.lib.util; + +import java.util.Arrays; +import java.util.BitSet; +import java.util.concurrent.atomic.AtomicLongArray; +import java.util.concurrent.atomic.AtomicReference; + +// https://github.com/Netflix/hollow/blob/master/hollow/src/main/java/com/netflix/hollow/core/memory/ThreadSafeBitSet.java +// Refactored to remove unused methods, deduplicate some code segments, and add extra functionality with #forEachSetSpan +public class AtomicBitset { + // 1024 bits, 128 bytes, 16 longs per segment + public static final int DEFAULT_LOG2_SEGMENT_SIZE_IN_BITS = 10; + + private final int numLongsPerSegment; + private final int log2SegmentSize; + private final int segmentMask; + private final AtomicReference segments; + + public AtomicBitset() { + this(DEFAULT_LOG2_SEGMENT_SIZE_IN_BITS); + } + + public AtomicBitset(int log2SegmentSizeInBits) { + this(log2SegmentSizeInBits, 0); + } + + public AtomicBitset(int log2SegmentSizeInBits, int numBitsToPreallocate) { + if (log2SegmentSizeInBits < 6) { + throw new IllegalArgumentException("Cannot specify fewer than 64 bits in each segment!"); + } + + this.log2SegmentSize = log2SegmentSizeInBits; + this.numLongsPerSegment = (1 << (log2SegmentSizeInBits - 6)); + this.segmentMask = numLongsPerSegment - 1; + + long numBitsPerSegment = numLongsPerSegment * 64L; + int numSegmentsToPreallocate = numBitsToPreallocate == 0 ? 1 : (int) (((numBitsToPreallocate - 1) / numBitsPerSegment) + 1); + + segments = new AtomicReference<>(new AtomicBitsetSegments(numSegmentsToPreallocate, numLongsPerSegment)); + } + + public void set(int position) { + int longPosition = longIndexInSegmentForPosition(position); + + AtomicLongArray segment = getSegmentForPosition(position); + + long mask = maskForPosition(position); + + // Thread safety: we need to loop until we win the race to set the long value. + while (true) { + // determine what the new long value will be after we set the appropriate bit. + long currentLongValue = segment.get(longPosition); + long newLongValue = currentLongValue | mask; + + // if no other thread has modified the value since we read it, we won the race and we are done. + if (segment.compareAndSet(longPosition, currentLongValue, newLongValue)) { + break; + } + } + } + + public void clear(int position) { + int longPosition = longIndexInSegmentForPosition(position); + + AtomicLongArray segment = getSegmentForPosition(position); + + long mask = ~maskForPosition(position); + + // Thread safety: we need to loop until we win the race to set the long value. + while (true) { + // determine what the new long value will be after we set the appropriate bit. + long currentLongValue = segment.get(longPosition); + long newLongValue = currentLongValue & mask; + + // if no other thread has modified the value since we read it, we won the race and we are done. + if (segment.compareAndSet(longPosition, currentLongValue, newLongValue)) { + break; + } + } + } + + public boolean get(int position) { + int segmentPosition = segmentIndexForPosition(position); + int longPosition = longIndexInSegmentForPosition(position); + + AtomicLongArray segment = segmentForPosition(segmentPosition); + + long mask = maskForPosition(position); + + return ((segment.get(longPosition) & mask) != 0); + } + + public long maxSetBit() { + AtomicBitsetSegments segments = this.segments.get(); + + int segmentIdx = segments.numSegments() - 1; + + for (; segmentIdx >= 0; segmentIdx--) { + AtomicLongArray segment = segments.getSegment(segmentIdx); + for (int longIdx = segment.length() - 1; longIdx >= 0; longIdx--) { + long l = segment.get(longIdx); + if (l != 0) { + return ((long) segmentIdx << log2SegmentSize) + (longIdx * 64L) + (63 - Long.numberOfLeadingZeros(l)); + } + } + } + + return -1; + } + + public int nextSetBit(int fromIndex) { + if (fromIndex < 0) { + throw new IndexOutOfBoundsException("fromIndex < 0: " + fromIndex); + } + + + AtomicBitsetSegments segments = this.segments.get(); + + int segmentPosition = segmentIndexForPosition(fromIndex); + if (segmentPosition >= segments.numSegments()) { + return -1; + } + + int longPosition = longIndexInSegmentForPosition(fromIndex); + AtomicLongArray segment = segments.getSegment(segmentPosition); + + long word = segment.get(longPosition) & (0xffffffffffffffffL << bitPosInLongForPosition(fromIndex)); + + while (true) { + if (word != 0) { + return (segmentPosition << (log2SegmentSize)) + (longPosition << 6) + Long.numberOfTrailingZeros(word); + } + if (++longPosition > segmentMask) { + segmentPosition++; + if (segmentPosition >= segments.numSegments()) { + return -1; + } + segment = segments.getSegment(segmentPosition); + longPosition = 0; + } + + word = segment.get(longPosition); + } + } + + public int nextClearBit(int fromIndex) { + if (fromIndex < 0) { + throw new IndexOutOfBoundsException("fromIndex < 0: " + fromIndex); + } + + int segmentPosition = segmentIndexForPosition(fromIndex); + + AtomicBitsetSegments segments = this.segments.get(); + + if (segmentPosition >= segments.numSegments()) { + return fromIndex; + } + + int longPosition = longIndexInSegmentForPosition(fromIndex); + AtomicLongArray segment = segments.getSegment(segmentPosition); + + long word = ~segment.get(longPosition) & (0xffffffffffffffffL << bitPosInLongForPosition(fromIndex)); + + while (true) { + if (word != 0) { + return (segmentPosition << (log2SegmentSize)) + (longPosition << 6) + Long.numberOfTrailingZeros(word); + } + if (++longPosition > segmentMask) { + segmentPosition++; + if (segmentPosition >= segments.numSegments()) { + return segments.numSegments() << log2SegmentSize; + } + segment = segments.getSegment(segmentPosition); + longPosition = 0; + } + + word = segment.get(longPosition); + } + } + + + /** + * @return the number of bits which are set in this bit set. + */ + public int cardinality() { + return this.segments.get() + .cardinality(); + } + + /** + * Iterate over each contiguous span of set bits. + * + * @param consumer The consumer to accept each span. + */ + public void forEachSetSpan(BitSpanConsumer consumer) { + AtomicBitsetSegments segments = this.segments.get(); + + if (segments.cardinality() == 0) { + return; + } + + int start = -1; + int end = -1; + + for (int segmentIndex = 0; segmentIndex < segments.numSegments(); segmentIndex++) { + AtomicLongArray segment = segments.getSegment(segmentIndex); + for (int longIndex = 0; longIndex < segment.length(); longIndex++) { + long l = segment.get(longIndex); + if (l != 0) { + // The JIT loves this loop. Trying to be clever by starting from Long.numberOfLeadingZeros(l) + // causes it to be much slower. + for (int bitIndex = 0; bitIndex < 64; bitIndex++) { + if ((l & (1L << bitIndex)) != 0) { + var position = (segmentIndex << log2SegmentSize) + (longIndex << 6) + bitIndex; + if (start == -1) { + start = position; + } + end = position; + } else { + if (start != -1) { + consumer.accept(start, end); + start = -1; + end = -1; + } + } + } + } else { + if (start != -1) { + consumer.accept(start, end); + start = -1; + end = -1; + } + } + } + } + + if (start != -1) { + consumer.accept(start, end); + } + } + + /** + * @return the number of bits which are currently specified by this bit set. This is the maximum value + * to which you might need to iterate, if you were to iterate over all bits in this set. + */ + public int currentCapacity() { + return segments.get() + .numSegments() * (1 << log2SegmentSize); + } + + public boolean isEmpty() { + return cardinality() == 0; + } + + /** + * Clear all bits to 0. + */ + public void clear() { + AtomicBitsetSegments segments = this.segments.get(); + + for (int i = 0; i < segments.numSegments(); i++) { + AtomicLongArray segment = segments.getSegment(i); + + for (int j = 0; j < segment.length(); j++) { + segment.set(j, 0L); + } + } + } + + /** + * Which bit in the long the given position resides in. + * + * @param position The absolute position in the bitset. + * @return The bit position in the long. + */ + private static int bitPosInLongForPosition(int position) { + // remainder of div by num bits in long (64) + return position & 0x3F; + } + + /** + * Which long in the segment the given position resides in. + * + * @param position The absolute position in the bitset + * @return The long position in the segment. + */ + private int longIndexInSegmentForPosition(int position) { + // remainder of div by num bits per segment + return (position >>> 6) & segmentMask; + } + + /** + * Which segment the given position resides in. + * + * @param position The absolute position in the bitset + * @return The segment index. + */ + private int segmentIndexForPosition(int position) { + // div by num bits per segment + return position >>> log2SegmentSize; + } + + private static long maskForPosition(int position) { + return 1L << bitPosInLongForPosition(position); + } + + private AtomicLongArray getSegmentForPosition(int position) { + return segmentForPosition(segmentIndexForPosition(position)); + } + + /** + * Get the segment at segmentIndex. If this segment does not yet exist, create it. + * + * @param segmentIndex the segment index + * @return the segment + */ + private AtomicLongArray segmentForPosition(int segmentIndex) { + AtomicBitsetSegments visibleSegments = segments.get(); + + while (visibleSegments.numSegments() <= segmentIndex) { + // Thread safety: newVisibleSegments contains all of the segments from the currently visible segments, plus extra. + // all of the segments in the currently visible segments are canonical and will not change. + AtomicBitsetSegments newVisibleSegments = new AtomicBitsetSegments(visibleSegments, segmentIndex + 1, numLongsPerSegment); + + // because we are using a compareAndSet, if this thread "wins the race" and successfully sets this variable, then the segments + // which are newly defined in newVisibleSegments become canonical. + if (segments.compareAndSet(visibleSegments, newVisibleSegments)) { + visibleSegments = newVisibleSegments; + } else { + // If we "lose the race" and are growing the AtomicBitset segments larger, + // then we will gather the new canonical sets from the update which we missed on the next iteration of this loop. + // Newly defined segments in newVisibleSegments will be discarded, they do not get to become canonical. + visibleSegments = segments.get(); + } + } + + return visibleSegments.getSegment(segmentIndex); + } + + private static class AtomicBitsetSegments { + private final AtomicLongArray[] segments; + + private AtomicBitsetSegments(int numSegments, int segmentLength) { + AtomicLongArray[] segments = new AtomicLongArray[numSegments]; + + for (int i = 0; i < numSegments; i++) { + segments[i] = new AtomicLongArray(segmentLength); + } + + // Thread safety: Because this.segments is final, the preceding operations in this constructor are guaranteed to be visible to any + // other thread which accesses this.segments. + this.segments = segments; + } + + private AtomicBitsetSegments(AtomicBitsetSegments copyFrom, int numSegments, int segmentLength) { + AtomicLongArray[] segments = new AtomicLongArray[numSegments]; + + for (int i = 0; i < numSegments; i++) { + segments[i] = i < copyFrom.numSegments() ? copyFrom.getSegment(i) : new AtomicLongArray(segmentLength); + } + + // see above re: thread-safety of this assignment + this.segments = segments; + } + + private int cardinality() { + int numSetBits = 0; + + for (int i = 0; i < numSegments(); i++) { + AtomicLongArray segment = getSegment(i); + for (int j = 0; j < segment.length(); j++) { + numSetBits += Long.bitCount(segment.get(j)); + } + } + return numSetBits; + } + + public int numSegments() { + return segments.length; + } + + public AtomicLongArray getSegment(int index) { + return segments[index]; + } + + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof AtomicBitset other)) { + return false; + } + + if (other.log2SegmentSize != log2SegmentSize) { + throw new IllegalArgumentException("Segment sizes must be the same"); + } + + AtomicBitsetSegments thisSegments = this.segments.get(); + AtomicBitsetSegments otherSegments = other.segments.get(); + + for (int i = 0; i < thisSegments.numSegments(); i++) { + AtomicLongArray thisArray = thisSegments.getSegment(i); + AtomicLongArray otherArray = (i < otherSegments.numSegments()) ? otherSegments.getSegment(i) : null; + + for (int j = 0; j < thisArray.length(); j++) { + long thisLong = thisArray.get(j); + long otherLong = (otherArray == null) ? 0 : otherArray.get(j); + + if (thisLong != otherLong) { + return false; + } + } + } + + for (int i = thisSegments.numSegments(); i < otherSegments.numSegments(); i++) { + AtomicLongArray otherArray = otherSegments.getSegment(i); + + for (int j = 0; j < otherArray.length(); j++) { + long l = otherArray.get(j); + + if (l != 0) { + return false; + } + } + } + + return true; + } + + @Override + public int hashCode() { + int result = log2SegmentSize; + result = 31 * result + Arrays.hashCode(segments.get().segments); + return result; + } + + /** + * @return a new BitSet with same bits set + */ + public BitSet toBitSet() { + BitSet resultSet = new BitSet(); + int ordinal = this.nextSetBit(0); + while (ordinal != -1) { + resultSet.set(ordinal); + ordinal = this.nextSetBit(ordinal + 1); + } + return resultSet; + } + + @Override + public String toString() { + return toBitSet().toString(); + } + + @FunctionalInterface + public interface BitSpanConsumer { + /** + * Consume a span of bits. + * + * @param startInclusive The first (inclusive) bit in the span. + * @param endInclusive The last (inclusive) bit in the span. + */ + void accept(int startInclusive, int endInclusive); + } +}