Better memoized shader compilation and global game state

- Properly separate compilation of vertex and fragment shaders
 - Game state is no longer per-program
 - Needs organization
This commit is contained in:
Jozufozu 2022-01-12 00:19:37 -08:00
parent c4f07db75f
commit 2854e1f1dc
14 changed files with 187 additions and 140 deletions

View file

@ -3,7 +3,9 @@ package com.jozufozu.flywheel.backend;
import java.util.HashMap;
import java.util.Map;
import com.jozufozu.flywheel.core.compile.ShaderConstants;
import com.jozufozu.flywheel.core.shader.GameStateProvider;
import com.jozufozu.flywheel.core.shader.StateSnapshot;
import net.minecraft.resources.ResourceLocation;
@ -32,4 +34,28 @@ public class GameStateRegistry {
registeredStateProviders.put(context.getID(), context);
}
public static StateSnapshot takeSnapshot() {
long ctx = 0;
for (GameStateProvider state : registeredStateProviders.values()) {
if (state.isTrue()) {
ctx |= 1;
}
ctx <<= 1;
}
return new StateSnapshot(ctx);
}
public static ShaderConstants getDefines(long ctx) {
long stateID = ctx;
ShaderConstants shaderConstants = new ShaderConstants();
for (GameStateProvider state : registeredStateProviders.values()) {
if ((stateID & 1) == 1) {
state.alterConstants(shaderConstants);
}
stateID >>= 1;
}
return shaderConstants;
}
}

View file

@ -268,4 +268,8 @@ public class SourceFile {
return -1;
}
@Override
public String toString() {
return name.toString();
}
}

View file

@ -1,12 +1,14 @@
package com.jozufozu.flywheel.core.compile;
import java.util.Objects;
import com.jozufozu.flywheel.backend.gl.shader.GlShader;
import com.jozufozu.flywheel.backend.gl.shader.ShaderType;
import com.jozufozu.flywheel.backend.source.FileResolution;
import com.jozufozu.flywheel.backend.source.SourceFile;
import com.jozufozu.flywheel.core.shader.ProgramSpec;
import com.jozufozu.flywheel.core.shader.StateSnapshot;
public class FragmentCompiler extends Memoizer<ProgramContext, GlShader> {
public class FragmentCompiler extends Memoizer<FragmentCompiler.Context, GlShader> {
private final FileResolution header;
private final Template<FragmentTemplateData> fragment;
@ -16,9 +18,8 @@ public class FragmentCompiler extends Memoizer<ProgramContext, GlShader> {
}
@Override
protected GlShader _create(ProgramContext key) {
ProgramSpec spec = key.spec();
SourceFile fragmentFile = spec.getFragmentFile();
protected GlShader _create(Context key) {
SourceFile fragmentFile = key.file;
FragmentTemplateData appliedTemplate = fragment.apply(fragmentFile);
StringBuilder builder = new StringBuilder();
@ -34,11 +35,52 @@ public class FragmentCompiler extends Memoizer<ProgramContext, GlShader> {
builder.append(appliedTemplate.generateFooter());
return new GlShader(spec.name, ShaderType.FRAGMENT, builder.toString());
return new GlShader(fragmentFile.name, ShaderType.FRAGMENT, builder.toString());
}
@Override
protected void _destroy(GlShader value) {
value.delete();
}
public static final class Context {
private final SourceFile file;
private final StateSnapshot ctx;
private final float alphaDiscard;
public Context(SourceFile file, StateSnapshot ctx, float alphaDiscard) {
this.file = file;
this.ctx = ctx;
this.alphaDiscard = alphaDiscard;
}
public ShaderConstants getShaderConstants() {
ShaderConstants shaderConstants = ctx.getDefines();
if (alphaDiscard > 0) {
shaderConstants.define("ALPHA_DISCARD", alphaDiscard);
}
return shaderConstants;
}
@Override
public boolean equals(Object obj) {
if (obj == this) return true;
if (obj == null || obj.getClass() != this.getClass()) return false;
var that = (Context) obj;
return this.file == that.file && Objects.equals(this.ctx, that.ctx) && Float.floatToIntBits(this.alphaDiscard) == Float.floatToIntBits(that.alphaDiscard);
}
@Override
public int hashCode() {
return Objects.hash(file, ctx, alphaDiscard);
}
@Override
public String toString() {
return "Context[" + "file=" + file + ", " + "ctx=" + ctx + ", " + "alphaDiscard=" + alphaDiscard + ']';
}
}
}

View file

@ -62,9 +62,9 @@ public class ProgramCompiler<P extends GlProgram> extends Memoizer<ProgramContex
@Override
protected P _create(ProgramContext ctx) {
return new ProgramAssembler(ctx.spec().name)
.attachShader(vertexCompiler.get(ctx))
.attachShader(fragmentCompiler.get(ctx))
return new ProgramAssembler(ctx.spec.name)
.attachShader(vertexCompiler.get(new VertexCompiler.Context(ctx.spec.getVertexFile(), ctx.ctx, ctx.vertexType)))
.attachShader(fragmentCompiler.get(new FragmentCompiler.Context(ctx.spec.getFragmentFile(), ctx.ctx, ctx.alphaDiscard)))
.link()
.build(this.factory);
}

View file

@ -6,21 +6,18 @@ import javax.annotation.Nullable;
import com.jozufozu.flywheel.api.vertex.VertexType;
import com.jozufozu.flywheel.backend.Backend;
import com.jozufozu.flywheel.backend.GameStateRegistry;
import com.jozufozu.flywheel.backend.RenderLayer;
import com.jozufozu.flywheel.core.shader.ProgramSpec;
import com.jozufozu.flywheel.core.shader.StateSnapshot;
import net.minecraft.resources.ResourceLocation;
/**
* Represents the entire context of a program's usage.
*
* @param alphaDiscard Alpha threshold below which pixels are discarded.
* @param vertexType The vertexType the program should be adapted for.
* @param spec The generic program name.
* @param ctx An ID representing the state at the time of usage.
*/
public record ProgramContext(float alphaDiscard, VertexType vertexType, ProgramSpec spec, long ctx) {
public final class ProgramContext {
/**
* Creates a compilation context for the given program, vertex type and render layer.
*
@ -37,32 +34,7 @@ public record ProgramContext(float alphaDiscard, VertexType vertexType, ProgramS
throw new NullPointerException("Cannot compile shader because '" + programName + "' is not recognized.");
}
return new ProgramContext(getAlphaDiscard(layer), vertexType, spec, spec.getCurrentStateID());
}
public ShaderConstants getShaderConstants() {
ShaderConstants shaderConstants = new ShaderConstants();
shaderConstants.defineAll(spec.getDefines(ctx));
if (alphaDiscard > 0) {
shaderConstants.define("ALPHA_DISCARD", alphaDiscard);
}
return shaderConstants;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ProgramContext that = (ProgramContext) o;
// override for instance equality on vertexType
return alphaDiscard == that.alphaDiscard && ctx == that.ctx && vertexType == that.vertexType && spec.equals(that.spec);
}
@Override
public int hashCode() {
return Objects.hash(alphaDiscard, vertexType, spec, ctx);
return new ProgramContext(spec, getAlphaDiscard(layer), vertexType, GameStateRegistry.takeSnapshot());
}
/**
@ -74,4 +46,40 @@ public record ProgramContext(float alphaDiscard, VertexType vertexType, ProgramS
public static float getAlphaDiscard(@Nullable RenderLayer layer) {
return layer == RenderLayer.CUTOUT ? 0.1f : 0f;
}
public final ProgramSpec spec;
public final float alphaDiscard;
public final VertexType vertexType;
public final StateSnapshot ctx;
/**
* @param spec The program to use.
* @param alphaDiscard Alpha threshold below which pixels are discarded.
* @param vertexType The vertexType the program should be adapted for.
* @param ctx A snapshot of the game state.
*/
public ProgramContext(ProgramSpec spec, float alphaDiscard, VertexType vertexType, StateSnapshot ctx) {
this.spec = spec;
this.alphaDiscard = alphaDiscard;
this.vertexType = vertexType;
this.ctx = ctx;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
var that = (ProgramContext) o;
return spec == that.spec && vertexType == that.vertexType && ctx.equals(that.ctx) && Float.floatToIntBits(alphaDiscard) == Float.floatToIntBits(that.alphaDiscard);
}
@Override
public int hashCode() {
return Objects.hash(spec, alphaDiscard, vertexType, ctx);
}
@Override
public String toString() {
return "ProgramContext{" + "spec=" + spec + ", alphaDiscard=" + alphaDiscard + ", vertexType=" + vertexType + ", ctx=" + ctx + '}';
}
}

View file

@ -1,12 +1,15 @@
package com.jozufozu.flywheel.core.compile;
import java.util.Objects;
import com.jozufozu.flywheel.api.vertex.VertexType;
import com.jozufozu.flywheel.backend.gl.shader.GlShader;
import com.jozufozu.flywheel.backend.gl.shader.ShaderType;
import com.jozufozu.flywheel.backend.source.FileResolution;
import com.jozufozu.flywheel.backend.source.SourceFile;
import com.jozufozu.flywheel.core.shader.ProgramSpec;
import com.jozufozu.flywheel.core.shader.StateSnapshot;
public class VertexCompiler extends Memoizer<ProgramContext, GlShader> {
public class VertexCompiler extends Memoizer<VertexCompiler.Context, GlShader> {
private final Template<? extends VertexData> template;
private final FileResolution header;
@ -16,12 +19,12 @@ public class VertexCompiler extends Memoizer<ProgramContext, GlShader> {
}
@Override
protected GlShader _create(ProgramContext key) {
protected GlShader _create(Context key) {
StringBuilder finalSource = new StringBuilder();
finalSource.append(CompileUtil.generateHeader(template.getVersion(), ShaderType.VERTEX));
key.getShaderConstants().writeInto(finalSource);
key.ctx.getDefines().writeInto(finalSource);
finalSource.append("""
struct Vertex {
@ -32,25 +35,47 @@ public class VertexCompiler extends Memoizer<ProgramContext, GlShader> {
vec3 normal;
};
""");
finalSource.append(key.vertexType()
.getShaderHeader());
finalSource.append(key.vertexType.getShaderHeader());
FileIndexImpl index = new FileIndexImpl();
header.getFile().generateFinalSource(index, finalSource);
ProgramSpec spec = key.spec();
SourceFile vertexFile = spec.getVertexFile();
vertexFile.generateFinalSource(index, finalSource);
key.file.generateFinalSource(index, finalSource);
VertexData appliedTemplate = template.apply(vertexFile);
finalSource.append(appliedTemplate.generateFooter(index, key.vertexType()));
VertexData appliedTemplate = template.apply(key.file);
finalSource.append(appliedTemplate.generateFooter(index, key.vertexType));
return new GlShader(spec.name, ShaderType.VERTEX, finalSource.toString());
return new GlShader(key.file.name, ShaderType.VERTEX, finalSource.toString());
}
@Override
protected void _destroy(GlShader value) {
value.delete();
}
public static class Context {
private final SourceFile file;
private final StateSnapshot ctx;
private final VertexType vertexType;
public Context(SourceFile file, StateSnapshot ctx, VertexType vertexType) {
this.file = file;
this.ctx = ctx;
this.vertexType = vertexType;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
var that = (Context) o;
return file == that.file && vertexType == that.vertexType && ctx.equals(that.ctx);
}
@Override
public int hashCode() {
return Objects.hash(file, ctx, vertexType);
}
}
}

View file

@ -1,15 +1,14 @@
package com.jozufozu.flywheel.core.shader;
import com.jozufozu.flywheel.backend.GameStateRegistry;
import com.mojang.serialization.Codec;
import com.jozufozu.flywheel.core.compile.ShaderConstants;
import net.minecraft.resources.ResourceLocation;
public interface GameStateProvider {
Codec<GameStateProvider> CODEC = ResourceLocation.CODEC.xmap(GameStateRegistry::getStateProvider, GameStateProvider::getID);
ResourceLocation getID();
boolean isTrue();
void alterConstants(ShaderConstants constants);
}

View file

@ -2,6 +2,7 @@ package com.jozufozu.flywheel.core.shader;
import com.jozufozu.flywheel.Flywheel;
import com.jozufozu.flywheel.config.FlwConfig;
import com.jozufozu.flywheel.core.compile.ShaderConstants;
import net.minecraft.resources.ResourceLocation;
@ -24,4 +25,9 @@ public class NormalDebugStateProvider implements GameStateProvider {
public ResourceLocation getID() {
return NAME;
}
@Override
public void alterConstants(ShaderConstants constants) {
constants.define("DEBUG_NORMAL");
}
}

View file

@ -1,10 +1,5 @@
package com.jozufozu.flywheel.core.shader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import com.google.common.collect.ImmutableList;
import com.jozufozu.flywheel.backend.source.FileResolution;
import com.jozufozu.flywheel.backend.source.Resolver;
import com.jozufozu.flywheel.backend.source.SourceFile;
@ -28,27 +23,20 @@ import net.minecraft.resources.ResourceLocation;
*/
public class ProgramSpec {
// TODO: Block model style inheritance?
public static final Codec<ProgramSpec> CODEC = RecordCodecBuilder.create(instance -> instance.group(
ResourceLocation.CODEC.fieldOf("vertex")
.forGetter(ProgramSpec::getSourceLoc),
ResourceLocation.CODEC.fieldOf("fragment")
.forGetter(ProgramSpec::getFragmentLoc),
ProgramState.CODEC.listOf()
.optionalFieldOf("states", Collections.emptyList())
.forGetter(ProgramSpec::getStates))
.forGetter(ProgramSpec::getFragmentLoc))
.apply(instance, ProgramSpec::new));
public ResourceLocation name;
public final FileResolution vertex;
public final FileResolution fragment;
public final ImmutableList<ProgramState> states;
public ProgramSpec(ResourceLocation vertex, ResourceLocation fragment, List<ProgramState> states) {
public ProgramSpec(ResourceLocation vertex, ResourceLocation fragment) {
this.vertex = Resolver.INSTANCE.get(vertex);
this.fragment = Resolver.INSTANCE.get(fragment);
this.states = ImmutableList.copyOf(states);
}
public void setName(ResourceLocation name) {
@ -73,36 +61,8 @@ public class ProgramSpec {
return fragment.getFile();
}
public ImmutableList<ProgramState> getStates() {
return states;
}
/**
* Calculate a unique ID representing the current game state.
*/
public long getCurrentStateID() {
long ctx = 0;
for (ProgramState state : states) {
if (state.context().isTrue()) {
ctx |= 1;
}
ctx <<= 1;
}
return ctx;
}
/**
* Given the stateID, get a list of defines to include at the top of a compiling program.
*/
public List<String> getDefines(long stateID) {
List<String> defines = new ArrayList<>();
for (ProgramState state : states) {
if ((stateID & 1) == 1) {
defines.addAll(state.defines());
}
stateID >>= 1;
}
return defines;
@Override
public String toString() {
return name.toString();
}
}

View file

@ -1,17 +0,0 @@
package com.jozufozu.flywheel.core.shader;
import java.util.Collections;
import java.util.List;
import com.jozufozu.flywheel.util.CodecUtil;
import com.mojang.serialization.Codec;
import com.mojang.serialization.codecs.RecordCodecBuilder;
public record ProgramState(GameStateProvider context, List<String> defines) {
public static final Codec<ProgramState> CODEC = RecordCodecBuilder.create(state -> state.group(GameStateProvider.CODEC.fieldOf("when")
.forGetter(ProgramState::context), CodecUtil.oneOrMore(Codec.STRING)
.optionalFieldOf("define", Collections.emptyList())
.forGetter(ProgramState::defines))
.apply(state, ProgramState::new));
}

View file

@ -0,0 +1,12 @@
package com.jozufozu.flywheel.core.shader;
import com.jozufozu.flywheel.backend.GameStateRegistry;
import com.jozufozu.flywheel.core.compile.ShaderConstants;
public record StateSnapshot(long ctx) {
// TODO: is this needed?
public ShaderConstants getDefines() {
return GameStateRegistry.getDefines(ctx);
}
}

View file

@ -1,10 +1,4 @@
{
"vertex": "flywheel:model.vert",
"fragment": "flywheel:block.frag",
"states": [
{
"when": "flywheel:normal_debug",
"define": "DEBUG_NORMAL"
}
]
"fragment": "flywheel:block.frag"
}

View file

@ -1,10 +1,4 @@
{
"vertex": "flywheel:oriented.vert",
"fragment": "flywheel:block.frag",
"states": [
{
"when": "flywheel:normal_debug",
"define": "DEBUG_NORMAL"
}
]
"fragment": "flywheel:block.frag"
}

View file

@ -1,10 +1,4 @@
{
"vertex": "flywheel:passthru.vert",
"fragment": "flywheel:block.frag",
"states": [
{
"when": "flywheel:normal_debug",
"define": "DEBUG_NORMAL"
}
]
"fragment": "flywheel:block.frag"
}