mirror of
https://github.com/Jozufozu/Flywheel.git
synced 2025-01-06 04:16:36 +01:00
A lot to unpack here
- Elements are 4 byte aligned and store their offset. - Make all ElementTypeImpls records. - Implement unpacking for all types/reprs. - Add many utilities to GlslExpr to facilitate unpacking.
This commit is contained in:
parent
9309266435
commit
7ad163588e
10 changed files with 326 additions and 220 deletions
|
@ -13,8 +13,7 @@ public interface Layout {
|
|||
@Unmodifiable
|
||||
List<Element> elements();
|
||||
|
||||
@Unmodifiable
|
||||
Map<String, ElementType> asMap();
|
||||
@Unmodifiable Map<String, Element> asMap();
|
||||
|
||||
int byteSize();
|
||||
|
||||
|
@ -23,5 +22,7 @@ public interface Layout {
|
|||
String name();
|
||||
|
||||
ElementType type();
|
||||
|
||||
int offset();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,10 +4,19 @@ import java.util.ArrayList;
|
|||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import com.jozufozu.flywheel.Flywheel;
|
||||
import com.jozufozu.flywheel.api.instance.InstanceType;
|
||||
import com.jozufozu.flywheel.api.layout.FloatRepr;
|
||||
import com.jozufozu.flywheel.api.layout.IntegerRepr;
|
||||
import com.jozufozu.flywheel.api.layout.Layout;
|
||||
import com.jozufozu.flywheel.api.layout.MatrixElementType;
|
||||
import com.jozufozu.flywheel.api.layout.ScalarElementType;
|
||||
import com.jozufozu.flywheel.api.layout.UnsignedIntegerRepr;
|
||||
import com.jozufozu.flywheel.api.layout.VectorElementType;
|
||||
import com.jozufozu.flywheel.backend.compile.LayoutInterpreter;
|
||||
import com.jozufozu.flywheel.backend.compile.Pipeline;
|
||||
import com.jozufozu.flywheel.glsl.SourceComponent;
|
||||
|
@ -15,7 +24,7 @@ import com.jozufozu.flywheel.glsl.generate.FnSignature;
|
|||
import com.jozufozu.flywheel.glsl.generate.GlslBlock;
|
||||
import com.jozufozu.flywheel.glsl.generate.GlslBuilder;
|
||||
import com.jozufozu.flywheel.glsl.generate.GlslExpr;
|
||||
import com.jozufozu.flywheel.lib.layout.LayoutItem;
|
||||
import com.jozufozu.flywheel.glsl.generate.GlslStruct;
|
||||
|
||||
import net.minecraft.resources.ResourceLocation;
|
||||
|
||||
|
@ -27,10 +36,8 @@ public class IndirectComponent implements SourceComponent {
|
|||
private static final String UNPACK_FN_NAME = "_flw_unpackInstance";
|
||||
|
||||
private final Layout layout;
|
||||
private final List<LayoutItem> layoutItems;
|
||||
|
||||
public IndirectComponent(InstanceType<?> type) {
|
||||
this.layoutItems = type.oldLayout().layoutItems;
|
||||
this.layout = type.layout();
|
||||
}
|
||||
|
||||
|
@ -60,8 +67,6 @@ public class IndirectComponent implements SourceComponent {
|
|||
public String generateIndirect() {
|
||||
var builder = new GlslBuilder();
|
||||
|
||||
generateHeader(builder);
|
||||
|
||||
generateInstanceStruct(builder);
|
||||
|
||||
builder.blankLine();
|
||||
|
@ -73,13 +78,6 @@ public class IndirectComponent implements SourceComponent {
|
|||
return builder.build();
|
||||
}
|
||||
|
||||
private void generateHeader(GlslBuilder builder) {
|
||||
layoutItems.stream()
|
||||
.map(LayoutItem::type)
|
||||
.distinct()
|
||||
.forEach(type -> type.declare(builder));
|
||||
}
|
||||
|
||||
private void generateInstanceStruct(GlslBuilder builder) {
|
||||
var instance = builder.struct();
|
||||
instance.setName(STRUCT_NAME);
|
||||
|
@ -93,13 +91,9 @@ public class IndirectComponent implements SourceComponent {
|
|||
packed.setName(PACKED_STRUCT_NAME);
|
||||
|
||||
var unpackArgs = new ArrayList<GlslExpr>();
|
||||
for (LayoutItem field : layoutItems) {
|
||||
GlslExpr unpack = UNPACKING_VARIABLE.access(field.name())
|
||||
.transform(field.type()::unpack);
|
||||
unpackArgs.add(unpack);
|
||||
|
||||
packed.addField(field.type()
|
||||
.packedTypeName(), field.name());
|
||||
for (Layout.Element element : layout.elements()) {
|
||||
unpackArgs.add(unpackElement(element, packed));
|
||||
}
|
||||
|
||||
var block = new GlslBlock();
|
||||
|
@ -114,4 +108,224 @@ public class IndirectComponent implements SourceComponent {
|
|||
.build())
|
||||
.body(block);
|
||||
}
|
||||
|
||||
public static GlslExpr unpackElement(Layout.Element element, GlslStruct packed) {
|
||||
// FIXME: I don't think we're unpacking signed byte/short values correctly
|
||||
// FIXME: we definitely don't consider endianness. this all assumes little endian which works on my machine.
|
||||
var type = element.type();
|
||||
|
||||
if (type instanceof ScalarElementType scalar) {
|
||||
return unpackScalar(element.name(), packed, scalar);
|
||||
} else if (type instanceof VectorElementType vector) {
|
||||
return unpackVector(element.name(), packed, vector);
|
||||
} else if (type instanceof MatrixElementType matrix) {
|
||||
return unpackMatrix(element.name(), packed, matrix);
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException("Unknown type " + type);
|
||||
}
|
||||
|
||||
private static GlslExpr unpackMatrix(String name, GlslStruct packed, MatrixElementType matrix) {
|
||||
var repr = matrix.repr();
|
||||
|
||||
int rows = matrix.rows();
|
||||
int columns = matrix.columns();
|
||||
|
||||
List<GlslExpr> args = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < columns; i++) {
|
||||
args.add(unpackFloatVector(name + "_" + i, repr, packed, rows));
|
||||
}
|
||||
|
||||
return GlslExpr.call("mat" + columns + "x" + rows, args);
|
||||
}
|
||||
|
||||
private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, ScalarElementType scalar) {
|
||||
var repr = scalar.repr();
|
||||
|
||||
if (repr instanceof FloatRepr floatRepr) {
|
||||
return unpackFloatScalar(fieldName, floatRepr, packed);
|
||||
} else if (repr instanceof IntegerRepr intRepr) {
|
||||
return unpackIntScalar(fieldName, intRepr, packed);
|
||||
} else if (repr instanceof UnsignedIntegerRepr unsignedIntegerRepr) {
|
||||
return unpackUnsignedScalar(fieldName, unsignedIntegerRepr, packed);
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException("Unknown repr " + repr);
|
||||
}
|
||||
|
||||
private static GlslExpr unpackVector(String fieldName, GlslStruct packed, VectorElementType vector) {
|
||||
var repr = vector.repr();
|
||||
|
||||
int size = vector.size();
|
||||
|
||||
if (repr instanceof FloatRepr floatRepr) {
|
||||
return unpackFloatVector(fieldName, floatRepr, packed, size);
|
||||
} else if (repr instanceof IntegerRepr intRepr) {
|
||||
return unpackIntVector(fieldName, intRepr, packed, size);
|
||||
} else if (repr instanceof UnsignedIntegerRepr unsignedIntegerRepr) {
|
||||
return unpackUnsignedVector(fieldName, unsignedIntegerRepr, packed, size);
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException("Unknown repr " + repr);
|
||||
}
|
||||
|
||||
private static GlslExpr unpackFloatVector(String fieldName, FloatRepr floatRepr, GlslStruct packed, int size) {
|
||||
return switch (floatRepr) {
|
||||
case NORMALIZED_BYTE -> unpackBuiltin(fieldName, packed, size, "unpackSnorm4x8");
|
||||
case NORMALIZED_UNSIGNED_BYTE -> unpackBuiltin(fieldName, packed, size, "unpackUnorm4x8");
|
||||
case NORMALIZED_SHORT -> unpackBuiltin(fieldName, packed, size, "unpackSnorm2x16");
|
||||
case NORMALIZED_UNSIGNED_SHORT -> unpackBuiltin(fieldName, packed, size, "unpackUnorm2x16");
|
||||
case NORMALIZED_INT -> unpack(fieldName, packed, size, "int", "vec" + size, e -> e.div(2147483647f)
|
||||
.clamp(-1, 1));
|
||||
case NORMALIZED_UNSIGNED_INT ->
|
||||
unpack(fieldName, packed, size, "uint", "vec" + size, e -> e.div(4294967295f));
|
||||
case BYTE -> unpackByteBacked(fieldName, packed, size, "vec" + size, e -> e.cast("int")
|
||||
.cast("float"));
|
||||
case UNSIGNED_BYTE -> unpackByteBacked(fieldName, packed, size, "vec" + size, e -> e.cast("float"));
|
||||
case SHORT -> unpackShortBacked(fieldName, packed, size, "vec" + size, e -> e.cast("int")
|
||||
.cast("float"));
|
||||
case UNSIGNED_SHORT -> unpackShortBacked(fieldName, packed, size, "vec" + size, e -> e.cast("float"));
|
||||
case INT -> unpack(fieldName, packed, size, "int", "vec" + size, e -> e.cast("float"));
|
||||
case UNSIGNED_INT -> unpack(fieldName, packed, size, "float", "vec" + size, e -> e.cast("float"));
|
||||
case FLOAT -> unpack(fieldName, packed, size, "float", "vec" + size);
|
||||
};
|
||||
}
|
||||
|
||||
private static GlslExpr unpackUnsignedVector(String fieldName, UnsignedIntegerRepr unsignedIntegerRepr, GlslStruct packed, int size) {
|
||||
return switch (unsignedIntegerRepr) {
|
||||
case UNSIGNED_BYTE -> unpackByteBacked(fieldName, packed, size, "uvec" + size, e -> e.cast("uint"));
|
||||
case UNSIGNED_SHORT -> unpackShortBacked(fieldName, packed, size, "uvec" + size, e -> e.cast("uint"));
|
||||
case UNSIGNED_INT -> unpack(fieldName, packed, size, "uint", "uvec" + size);
|
||||
};
|
||||
}
|
||||
|
||||
private static GlslExpr unpackIntVector(String fieldName, IntegerRepr repr, GlslStruct packed, int size) {
|
||||
return switch (repr) {
|
||||
case BYTE -> unpackByteBacked(fieldName, packed, size, "ivec" + size, e -> e.cast("int"));
|
||||
case SHORT -> unpackShortBacked(fieldName, packed, size, "ivec" + size, e -> e.cast("int"));
|
||||
case INT -> unpack(fieldName, packed, size, "int", "ivec" + size);
|
||||
};
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static GlslExpr unpack(String fieldName, GlslStruct packed, int size, String backingType, String outType) {
|
||||
return unpack(fieldName, packed, size, backingType, outType, Function.identity());
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static GlslExpr unpack(String fieldName, GlslStruct packed, int size, String backingType, String outType, Function<GlslExpr, GlslExpr> perElement) {
|
||||
List<GlslExpr> args = new ArrayList<>();
|
||||
for (int i = 0; i < size; i++) {
|
||||
var name = "_" + fieldName + "_" + i;
|
||||
packed.addField(backingType, name);
|
||||
args.add(UNPACKING_VARIABLE.access(name)
|
||||
.transform(perElement));
|
||||
}
|
||||
return GlslExpr.call(outType, args);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static GlslExpr unpackBuiltin(String fieldName, GlslStruct packed, int size, String func) {
|
||||
packed.addField("uint", fieldName);
|
||||
GlslExpr expr = UNPACKING_VARIABLE.access(fieldName)
|
||||
.callFunction(func);
|
||||
return switch (size) {
|
||||
case 2 -> expr.swizzle("xy");
|
||||
case 3 -> expr.swizzle("xyz");
|
||||
case 4 -> expr;
|
||||
default -> throw new IllegalArgumentException("Invalid vector size " + size);
|
||||
};
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static GlslExpr unpackByteBacked(String fieldName, GlslStruct packed, int size, String outType, Function<GlslExpr, GlslExpr> perElement) {
|
||||
packed.addField("uint", fieldName);
|
||||
List<GlslExpr> args = new ArrayList<>();
|
||||
for (int i = 0; i < size; i++) {
|
||||
int bitPos = i * 8;
|
||||
var element = UNPACKING_VARIABLE.access(fieldName)
|
||||
.and(0xFF << bitPos)
|
||||
.rsh(bitPos);
|
||||
args.add(perElement.apply(element));
|
||||
}
|
||||
return GlslExpr.call(outType + size, args);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static GlslExpr unpackShortBacked(String fieldName, GlslStruct packed, int size, String outType, Function<GlslExpr, GlslExpr> perElement) {
|
||||
List<GlslExpr> args = new ArrayList<>();
|
||||
for (int i = 0; i < size; i++) {
|
||||
int unpackField = i / 2;
|
||||
int bitPos = (i % 2) * 16;
|
||||
var name = "_" + fieldName + "_" + unpackField;
|
||||
if (bitPos == 0) {
|
||||
// First time we're seeing this field, add it to the struct.
|
||||
packed.addField("uint", name);
|
||||
}
|
||||
var element = UNPACKING_VARIABLE.access(name)
|
||||
.and(0xFFFF << bitPos)
|
||||
.rsh(bitPos);
|
||||
args.add(perElement.apply(element));
|
||||
}
|
||||
return GlslExpr.call(outType, args);
|
||||
}
|
||||
|
||||
private static GlslExpr unpackFloatScalar(String fieldName, FloatRepr repr, GlslStruct packed) {
|
||||
return switch (repr) {
|
||||
case BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF)
|
||||
.cast("int")
|
||||
.cast("float"));
|
||||
case NORMALIZED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackSnorm4x8")
|
||||
.swizzle("x"));
|
||||
case UNSIGNED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF)
|
||||
.cast("float"));
|
||||
case NORMALIZED_UNSIGNED_BYTE ->
|
||||
unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackUnorm4x8")
|
||||
.swizzle("x"));
|
||||
case SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF)
|
||||
.cast("int")
|
||||
.cast("float"));
|
||||
case NORMALIZED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackSnorm2x16")
|
||||
.swizzle("x"));
|
||||
case UNSIGNED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF)
|
||||
.cast("float"));
|
||||
case NORMALIZED_UNSIGNED_SHORT ->
|
||||
unpackScalar(fieldName, packed, "uint", e -> e.callFunction("unpackUnorm2x16")
|
||||
.swizzle("x"));
|
||||
case INT -> unpackScalar(fieldName, packed, "int", e -> e.cast("float"));
|
||||
case NORMALIZED_INT -> unpackScalar(fieldName, packed, "int", e -> e.div(2147483647f)
|
||||
.clamp(-1, 1));
|
||||
case UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint", e -> e.cast("float"));
|
||||
case NORMALIZED_UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint", e -> e.div(4294967295f));
|
||||
case FLOAT -> unpackScalar(fieldName, packed, "float");
|
||||
};
|
||||
}
|
||||
|
||||
private static GlslExpr unpackUnsignedScalar(String fieldName, UnsignedIntegerRepr repr, GlslStruct packed) {
|
||||
return switch (repr) {
|
||||
case UNSIGNED_BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF));
|
||||
case UNSIGNED_SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF));
|
||||
case UNSIGNED_INT -> unpackScalar(fieldName, packed, "uint");
|
||||
};
|
||||
}
|
||||
|
||||
private static GlslExpr unpackIntScalar(String fieldName, IntegerRepr intRepr, GlslStruct packed) {
|
||||
return switch (intRepr) {
|
||||
case BYTE -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFF)
|
||||
.cast("int"));
|
||||
case SHORT -> unpackScalar(fieldName, packed, "uint", e -> e.and(0xFFFF)
|
||||
.cast("int"));
|
||||
case INT -> unpackScalar(fieldName, packed, "int");
|
||||
};
|
||||
}
|
||||
|
||||
private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, String packedType) {
|
||||
return unpackScalar(fieldName, packed, packedType, Function.identity());
|
||||
}
|
||||
|
||||
private static GlslExpr unpackScalar(String fieldName, GlslStruct packed, String packedType, Function<GlslExpr, GlslExpr> perElement) {
|
||||
packed.addField(packedType, fieldName);
|
||||
return perElement.apply(UNPACKING_VARIABLE.access(fieldName));
|
||||
}
|
||||
}
|
||||
|
|
15
src/main/java/com/jozufozu/flywheel/glsl/generate/BinOp.java
Normal file
15
src/main/java/com/jozufozu/flywheel/glsl/generate/BinOp.java
Normal file
|
@ -0,0 +1,15 @@
|
|||
package com.jozufozu.flywheel.glsl.generate;
|
||||
|
||||
public enum BinOp {
|
||||
BITWISE_AND("&"),
|
||||
RIGHT_SHIFT(">>"),
|
||||
DIVIDE("/"),
|
||||
// TODO: add more as we need them
|
||||
;
|
||||
|
||||
public final String op;
|
||||
|
||||
BinOp(String op) {
|
||||
this.op = op;
|
||||
}
|
||||
}
|
|
@ -32,15 +32,23 @@ public interface GlslExpr {
|
|||
}
|
||||
|
||||
static GlslExpr intLiteral(int expr) {
|
||||
return new IntLiteral(expr);
|
||||
return new RawLiteral(Integer.toString(expr));
|
||||
}
|
||||
|
||||
static GlslExpr uintLiteral(int expr) {
|
||||
return new UIntLiteral(expr);
|
||||
return new RawLiteral(Integer.toUnsignedString(expr) + 'u');
|
||||
}
|
||||
|
||||
static GlslExpr uintHexLiteral(int expr) {
|
||||
return new RawLiteral("0x" + Integer.toHexString(expr) + 'u');
|
||||
}
|
||||
|
||||
static GlslExpr boolLiteral(boolean expr) {
|
||||
return new BoolLiteral(expr);
|
||||
return new RawLiteral(Boolean.toString(expr));
|
||||
}
|
||||
|
||||
static GlslExpr floatLiteral(float expr) {
|
||||
return new RawLiteral(Float.toString(expr));
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -53,6 +61,10 @@ public interface GlslExpr {
|
|||
return new FunctionCall(name, this);
|
||||
}
|
||||
|
||||
default FunctionCall cast(String name) {
|
||||
return new FunctionCall(name, this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Swizzle the components of this expression.
|
||||
*
|
||||
|
@ -83,6 +95,25 @@ public interface GlslExpr {
|
|||
return f.apply(this);
|
||||
}
|
||||
|
||||
default GlslExpr and(int mask) {
|
||||
return new Binary(this, uintHexLiteral(mask), BinOp.BITWISE_AND);
|
||||
}
|
||||
|
||||
default GlslExpr rsh(int by) {
|
||||
if (by == 0) {
|
||||
return this;
|
||||
}
|
||||
return new Binary(this, uintLiteral(by), BinOp.RIGHT_SHIFT);
|
||||
}
|
||||
|
||||
default GlslExpr div(float v) {
|
||||
return new Binary(this, floatLiteral(v), BinOp.DIVIDE);
|
||||
}
|
||||
|
||||
default GlslExpr clamp(float from, float to) {
|
||||
return new Clamp(this, floatLiteral(from), floatLiteral(to));
|
||||
}
|
||||
|
||||
String prettyPrint();
|
||||
|
||||
record Variable(String name) implements GlslExpr {
|
||||
|
@ -132,30 +163,24 @@ public interface GlslExpr {
|
|||
|
||||
}
|
||||
|
||||
record IntLiteral(int value) implements GlslExpr {
|
||||
record Clamp(GlslExpr value, GlslExpr from, GlslExpr to) implements GlslExpr {
|
||||
@Override
|
||||
public String prettyPrint() {
|
||||
return Integer.toString(value);
|
||||
return "clamp(" + value.prettyPrint() + ", " + from.prettyPrint() + ", " + to.prettyPrint() + ")";
|
||||
}
|
||||
}
|
||||
|
||||
record UIntLiteral(int value) implements GlslExpr {
|
||||
public UIntLiteral {
|
||||
if (value < 0) {
|
||||
throw new IllegalArgumentException("UIntLiteral must be positive");
|
||||
}
|
||||
}
|
||||
|
||||
record Binary(GlslExpr lhs, GlslExpr rhs, BinOp op) implements GlslExpr {
|
||||
@Override
|
||||
public String prettyPrint() {
|
||||
return Integer.toString(value) + 'u';
|
||||
return "(" + lhs.prettyPrint() + " " + op.op + " " + rhs.prettyPrint() + ")";
|
||||
}
|
||||
}
|
||||
|
||||
record BoolLiteral(boolean value) implements GlslExpr {
|
||||
record RawLiteral(String value) implements GlslExpr {
|
||||
@Override
|
||||
public String prettyPrint() {
|
||||
return Boolean.toString(value);
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,12 +7,14 @@ import java.util.Set;
|
|||
|
||||
import org.jetbrains.annotations.Range;
|
||||
|
||||
import com.jozufozu.flywheel.api.layout.ElementType;
|
||||
import com.jozufozu.flywheel.api.layout.FloatRepr;
|
||||
import com.jozufozu.flywheel.api.layout.Layout;
|
||||
import com.jozufozu.flywheel.api.layout.Layout.Element;
|
||||
import com.jozufozu.flywheel.api.layout.LayoutBuilder;
|
||||
import com.jozufozu.flywheel.api.layout.ValueRepr;
|
||||
import com.jozufozu.flywheel.impl.layout.LayoutImpl.ElementImpl;
|
||||
import com.jozufozu.flywheel.lib.math.MoreMath;
|
||||
|
||||
import it.unimi.dsi.fastutil.objects.Object2IntMap;
|
||||
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
|
||||
|
@ -101,23 +103,21 @@ public class LayoutBuilderImpl implements LayoutBuilder {
|
|||
);
|
||||
|
||||
private final List<Element> elements = new ArrayList<>();
|
||||
private int offset = 0;
|
||||
|
||||
@Override
|
||||
public LayoutBuilder scalar(String name, ValueRepr repr) {
|
||||
elements.add(new ElementImpl(name, new ScalarElementTypeImpl(repr)));
|
||||
return this;
|
||||
return element(name, ScalarElementTypeImpl.create(repr));
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayoutBuilder vector(String name, ValueRepr repr, @Range(from = 2, to = 4) int size) {
|
||||
elements.add(new ElementImpl(name, new VectorElementTypeImpl(repr, size)));
|
||||
return this;
|
||||
return element(name, VectorElementTypeImpl.create(repr, size));
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayoutBuilder matrix(String name, FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) {
|
||||
elements.add(new ElementImpl(name, new MatrixElementTypeImpl(repr, rows, columns)));
|
||||
return this;
|
||||
return element(name, MatrixElementTypeImpl.create(repr, rows, columns));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -125,6 +125,13 @@ public class LayoutBuilderImpl implements LayoutBuilder {
|
|||
return matrix(name, repr, size, size);
|
||||
}
|
||||
|
||||
private LayoutBuilder element(String name, ElementType type) {
|
||||
elements.add(new ElementImpl(name, type, offset));
|
||||
offset += type.byteSize();
|
||||
offset = MoreMath.align4(offset);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Layout build() {
|
||||
Object2IntMap<String> name2IndexMap = new Object2IntOpenHashMap<>();
|
||||
|
|
|
@ -15,16 +15,16 @@ final class LayoutImpl implements Layout {
|
|||
@Unmodifiable
|
||||
private final List<Element> elements;
|
||||
@Unmodifiable
|
||||
private final Map<String, ElementType> map;
|
||||
private final Map<String, Element> map;
|
||||
private final int byteSize;
|
||||
|
||||
LayoutImpl(@Unmodifiable List<Element> elements) {
|
||||
this.elements = elements;
|
||||
|
||||
Object2ObjectOpenHashMap<String, ElementType> map = new Object2ObjectOpenHashMap<>();
|
||||
Object2ObjectOpenHashMap<String, Element> map = new Object2ObjectOpenHashMap<>();
|
||||
int byteSize = 0;
|
||||
for (Element element : this.elements) {
|
||||
map.put(element.name(), element.type());
|
||||
map.put(element.name(), element);
|
||||
byteSize += element.type().byteSize();
|
||||
}
|
||||
map.trim();
|
||||
|
@ -41,7 +41,7 @@ final class LayoutImpl implements Layout {
|
|||
|
||||
@Override
|
||||
@Unmodifiable
|
||||
public Map<String, ElementType> asMap() {
|
||||
public Map<String, Element> asMap() {
|
||||
return map;
|
||||
}
|
||||
|
||||
|
@ -73,29 +73,6 @@ final class LayoutImpl implements Layout {
|
|||
return elements.equals(other.elements);
|
||||
}
|
||||
|
||||
record ElementImpl(String name, ElementType type) implements Element {
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + name.hashCode();
|
||||
result = prime * result + type.hashCode();
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
ElementImpl other = (ElementImpl) obj;
|
||||
return name.equals(other.name) && type.equals(other.type);
|
||||
}
|
||||
record ElementImpl(String name, ElementType type, int offset) implements Element {
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,72 +5,16 @@ import org.jetbrains.annotations.Range;
|
|||
import com.jozufozu.flywheel.api.layout.FloatRepr;
|
||||
import com.jozufozu.flywheel.api.layout.MatrixElementType;
|
||||
|
||||
final class MatrixElementTypeImpl implements MatrixElementType {
|
||||
private final FloatRepr repr;
|
||||
@Range(from = 2, to = 4)
|
||||
private final int rows;
|
||||
@Range(from = 2, to = 4)
|
||||
private final int columns;
|
||||
private final int byteSize;
|
||||
|
||||
MatrixElementTypeImpl(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) {
|
||||
record MatrixElementTypeImpl(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns,
|
||||
int byteSize) implements MatrixElementType {
|
||||
static MatrixElementTypeImpl create(FloatRepr repr, @Range(from = 2, to = 4) int rows, @Range(from = 2, to = 4) int columns) {
|
||||
if (rows < 2 || rows > 4) {
|
||||
throw new IllegalArgumentException("Matrix element row count must be in range [2, 4]!");
|
||||
}
|
||||
if (columns < 2 || columns > 4) {
|
||||
throw new IllegalArgumentException("Matrix element column count must be in range [2, 4]!");
|
||||
}
|
||||
|
||||
this.repr = repr;
|
||||
this.rows = rows;
|
||||
this.columns = columns;
|
||||
byteSize = repr.byteSize() * rows * columns;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatRepr repr() {
|
||||
return repr;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Range(from = 2, to = 4)
|
||||
public int rows() {
|
||||
return rows;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Range(from = 2, to = 4)
|
||||
public int columns() {
|
||||
return columns;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int byteSize() {
|
||||
return byteSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + repr.hashCode();
|
||||
result = prime * result + rows;
|
||||
result = prime * result + columns;
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
MatrixElementTypeImpl other = (MatrixElementTypeImpl) obj;
|
||||
return repr == other.repr && rows == other.rows && columns == other.columns;
|
||||
int byteSize = repr.byteSize() * rows * columns;
|
||||
return new MatrixElementTypeImpl(repr, rows, columns, byteSize);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,45 +3,8 @@ package com.jozufozu.flywheel.impl.layout;
|
|||
import com.jozufozu.flywheel.api.layout.ScalarElementType;
|
||||
import com.jozufozu.flywheel.api.layout.ValueRepr;
|
||||
|
||||
final class ScalarElementTypeImpl implements ScalarElementType {
|
||||
private final ValueRepr repr;
|
||||
private final int byteSize;
|
||||
|
||||
ScalarElementTypeImpl(ValueRepr repr) {
|
||||
this.repr = repr;
|
||||
byteSize = repr.byteSize();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ValueRepr repr() {
|
||||
return repr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int byteSize() {
|
||||
return byteSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + repr.hashCode();
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
ScalarElementTypeImpl other = (ScalarElementTypeImpl) obj;
|
||||
return repr == other.repr;
|
||||
record ScalarElementTypeImpl(ValueRepr repr, int byteSize) implements ScalarElementType {
|
||||
static ScalarElementTypeImpl create(ValueRepr repr) {
|
||||
return new ScalarElementTypeImpl(repr, repr.byteSize());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,59 +5,15 @@ import org.jetbrains.annotations.Range;
|
|||
import com.jozufozu.flywheel.api.layout.ValueRepr;
|
||||
import com.jozufozu.flywheel.api.layout.VectorElementType;
|
||||
|
||||
final class VectorElementTypeImpl implements VectorElementType {
|
||||
private final ValueRepr repr;
|
||||
@Range(from = 2, to = 4)
|
||||
private final int size;
|
||||
private final int byteSize;
|
||||
record VectorElementTypeImpl(ValueRepr repr, @Range(from = 2, to = 4) int size,
|
||||
int byteSize) implements VectorElementType {
|
||||
|
||||
VectorElementTypeImpl(ValueRepr repr, @Range(from = 2, to = 4) int size) {
|
||||
static VectorElementTypeImpl create(ValueRepr repr, @Range(from = 2, to = 4) int size) {
|
||||
if (size < 2 || size > 4) {
|
||||
throw new IllegalArgumentException("Vector element size must be in range [2, 4]!");
|
||||
}
|
||||
|
||||
this.repr = repr;
|
||||
this.size = size;
|
||||
byteSize = repr.byteSize() * size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ValueRepr repr() {
|
||||
return repr;
|
||||
}
|
||||
|
||||
@Override
|
||||
@Range(from = 2, to = 4)
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int byteSize() {
|
||||
return byteSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
final int prime = 31;
|
||||
int result = 1;
|
||||
result = prime * result + repr.hashCode();
|
||||
result = prime * result + size;
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if (this == obj) {
|
||||
return true;
|
||||
}
|
||||
if (obj == null) {
|
||||
return false;
|
||||
}
|
||||
if (getClass() != obj.getClass()) {
|
||||
return false;
|
||||
}
|
||||
VectorElementTypeImpl other = (VectorElementTypeImpl) obj;
|
||||
return repr == other.repr && size == other.size;
|
||||
int byteSize = repr.byteSize() * size;
|
||||
return new VectorElementTypeImpl(repr, size, byteSize);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,7 +10,11 @@ public final class MoreMath {
|
|||
public static final float SQRT_3_OVER_2 = (float) (Math.sqrt(3.0) / 2.0);
|
||||
|
||||
public static int align16(int numToRound) {
|
||||
return (numToRound + 16 - 1) & -16;
|
||||
return (numToRound + 15) & ~15;
|
||||
}
|
||||
|
||||
public static int align4(int offset1) {
|
||||
return (offset1 + 3) & ~3;
|
||||
}
|
||||
|
||||
public static int ceilingDiv(int numerator, int denominator) {
|
||||
|
|
Loading…
Reference in a new issue