A lot to unpack here

- Elements are 4 byte aligned and store their offset.
- Make all ElementTypeImpls records.
- Implement unpacking for all types/reprs.
- Add many utilities to GlslExpr to facilitate unpacking.
This commit is contained in:
Jozufozu 2024-01-05 18:34:28 -08:00
parent 9309266435
commit 7ad163588e
10 changed files with 326 additions and 220 deletions

View file

@ -13,8 +13,7 @@ public interface Layout {
@Unmodifiable
List<Element> elements();
@Unmodifiable
Map<String, ElementType> asMap();
@Unmodifiable Map<String, Element> asMap();
int byteSize();
@ -23,5 +22,7 @@ public interface Layout {
String name();
ElementType type();
int offset();
}
}

View file

@ -4,10 +4,19 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import org.jetbrains.annotations.NotNull;
import com.jozufozu.flywheel.Flywheel;
import com.jozufozu.flywheel.api.instance.InstanceType;
import com.jozufozu.flywheel.api.layout.FloatRepr;
import com.jozufozu.flywheel.api.layout.IntegerRepr;
import com.jozufozu.flywheel.api.layout.Layout;
import com.jozufozu.flywheel.api.layout.MatrixElementType;
import com.jozufozu.flywheel.api.layout.ScalarElementType;
import com.jozufozu.flywheel.api.layout.UnsignedIntegerRepr;
import com.jozufozu.flywheel.api.layout.VectorElementType;
import com.jozufozu.flywheel.backend.compile.LayoutInterpreter;
import com.jozufozu.flywheel.backend.compile.Pipeline;
import com.jozufozu.flywheel.glsl.SourceComponent;
@ -15,7 +24,7 @@ import com.jozufozu.flywheel.glsl.generate.FnSignature;
import com.jozufozu.flywheel.glsl.generate.GlslBlock;
import com.jozufozu.flywheel.glsl.generate.GlslBuilder;
import com.jozufozu.flywheel.glsl.generate.GlslExpr;
import com.jozufozu.flywheel.lib.layout.LayoutItem;
import com.jozufozu.flywheel.glsl.generate.GlslStruct;
import net.minecraft.resources.ResourceLocation;
@ -27,10 +36,8 @@ public class IndirectComponent implements SourceComponent {
private static final String UNPACK_FN_NAME = "_flw_unpackInstance";
private final Layout layout;
private final List<LayoutItem> layoutItems;
public IndirectComponent(InstanceType<?> type) {
this.layoutItems = type.oldLayout().layoutItems;
this.layout = type.layout();
}
@ -60,8 +67,6 @@ public class IndirectComponent implements SourceComponent {
public String generateIndirect() {
var builder = new GlslBuilder();
generateHeader(builder);
generateInstanceStruct(builder);
builder.blankLine();
@ -73,13 +78,6 @@ public class IndirectComponent implements SourceComponent {
return builder.build();
}
private void generateHeader(GlslBuilder builder) {
layoutItems.stream()
.map(LayoutItem::type)
.distinct()
.forEach(type -> type.declare(builder));
}
private void generateInstanceStruct(GlslBuilder builder) {
var instance = builder.struct();
instance.setName(STRUCT_NAME);
@ -93,13 +91,9 @@ public class IndirectComponent implements SourceComponent {
packed.setName(PACKED_STRUCT_NAME);
var unpackArgs = new ArrayList<GlslExpr>();
for (LayoutItem field : layoutItems) {
GlslExpr unpack = UNPACKING_VARIABLE.access(field.name())
.transform(field.type()::unpack);
unpackArgs.add(unpack);
packed.addField(field.type()
.packedTypeName(), field.name());
for (Layout.Element element : layout.elements()) {
unpackArgs.add(unpackElement(element, packed));
}
var block = new GlslBlock();
@ -114,4 +108,224 @@ public class IndirectComponent implements SourceComponent {
.build())
.body(block);
}
public static GlslExpr unpackElement(Layout.Element element, GlslStruct packed) {
// FIXME: I don't think we're unpacking signed byte/short values correctly
// FIXME: we definitely don't consider endianness. this all assumes little endian which works on my machine.
var type = element.type();
if (type instanceof ScalarElementType scalar) {
return unpackScalar(element.name(), packed, scalar);
} else if (type instanceof VectorElementType vector) {
return unpackVector(element.name(), packed, vector);
} else if (type instanceof MatrixElementType matrix) {
return unpackMatrix(element.name(), packed, matrix);
}
throw new IllegalArgumentException("Unknown type " + type);
}
private static GlslExpr unpackMatrix(String name, GlslStruct packed, MatrixElementType matrix) {
var repr = matrix.repr();
int rows = matrix.rows();
int columns = matrix.columns();
List<GlslExpr> args = new ArrayList<>();
for (int i = 0; i < columns; i++) {
args.add(unpackFloatVector(name + "_" + i, repr, packed, rows));
}
return GlslExpr.call("mat" + columns + "x" + rows, args);
}
private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, ScalarElementType scalar) {
var repr = scalar.repr();
if (repr instanceof FloatRepr floatRepr) {
return unpackFloatScalar(fieldName, floatRepr, packed);
} else if (repr instanceof IntegerRepr intRepr) {
return unpackIntScalar(fieldName, intRepr, packed);
} else if (repr instanceof UnsignedIntegerRepr unsignedIntegerRepr) {
return unpackUnsignedScalar(fieldName, unsignedIntegerRepr, packed);
}
throw new IllegalArgumentException("Unknown repr " + repr);
}
private static GlslExpr unpackVector(String fieldName, GlslStruct packed, VectorElementType vector) {
var repr = vector.repr();
int size = vector.size();
if (repr instanceof FloatRepr floatRepr) {
return unpackFloatVector(fieldName, floatRepr, packed, size);
} else if (repr instanceof IntegerRepr intRepr) {
return unpackIntVector(fieldName, intRepr, packed, size);
} else if (repr instanceof UnsignedIntegerRepr unsignedIntegerRepr) {
return unpackUnsignedVector(fieldName, unsignedIntegerRepr, packed, size);
}
throw new IllegalArgumentException("Unknown repr " + repr);
}
private static GlslExpr unpackFloatVector(String fieldName, FloatRepr floatRepr, GlslStruct packed, int size) {
return switch (floatRepr) {
case NORMALIZED_BYTE -> unpackBuiltin(fieldName, packed, size, "unpackSnorm4x8");
case NORMALIZED_UNSIGNED_BYTE -> unpackBuiltin(fieldName, packed, size, "unpackUnorm4x8");
case NORMALIZED_SHORT -> unpackBuiltin(fieldName, packed, size, "unpackSnorm2x16");
case NORMALIZED_UNSIGNED_SHORT -> unpackBuiltin(fieldName, packed, size, "unpackUnorm2x16");
case NORMALIZED_INT -> unpack(fieldName, packed, size, "int", "vec" + size, e -> e.div(2147483647f)
.clamp(-1, 1));
case NORMALIZED_UNSIGNED_INT ->
unpack(fieldName, packed, size, "uint", "vec" + size, e -> e.div(4294967295f));
case BYTE -> unpackByteBacked(fieldName, packed, size, "vec" + size, e -> e.cast("int")
.cast("float"));
case UNSIGNED_BYTE -> unpackByteBacked(fieldName, packed, size, "vec" + size, e -> e.cast("float"));
case SHORT -> unpackShortBacked(fieldName, packed, size, "vec" + size, e -> e.cast("int")
.cast("float"));
case UNSIGNED_SHORT -> unpackShortBacked(fieldName, packed, size, "vec" + size, e -> e.cast("float"));
case INT -> unpack(fieldName, packed, size, "int", "vec" + size, e -> e.cast("float"));
case UNSIGNED_INT -> unpack(fieldName, packed, size, "float", "vec" + size, e -> e.cast("float"));
case FLOAT -> unpack(fieldName, packed, size, "float", "vec" + size);
};
}
private static GlslExpr unpackUnsignedVector(String fieldName, UnsignedIntegerRepr unsignedIntegerRepr, GlslStruct packed, int size) {
return switch (unsignedIntegerRepr) {
case UNSIGNED_BYTE -> unpackByteBacked(fieldName, packed, size, "uvec" + size, e -> e.cast("uint"));
case UNSIGNED_SHORT -> unpackShortBacked(fieldName, packed, size, "uvec" + size, e -> e.cast("uint"));
case UNSIGNED_INT -> unpack(fieldName, packed, size, "uint", "uvec" + size);
};
}
private static GlslExpr unpackIntVector(String fieldName, IntegerRepr repr, GlslStruct packed, int size) {
return switch (repr) {
case BYTE -> unpackByteBacked(fieldName, packed, size, "ivec" + size, e -> e.cast("int"));
case SHORT -> unpackShortBacked(fieldName, packed, size, "ivec" + size, e -> e.cast("int"));
case INT -> unpack(fieldName, packed, size, "int", "ivec" + size);
};
}
@NotNull
private static GlslExpr unpack(String fieldName, GlslStruct packed, int size, String backingType, String outType) {
return unpack(fieldName, packed, size, backingType, outType, Function.identity());
}
@NotNull
private static GlslExpr unpack(String fieldName, GlslStruct packed, int size, String backingType, String outType, Function<GlslExpr, GlslExpr> perElement) {
List<GlslExpr> args = new ArrayList<>();
for (int i = 0; i < size; i++) {
var name = "_" + fieldName + "_" + i;
packed.addField(backingType, name);
args.add(UNPACKING_VARIABLE.access(name)
.transform(perElement));
}
return GlslExpr.call(outType, args);
}
@NotNull
private static GlslExpr unpackBuiltin(String fieldName, GlslStruct packed, int size, String func) {
packed.addField("uint", fieldName);
GlslExpr expr = UNPACKING_VARIABLE.access(fieldName)
.callFunction(func);
return switch (size) {
case 2 -> expr.swizzle("xy");
case 3 -> expr.swizzle("xyz");
case 4 -> expr;
default -> throw new IllegalArgumentException("Invalid vector size " + size);
};
}
@NotNull
private static GlslExpr unpackByteBacked(String fieldName, GlslStruct packed, int size, String outType, Function<GlslExpr, GlslExpr> perElement) {
packed.addField("uint", fieldName);
List<GlslExpr> args = new ArrayList<>();
for (int i = 0; i < size; i++) {
int bitPos = i * 8;
var element = UNPACKING_VARIABLE.access(fieldName)
.and(0xFF << bitPos)
.rsh(bitPos);
args.add(perElement.apply(element));
}
return GlslExpr.call(outType + size, args);
}
@NotNull
private static GlslExpr unpackShortBacked(String fieldName, GlslStruct packed, int size, String outType, Function<GlslExpr, GlslExpr> perElement) {
List<GlslExpr> args = new ArrayList<>();
for (int i = 0; i < size; i++) {
int unpackField = i / 2;
int bitPos = (i % 2) * 16;
var name = "_" + fieldName + "_" + unpackField;
if (bitPos == 0) {
// First time we're seeing this field, add it to the struct.
packed.addField("uint", name);
}
var element = UNPACKING_VARIABLE.access(name)
.and(0xFFFF << bitPos)
.rsh(bitPos);
args.add(perElement.apply(element));
}
return GlslExpr.call(outType, args);
}
private static GlslExpr unpackFloatScalar(String fieldName, FloatRepr repr, GlslStruct packed) {
return switch (repr) {
case BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF)
.cast("int")
.cast("float"));
case NORMALIZED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackSnorm4x8")
.swizzle("x"));
case UNSIGNED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF)
.cast("float"));
case NORMALIZED_UNSIGNED_BYTE ->
unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackUnorm4x8")
.swizzle("x"));
case SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF)
.cast("int")
.cast("float"));
case NORMALIZED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackSnorm2x16")
.swizzle("x"));
case UNSIGNED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF)
.cast("float"));
case NORMALIZED_UNSIGNED_SHORT ->
unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackUnorm2x16")
.swizzle("x"));
case INT -> unpackScalar(fieldName, packed, "int", e -> e.cast("float"));
case NORMALIZED_INT -> unpackScalar(fieldName, packed, "int", e -> e.div(2147483647f)
.clamp(-1, 1));
case UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint", e -> e.cast("float"));
case NORMALIZED_UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint", e -> e.div(4294967295f));
case FLOAT -> unpackScalar(fieldName, packed, "float");
};
}
private static GlslExpr unpackUnsignedScalar(String fieldName, UnsignedIntegerRepr repr, GlslStruct packed) {
return switch (repr) {
case UNSIGNED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF));
case UNSIGNED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF));
case UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint");
};
}
private static GlslExpr unpackIntScalar(String fieldName, IntegerRepr intRepr, GlslStruct packed) {
return switch (intRepr) {
case BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF)
.cast("int"));
case SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF)
.cast("int"));
case INT -> unpackScalar(fieldName, packed, "int");
};
}
private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, String packedType) {
return unpackScalar(fieldName, packed, packedType, Function.identity());
}
private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, String packedType, Function<GlslExpr, GlslExpr> perElement) {
packed.addField(packedType, fieldName);
return perElement.apply(UNPACKING_VARIABLE.access(fieldName));
}
}

View file

@ -0,0 +1,15 @@
package com.jozufozu.flywheel.glsl.generate;
public enum BinOp {
BITWISE_AND("&"),
RIGHT_SHIFT(">>"),
DIVIDE("/"),
// TODO: add more as we need them
;
public final String op;
BinOp(String op) {
this.op = op;
}
}

View file

@ -32,15 +32,23 @@ public interface GlslExpr {
}
static GlslExpr intLiteral(int expr) {
return new IntLiteral(expr);
return new RawLiteral(Integer.toString(expr));
}
static GlslExpr uintLiteral(int expr) {
return new UIntLiteral(expr);
return new RawLiteral(Integer.toUnsignedString(expr) + 'u');
}
static GlslExpr uintHexLiteral(int expr) {
return new RawLiteral("0x" + Integer.toHexString(expr) + 'u');
}
static GlslExpr boolLiteral(boolean expr) {
return new BoolLiteral(expr);
return new RawLiteral(Boolean.toString(expr));
}
static GlslExpr floatLiteral(float expr) {
return new RawLiteral(Float.toString(expr));
}
/**
@ -53,6 +61,10 @@ public interface GlslExpr {
return new FunctionCall(name, this);
}
default FunctionCall cast(String name) {
return new FunctionCall(name, this);
}
/**
* Swizzle the components of this expression.
*
@ -83,6 +95,25 @@ public interface GlslExpr {
return f.apply(this);
}
default GlslExpr and(int mask) {
return new Binary(this, uintHexLiteral(mask), BinOp.BITWISE_AND);
}
default GlslExpr rsh(int by) {
if (by == 0) {
return this;
}
return new Binary(this, uintLiteral(by), BinOp.RIGHT_SHIFT);
}
default GlslExpr div(float v) {
return new Binary(this, floatLiteral(v), BinOp.DIVIDE);
}
default GlslExpr clamp(float from, float to) {
return new Clamp(this, floatLiteral(from), floatLiteral(to));
}
String prettyPrint();
record Variable(String name) implements GlslExpr {
@ -132,30 +163,24 @@ public interface GlslExpr {
}
record IntLiteral(int value) implements GlslExpr {
record Clamp(GlslExpr value, GlslExpr from, GlslExpr to) implements GlslExpr {
@Override
public String prettyPrint() {
return Integer.toString(value);
return "clamp(" + value.prettyPrint() + ", " + from.prettyPrint() + ", " + to.prettyPrint() + ")";
}
}
record UIntLiteral(int value) implements GlslExpr {
public UIntLiteral {
if (value < 0) {
throw new IllegalArgumentException("UIntLiteral must be positive");
record Binary(GlslExpr lhs, GlslExpr rhs, BinOp op) implements GlslExpr {
@Override
public String prettyPrint() {
return "(" + lhs.prettyPrint() + " " + op.op + " " + rhs.prettyPrint() + ")";
}
}
record RawLiteral(String value) implements GlslExpr {
@Override
public String prettyPrint() {
return Integer.toString(value) + 'u';
}
}
record BoolLiteral(boolean value) implements GlslExpr {
@Override
public String prettyPrint() {
return Boolean.toString(value);
return value;
}
}
}

View file

@ -7,12 +7,14 @@ import java.util.Set;
import org.jetbrains.annotations.Range;
import com.jozufozu.flywheel.api.layout.ElementType;
import com.jozufozu.flywheel.api.layout.FloatRepr;
import com.jozufozu.flywheel.api.layout.Layout;
import com.jozufozu.flywheel.api.layout.Layout.Element;
import com.jozufozu.flywheel.api.layout.LayoutBuilder;
import com.jozufozu.flywheel.api.layout.ValueRepr;
import com.jozufozu.flywheel.impl.layout.LayoutImpl.ElementImpl;
import com.jozufozu.flywheel.lib.math.MoreMath;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
@ -101,23 +103,21 @@ public class LayoutBuilderImpl implements LayoutBuilder {
);
private final List<Element> elements = new ArrayList<>();
private int offset = 0;
@Override
public LayoutBuilder scalar(String name, ValueRepr repr) {
elements.add(new ElementImpl(name, new ScalarElementTypeImpl(repr)));
return this;
return element(name, ScalarElementTypeImpl.create(repr));
}
@Override
public LayoutBuilder vector(String name, ValueRepr repr, @Range(from = 2, to = 4) int size) {
elements.add(new ElementImpl(name, new VectorElementTypeImpl(repr, size)));
return this;
return element(name, VectorElementTypeImpl.create(repr, size));
}
@Override
public LayoutBuilder matrix(String name, FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) {
elements.add(new ElementImpl(name, new MatrixElementTypeImpl(repr, rows, columns)));
return this;
return element(name, MatrixElementTypeImpl.create(repr, rows, columns));
}
@Override
@ -125,6 +125,13 @@ public class LayoutBuilderImpl implements LayoutBuilder {
return matrix(name, repr, size, size);
}
private LayoutBuilder element(String name, ElementType type) {
elements.add(new ElementImpl(name, type, offset));
offset += type.byteSize();
offset = MoreMath.align4(offset);
return this;
}
@Override
public Layout build() {
Object2IntMap<String> name2IndexMap = new Object2IntOpenHashMap<>();

View file

@ -15,16 +15,16 @@ final class LayoutImpl implements Layout {
@Unmodifiable
private final List<Element> elements;
@Unmodifiable
private final Map<String, ElementType> map;
private final Map<String, Element> map;
private final int byteSize;
LayoutImpl(@Unmodifiable List<Element> elements) {
this.elements = elements;
Object2ObjectOpenHashMap<String, ElementType> map = new Object2ObjectOpenHashMap<>();
Object2ObjectOpenHashMap<String, Element> map = new Object2ObjectOpenHashMap<>();
int byteSize = 0;
for (Element element : this.elements) {
map.put(element.name(), element.type());
map.put(element.name(), element);
byteSize += element.type().byteSize();
}
map.trim();
@ -41,7 +41,7 @@ final class LayoutImpl implements Layout {
@Override
@Unmodifiable
public Map<String, ElementType> asMap() {
public Map<String, Element> asMap() {
return map;
}
@ -73,29 +73,6 @@ final class LayoutImpl implements Layout {
return elements.equals(other.elements);
}
record ElementImpl(String name, ElementType type) implements Element {
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + name.hashCode();
result = prime * result + type.hashCode();
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
ElementImpl other = (ElementImpl) obj;
return name.equals(other.name) && type.equals(other.type);
}
record ElementImpl(String name, ElementType type, int offset) implements Element {
}
}

View file

@ -5,72 +5,16 @@ import org.jetbrains.annotations.Range;
import com.jozufozu.flywheel.api.layout.FloatRepr;
import com.jozufozu.flywheel.api.layout.MatrixElementType;
final class MatrixElementTypeImpl implements MatrixElementType {
private final FloatRepr repr;
@Range(from = 2, to = 4)
private final int rows;
@Range(from = 2, to = 4)
private final int columns;
private final int byteSize;
MatrixElementTypeImpl(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) {
record MatrixElementTypeImpl(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns,
int byteSize) implements MatrixElementType {
static MatrixElementTypeImpl create(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) {
if (rows < 2 || rows > 4) {
throw new IllegalArgumentException("Matrix element row count must be in range [2, 4]!");
}
if (columns < 2 || columns > 4) {
throw new IllegalArgumentException("Matrix element column count must be in range [2, 4]!");
}
this.repr = repr;
this.rows = rows;
this.columns = columns;
byteSize = repr.byteSize() * rows * columns;
}
@Override
public FloatRepr repr() {
return repr;
}
@Override
@Range(from = 2, to = 4)
public int rows() {
return rows;
}
@Override
@Range(from = 2, to = 4)
public int columns() {
return columns;
}
@Override
public int byteSize() {
return byteSize;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + repr.hashCode();
result = prime * result + rows;
result = prime * result + columns;
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
MatrixElementTypeImpl other = (MatrixElementTypeImpl) obj;
return repr == other.repr && rows == other.rows && columns == other.columns;
int byteSize = repr.byteSize() * rows * columns;
return new MatrixElementTypeImpl(repr, rows, columns, byteSize);
}
}

View file

@ -3,45 +3,8 @@ package com.jozufozu.flywheel.impl.layout;
import com.jozufozu.flywheel.api.layout.ScalarElementType;
import com.jozufozu.flywheel.api.layout.ValueRepr;
final class ScalarElementTypeImpl implements ScalarElementType {
private final ValueRepr repr;
private final int byteSize;
ScalarElementTypeImpl(ValueRepr repr) {
this.repr = repr;
byteSize = repr.byteSize();
}
@Override
public ValueRepr repr() {
return repr;
}
@Override
public int byteSize() {
return byteSize;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + repr.hashCode();
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
ScalarElementTypeImpl other = (ScalarElementTypeImpl) obj;
return repr == other.repr;
record ScalarElementTypeImpl(ValueRepr repr, int byteSize) implements ScalarElementType {
static ScalarElementTypeImpl create(ValueRepr repr) {
return new ScalarElementTypeImpl(repr, repr.byteSize());
}
}

View file

@ -5,59 +5,15 @@ import org.jetbrains.annotations.Range;
import com.jozufozu.flywheel.api.layout.ValueRepr;
import com.jozufozu.flywheel.api.layout.VectorElementType;
final class VectorElementTypeImpl implements VectorElementType {
private final ValueRepr repr;
@Range(from = 2, to = 4)
private final int size;
private final int byteSize;
record VectorElementTypeImpl(ValueRepr repr, @Range(from = 2, to = 4) int size,
int byteSize) implements VectorElementType {
VectorElementTypeImpl(ValueRepr repr, @Range(from = 2, to = 4) int size) {
static VectorElementTypeImpl create(ValueRepr repr, @Range(from = 2, to = 4) int size) {
if (size < 2 || size > 4) {
throw new IllegalArgumentException("Vector element size must be in range [2, 4]!");
}
this.repr = repr;
this.size = size;
byteSize = repr.byteSize() * size;
}
@Override
public ValueRepr repr() {
return repr;
}
@Override
@Range(from = 2, to = 4)
public int size() {
return size;
}
@Override
public int byteSize() {
return byteSize;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + repr.hashCode();
result = prime * result + size;
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
VectorElementTypeImpl other = (VectorElementTypeImpl) obj;
return repr == other.repr && size == other.size;
int byteSize = repr.byteSize() * size;
return new VectorElementTypeImpl(repr, size, byteSize);
}
}

View file

@ -10,7 +10,11 @@ public final class MoreMath {
public static final float SQRT_3_OVER_2 = (float) (Math.sqrt(3.0) / 2.0);
public static int align16(int numToRound) {
return (numToRound + 16 - 1) & -16;
return (numToRound + 15) & ~15;
}
public static int align4(int offset1) {
return (offset1 + 3) & ~3;
}
public static int ceilingDiv(int numerator, int denominator) {