feat(aa): add merge methods to the rule interface.

This commit is contained in:
Alexandre Pujol 2024-06-22 20:59:43 +01:00
parent a91e2ddf56
commit 6791dcde28
No known key found for this signature in database
GPG Key ID: C5469996F0DF68EC
14 changed files with 192 additions and 11 deletions

View File

@ -24,6 +24,11 @@ func (r *All) Compare(other Rule) int {
return 0
}
func (r *All) Merge(other Rule) bool {
o, _ := other.(*All)
return r.RuleBase.merge(o.RuleBase)
}
func (r *All) String() string {
return renderTemplate(r.Kind(), r)
}

View File

@ -83,6 +83,11 @@ func (r RuleBase) Merge(other Rule) bool {
return false
}
func (r RuleBase) merge(other RuleBase) bool {
r.Comment += " " + other.Comment
return true
}
type Qualifier struct {
Audit bool
AccessType string
@ -102,3 +107,7 @@ func (r Qualifier) Compare(o Qualifier) int {
}
return compare(r.AccessType, o.AccessType)
}
func (r Qualifier) Equal(o Qualifier) bool {
return r.Audit == o.Audit && r.AccessType == o.AccessType
}

View File

@ -26,6 +26,23 @@ func boolToInt(b bool) int {
return 0
}
func merge(kind Kind, key string, a, b []string) []string {
a = append(a, b...)
switch kind {
case FILE:
slices.SortFunc(a, compareFileAccess)
case VARIABLE:
slices.SortFunc(a, func(s1, s2 string) int {
return compare(s1, s2)
})
default:
slices.SortFunc(a, func(i, j string) int {
return requirementsWeights[kind][key][i] - requirementsWeights[kind][key][j]
})
}
return slices.Compact(a)
}
func compare(a, b any) int {
switch a := a.(type) {
case int:

View File

@ -110,6 +110,21 @@ func (r *Dbus) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Dbus) Merge(other Rule) bool {
o, _ := other.(*Dbus)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.Bus == o.Bus && r.Name == o.Name && r.Path == o.Path &&
r.Interface == o.Interface && r.Member == o.Member &&
r.PeerName == o.PeerName && r.PeerLabel == o.PeerLabel {
r.Access = merge(r.Kind(), "access", r.Access, o.Access)
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Dbus) String() string {
return renderTemplate(r.Kind(), r)
}

View File

@ -132,11 +132,13 @@ func (r *File) Compare(other Rule) int {
func (r *File) Merge(other Rule) bool {
o, _ := other.(*File)
if r.Path == o.Path {
r.Access = append(r.Access, o.Access...)
slices.SortFunc(r.Access, compareFileAccess)
r.Access = slices.Compact(r.Access)
return true
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.Owner == o.Owner && r.Path == o.Path && r.Target == o.Target {
r.Access = merge(r.Kind(), "access", r.Access, o.Access)
return r.RuleBase.merge(o.RuleBase)
}
return false
}

View File

@ -63,6 +63,19 @@ func (r *IOUring) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *IOUring) Merge(other Rule) bool {
o, _ := other.(*IOUring)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.Label == o.Label {
r.Access = merge(r.Kind(), "access", r.Access, o.Access)
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *IOUring) String() string {
return renderTemplate(r.Kind(), r)
}

View File

@ -65,6 +65,14 @@ func (m MountConditions) Compare(other MountConditions) int {
return compare(m.Options, other.Options)
}
func (m MountConditions) Merge(other MountConditions) bool {
if m.FsType == other.FsType {
m.Options = merge(MOUNT, "flags", m.Options, other.Options)
return true
}
return false
}
type Mount struct {
RuleBase
Qualifier
@ -133,6 +141,19 @@ func (r *Mount) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Mount) Merge(other Rule) bool {
o, _ := other.(*Mount)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.Source == o.Source && r.MountPoint == o.MountPoint &&
r.MountConditions.Merge(o.MountConditions) {
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Mount) String() string {
return renderTemplate(r.Kind(), r)
}
@ -197,6 +218,18 @@ func (r *Umount) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Umount) Merge(other Rule) bool {
o, _ := other.(*Umount)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.MountPoint == o.MountPoint && r.MountConditions.Merge(o.MountConditions) {
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Umount) String() string {
return renderTemplate(r.Kind(), r)
}
@ -262,6 +295,18 @@ func (r *Remount) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Remount) Merge(other Rule) bool {
o, _ := other.(*Remount)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.MountPoint == o.MountPoint && r.MountConditions.Merge(o.MountConditions) {
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Remount) String() string {
return renderTemplate(r.Kind(), r)
}

View File

@ -97,6 +97,19 @@ func (r *Mqueue) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Mqueue) Merge(other Rule) bool {
o, _ := other.(*Mqueue)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.Type == o.Type && r.Label == o.Label && r.Name == o.Name {
r.Access = merge(r.Kind(), "access", r.Access, o.Access)
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Mqueue) String() string {
return renderTemplate(r.Kind(), r)
}

View File

@ -57,11 +57,6 @@ func (r AddressExpr) Compare(other AddressExpr) int {
return compare(r.Port, other.Port)
}
func (r AddressExpr) Equals(other AddressExpr) bool {
return r.Source == other.Source && r.Destination == other.Destination &&
r.Port == other.Port
}
type Network struct {
RuleBase
Qualifier

View File

@ -256,8 +256,25 @@ func (r *Variable) Validate() error {
return nil
}
func (r *Variable) Merge(other Rule) bool {
o, _ := other.(*Variable)
if r.Name == o.Name && r.Define == o.Define {
r.Values = merge(r.Kind(), "access", r.Values, o.Values)
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Variable) Compare(other Rule) int {
return 0
o, _ := other.(*Variable)
if res := compare(r.Name, o.Name); res != 0 {
return res
}
if res := compare(r.Define, o.Define); res != 0 {
return res
}
return compare(r.Values, o.Values)
}
func (r *Variable) String() string {

View File

@ -54,6 +54,19 @@ func (r *Ptrace) Validate() error {
return nil
}
func (r *Ptrace) Merge(other Rule) bool {
o, _ := other.(*Ptrace)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.Peer == o.Peer {
r.Access = merge(r.Kind(), "access", r.Access, o.Access)
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Ptrace) Compare(other Rule) int {
o, _ := other.(*Ptrace)
if res := compare(r.Access, o.Access); res != 0 {

View File

@ -77,6 +77,23 @@ func (r *Signal) Validate() error {
return nil
}
func (r *Signal) Merge(other Rule) bool {
o, _ := other.(*Signal)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
switch {
case r.Peer == o.Peer && compare(r.Set, o.Set) == 0:
r.Access = merge(r.Kind(), "access", r.Access, o.Access)
return r.RuleBase.merge(o.RuleBase)
case r.Peer == o.Peer && compare(r.Access, o.Access) == 0:
r.Set = merge(r.Kind(), "set", r.Set, o.Set)
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Signal) Compare(other Rule) int {
o, _ := other.(*Signal)
if res := compare(r.Access, o.Access); res != 0 {

View File

@ -109,6 +109,21 @@ func (r *Unix) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Unix) Merge(other Rule) bool {
o, _ := other.(*Unix)
if !r.Qualifier.Equal(o.Qualifier) {
return false
}
if r.Type == o.Type && r.Protocol == o.Protocol && r.Address == o.Address &&
r.Label == o.Label && r.Attr == o.Attr && r.Opt == o.Opt &&
r.PeerLabel == o.PeerLabel && r.PeerAddr == o.PeerAddr {
r.Access = merge(r.Kind(), "access", r.Access, o.Access)
return r.RuleBase.merge(o.RuleBase)
}
return false
}
func (r *Unix) String() string {
return renderTemplate(r.Kind(), r)
}

View File

@ -54,6 +54,11 @@ func (r *Userns) Compare(other Rule) int {
return r.Qualifier.Compare(o.Qualifier)
}
func (r *Userns) Merge(other Rule) bool {
o, _ := other.(*Userns)
return r.RuleBase.merge(o.RuleBase)
}
func (r *Userns) String() string {
return renderTemplate(r.Kind(), r)
}