feat(aa): add initial profile validation structure.

This commit is contained in:
Alexandre Pujol 2024-05-25 22:36:39 +01:00
parent 2dd6046697
commit 92641e7e28
Failed to generate hash of commit
20 changed files with 222 additions and 2 deletions

View file

@ -12,6 +12,10 @@ type All struct {
RuleBase
}
func (r *All) Validate() error {
return nil
}
func (r *All) Less(other any) bool {
return false
}

View file

@ -49,6 +49,19 @@ func (f *AppArmorProfileFile) String() string {
return renderTemplate("apparmor", f)
}
// Validate the profile file
func (f *AppArmorProfileFile) Validate() error {
if err := f.Preamble.Validate(); err != nil {
return err
}
for _, p := range f.Profiles {
if err := p.Validate(); err != nil {
return err
}
}
return nil
}
// GetDefaultProfile ensure a profile is always present in the profile file and
// return it, as a default profile.
func (f *AppArmorProfileFile) GetDefaultProfile() *Profile {

View file

@ -16,6 +16,10 @@ type Hat struct {
Rules Rules
}
func (r *Hat) Validate() error {
return nil
}
func (p *Hat) Less(other any) bool {
o, _ := other.(*Hat)
return p.Name < o.Name

View file

@ -5,6 +5,7 @@
package aa
import (
"fmt"
"slices"
)
@ -39,6 +40,13 @@ func newCapabilityFromLog(log map[string]string) Rule {
}
}
func (r *Capability) Validate() error {
if err := validateValues(r.Kind(), "name", r.Names); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Capability) Less(other any) bool {
o, _ := other.(*Capability)
for i := 0; i < len(r.Names) && i < len(o.Names); i++ {

View file

@ -30,6 +30,13 @@ func newChangeProfileFromLog(log map[string]string) Rule {
}
}
func (r *ChangeProfile) Validate() error {
if err := validateValues(r.Kind(), "mode", []string{r.ExecMode}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *ChangeProfile) Less(other any) bool {
o, _ := other.(*ChangeProfile)
if r.ExecMode != o.ExecMode {

View file

@ -5,6 +5,7 @@
package aa
import (
"fmt"
"slices"
)
@ -55,6 +56,13 @@ func newDbusFromLog(log map[string]string) Rule {
}
}
func (r *Dbus) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return validateValues(r.Kind(), "bus", []string{r.Bus})
}
func (r *Dbus) Less(other any) bool {
o, _ := other.(*Dbus)
for i := 0; i < len(r.Access) && i < len(o.Access); i++ {

View file

@ -81,6 +81,10 @@ func newFileFromLog(log map[string]string) Rule {
}
}
func (r *File) Validate() error {
return nil
}
func (r *File) Less(other any) bool {
o, _ := other.(*File)
letterR := getLetterIn(fileAlphabet, r.Path)
@ -140,6 +144,10 @@ func newLinkFromLog(log map[string]string) Rule {
}
}
func (r *Link) Validate() error {
return nil
}
func (r *Link) Less(other any) bool {
o, _ := other.(*Link)
if r.Path != o.Path {

View file

@ -4,7 +4,10 @@
package aa
import "slices"
import (
"fmt"
"slices"
)
const tokIOURING = "io_uring"
@ -30,6 +33,13 @@ func newIOUringFromLog(log map[string]string) Rule {
}
}
func (r *IOUring) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *IOUring) Less(other any) bool {
o, _ := other.(*IOUring)
if len(r.Access) != len(o.Access) {

View file

@ -42,6 +42,10 @@ func newMountConditionsFromLog(log map[string]string) MountConditions {
return MountConditions{FsType: log["fstype"]}
}
func (m MountConditions) Validate() error {
return validateValues(tokMOUNT, "flags", m.Options)
}
func (m MountConditions) Less(other MountConditions) bool {
if m.FsType != other.FsType {
return m.FsType < other.FsType
@ -71,6 +75,13 @@ func newMountFromLog(log map[string]string) Rule {
}
}
func (r *Mount) Validate() error {
if err := r.MountConditions.Validate(); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Mount) Less(other any) bool {
o, _ := other.(*Mount)
if r.Source != o.Source {
@ -120,6 +131,13 @@ func newUmountFromLog(log map[string]string) Rule {
}
}
func (r *Umount) Validate() error {
if err := r.MountConditions.Validate(); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Umount) Less(other any) bool {
o, _ := other.(*Umount)
if r.MountPoint != o.MountPoint {
@ -166,6 +184,13 @@ func newRemountFromLog(log map[string]string) Rule {
}
}
func (r *Remount) Validate() error {
if err := r.MountConditions.Validate(); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Remount) Less(other any) bool {
o, _ := other.(*Remount)
if r.MountPoint != o.MountPoint {

View file

@ -5,6 +5,7 @@
package aa
import (
"fmt"
"slices"
"strings"
)
@ -47,6 +48,16 @@ func newMqueueFromLog(log map[string]string) Rule {
}
}
func (r *Mqueue) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "type", []string{r.Type}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Mqueue) Less(other any) bool {
o, _ := other.(*Mqueue)
if len(r.Access) != len(o.Access) {

View file

@ -4,6 +4,10 @@
package aa
import (
"fmt"
)
const tokNETWORK = "network"
func init() {
@ -77,6 +81,19 @@ func newNetworkFromLog(log map[string]string) Rule {
}
}
func (r *Network) Validate() error {
if err := validateValues(r.Kind(), "domains", []string{r.Domain}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "type", []string{r.Type}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "protocol", []string{r.Protocol}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Network) Less(other any) bool {
o, _ := other.(*Network)
if r.Domain != o.Domain {

View file

@ -24,6 +24,10 @@ func newPivotRootFromLog(log map[string]string) Rule {
}
}
func (r *PivotRoot) Validate() error {
return nil
}
func (r *PivotRoot) Less(other any) bool {
o, _ := other.(*PivotRoot)
if r.OldRoot != o.OldRoot {

View file

@ -21,6 +21,10 @@ type Comment struct {
RuleBase
}
func (r *Comment) Validate() error {
return nil
}
func (r *Comment) Less(other any) bool {
return false
}
@ -51,6 +55,10 @@ type Abi struct {
IsMagic bool
}
func (r *Abi) Validate() error {
return nil
}
func (r *Abi) Less(other any) bool {
o, _ := other.(*Abi)
if r.Path != o.Path {
@ -82,6 +90,10 @@ type Alias struct {
RewrittenPath string
}
func (r *Alias) Validate() error {
return nil
}
func (r Alias) Less(other any) bool {
o, _ := other.(*Alias)
if r.Path != o.Path {
@ -114,6 +126,10 @@ type Include struct {
IsMagic bool
}
func (r *Include) Validate() error {
return nil
}
func (r *Include) Less(other any) bool {
o, _ := other.(*Include)
if r.Path == o.Path {
@ -149,6 +165,10 @@ type Variable struct {
Define bool
}
func (r *Variable) Validate() error {
return nil
}
func (r *Variable) Less(other any) bool {
o, _ := other.(*Variable)
if r.Name != o.Name {

View file

@ -5,6 +5,7 @@
package aa
import (
"fmt"
"maps"
"reflect"
"slices"
@ -18,6 +19,17 @@ const (
tokPROFILE = "profile"
)
func init() {
requirements[tokPROFILE] = requirement{
tokFLAGS: {
"enforce", "complain", "kill", "default_allow", "unconfined",
"prompt", "audit", "mediate_deleted", "attach_disconnected",
"attach_disconneced.path=", "chroot_relative", "debug",
"interruptible", "kill", "kill.signal=",
},
}
}
// Profile represents a single AppArmor profile.
type Profile struct {
RuleBase
@ -33,6 +45,13 @@ type Header struct {
Flags []string
}
func (r *Profile) Validate() error {
if err := validateValues(r.Kind(), tokFLAGS, r.Flags); err != nil {
return fmt.Errorf("profile %s: %w", r.Name, err)
}
return r.Rules.Validate()
}
func (p *Profile) Less(other any) bool {
o, _ := other.(*Profile)
if p.Name != o.Name {

View file

@ -5,6 +5,7 @@
package aa
import (
"fmt"
"slices"
)
@ -34,6 +35,13 @@ func newPtraceFromLog(log map[string]string) Rule {
}
}
func (r *Ptrace) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Ptrace) Less(other any) bool {
o, _ := other.(*Ptrace)
if len(r.Access) != len(o.Access) {

View file

@ -35,6 +35,13 @@ func newRlimitFromLog(log map[string]string) Rule {
}
}
func (r *Rlimit) Validate() error {
if err := validateValues(r.Kind(), "keys", []string{r.Key}); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Rlimit) Less(other any) bool {
o, _ := other.(*Rlimit)
if r.Key != o.Key {

View file

@ -28,6 +28,7 @@ const (
// Rule generic interface for all AppArmor rules
type Rule interface {
Validate() error
Less(other any) bool
Equals(other any) bool
String() string
@ -37,6 +38,15 @@ type Rule interface {
type Rules []Rule
func (r Rules) Validate() error {
for _, rule := range r {
if err := rule.Validate(); err != nil {
return err
}
}
return nil
}
func (r Rules) String() string {
return renderTemplate("rules", r)
}
@ -82,6 +92,18 @@ func Must[T any](v T, err error) T {
return v
}
func validateValues(rule string, key string, values []string) error {
for _, v := range values {
if v == "" {
continue
}
if !slices.Contains(requirements[rule][key], v) {
return fmt.Errorf("invalid mode '%s'", v)
}
}
return nil
}
// Helper function to convert a string to a slice of rule values according to
// the rule requirements as defined in the requirements map.
func toValues(rule string, key string, input string) ([]string, error) {

View file

@ -5,6 +5,7 @@
package aa
import (
"fmt"
"slices"
)
@ -49,6 +50,16 @@ func newSignalFromLog(log map[string]string) Rule {
}
}
func (r *Signal) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
if err := validateValues(r.Kind(), "set", r.Set); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Signal) Less(other any) bool {
o, _ := other.(*Signal)
if len(r.Access) != len(o.Access) {

View file

@ -4,7 +4,10 @@
package aa
import "slices"
import (
"fmt"
"slices"
)
const tokUNIX = "unix"
@ -48,6 +51,13 @@ func newUnixFromLog(log map[string]string) Rule {
}
}
func (r *Unix) Validate() error {
if err := validateValues(r.Kind(), "access", r.Access); err != nil {
return fmt.Errorf("%s: %w", r, err)
}
return nil
}
func (r *Unix) Less(other any) bool {
o, _ := other.(*Unix)
if len(r.Access) != len(o.Access) {

View file

@ -20,6 +20,10 @@ func newUsernsFromLog(log map[string]string) Rule {
}
}
func (r *Userns) Validate() error {
return nil
}
func (r *Userns) Less(other any) bool {
o, _ := other.(*Userns)
if r.Create != o.Create {