From 9f71c15e573ca96026c8e158ed72de7a2050b563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20I=C3=B1iguez=20Goia?= Date: Wed, 24 May 2023 14:26:58 +0200 Subject: [PATCH] sys,fw: fixed race condition creating system rules Hard to reproduce, but not impossible --- daemon/firewall/nftables/chains.go | 8 +++--- daemon/firewall/nftables/nftables.go | 2 +- daemon/firewall/nftables/rule_helpers.go | 4 +-- daemon/firewall/nftables/rules.go | 8 +++--- daemon/firewall/nftables/system.go | 36 ++++++++++++++++++++++-- daemon/firewall/nftables/tables.go | 10 +++---- 6 files changed, 50 insertions(+), 18 deletions(-) diff --git a/daemon/firewall/nftables/chains.go b/daemon/firewall/nftables/chains.go index cc8ede27..6d803dea 100644 --- a/daemon/firewall/nftables/chains.go +++ b/daemon/firewall/nftables/chains.go @@ -34,7 +34,7 @@ func (n *Nft) AddChain(name, table, family string, priority *nftables.ChainPrior if family == "" { family = exprs.NFT_FAMILY_INET } - tbl := getTable(table, family) + tbl := n.getTable(table, family) if tbl == nil { log.Error("%s addChain, Error getting table: %s, %s", logTag, table, family) return nil @@ -87,7 +87,7 @@ func (n *Nft) getChain(name string, table *nftables.Table, family string) *nftab // regular chains are user-defined chains, to better organize fw rules. // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_regular_chains func (n *Nft) addRegularChain(name, table, family string) error { - tbl := getTable(table, family) + tbl := n.getTable(table, family) if tbl == nil { return fmt.Errorf("%s addRegularChain, Error getting table: %s, %s", logTag, table, family) } @@ -111,7 +111,7 @@ func (n *Nft) addInterceptionChains() error { filterPolicy = nftables.ChainPolicyAccept manglePolicy = nftables.ChainPolicyAccept - tbl := getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET) + tbl := n.getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET) if tbl != nil { key := getChainKey(exprs.NFT_HOOK_INPUT, tbl) ch, found := sysChains.Load(key) @@ -119,7 +119,7 @@ func (n *Nft) addInterceptionChains() error { filterPolicy = *ch.(*nftables.Chain).Policy } } - tbl = getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET) + tbl = n.getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET) if tbl != nil { key := getChainKey(exprs.NFT_HOOK_OUTPUT, tbl) ch, found := sysChains.Load(key) diff --git a/daemon/firewall/nftables/nftables.go b/daemon/firewall/nftables/nftables.go index 5d05ea87..3d95c68a 100644 --- a/daemon/firewall/nftables/nftables.go +++ b/daemon/firewall/nftables/nftables.go @@ -130,8 +130,8 @@ func (n *Nft) EnableInterception() { // DisableInterception removes firewall rules to intercept outbound connections. func (n *Nft) DisableInterception(logErrors bool) { - n.delInterceptionRules() n.StopCheckingRules() + n.delInterceptionRules() } // CleanRules deletes the rules we added. diff --git a/daemon/firewall/nftables/rule_helpers.go b/daemon/firewall/nftables/rule_helpers.go index 73a20092..d2043c71 100644 --- a/daemon/firewall/nftables/rule_helpers.go +++ b/daemon/firewall/nftables/rule_helpers.go @@ -13,7 +13,7 @@ import ( // rules examples: https://github.com/google/nftables/blob/master/nftables_test.go func (n *Nft) buildICMPRule(table, family string, icmpProtoVersion string, icmpOptions []*config.ExprValues) *[]expr.Any { - tbl := getTable(table, family) + tbl := n.getTable(table, family) if tbl == nil { return nil } @@ -140,7 +140,7 @@ Exit: } func (n *Nft) buildProtocolRule(table, family, ports string, cmpOp *expr.CmpOp) *[]expr.Any { - tbl := getTable(table, family) + tbl := n.getTable(table, family) if tbl == nil { return nil } diff --git a/daemon/firewall/nftables/rules.go b/daemon/firewall/nftables/rules.go index d70424c1..607e4d6a 100644 --- a/daemon/firewall/nftables/rules.go +++ b/daemon/firewall/nftables/rules.go @@ -22,7 +22,7 @@ func (n *Nft) QueueDNSResponses(enable bool, logError bool) (error, error) { } families := []string{exprs.NFT_FAMILY_INET} for _, fam := range families { - table := getTable(exprs.NFT_CHAIN_FILTER, fam) + table := n.getTable(exprs.NFT_CHAIN_FILTER, fam) chain := getChain(exprs.NFT_HOOK_INPUT, table) if table == nil { log.Error("QueueDNSResponses() Error getting table: %s-filter", fam) @@ -82,7 +82,7 @@ func (n *Nft) QueueConnections(enable bool, logError bool) (error, error) { if n.conn == nil { return nil, fmt.Errorf("nftables QueueConnections: netlink connection not active") } - table := getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET) + table := n.getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET) if table == nil { return nil, fmt.Errorf("QueueConnections() Error getting table mangle-inet") } @@ -130,7 +130,7 @@ func (n *Nft) QueueConnections(enable bool, logError bool) (error, error) { } func (n *Nft) insertRule(chain, table, family string, position uint64, exprs *[]expr.Any) error { - tbl := getTable(table, family) + tbl := n.getTable(table, family) if tbl == nil { return fmt.Errorf("%s addRule, Error getting table: %s, %s", logTag, table, family) } @@ -157,7 +157,7 @@ func (n *Nft) insertRule(chain, table, family string, position uint64, exprs *[] } func (n *Nft) addRule(chain, table, family string, position uint64, exprs *[]expr.Any) error { - tbl := getTable(table, family) + tbl := n.getTable(table, family) if tbl == nil { return fmt.Errorf("%s addRule, Error getting table: %s, %s", logTag, table, family) } diff --git a/daemon/firewall/nftables/system.go b/daemon/firewall/nftables/system.go index 0681ffb3..33f3377f 100644 --- a/daemon/firewall/nftables/system.go +++ b/daemon/firewall/nftables/system.go @@ -13,16 +13,48 @@ import ( "github.com/google/uuid" ) +// store of tables added to the system +type sysTablesT struct { + tables map[string]*nftables.Table + sync.RWMutex +} + +func (t *sysTablesT) Add(name string, tbl *nftables.Table) { + t.Lock() + defer t.Unlock() + t.tables[name] = tbl +} + +func (t *sysTablesT) Get(name string) *nftables.Table { + t.RLock() + defer t.RUnlock() + return t.tables[name] +} + +func (t *sysTablesT) List() map[string]*nftables.Table { + t.RLock() + defer t.RUnlock() + return t.tables +} + +func (t *sysTablesT) Del(name string) { + t.Lock() + defer t.Unlock() + delete(t.tables, name) +} + var ( logTag = "nftables:" - sysTables map[string]*nftables.Table + sysTables *sysTablesT sysChains *sync.Map origSysChains map[string]*nftables.Chain sysSets []*nftables.Set ) func initMapsStore() { - sysTables = make(map[string]*nftables.Table) + sysTables = &sysTablesT{ + tables: make(map[string]*nftables.Table), + } sysChains = &sync.Map{} origSysChains = make(map[string]*nftables.Chain) } diff --git a/daemon/firewall/nftables/tables.go b/daemon/firewall/nftables/tables.go index 191fe68c..5efa85c9 100644 --- a/daemon/firewall/nftables/tables.go +++ b/daemon/firewall/nftables/tables.go @@ -21,12 +21,12 @@ func (n *Nft) AddTable(name, family string) (*nftables.Table, error) { return nil, fmt.Errorf("%s error adding system firewall table: %s, family: %s (%d)", logTag, name, family, famCode) } key := getTableKey(name, family) - sysTables[key] = tbl + sysTables.Add(key, tbl) return tbl, nil } -func getTable(name, family string) *nftables.Table { - return sysTables[getTableKey(name, family)] +func (n *Nft) getTable(name, family string) *nftables.Table { + return sysTables.Get(getTableKey(name, family)) } func getTableKey(name string, family interface{}) string { @@ -73,7 +73,7 @@ func (n *Nft) nonSystemRules(tbl *nftables.Table) int { } func (n *Nft) delSystemTables() { - for k, tbl := range sysTables { + for k, tbl := range sysTables.List() { if n.nonSystemRules(tbl) != 0 { continue } @@ -82,6 +82,6 @@ func (n *Nft) delSystemTables() { log.Warning("error deleting system table: %s", k) continue } - delete(sysTables, k) + sysTables.Del(k) } }