Fix tracked MemoryBlocks

This commit is contained in:
PepperCode1 2022-08-11 15:48:08 -07:00
parent f0823af00f
commit 293d6ee59c
5 changed files with 60 additions and 42 deletions

View file

@ -3,11 +3,11 @@ package com.jozufozu.flywheel.api.struct;
import com.jozufozu.flywheel.api.instancer.InstancedPart; import com.jozufozu.flywheel.api.instancer.InstancedPart;
/** /**
* StructWriters can quickly consume many instances of S and write them to some backing buffer. * StructWriters can quickly consume many instances of S and write them to some memory address.
*/ */
public interface StructWriter<S extends InstancedPart> { public interface StructWriter<S extends InstancedPart> {
/** /**
* Write the given struct to given memory address. * Write the given struct to the given memory address.
*/ */
void write(long ptr, S struct); void write(long ptr, S struct);
} }

View file

@ -21,50 +21,40 @@ public class FlwMemoryTracker {
public static MemoryBlock mallocBlock(long size) { public static MemoryBlock mallocBlock(long size) {
MemoryBlock block = new MemoryBlockImpl(malloc(size), size); MemoryBlock block = new MemoryBlockImpl(malloc(size), size);
cpuMemory += block.size(); _allocCPUMemory(block.size());
return block; return block;
} }
public static MemoryBlock mallocBlockTracked(long size) { public static MemoryBlock mallocBlockTracked(long size) {
TrackedMemoryBlockImpl block = new TrackedMemoryBlockImpl(malloc(size), size); MemoryBlock block = new TrackedMemoryBlockImpl(malloc(size), size, CLEANER);
block.cleanable = CLEANER.register(block, () -> { _allocCPUMemory(block.size());
if (!block.isFreed()) {
block.free();
}
});
cpuMemory += block.size();
return block; return block;
} }
@Deprecated @Deprecated
public static ByteBuffer mallocBuffer(int size) { public static ByteBuffer mallocBuffer(int size) {
ByteBuffer buffer = MemoryUtil.memByteBuffer(malloc(size), size); ByteBuffer buffer = MemoryUtil.memByteBuffer(malloc(size), size);
cpuMemory += buffer.capacity(); _allocCPUMemory(buffer.capacity());
return buffer; return buffer;
} }
public static long calloc(long num, long size) { public static long calloc(long num, long size) {
long ptr = MemoryUtil.nmemCalloc(num, size); long ptr = MemoryUtil.nmemCalloc(num, size);
if (ptr == MemoryUtil.NULL) { if (ptr == MemoryUtil.NULL) {
throw new OutOfMemoryError("Failed to allocate " + num + " blocks of size " + size + " bytes"); throw new OutOfMemoryError("Failed to allocate " + num + " elements of size " + size + " bytes");
} }
return ptr; return ptr;
} }
public static MemoryBlock callocBlock(long num, long size) { public static MemoryBlock callocBlock(long num, long size) {
MemoryBlock block = new MemoryBlockImpl(calloc(num, size), num * size); MemoryBlock block = new MemoryBlockImpl(calloc(num, size), num * size);
cpuMemory += block.size(); _allocCPUMemory(block.size());
return block; return block;
} }
public static MemoryBlock callocBlockTracked(long num, long size) { public static MemoryBlock callocBlockTracked(long num, long size) {
TrackedMemoryBlockImpl block = new TrackedMemoryBlockImpl(calloc(num, size), num * size); MemoryBlock block = new TrackedMemoryBlockImpl(calloc(num, size), num * size, CLEANER);
block.cleanable = CLEANER.register(block, () -> { _allocCPUMemory(block.size());
if (!block.isFreed()) {
block.free();
}
});
cpuMemory += block.size();
return block; return block;
} }
@ -78,25 +68,23 @@ public class FlwMemoryTracker {
public static MemoryBlock reallocBlock(MemoryBlock block, long size) { public static MemoryBlock reallocBlock(MemoryBlock block, long size) {
MemoryBlock newBlock = new MemoryBlockImpl(realloc(block.ptr(), size), size); MemoryBlock newBlock = new MemoryBlockImpl(realloc(block.ptr(), size), size);
cpuMemory += -block.size() + newBlock.size(); _freeCPUMemory(block.size());
_allocCPUMemory(newBlock.size());
return newBlock; return newBlock;
} }
public static MemoryBlock reallocBlockTracked(MemoryBlock block, long size) { public static MemoryBlock reallocBlockTracked(MemoryBlock block, long size) {
TrackedMemoryBlockImpl newBlock = new TrackedMemoryBlockImpl(realloc(block.ptr(), size), size); MemoryBlock newBlock = new TrackedMemoryBlockImpl(realloc(block.ptr(), size), size, CLEANER);
newBlock.cleanable = CLEANER.register(newBlock, () -> { _freeCPUMemory(block.size());
if (!newBlock.isFreed()) { _allocCPUMemory(newBlock.size());
newBlock.free();
}
});
cpuMemory += -block.size() + newBlock.size();
return newBlock; return newBlock;
} }
@Deprecated @Deprecated
public static ByteBuffer reallocBuffer(ByteBuffer buffer, int size) { public static ByteBuffer reallocBuffer(ByteBuffer buffer, int size) {
ByteBuffer newBuffer = MemoryUtil.memByteBuffer(realloc(MemoryUtil.memAddress(buffer), size), size); ByteBuffer newBuffer = MemoryUtil.memByteBuffer(realloc(MemoryUtil.memAddress(buffer), size), size);
cpuMemory += -buffer.capacity() + newBuffer.capacity(); _freeCPUMemory(buffer.capacity());
_allocCPUMemory(newBuffer.capacity());
return newBuffer; return newBuffer;
} }
@ -106,13 +94,13 @@ public class FlwMemoryTracker {
public static void freeBlock(MemoryBlock block) { public static void freeBlock(MemoryBlock block) {
free(block.ptr()); free(block.ptr());
cpuMemory -= block.size(); _freeCPUMemory(block.size());
} }
@Deprecated @Deprecated
public static void freeBuffer(ByteBuffer buffer) { public static void freeBuffer(ByteBuffer buffer) {
free(MemoryUtil.memAddress(buffer)); free(MemoryUtil.memAddress(buffer));
cpuMemory -= buffer.capacity(); _freeCPUMemory(buffer.capacity());
} }
public static void _allocCPUMemory(long size) { public static void _allocCPUMemory(long size) {

View file

@ -13,9 +13,7 @@ public sealed interface MemoryBlock permits MemoryBlockImpl {
void copyTo(long ptr, long bytes); void copyTo(long ptr, long bytes);
default void copyTo(long ptr) { void copyTo(long ptr);
copyTo(ptr, size());
}
void clear(); void clear();

View file

@ -40,6 +40,11 @@ sealed class MemoryBlockImpl implements MemoryBlock permits TrackedMemoryBlockIm
MemoryUtil.memCopy(this.ptr, ptr, bytes); MemoryUtil.memCopy(this.ptr, ptr, bytes);
} }
@Override
public void copyTo(long ptr) {
copyTo(ptr, size);
}
@Override @Override
public void clear() { public void clear() {
MemoryUtil.memSet(ptr, 0, size); MemoryUtil.memSet(ptr, 0, size);

View file

@ -3,10 +3,13 @@ package com.jozufozu.flywheel.backend.memory;
import java.lang.ref.Cleaner; import java.lang.ref.Cleaner;
final class TrackedMemoryBlockImpl extends MemoryBlockImpl { final class TrackedMemoryBlockImpl extends MemoryBlockImpl {
Cleaner.Cleanable cleanable; final CleaningAction cleaningAction;
final Cleaner.Cleanable cleanable;
TrackedMemoryBlockImpl(long ptr, long size) { TrackedMemoryBlockImpl(long ptr, long size, Cleaner cleaner) {
super(ptr, size); super(ptr, size);
cleaningAction = new CleaningAction(ptr, size);
cleanable = cleaner.register(this, cleaningAction);
} }
@Override @Override
@ -14,25 +17,49 @@ final class TrackedMemoryBlockImpl extends MemoryBlockImpl {
return true; return true;
} }
void freeInner() {
freed = true;
cleaningAction.freed = true;
cleanable.clean();
}
@Override @Override
public MemoryBlock realloc(long size) { public MemoryBlock realloc(long size) {
MemoryBlock block = FlwMemoryTracker.reallocBlock(this, size); MemoryBlock block = FlwMemoryTracker.reallocBlock(this, size);
freed = true; freeInner();
cleanable.clean();
return block; return block;
} }
@Override @Override
public MemoryBlock reallocTracked(long size) { public MemoryBlock reallocTracked(long size) {
MemoryBlock block = FlwMemoryTracker.reallocBlockTracked(this, size); MemoryBlock block = FlwMemoryTracker.reallocBlockTracked(this, size);
freed = true; freeInner();
cleanable.clean();
return block; return block;
} }
@Override @Override
public void free() { public void free() {
cleanable.clean(); FlwMemoryTracker.freeBlock(this);
freed = true; freeInner();
}
static class CleaningAction implements Runnable {
final long ptr;
final long size;
boolean freed;
CleaningAction(long ptr, long size) {
this.ptr = ptr;
this.size = size;
}
@Override
public void run() {
if (!freed) {
FlwMemoryTracker.free(ptr);
FlwMemoryTracker._freeCPUMemory(size);
}
}
} }
} }