diff --git a/src/main/java/com/jozufozu/flywheel/api/layout/Layout.java b/src/main/java/com/jozufozu/flywheel/api/layout/Layout.java index 3c4d1e83e..10d6b9a10 100644 --- a/src/main/java/com/jozufozu/flywheel/api/layout/Layout.java +++ b/src/main/java/com/jozufozu/flywheel/api/layout/Layout.java @@ -13,8 +13,7 @@ public interface Layout { @Unmodifiable List elements(); - @Unmodifiable - Map asMap(); + @Unmodifiable Map asMap(); int byteSize(); @@ -23,5 +22,7 @@ public interface Layout { String name(); ElementType type(); + + int offset(); } } diff --git a/src/main/java/com/jozufozu/flywheel/backend/compile/component/IndirectComponent.java b/src/main/java/com/jozufozu/flywheel/backend/compile/component/IndirectComponent.java index af02aa7d7..bc62f1d8e 100644 --- a/src/main/java/com/jozufozu/flywheel/backend/compile/component/IndirectComponent.java +++ b/src/main/java/com/jozufozu/flywheel/backend/compile/component/IndirectComponent.java @@ -4,10 +4,19 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.function.Function; + +import org.jetbrains.annotations.NotNull; import com.jozufozu.flywheel.Flywheel; import com.jozufozu.flywheel.api.instance.InstanceType; +import com.jozufozu.flywheel.api.layout.FloatRepr; +import com.jozufozu.flywheel.api.layout.IntegerRepr; import com.jozufozu.flywheel.api.layout.Layout; +import com.jozufozu.flywheel.api.layout.MatrixElementType; +import com.jozufozu.flywheel.api.layout.ScalarElementType; +import com.jozufozu.flywheel.api.layout.UnsignedIntegerRepr; +import com.jozufozu.flywheel.api.layout.VectorElementType; import com.jozufozu.flywheel.backend.compile.LayoutInterpreter; import com.jozufozu.flywheel.backend.compile.Pipeline; import com.jozufozu.flywheel.glsl.SourceComponent; @@ -15,7 +24,7 @@ import com.jozufozu.flywheel.glsl.generate.FnSignature; import com.jozufozu.flywheel.glsl.generate.GlslBlock; import com.jozufozu.flywheel.glsl.generate.GlslBuilder; import com.jozufozu.flywheel.glsl.generate.GlslExpr; -import com.jozufozu.flywheel.lib.layout.LayoutItem; +import com.jozufozu.flywheel.glsl.generate.GlslStruct; import net.minecraft.resources.ResourceLocation; @@ -27,10 +36,8 @@ public class IndirectComponent implements SourceComponent { private static final String UNPACK_FN_NAME = "_flw_unpackInstance"; private final Layout layout; - private final List layoutItems; public IndirectComponent(InstanceType type) { - this.layoutItems = type.oldLayout().layoutItems; this.layout = type.layout(); } @@ -60,8 +67,6 @@ public class IndirectComponent implements SourceComponent { public String generateIndirect() { var builder = new GlslBuilder(); - generateHeader(builder); - generateInstanceStruct(builder); builder.blankLine(); @@ -73,13 +78,6 @@ public class IndirectComponent implements SourceComponent { return builder.build(); } - private void generateHeader(GlslBuilder builder) { - layoutItems.stream() - .map(LayoutItem::type) - .distinct() - .forEach(type -> type.declare(builder)); - } - private void generateInstanceStruct(GlslBuilder builder) { var instance = builder.struct(); instance.setName(STRUCT_NAME); @@ -93,13 +91,9 @@ public class IndirectComponent implements SourceComponent { packed.setName(PACKED_STRUCT_NAME); var unpackArgs = new ArrayList(); - for (LayoutItem field : layoutItems) { - GlslExpr unpack = UNPACKING_VARIABLE.access(field.name()) - .transform(field.type()::unpack); - unpackArgs.add(unpack); - packed.addField(field.type() - .packedTypeName(), field.name()); + for (Layout.Element element : layout.elements()) { + unpackArgs.add(unpackElement(element, packed)); } var block = new GlslBlock(); @@ -114,4 +108,224 @@ public class IndirectComponent implements SourceComponent { .build()) .body(block); } + + public static GlslExpr unpackElement(Layout.Element element, GlslStruct packed) { + // FIXME: I don't think we're unpacking signed byte/short values correctly + // FIXME: we definitely don't consider endianness. this all assumes little endian which works on my machine. + var type = element.type(); + + if (type instanceof ScalarElementType scalar) { + return unpackScalar(element.name(), packed, scalar); + } else if (type instanceof VectorElementType vector) { + return unpackVector(element.name(), packed, vector); + } else if (type instanceof MatrixElementType matrix) { + return unpackMatrix(element.name(), packed, matrix); + } + + throw new IllegalArgumentException("Unknown type " + type); + } + + private static GlslExpr unpackMatrix(String name, GlslStruct packed, MatrixElementType matrix) { + var repr = matrix.repr(); + + int rows = matrix.rows(); + int columns = matrix.columns(); + + List args = new ArrayList<>(); + + for (int i = 0; i < columns; i++) { + args.add(unpackFloatVector(name + "_" + i, repr, packed, rows)); + } + + return GlslExpr.call("mat" + columns + "x" + rows, args); + } + + private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, ScalarElementType scalar) { + var repr = scalar.repr(); + + if (repr instanceof FloatRepr floatRepr) { + return unpackFloatScalar(fieldName, floatRepr, packed); + } else if (repr instanceof IntegerRepr intRepr) { + return unpackIntScalar(fieldName, intRepr, packed); + } else if (repr instanceof UnsignedIntegerRepr unsignedIntegerRepr) { + return unpackUnsignedScalar(fieldName, unsignedIntegerRepr, packed); + } + + throw new IllegalArgumentException("Unknown repr " + repr); + } + + private static GlslExpr unpackVector(String fieldName, GlslStruct packed, VectorElementType vector) { + var repr = vector.repr(); + + int size = vector.size(); + + if (repr instanceof FloatRepr floatRepr) { + return unpackFloatVector(fieldName, floatRepr, packed, size); + } else if (repr instanceof IntegerRepr intRepr) { + return unpackIntVector(fieldName, intRepr, packed, size); + } else if (repr instanceof UnsignedIntegerRepr unsignedIntegerRepr) { + return unpackUnsignedVector(fieldName, unsignedIntegerRepr, packed, size); + } + + throw new IllegalArgumentException("Unknown repr " + repr); + } + + private static GlslExpr unpackFloatVector(String fieldName, FloatRepr floatRepr, GlslStruct packed, int size) { + return switch (floatRepr) { + case NORMALIZED_BYTE -> unpackBuiltin(fieldName, packed, size, "unpackSnorm4x8"); + case NORMALIZED_UNSIGNED_BYTE -> unpackBuiltin(fieldName, packed, size, "unpackUnorm4x8"); + case NORMALIZED_SHORT -> unpackBuiltin(fieldName, packed, size, "unpackSnorm2x16"); + case NORMALIZED_UNSIGNED_SHORT -> unpackBuiltin(fieldName, packed, size, "unpackUnorm2x16"); + case NORMALIZED_INT -> unpack(fieldName, packed, size, "int", "vec" + size, e -> e.div(2147483647f) + .clamp(-1, 1)); + case NORMALIZED_UNSIGNED_INT -> + unpack(fieldName, packed, size, "uint", "vec" + size, e -> e.div(4294967295f)); + case BYTE -> unpackByteBacked(fieldName, packed, size, "vec" + size, e -> e.cast("int") + .cast("float")); + case UNSIGNED_BYTE -> unpackByteBacked(fieldName, packed, size, "vec" + size, e -> e.cast("float")); + case SHORT -> unpackShortBacked(fieldName, packed, size, "vec" + size, e -> e.cast("int") + .cast("float")); + case UNSIGNED_SHORT -> unpackShortBacked(fieldName, packed, size, "vec" + size, e -> e.cast("float")); + case INT -> unpack(fieldName, packed, size, "int", "vec" + size, e -> e.cast("float")); + case UNSIGNED_INT -> unpack(fieldName, packed, size, "float", "vec" + size, e -> e.cast("float")); + case FLOAT -> unpack(fieldName, packed, size, "float", "vec" + size); + }; + } + + private static GlslExpr unpackUnsignedVector(String fieldName, UnsignedIntegerRepr unsignedIntegerRepr, GlslStruct packed, int size) { + return switch (unsignedIntegerRepr) { + case UNSIGNED_BYTE -> unpackByteBacked(fieldName, packed, size, "uvec" + size, e -> e.cast("uint")); + case UNSIGNED_SHORT -> unpackShortBacked(fieldName, packed, size, "uvec" + size, e -> e.cast("uint")); + case UNSIGNED_INT -> unpack(fieldName, packed, size, "uint", "uvec" + size); + }; + } + + private static GlslExpr unpackIntVector(String fieldName, IntegerRepr repr, GlslStruct packed, int size) { + return switch (repr) { + case BYTE -> unpackByteBacked(fieldName, packed, size, "ivec" + size, e -> e.cast("int")); + case SHORT -> unpackShortBacked(fieldName, packed, size, "ivec" + size, e -> e.cast("int")); + case INT -> unpack(fieldName, packed, size, "int", "ivec" + size); + }; + } + + @NotNull + private static GlslExpr unpack(String fieldName, GlslStruct packed, int size, String backingType, String outType) { + return unpack(fieldName, packed, size, backingType, outType, Function.identity()); + } + + @NotNull + private static GlslExpr unpack(String fieldName, GlslStruct packed, int size, String backingType, String outType, Function perElement) { + List args = new ArrayList<>(); + for (int i = 0; i < size; i++) { + var name = "_" + fieldName + "_" + i; + packed.addField(backingType, name); + args.add(UNPACKING_VARIABLE.access(name) + .transform(perElement)); + } + return GlslExpr.call(outType, args); + } + + @NotNull + private static GlslExpr unpackBuiltin(String fieldName, GlslStruct packed, int size, String func) { + packed.addField("uint", fieldName); + GlslExpr expr = UNPACKING_VARIABLE.access(fieldName) + .callFunction(func); + return switch (size) { + case 2 -> expr.swizzle("xy"); + case 3 -> expr.swizzle("xyz"); + case 4 -> expr; + default -> throw new IllegalArgumentException("Invalid vector size " + size); + }; + } + + @NotNull + private static GlslExpr unpackByteBacked(String fieldName, GlslStruct packed, int size, String outType, Function perElement) { + packed.addField("uint", fieldName); + List args = new ArrayList<>(); + for (int i = 0; i < size; i++) { + int bitPos = i * 8; + var element = UNPACKING_VARIABLE.access(fieldName) + .and(0xFF << bitPos) + .rsh(bitPos); + args.add(perElement.apply(element)); + } + return GlslExpr.call(outType + size, args); + } + + @NotNull + private static GlslExpr unpackShortBacked(String fieldName, GlslStruct packed, int size, String outType, Function perElement) { + List args = new ArrayList<>(); + for (int i = 0; i < size; i++) { + int unpackField = i / 2; + int bitPos = (i % 2) * 16; + var name = "_" + fieldName + "_" + unpackField; + if (bitPos == 0) { + // First time we're seeing this field, add it to the struct. + packed.addField("uint", name); + } + var element = UNPACKING_VARIABLE.access(name) + .and(0xFFFF << bitPos) + .rsh(bitPos); + args.add(perElement.apply(element)); + } + return GlslExpr.call(outType, args); + } + + private static GlslExpr unpackFloatScalar(String fieldName, FloatRepr repr, GlslStruct packed) { + return switch (repr) { + case BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF) + .cast("int") + .cast("float")); + case NORMALIZED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackSnorm4x8") + .swizzle("x")); + case UNSIGNED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF) + .cast("float")); + case NORMALIZED_UNSIGNED_BYTE -> + unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackUnorm4x8") + .swizzle("x")); + case SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF) + .cast("int") + .cast("float")); + case NORMALIZED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackSnorm2x16") + .swizzle("x")); + case UNSIGNED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF) + .cast("float")); + case NORMALIZED_UNSIGNED_SHORT -> + unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackUnorm2x16") + .swizzle("x")); + case INT -> unpackScalar(fieldName, packed, "int", e -> e.cast("float")); + case NORMALIZED_INT -> unpackScalar(fieldName, packed, "int", e -> e.div(2147483647f) + .clamp(-1, 1)); + case UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint", e -> e.cast("float")); + case NORMALIZED_UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint", e -> e.div(4294967295f)); + case FLOAT -> unpackScalar(fieldName, packed, "float"); + }; + } + + private static GlslExpr unpackUnsignedScalar(String fieldName, UnsignedIntegerRepr repr, GlslStruct packed) { + return switch (repr) { + case UNSIGNED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF)); + case UNSIGNED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF)); + case UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint"); + }; + } + + private static GlslExpr unpackIntScalar(String fieldName, IntegerRepr intRepr, GlslStruct packed) { + return switch (intRepr) { + case BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF) + .cast("int")); + case SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF) + .cast("int")); + case INT -> unpackScalar(fieldName, packed, "int"); + }; + } + + private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, String packedType) { + return unpackScalar(fieldName, packed, packedType, Function.identity()); + } + + private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, String packedType, Function perElement) { + packed.addField(packedType, fieldName); + return perElement.apply(UNPACKING_VARIABLE.access(fieldName)); + } } diff --git a/src/main/java/com/jozufozu/flywheel/glsl/generate/BinOp.java b/src/main/java/com/jozufozu/flywheel/glsl/generate/BinOp.java new file mode 100644 index 000000000..27bea127e --- /dev/null +++ b/src/main/java/com/jozufozu/flywheel/glsl/generate/BinOp.java @@ -0,0 +1,15 @@ +package com.jozufozu.flywheel.glsl.generate; + +public enum BinOp { + BITWISE_AND("&"), + RIGHT_SHIFT(">>"), + DIVIDE("/"), + // TODO: add more as we need them + ; + + public final String op; + + BinOp(String op) { + this.op = op; + } +} diff --git a/src/main/java/com/jozufozu/flywheel/glsl/generate/GlslExpr.java b/src/main/java/com/jozufozu/flywheel/glsl/generate/GlslExpr.java index 5103a6d86..801c4f79a 100644 --- a/src/main/java/com/jozufozu/flywheel/glsl/generate/GlslExpr.java +++ b/src/main/java/com/jozufozu/flywheel/glsl/generate/GlslExpr.java @@ -32,15 +32,23 @@ public interface GlslExpr { } static GlslExpr intLiteral(int expr) { - return new IntLiteral(expr); + return new RawLiteral(Integer.toString(expr)); } static GlslExpr uintLiteral(int expr) { - return new UIntLiteral(expr); + return new RawLiteral(Integer.toUnsignedString(expr) + 'u'); + } + + static GlslExpr uintHexLiteral(int expr) { + return new RawLiteral("0x" + Integer.toHexString(expr) + 'u'); } static GlslExpr boolLiteral(boolean expr) { - return new BoolLiteral(expr); + return new RawLiteral(Boolean.toString(expr)); + } + + static GlslExpr floatLiteral(float expr) { + return new RawLiteral(Float.toString(expr)); } /** @@ -53,6 +61,10 @@ public interface GlslExpr { return new FunctionCall(name, this); } + default FunctionCall cast(String name) { + return new FunctionCall(name, this); + } + /** * Swizzle the components of this expression. * @@ -83,6 +95,25 @@ public interface GlslExpr { return f.apply(this); } + default GlslExpr and(int mask) { + return new Binary(this, uintHexLiteral(mask), BinOp.BITWISE_AND); + } + + default GlslExpr rsh(int by) { + if (by == 0) { + return this; + } + return new Binary(this, uintLiteral(by), BinOp.RIGHT_SHIFT); + } + + default GlslExpr div(float v) { + return new Binary(this, floatLiteral(v), BinOp.DIVIDE); + } + + default GlslExpr clamp(float from, float to) { + return new Clamp(this, floatLiteral(from), floatLiteral(to)); + } + String prettyPrint(); record Variable(String name) implements GlslExpr { @@ -132,30 +163,24 @@ public interface GlslExpr { } - record IntLiteral(int value) implements GlslExpr { + record Clamp(GlslExpr value, GlslExpr from, GlslExpr to) implements GlslExpr { @Override public String prettyPrint() { - return Integer.toString(value); + return "clamp(" + value.prettyPrint() + ", " + from.prettyPrint() + ", " + to.prettyPrint() + ")"; } } - record UIntLiteral(int value) implements GlslExpr { - public UIntLiteral { - if (value < 0) { - throw new IllegalArgumentException("UIntLiteral must be positive"); - } - } - + record Binary(GlslExpr lhs, GlslExpr rhs, BinOp op) implements GlslExpr { @Override public String prettyPrint() { - return Integer.toString(value) + 'u'; + return "(" + lhs.prettyPrint() + " " + op.op + " " + rhs.prettyPrint() + ")"; } } - record BoolLiteral(boolean value) implements GlslExpr { + record RawLiteral(String value) implements GlslExpr { @Override public String prettyPrint() { - return Boolean.toString(value); + return value; } } } diff --git a/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutBuilderImpl.java b/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutBuilderImpl.java index 502fca4da..6ca6d034c 100644 --- a/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutBuilderImpl.java +++ b/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutBuilderImpl.java @@ -7,12 +7,14 @@ import java.util.Set; import org.jetbrains.annotations.Range; +import com.jozufozu.flywheel.api.layout.ElementType; import com.jozufozu.flywheel.api.layout.FloatRepr; import com.jozufozu.flywheel.api.layout.Layout; import com.jozufozu.flywheel.api.layout.Layout.Element; import com.jozufozu.flywheel.api.layout.LayoutBuilder; import com.jozufozu.flywheel.api.layout.ValueRepr; import com.jozufozu.flywheel.impl.layout.LayoutImpl.ElementImpl; +import com.jozufozu.flywheel.lib.math.MoreMath; import it.unimi.dsi.fastutil.objects.Object2IntMap; import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; @@ -101,23 +103,21 @@ public class LayoutBuilderImpl implements LayoutBuilder { ); private final List elements = new ArrayList<>(); + private int offset = 0; @Override public LayoutBuilder scalar(String name, ValueRepr repr) { - elements.add(new ElementImpl(name, new ScalarElementTypeImpl(repr))); - return this; + return element(name, ScalarElementTypeImpl.create(repr)); } @Override public LayoutBuilder vector(String name, ValueRepr repr, @Range(from = 2, to = 4) int size) { - elements.add(new ElementImpl(name, new VectorElementTypeImpl(repr, size))); - return this; + return element(name, VectorElementTypeImpl.create(repr, size)); } @Override public LayoutBuilder matrix(String name, FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) { - elements.add(new ElementImpl(name, new MatrixElementTypeImpl(repr, rows, columns))); - return this; + return element(name, MatrixElementTypeImpl.create(repr, rows, columns)); } @Override @@ -125,6 +125,13 @@ public class LayoutBuilderImpl implements LayoutBuilder { return matrix(name, repr, size, size); } + private LayoutBuilder element(String name, ElementType type) { + elements.add(new ElementImpl(name, type, offset)); + offset += type.byteSize(); + offset = MoreMath.align4(offset); + return this; + } + @Override public Layout build() { Object2IntMap name2IndexMap = new Object2IntOpenHashMap<>(); diff --git a/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutImpl.java b/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutImpl.java index b2feaac2b..cac5567ea 100644 --- a/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutImpl.java +++ b/src/main/java/com/jozufozu/flywheel/impl/layout/LayoutImpl.java @@ -15,16 +15,16 @@ final class LayoutImpl implements Layout { @Unmodifiable private final List elements; @Unmodifiable - private final Map map; + private final Map map; private final int byteSize; LayoutImpl(@Unmodifiable List elements) { this.elements = elements; - Object2ObjectOpenHashMap map = new Object2ObjectOpenHashMap<>(); + Object2ObjectOpenHashMap map = new Object2ObjectOpenHashMap<>(); int byteSize = 0; for (Element element : this.elements) { - map.put(element.name(), element.type()); + map.put(element.name(), element); byteSize += element.type().byteSize(); } map.trim(); @@ -41,7 +41,7 @@ final class LayoutImpl implements Layout { @Override @Unmodifiable - public Map asMap() { + public Map asMap() { return map; } @@ -73,29 +73,6 @@ final class LayoutImpl implements Layout { return elements.equals(other.elements); } - record ElementImpl(String name, ElementType type) implements Element { - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + name.hashCode(); - result = prime * result + type.hashCode(); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - ElementImpl other = (ElementImpl) obj; - return name.equals(other.name) && type.equals(other.type); - } + record ElementImpl(String name, ElementType type, int offset) implements Element { } } diff --git a/src/main/java/com/jozufozu/flywheel/impl/layout/MatrixElementTypeImpl.java b/src/main/java/com/jozufozu/flywheel/impl/layout/MatrixElementTypeImpl.java index 9c5c45b05..d65491478 100644 --- a/src/main/java/com/jozufozu/flywheel/impl/layout/MatrixElementTypeImpl.java +++ b/src/main/java/com/jozufozu/flywheel/impl/layout/MatrixElementTypeImpl.java @@ -5,72 +5,16 @@ import org.jetbrains.annotations.Range; import com.jozufozu.flywheel.api.layout.FloatRepr; import com.jozufozu.flywheel.api.layout.MatrixElementType; -final class MatrixElementTypeImpl implements MatrixElementType { - private final FloatRepr repr; - @Range(from = 2, to = 4) - private final int rows; - @Range(from = 2, to = 4) - private final int columns; - private final int byteSize; - - MatrixElementTypeImpl(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) { +record MatrixElementTypeImpl(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns, + int byteSize) implements MatrixElementType { + static MatrixElementTypeImpl create(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) { if (rows < 2 || rows > 4) { throw new IllegalArgumentException("Matrix element row count must be in range [2, 4]!"); } if (columns < 2 || columns > 4) { throw new IllegalArgumentException("Matrix element column count must be in range [2, 4]!"); } - - this.repr = repr; - this.rows = rows; - this.columns = columns; - byteSize = repr.byteSize() * rows * columns; - } - - @Override - public FloatRepr repr() { - return repr; - } - - @Override - @Range(from = 2, to = 4) - public int rows() { - return rows; - } - - @Override - @Range(from = 2, to = 4) - public int columns() { - return columns; - } - - @Override - public int byteSize() { - return byteSize; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + repr.hashCode(); - result = prime * result + rows; - result = prime * result + columns; - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - MatrixElementTypeImpl other = (MatrixElementTypeImpl) obj; - return repr == other.repr && rows == other.rows && columns == other.columns; + int byteSize = repr.byteSize() * rows * columns; + return new MatrixElementTypeImpl(repr, rows, columns, byteSize); } } diff --git a/src/main/java/com/jozufozu/flywheel/impl/layout/ScalarElementTypeImpl.java b/src/main/java/com/jozufozu/flywheel/impl/layout/ScalarElementTypeImpl.java index a6b6d4361..f2d9cfcf0 100644 --- a/src/main/java/com/jozufozu/flywheel/impl/layout/ScalarElementTypeImpl.java +++ b/src/main/java/com/jozufozu/flywheel/impl/layout/ScalarElementTypeImpl.java @@ -3,45 +3,8 @@ package com.jozufozu.flywheel.impl.layout; import com.jozufozu.flywheel.api.layout.ScalarElementType; import com.jozufozu.flywheel.api.layout.ValueRepr; -final class ScalarElementTypeImpl implements ScalarElementType { - private final ValueRepr repr; - private final int byteSize; - - ScalarElementTypeImpl(ValueRepr repr) { - this.repr = repr; - byteSize = repr.byteSize(); - } - - @Override - public ValueRepr repr() { - return repr; - } - - @Override - public int byteSize() { - return byteSize; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + repr.hashCode(); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - ScalarElementTypeImpl other = (ScalarElementTypeImpl) obj; - return repr == other.repr; +record ScalarElementTypeImpl(ValueRepr repr, int byteSize) implements ScalarElementType { + static ScalarElementTypeImpl create(ValueRepr repr) { + return new ScalarElementTypeImpl(repr, repr.byteSize()); } } diff --git a/src/main/java/com/jozufozu/flywheel/impl/layout/VectorElementTypeImpl.java b/src/main/java/com/jozufozu/flywheel/impl/layout/VectorElementTypeImpl.java index 5f2426a2b..cdb5712f7 100644 --- a/src/main/java/com/jozufozu/flywheel/impl/layout/VectorElementTypeImpl.java +++ b/src/main/java/com/jozufozu/flywheel/impl/layout/VectorElementTypeImpl.java @@ -5,59 +5,15 @@ import org.jetbrains.annotations.Range; import com.jozufozu.flywheel.api.layout.ValueRepr; import com.jozufozu.flywheel.api.layout.VectorElementType; -final class VectorElementTypeImpl implements VectorElementType { - private final ValueRepr repr; - @Range(from = 2, to = 4) - private final int size; - private final int byteSize; +record VectorElementTypeImpl(ValueRepr repr, @Range(from = 2, to = 4) int size, + int byteSize) implements VectorElementType { - VectorElementTypeImpl(ValueRepr repr, @Range(from = 2, to = 4) int size) { + static VectorElementTypeImpl create(ValueRepr repr, @Range(from = 2, to = 4) int size) { if (size < 2 || size > 4) { throw new IllegalArgumentException("Vector element size must be in range [2, 4]!"); } - this.repr = repr; - this.size = size; - byteSize = repr.byteSize() * size; - } - - @Override - public ValueRepr repr() { - return repr; - } - - @Override - @Range(from = 2, to = 4) - public int size() { - return size; - } - - @Override - public int byteSize() { - return byteSize; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + repr.hashCode(); - result = prime * result + size; - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - VectorElementTypeImpl other = (VectorElementTypeImpl) obj; - return repr == other.repr && size == other.size; + int byteSize = repr.byteSize() * size; + return new VectorElementTypeImpl(repr, size, byteSize); } } diff --git a/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java b/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java index d0ed97451..e6b382a4f 100644 --- a/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java +++ b/src/main/java/com/jozufozu/flywheel/lib/math/MoreMath.java @@ -10,7 +10,11 @@ public final class MoreMath { public static final float SQRT_3_OVER_2 = (float) (Math.sqrt(3.0) / 2.0); public static int align16(int numToRound) { - return (numToRound + 16 - 1) & -16; + return (numToRound + 15) & ~15; + } + + public static int align4(int offset1) { + return (offset1 + 3) & ~3; } public static int ceilingDiv(int numerator, int denominator) {