feat(aa): rule interface: replace less & equal by the compare method.

- set a new alphabet order to sort AARE based string.
- unify compare function for all rules
- handle some special sort order, eg: base include
This commit is contained in:
Alexandre Pujol 2024-06-19 18:34:58 +01:00
parent 747292e954
commit 4cbacc186c
Failed to generate hash of commit
22 changed files with 250 additions and 399 deletions

View file

@ -16,12 +16,8 @@ func (r *All) Validate() error {
return nil
}
func (r *All) Less(other any) bool {
return false
}
func (r *All) Equals(other any) bool {
return false
func (r *All) Compare(other Rule) int {
return 0
}
func (r *All) String() string {

View file

@ -76,26 +76,6 @@ func newRuleFromLog(log map[string]string) RuleBase {
}
}
func (r RuleBase) Less(other any) bool {
return false
}
func (r RuleBase) Equals(other any) bool {
return false
}
func (r RuleBase) String() string {
return renderTemplate(r.Kind(), r)
}
func (r RuleBase) Constraint() constraint {
return anyKind
}
func (r RuleBase) Kind() Kind {
return COMMENT
}
type Qualifier struct {
Audit bool
AccessType string
@ -109,13 +89,9 @@ func newQualifierFromLog(log map[string]string) Qualifier {
return Qualifier{Audit: audit}
}
func (r Qualifier) Less(other Qualifier) bool {
if r.Audit != other.Audit {
return r.Audit
func (r Qualifier) Compare(o Qualifier) int {
if r := compare(r.Audit, o.Audit); r != 0 {
return r
}
return r.AccessType < other.AccessType
}
func (r Qualifier) Equals(other Qualifier) bool {
return r.Audit == other.Audit && r.AccessType == other.AccessType
return compare(r.AccessType, o.AccessType)
}

View file

@ -19,14 +19,9 @@ func (r *Hat) Validate() error {
return nil
}
func (p *Hat) Less(other any) bool {
func (r *Hat) Compare(other Rule) int {
o, _ := other.(*Hat)
return p.Name < o.Name
}
func (p *Hat) Equals(other any) bool {
o, _ := other.(*Hat)
return p.Name == o.Name
return compare(r.Name, o.Name)
}
func (p *Hat) String() string {

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
)
const CAPABILITY Kind = "capability"
@ -47,19 +46,12 @@ func (r *Capability) Validate() error {
return nil
}
func (r *Capability) Less(other any) bool {
func (r *Capability) Compare(other Rule) int {
o, _ := other.(*Capability)
for i := 0; i < len(r.Names) && i < len(o.Names); i++ {
if r.Names[i] != o.Names[i] {
return r.Names[i] < o.Names[i]
}
if res := compare(r.Names, o.Names); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Capability) Equals(other any) bool {
o, _ := other.(*Capability)
return slices.Equal(r.Names, o.Names) && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Capability) String() string {

View file

@ -39,24 +39,18 @@ func (r *ChangeProfile) Validate() error {
return nil
}
func (r *ChangeProfile) Less(other any) bool {
func (r *ChangeProfile) Compare(other Rule) int {
o, _ := other.(*ChangeProfile)
if r.ExecMode != o.ExecMode {
return r.ExecMode < o.ExecMode
if res := compare(r.ExecMode, o.ExecMode); res != 0 {
return res
}
if r.Exec != o.Exec {
return r.Exec < o.Exec
if res := compare(r.Exec, o.Exec); res != 0 {
return res
}
if r.ProfileName != o.ProfileName {
return r.ProfileName < o.ProfileName
if res := compare(r.ProfileName, o.ProfileName); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *ChangeProfile) Equals(other any) bool {
o, _ := other.(*ChangeProfile)
return r.ExecMode == o.ExecMode && r.Exec == o.Exec &&
r.ProfileName == o.ProfileName && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *ChangeProfile) String() string {

View file

@ -19,9 +19,43 @@ func Must[T any](v T, err error) T {
return v
}
// cmpFileAccess compares two access strings for file rules.
func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
func compare(a, b any) int {
switch a := a.(type) {
case int:
return a - b.(int)
case string:
a = strings.ToLower(a)
b := strings.ToLower(b.(string))
if a == b {
return 0
}
for i := 0; i < len(a) && i < len(b); i++ {
if a[i] != b[i] {
return stringWeights[a[i]] - stringWeights[b[i]]
}
}
return len(a) - len(b)
case []string:
return slices.CompareFunc(a, b.([]string), func(s1, s2 string) int {
return compare(s1, s2)
})
case bool:
return boolToInt(a) - boolToInt(b.(bool))
default:
panic("compare: unsupported type")
}
}
// compareFileAccess compares two access strings for file rules.
// It is aimed to be used in slices.SortFunc.
func cmpFileAccess(i, j string) int {
func compareFileAccess(i, j string) int {
if slices.Contains(requirements[FILE]["access"], i) &&
slices.Contains(requirements[FILE]["access"], j) {
return requirementsWeights[FILE]["access"][i] - requirementsWeights[FILE]["access"][j]
@ -115,6 +149,6 @@ func toAccess(kind Kind, input string) ([]string, error) {
return toValues(kind, "access", input)
}
slices.SortFunc(res, cmpFileAccess)
slices.SortFunc(res, compareFileAccess)
return slices.Compact(res), nil
}

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
)
const DBUS Kind = "dbus"
@ -63,43 +62,33 @@ func (r *Dbus) Validate() error {
return validateValues(r.Kind(), "bus", []string{r.Bus})
}
func (r *Dbus) Less(other any) bool {
func (r *Dbus) Compare(other Rule) int {
o, _ := other.(*Dbus)
for i := 0; i < len(r.Access) && i < len(o.Access); i++ {
if r.Access[i] != o.Access[i] {
return r.Access[i] < o.Access[i]
}
if res := compare(r.Access, o.Access); res != 0 {
return res
}
if r.Bus != o.Bus {
return r.Bus < o.Bus
if res := compare(r.Bus, o.Bus); res != 0 {
return res
}
if r.Name != o.Name {
return r.Name < o.Name
if res := compare(r.Name, o.Name); res != 0 {
return res
}
if r.Path != o.Path {
return r.Path < o.Path
if res := compare(r.Path, o.Path); res != 0 {
return res
}
if r.Interface != o.Interface {
return r.Interface < o.Interface
if res := compare(r.Interface, o.Interface); res != 0 {
return res
}
if r.Member != o.Member {
return r.Member < o.Member
if res := compare(r.Member, o.Member); res != 0 {
return res
}
if r.PeerName != o.PeerName {
return r.PeerName < o.PeerName
if res := compare(r.PeerName, o.PeerName); res != 0 {
return res
}
if r.PeerLabel != o.PeerLabel {
return r.PeerLabel < o.PeerLabel
if res := compare(r.PeerLabel, o.PeerLabel); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Dbus) Equals(other any) bool {
o, _ := other.(*Dbus)
return slices.Equal(r.Access, o.Access) && r.Bus == o.Bus && r.Name == o.Name &&
r.Path == o.Path && r.Interface == o.Interface &&
r.Member == o.Member && r.PeerName == o.PeerName &&
r.PeerLabel == o.PeerLabel && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Dbus) String() string {

View file

@ -68,32 +68,27 @@ func (r *File) Validate() error {
return nil
}
func (r *File) Less(other any) bool {
func (r *File) Compare(other Rule) int {
o, _ := other.(*File)
letterR := getLetterIn(fileAlphabet, r.Path)
letterO := getLetterIn(fileAlphabet, o.Path)
if fileWeights[letterR] != fileWeights[letterO] && letterR != "" && letterO != "" {
return fileWeights[letterR] < fileWeights[letterO]
return fileWeights[letterR] - fileWeights[letterO]
}
if r.Path != o.Path {
return r.Path < o.Path
if res := compare(r.Owner, o.Owner); res != 0 {
return res
}
if o.Owner != r.Owner {
return r.Owner
if res := compare(r.Path, o.Path); res != 0 {
return res
}
if len(r.Access) != len(o.Access) {
return len(r.Access) < len(o.Access)
if res := compare(r.Access, o.Access); res != 0 {
return res
}
if r.Target != o.Target {
return r.Target < o.Target
if res := compare(r.Target, o.Target); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *File) Equals(other any) bool {
o, _ := other.(*File)
return r.Path == o.Path && slices.Equal(r.Access, o.Access) && r.Owner == o.Owner &&
r.Target == o.Target && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *File) String() string {
@ -131,27 +126,22 @@ func (r *Link) Validate() error {
return nil
}
func (r *Link) Less(other any) bool {
func (r *Link) Compare(other Rule) int {
o, _ := other.(*Link)
if r.Path != o.Path {
return r.Path < o.Path
}
if o.Owner != r.Owner {
return r.Owner
}
if r.Target != o.Target {
return r.Target < o.Target
}
if r.Subset != o.Subset {
return r.Subset
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Link) Equals(other any) bool {
o, _ := other.(*Link)
return r.Subset == o.Subset && r.Owner == o.Owner && r.Path == o.Path &&
r.Target == o.Target && r.Qualifier.Equals(o.Qualifier)
if res := compare(r.Owner, o.Owner); res != 0 {
return res
}
if res := compare(r.Path, o.Path); res != 0 {
return res
}
if res := compare(r.Target, o.Target); res != 0 {
return res
}
if res := compare(r.Subset, o.Subset); res != 0 {
return res
}
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Link) String() string {

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
)
const IOURING Kind = "io_uring"
@ -40,20 +39,15 @@ func (r *IOUring) Validate() error {
return nil
}
func (r *IOUring) Less(other any) bool {
func (r *IOUring) Compare(other Rule) int {
o, _ := other.(*IOUring)
if len(r.Access) != len(o.Access) {
return len(r.Access) < len(o.Access)
if res := compare(r.Access, o.Access); res != 0 {
return res
}
if r.Label != o.Label {
return r.Label < o.Label
if res := compare(r.Label, o.Label); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *IOUring) Equals(other any) bool {
o, _ := other.(*IOUring)
return slices.Equal(r.Access, o.Access) && r.Label == o.Label && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *IOUring) String() string {

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
)
const (
@ -48,15 +47,11 @@ func (m MountConditions) Validate() error {
return validateValues(MOUNT, "flags", m.Options)
}
func (m MountConditions) Less(other MountConditions) bool {
if m.FsType != other.FsType {
return m.FsType < other.FsType
func (m MountConditions) Compare(other MountConditions) int {
if res := compare(m.FsType, other.FsType); res != 0 {
return res
}
return len(m.Options) < len(other.Options)
}
func (m MountConditions) Equals(other MountConditions) bool {
return m.FsType == other.FsType && slices.Equal(m.Options, other.Options)
return compare(m.Options, other.Options)
}
type Mount struct {
@ -84,25 +79,18 @@ func (r *Mount) Validate() error {
return nil
}
func (r *Mount) Less(other any) bool {
func (r *Mount) Compare(other Rule) int {
o, _ := other.(*Mount)
if r.Source != o.Source {
return r.Source < o.Source
if res := compare(r.Source, o.Source); res != 0 {
return res
}
if r.MountPoint != o.MountPoint {
return r.MountPoint < o.MountPoint
if res := compare(r.MountPoint, o.MountPoint); res != 0 {
return res
}
if r.MountConditions.Equals(o.MountConditions) {
return r.MountConditions.Less(o.MountConditions)
if res := r.MountConditions.Compare(o.MountConditions); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Mount) Equals(other any) bool {
o, _ := other.(*Mount)
return r.Source == o.Source && r.MountPoint == o.MountPoint &&
r.MountConditions.Equals(o.MountConditions) &&
r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Mount) String() string {
@ -140,22 +128,15 @@ func (r *Umount) Validate() error {
return nil
}
func (r *Umount) Less(other any) bool {
func (r *Umount) Compare(other Rule) int {
o, _ := other.(*Umount)
if r.MountPoint != o.MountPoint {
return r.MountPoint < o.MountPoint
if res := compare(r.MountPoint, o.MountPoint); res != 0 {
return res
}
if r.MountConditions.Equals(o.MountConditions) {
return r.MountConditions.Less(o.MountConditions)
if res := r.MountConditions.Compare(o.MountConditions); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Umount) Equals(other any) bool {
o, _ := other.(*Umount)
return r.MountPoint == o.MountPoint &&
r.MountConditions.Equals(o.MountConditions) &&
r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Umount) String() string {
@ -193,22 +174,15 @@ func (r *Remount) Validate() error {
return nil
}
func (r *Remount) Less(other any) bool {
func (r *Remount) Compare(other Rule) int {
o, _ := other.(*Remount)
if r.MountPoint != o.MountPoint {
return r.MountPoint < o.MountPoint
if res := compare(r.MountPoint, o.MountPoint); res != 0 {
return res
}
if r.MountConditions.Equals(o.MountConditions) {
return r.MountConditions.Less(o.MountConditions)
if res := r.MountConditions.Compare(o.MountConditions); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Remount) Equals(other any) bool {
o, _ := other.(*Remount)
return r.MountPoint == o.MountPoint &&
r.MountConditions.Equals(o.MountConditions) &&
r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Remount) String() string {

View file

@ -58,24 +58,18 @@ func (r *Mqueue) Validate() error {
return nil
}
func (r *Mqueue) Less(other any) bool {
func (r *Mqueue) Compare(other Rule) int {
o, _ := other.(*Mqueue)
if len(r.Access) != len(o.Access) {
return len(r.Access) < len(o.Access)
if res := compare(r.Access, o.Access); res != 0 {
return res
}
if r.Type != o.Type {
return r.Type < o.Type
if res := compare(r.Type, o.Type); res != 0 {
return res
}
if r.Label != o.Label {
return r.Label < o.Label
if res := compare(r.Label, o.Label); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Mqueue) Equals(other any) bool {
o, _ := other.(*Mqueue)
return slices.Equal(r.Access, o.Access) && r.Type == o.Type && r.Label == o.Label &&
r.Name == o.Name && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Mqueue) String() string {

View file

@ -46,14 +46,14 @@ func newAddressExprFromLog(log map[string]string) AddressExpr {
}
}
func (r AddressExpr) Less(other AddressExpr) bool {
if r.Source != other.Source {
return r.Source < other.Source
func (r AddressExpr) Compare(other AddressExpr) int {
if res := compare(r.Source, other.Source); res != 0 {
return res
}
if r.Destination != other.Destination {
return r.Destination < other.Destination
if res := compare(r.Destination, other.Destination); res != 0 {
return res
}
return r.Port < other.Port
return compare(r.Port, other.Port)
}
func (r AddressExpr) Equals(other AddressExpr) bool {
@ -94,28 +94,21 @@ func (r *Network) Validate() error {
return nil
}
func (r *Network) Less(other any) bool {
func (r *Network) Compare(other Rule) int {
o, _ := other.(*Network)
if r.Domain != o.Domain {
return r.Domain < o.Domain
if res := compare(r.Domain, o.Domain); res != 0 {
return res
}
if r.Type != o.Type {
return r.Type < o.Type
if res := compare(r.Type, o.Type); res != 0 {
return res
}
if r.Protocol != o.Protocol {
return r.Protocol < o.Protocol
if res := compare(r.Protocol, o.Protocol); res != 0 {
return res
}
if r.AddressExpr.Less(o.AddressExpr) {
return r.AddressExpr.Less(o.AddressExpr)
if res := r.AddressExpr.Compare(o.AddressExpr); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Network) Equals(other any) bool {
o, _ := other.(*Network)
return r.Domain == o.Domain && r.Type == o.Type &&
r.Protocol == o.Protocol && r.AddressExpr.Equals(o.AddressExpr) &&
r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Network) String() string {

View file

@ -28,25 +28,18 @@ func (r *PivotRoot) Validate() error {
return nil
}
func (r *PivotRoot) Less(other any) bool {
func (r *PivotRoot) Compare(other Rule) int {
o, _ := other.(*PivotRoot)
if r.OldRoot != o.OldRoot {
return r.OldRoot < o.OldRoot
if res := compare(r.OldRoot, o.OldRoot); res != 0 {
return res
}
if r.NewRoot != o.NewRoot {
return r.NewRoot < o.NewRoot
if res := compare(r.NewRoot, o.NewRoot); res != 0 {
return res
}
if r.TargetProfile != o.TargetProfile {
return r.TargetProfile < o.TargetProfile
if res := compare(r.TargetProfile, o.TargetProfile); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *PivotRoot) Equals(other any) bool {
o, _ := other.(*PivotRoot)
return r.OldRoot == o.OldRoot && r.NewRoot == o.NewRoot &&
r.TargetProfile == o.TargetProfile &&
r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *PivotRoot) String() string {

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
"strings"
)
@ -34,12 +33,8 @@ func (r *Comment) Validate() error {
return nil
}
func (r *Comment) Less(other any) bool {
return false
}
func (r *Comment) Equals(other any) bool {
return false
func (r *Comment) Compare(other Rule) int {
return 0
}
func (r *Comment) String() string {
@ -93,17 +88,12 @@ func (r *Abi) Validate() error {
return nil
}
func (r *Abi) Less(other any) bool {
func (r *Abi) Compare(other Rule) int {
o, _ := other.(*Abi)
if r.Path != o.Path {
return r.Path < o.Path
if res := compare(r.Path, o.Path); res != 0 {
return res
}
return r.IsMagic == o.IsMagic
}
func (r *Abi) Equals(other any) bool {
o, _ := other.(*Abi)
return r.Path == o.Path && r.IsMagic == o.IsMagic
return compare(r.IsMagic, o.IsMagic)
}
func (r *Abi) String() string {
@ -145,17 +135,12 @@ func (r *Alias) Validate() error {
return nil
}
func (r Alias) Less(other any) bool {
func (r *Alias) Compare(other Rule) int {
o, _ := other.(*Alias)
if r.Path != o.Path {
return r.Path < o.Path
if res := compare(r.Path, o.Path); res != 0 {
return res
}
return r.RewrittenPath < o.RewrittenPath
}
func (r Alias) Equals(other any) bool {
o, _ := other.(*Alias)
return r.Path == o.Path && r.RewrittenPath == o.RewrittenPath
return compare(r.RewrittenPath, o.RewrittenPath)
}
func (r *Alias) String() string {
@ -216,20 +201,22 @@ func (r *Include) Validate() error {
return nil
}
func (r *Include) Less(other any) bool {
func (r *Include) Compare(other Rule) int {
const base = "abstractions/base"
o, _ := other.(*Include)
if r.Path == o.Path {
return r.Path < o.Path
if res := compare(r.Path, o.Path); res != 0 {
if r.Path == base {
return -1
}
if o.Path == base {
return 1
}
return res
}
if r.IsMagic != o.IsMagic {
return r.IsMagic
if res := compare(r.IsMagic, o.IsMagic); res != 0 {
return res
}
return r.IfExists
}
func (r *Include) Equals(other any) bool {
o, _ := other.(*Include)
return r.Path == o.Path && r.IsMagic == o.IsMagic && r.IfExists == o.IfExists
return compare(r.IfExists, o.IfExists)
}
func (r *Include) String() string {
@ -284,17 +271,8 @@ func (r *Variable) Validate() error {
return nil
}
func (r *Variable) Less(other any) bool {
o, _ := other.(*Variable)
if r.Name != o.Name {
return r.Name < o.Name
}
return len(r.Values) < len(o.Values)
}
func (r *Variable) Equals(other any) bool {
o, _ := other.(*Variable)
return r.Name == o.Name && slices.Equal(r.Values, o.Values)
func (r *Variable) Compare(other Rule) int {
return 0
}
func (r *Variable) String() string {

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"maps"
"slices"
"strings"
)
@ -96,19 +95,12 @@ func (r *Profile) Validate() error {
return r.Rules.Validate()
}
func (p *Profile) Less(other any) bool {
func (r *Profile) Compare(other Rule) int {
o, _ := other.(*Profile)
if p.Name != o.Name {
return p.Name < o.Name
if res := compare(r.Name, o.Name); res != 0 {
return res
}
return len(p.Attachments) < len(o.Attachments)
}
func (p *Profile) Equals(other any) bool {
o, _ := other.(*Profile)
return p.Name == o.Name && slices.Equal(p.Attachments, o.Attachments) &&
maps.Equal(p.Attributes, o.Attributes) &&
slices.Equal(p.Flags, o.Flags)
return compare(r.Attachments, o.Attachments)
}
func (p *Profile) String() string {

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
)
const PTRACE Kind = "ptrace"
@ -42,21 +41,15 @@ func (r *Ptrace) Validate() error {
return nil
}
func (r *Ptrace) Less(other any) bool {
func (r *Ptrace) Compare(other Rule) int {
o, _ := other.(*Ptrace)
if len(r.Access) != len(o.Access) {
return len(r.Access) < len(o.Access)
if res := compare(r.Access, o.Access); res != 0 {
return res
}
if r.Peer != o.Peer {
return r.Peer == o.Peer
if res := compare(r.Peer, o.Peer); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Ptrace) Equals(other any) bool {
o, _ := other.(*Ptrace)
return slices.Equal(r.Access, o.Access) && r.Peer == o.Peer &&
r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Ptrace) String() string {

View file

@ -43,20 +43,15 @@ func (r *Rlimit) Validate() error {
return nil
}
func (r *Rlimit) Less(other any) bool {
func (r *Rlimit) Compare(other Rule) int {
o, _ := other.(*Rlimit)
if r.Key != o.Key {
return r.Key < o.Key
if res := compare(r.Key, o.Key); res != 0 {
return res
}
if r.Op != o.Op {
return r.Op < o.Op
if res := compare(r.Op, o.Op); res != 0 {
return res
}
return r.Value < o.Value
}
func (r *Rlimit) Equals(other any) bool {
o, _ := other.(*Rlimit)
return r.Key == o.Key && r.Op == o.Op && r.Value == o.Value
return compare(r.Value, o.Value)
}
func (r *Rlimit) String() string {

View file

@ -35,8 +35,7 @@ func (k Kind) Tok() string {
// Rule generic interface for all AppArmor rules
type Rule interface {
Validate() error
Less(other any) bool
Equals(other any) bool
Compare(other Rule) int
String() string
Constraint() constraint
Kind() Kind
@ -66,7 +65,7 @@ func (r Rules) Index(item Rule) int {
if rule == nil {
continue
}
if rule.Kind() == item.Kind() && rule.Equals(item) {
if rule.Kind() == item.Kind() && rule.Compare(item) == 0 {
return idx
}
}
@ -153,7 +152,7 @@ func (r Rules) Merge() Rules {
}
// If rules are identical, merge them
if r[i].Equals(r[j]) {
if r[i].Compare(r[j]) == 0 {
r = r.Delete(j)
j--
continue
@ -166,7 +165,7 @@ func (r Rules) Merge() Rules {
fileJ := r[j].(*File)
if fileI.Path == fileJ.Path {
fileI.Access = append(fileI.Access, fileJ.Access...)
slices.SortFunc(fileI.Access, cmpFileAccess)
slices.SortFunc(fileI.Access, compareFileAccess)
fileI.Access = slices.Compact(fileI.Access)
r = r.Delete(j)
j--
@ -192,13 +191,7 @@ func (r Rules) Sort() Rules {
}
return ruleWeights[kindOfA] - ruleWeights[kindOfB]
}
if a.Equals(b) {
return 0
}
if a.Less(b) {
return -1
}
return 1
return a.Compare(b)
})
return r
}

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
)
const SIGNAL Kind = "signal"
@ -60,24 +59,18 @@ func (r *Signal) Validate() error {
return nil
}
func (r *Signal) Less(other any) bool {
func (r *Signal) Compare(other Rule) int {
o, _ := other.(*Signal)
if len(r.Access) != len(o.Access) {
return len(r.Access) < len(o.Access)
if res := compare(r.Access, o.Access); res != 0 {
return res
}
if len(r.Set) != len(o.Set) {
return len(r.Set) < len(o.Set)
if res := compare(r.Set, o.Set); res != 0 {
return res
}
if r.Peer != o.Peer {
return r.Peer < o.Peer
if res := compare(r.Peer, o.Peer); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Signal) Equals(other any) bool {
o, _ := other.(*Signal)
return slices.Equal(r.Access, o.Access) && slices.Equal(r.Set, o.Set) &&
r.Peer == o.Peer && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Signal) String() string {

View file

@ -117,6 +117,12 @@ var (
}
fileWeights = generateWeights(fileAlphabet)
// The order AARE should be sorted
stringAlphabet = []byte(
"!\"#$%&'(){}[]*+,-./:;<=>?@\\^_`|~0123456789abcdefghijklmnopqrstuvwxyz",
)
stringWeights = generateWeights(stringAlphabet)
// The order the rule values (access, type, domains, etc) should be sorted
requirements = map[Kind]requirement{}
requirementsWeights map[Kind]map[string]map[string]int
@ -155,7 +161,7 @@ func renderTemplate(name Kind, data any) string {
return res.String()
}
func generateWeights[T Kind | string](alphabet []T) map[T]int {
func generateWeights[T comparable](alphabet []T) map[T]int {
res := make(map[T]int, len(alphabet))
for i, r := range alphabet {
res[r] = i

View file

@ -6,7 +6,6 @@ package aa
import (
"fmt"
"slices"
)
const UNIX Kind = "unix"
@ -58,45 +57,36 @@ func (r *Unix) Validate() error {
return nil
}
func (r *Unix) Less(other any) bool {
func (r *Unix) Compare(other Rule) int {
o, _ := other.(*Unix)
if len(r.Access) != len(o.Access) {
return len(r.Access) < len(o.Access)
if res := compare(r.Access, o.Access); res != 0 {
return res
}
if r.Type != o.Type {
return r.Type < o.Type
if res := compare(r.Type, o.Type); res != 0 {
return res
}
if r.Protocol != o.Protocol {
return r.Protocol < o.Protocol
if res := compare(r.Protocol, o.Protocol); res != 0 {
return res
}
if r.Address != o.Address {
return r.Address < o.Address
if res := compare(r.Address, o.Address); res != 0 {
return res
}
if r.Label != o.Label {
return r.Label < o.Label
if res := compare(r.Label, o.Label); res != 0 {
return res
}
if r.Attr != o.Attr {
return r.Attr < o.Attr
if res := compare(r.Attr, o.Attr); res != 0 {
return res
}
if r.Opt != o.Opt {
return r.Opt < o.Opt
if res := compare(r.Opt, o.Opt); res != 0 {
return res
}
if r.PeerLabel != o.PeerLabel {
return r.PeerLabel < o.PeerLabel
if res := compare(r.PeerLabel, o.PeerLabel); res != 0 {
return res
}
if r.PeerAddr != o.PeerAddr {
return r.PeerAddr < o.PeerAddr
if res := compare(r.PeerAddr, o.PeerAddr); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Unix) Equals(other any) bool {
o, _ := other.(*Unix)
return slices.Equal(r.Access, o.Access) && r.Type == o.Type &&
r.Protocol == o.Protocol && r.Address == o.Address &&
r.Label == o.Label && r.Attr == o.Attr && r.Opt == o.Opt &&
r.PeerLabel == o.PeerLabel && r.PeerAddr == o.PeerAddr &&
r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Unix) String() string {

View file

@ -4,6 +4,8 @@
package aa
import "fmt"
const USERNS Kind = "userns"
type Userns struct {
@ -24,17 +26,12 @@ func (r *Userns) Validate() error {
return nil
}
func (r *Userns) Less(other any) bool {
func (r *Userns) Compare(other Rule) int {
o, _ := other.(*Userns)
if r.Create != o.Create {
return r.Create
if res := compare(r.Create, o.Create); res != 0 {
return res
}
return r.Qualifier.Less(o.Qualifier)
}
func (r *Userns) Equals(other any) bool {
o, _ := other.(*Userns)
return r.Create == o.Create && r.Qualifier.Equals(o.Qualifier)
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Userns) String() string {