From 6791dcde280ffbaada807aecbf4f89874498aaf6 Mon Sep 17 00:00:00 2001 From: Alexandre Pujol Date: Sat, 22 Jun 2024 20:59:43 +0100 Subject: [PATCH] feat(aa): add merge methods to the rule interface. --- pkg/aa/all.go | 5 +++++ pkg/aa/base.go | 9 +++++++++ pkg/aa/convert.go | 17 +++++++++++++++++ pkg/aa/dbus.go | 15 +++++++++++++++ pkg/aa/file.go | 12 +++++++----- pkg/aa/io_uring.go | 13 +++++++++++++ pkg/aa/mount.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ pkg/aa/mqueue.go | 13 +++++++++++++ pkg/aa/network.go | 5 ----- pkg/aa/preamble.go | 19 ++++++++++++++++++- pkg/aa/ptrace.go | 13 +++++++++++++ pkg/aa/signal.go | 17 +++++++++++++++++ pkg/aa/unix.go | 15 +++++++++++++++ pkg/aa/userns.go | 5 +++++ 14 files changed, 192 insertions(+), 11 deletions(-) diff --git a/pkg/aa/all.go b/pkg/aa/all.go index 2ef5441b..5351afb7 100644 --- a/pkg/aa/all.go +++ b/pkg/aa/all.go @@ -24,6 +24,11 @@ func (r *All) Compare(other Rule) int { return 0 } +func (r *All) Merge(other Rule) bool { + o, _ := other.(*All) + return r.RuleBase.merge(o.RuleBase) +} + func (r *All) String() string { return renderTemplate(r.Kind(), r) } diff --git a/pkg/aa/base.go b/pkg/aa/base.go index a9c86487..92aa76f8 100644 --- a/pkg/aa/base.go +++ b/pkg/aa/base.go @@ -83,6 +83,11 @@ func (r RuleBase) Merge(other Rule) bool { return false } +func (r RuleBase) merge(other RuleBase) bool { + r.Comment += " " + other.Comment + return true +} + type Qualifier struct { Audit bool AccessType string @@ -102,3 +107,7 @@ func (r Qualifier) Compare(o Qualifier) int { } return compare(r.AccessType, o.AccessType) } + +func (r Qualifier) Equal(o Qualifier) bool { + return r.Audit == o.Audit && r.AccessType == o.AccessType +} diff --git a/pkg/aa/convert.go b/pkg/aa/convert.go index eef77db4..e0889360 100644 --- a/pkg/aa/convert.go +++ b/pkg/aa/convert.go @@ -26,6 +26,23 @@ func boolToInt(b bool) int { return 0 } +func merge(kind Kind, key string, a, b []string) []string { + a = append(a, b...) + switch kind { + case FILE: + slices.SortFunc(a, compareFileAccess) + case VARIABLE: + slices.SortFunc(a, func(s1, s2 string) int { + return compare(s1, s2) + }) + default: + slices.SortFunc(a, func(i, j string) int { + return requirementsWeights[kind][key][i] - requirementsWeights[kind][key][j] + }) + } + return slices.Compact(a) +} + func compare(a, b any) int { switch a := a.(type) { case int: diff --git a/pkg/aa/dbus.go b/pkg/aa/dbus.go index a3f679df..8602e93d 100644 --- a/pkg/aa/dbus.go +++ b/pkg/aa/dbus.go @@ -110,6 +110,21 @@ func (r *Dbus) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *Dbus) Merge(other Rule) bool { + o, _ := other.(*Dbus) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.Bus == o.Bus && r.Name == o.Name && r.Path == o.Path && + r.Interface == o.Interface && r.Member == o.Member && + r.PeerName == o.PeerName && r.PeerLabel == o.PeerLabel { + r.Access = merge(r.Kind(), "access", r.Access, o.Access) + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Dbus) String() string { return renderTemplate(r.Kind(), r) } diff --git a/pkg/aa/file.go b/pkg/aa/file.go index 7cc6d4dc..928e897c 100644 --- a/pkg/aa/file.go +++ b/pkg/aa/file.go @@ -132,11 +132,13 @@ func (r *File) Compare(other Rule) int { func (r *File) Merge(other Rule) bool { o, _ := other.(*File) - if r.Path == o.Path { - r.Access = append(r.Access, o.Access...) - slices.SortFunc(r.Access, compareFileAccess) - r.Access = slices.Compact(r.Access) - return true + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.Owner == o.Owner && r.Path == o.Path && r.Target == o.Target { + r.Access = merge(r.Kind(), "access", r.Access, o.Access) + return r.RuleBase.merge(o.RuleBase) } return false } diff --git a/pkg/aa/io_uring.go b/pkg/aa/io_uring.go index f7adff8f..06c33ad3 100644 --- a/pkg/aa/io_uring.go +++ b/pkg/aa/io_uring.go @@ -63,6 +63,19 @@ func (r *IOUring) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *IOUring) Merge(other Rule) bool { + o, _ := other.(*IOUring) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.Label == o.Label { + r.Access = merge(r.Kind(), "access", r.Access, o.Access) + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *IOUring) String() string { return renderTemplate(r.Kind(), r) } diff --git a/pkg/aa/mount.go b/pkg/aa/mount.go index 4d7928b0..480afa2f 100644 --- a/pkg/aa/mount.go +++ b/pkg/aa/mount.go @@ -65,6 +65,14 @@ func (m MountConditions) Compare(other MountConditions) int { return compare(m.Options, other.Options) } +func (m MountConditions) Merge(other MountConditions) bool { + if m.FsType == other.FsType { + m.Options = merge(MOUNT, "flags", m.Options, other.Options) + return true + } + return false +} + type Mount struct { RuleBase Qualifier @@ -133,6 +141,19 @@ func (r *Mount) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *Mount) Merge(other Rule) bool { + o, _ := other.(*Mount) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.Source == o.Source && r.MountPoint == o.MountPoint && + r.MountConditions.Merge(o.MountConditions) { + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Mount) String() string { return renderTemplate(r.Kind(), r) } @@ -197,6 +218,18 @@ func (r *Umount) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *Umount) Merge(other Rule) bool { + o, _ := other.(*Umount) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.MountPoint == o.MountPoint && r.MountConditions.Merge(o.MountConditions) { + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Umount) String() string { return renderTemplate(r.Kind(), r) } @@ -262,6 +295,18 @@ func (r *Remount) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *Remount) Merge(other Rule) bool { + o, _ := other.(*Remount) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.MountPoint == o.MountPoint && r.MountConditions.Merge(o.MountConditions) { + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Remount) String() string { return renderTemplate(r.Kind(), r) } diff --git a/pkg/aa/mqueue.go b/pkg/aa/mqueue.go index e8d4dffd..7edd9358 100644 --- a/pkg/aa/mqueue.go +++ b/pkg/aa/mqueue.go @@ -97,6 +97,19 @@ func (r *Mqueue) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *Mqueue) Merge(other Rule) bool { + o, _ := other.(*Mqueue) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.Type == o.Type && r.Label == o.Label && r.Name == o.Name { + r.Access = merge(r.Kind(), "access", r.Access, o.Access) + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Mqueue) String() string { return renderTemplate(r.Kind(), r) } diff --git a/pkg/aa/network.go b/pkg/aa/network.go index 0478b310..38818de1 100644 --- a/pkg/aa/network.go +++ b/pkg/aa/network.go @@ -57,11 +57,6 @@ func (r AddressExpr) Compare(other AddressExpr) int { return compare(r.Port, other.Port) } -func (r AddressExpr) Equals(other AddressExpr) bool { - return r.Source == other.Source && r.Destination == other.Destination && - r.Port == other.Port -} - type Network struct { RuleBase Qualifier diff --git a/pkg/aa/preamble.go b/pkg/aa/preamble.go index d9e8b1f2..e628417b 100644 --- a/pkg/aa/preamble.go +++ b/pkg/aa/preamble.go @@ -256,8 +256,25 @@ func (r *Variable) Validate() error { return nil } +func (r *Variable) Merge(other Rule) bool { + o, _ := other.(*Variable) + + if r.Name == o.Name && r.Define == o.Define { + r.Values = merge(r.Kind(), "access", r.Values, o.Values) + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Variable) Compare(other Rule) int { - return 0 + o, _ := other.(*Variable) + if res := compare(r.Name, o.Name); res != 0 { + return res + } + if res := compare(r.Define, o.Define); res != 0 { + return res + } + return compare(r.Values, o.Values) } func (r *Variable) String() string { diff --git a/pkg/aa/ptrace.go b/pkg/aa/ptrace.go index 2a5109a6..1ecec49a 100644 --- a/pkg/aa/ptrace.go +++ b/pkg/aa/ptrace.go @@ -54,6 +54,19 @@ func (r *Ptrace) Validate() error { return nil } +func (r *Ptrace) Merge(other Rule) bool { + o, _ := other.(*Ptrace) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.Peer == o.Peer { + r.Access = merge(r.Kind(), "access", r.Access, o.Access) + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Ptrace) Compare(other Rule) int { o, _ := other.(*Ptrace) if res := compare(r.Access, o.Access); res != 0 { diff --git a/pkg/aa/signal.go b/pkg/aa/signal.go index 8e674df2..6d590b10 100644 --- a/pkg/aa/signal.go +++ b/pkg/aa/signal.go @@ -77,6 +77,23 @@ func (r *Signal) Validate() error { return nil } +func (r *Signal) Merge(other Rule) bool { + o, _ := other.(*Signal) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + switch { + case r.Peer == o.Peer && compare(r.Set, o.Set) == 0: + r.Access = merge(r.Kind(), "access", r.Access, o.Access) + return r.RuleBase.merge(o.RuleBase) + case r.Peer == o.Peer && compare(r.Access, o.Access) == 0: + r.Set = merge(r.Kind(), "set", r.Set, o.Set) + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Signal) Compare(other Rule) int { o, _ := other.(*Signal) if res := compare(r.Access, o.Access); res != 0 { diff --git a/pkg/aa/unix.go b/pkg/aa/unix.go index f65c953d..5ccae971 100644 --- a/pkg/aa/unix.go +++ b/pkg/aa/unix.go @@ -109,6 +109,21 @@ func (r *Unix) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *Unix) Merge(other Rule) bool { + o, _ := other.(*Unix) + + if !r.Qualifier.Equal(o.Qualifier) { + return false + } + if r.Type == o.Type && r.Protocol == o.Protocol && r.Address == o.Address && + r.Label == o.Label && r.Attr == o.Attr && r.Opt == o.Opt && + r.PeerLabel == o.PeerLabel && r.PeerAddr == o.PeerAddr { + r.Access = merge(r.Kind(), "access", r.Access, o.Access) + return r.RuleBase.merge(o.RuleBase) + } + return false +} + func (r *Unix) String() string { return renderTemplate(r.Kind(), r) } diff --git a/pkg/aa/userns.go b/pkg/aa/userns.go index a28870c9..6770106b 100644 --- a/pkg/aa/userns.go +++ b/pkg/aa/userns.go @@ -54,6 +54,11 @@ func (r *Userns) Compare(other Rule) int { return r.Qualifier.Compare(o.Qualifier) } +func (r *Userns) Merge(other Rule) bool { + o, _ := other.(*Userns) + return r.RuleBase.merge(o.RuleBase) +} + func (r *Userns) String() string { return renderTemplate(r.Kind(), r) }