sys,fw: fixed race condition creating system rules

Hard to reproduce, but not impossible
This commit is contained in:
Gustavo Iñiguez Goia 2023-05-24 14:26:58 +02:00
parent 631f27ee24
commit 9f71c15e57
Failed to generate hash of commit
6 changed files with 50 additions and 18 deletions

View file

@ -34,7 +34,7 @@ func (n *Nft) AddChain(name, table, family string, priority *nftables.ChainPrior
if family == "" {
family = exprs.NFT_FAMILY_INET
}
tbl := getTable(table, family)
tbl := n.getTable(table, family)
if tbl == nil {
log.Error("%s addChain, Error getting table: %s, %s", logTag, table, family)
return nil
@ -87,7 +87,7 @@ func (n *Nft) getChain(name string, table *nftables.Table, family string) *nftab
// regular chains are user-defined chains, to better organize fw rules.
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_regular_chains
func (n *Nft) addRegularChain(name, table, family string) error {
tbl := getTable(table, family)
tbl := n.getTable(table, family)
if tbl == nil {
return fmt.Errorf("%s addRegularChain, Error getting table: %s, %s", logTag, table, family)
}
@ -111,7 +111,7 @@ func (n *Nft) addInterceptionChains() error {
filterPolicy = nftables.ChainPolicyAccept
manglePolicy = nftables.ChainPolicyAccept
tbl := getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
tbl := n.getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
if tbl != nil {
key := getChainKey(exprs.NFT_HOOK_INPUT, tbl)
ch, found := sysChains.Load(key)
@ -119,7 +119,7 @@ func (n *Nft) addInterceptionChains() error {
filterPolicy = *ch.(*nftables.Chain).Policy
}
}
tbl = getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
tbl = n.getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
if tbl != nil {
key := getChainKey(exprs.NFT_HOOK_OUTPUT, tbl)
ch, found := sysChains.Load(key)

View file

@ -130,8 +130,8 @@ func (n *Nft) EnableInterception() {
// DisableInterception removes firewall rules to intercept outbound connections.
func (n *Nft) DisableInterception(logErrors bool) {
n.delInterceptionRules()
n.StopCheckingRules()
n.delInterceptionRules()
}
// CleanRules deletes the rules we added.

View file

@ -13,7 +13,7 @@ import (
// rules examples: https://github.com/google/nftables/blob/master/nftables_test.go
func (n *Nft) buildICMPRule(table, family string, icmpProtoVersion string, icmpOptions []*config.ExprValues) *[]expr.Any {
tbl := getTable(table, family)
tbl := n.getTable(table, family)
if tbl == nil {
return nil
}
@ -140,7 +140,7 @@ Exit:
}
func (n *Nft) buildProtocolRule(table, family, ports string, cmpOp *expr.CmpOp) *[]expr.Any {
tbl := getTable(table, family)
tbl := n.getTable(table, family)
if tbl == nil {
return nil
}

View file

@ -22,7 +22,7 @@ func (n *Nft) QueueDNSResponses(enable bool, logError bool) (error, error) {
}
families := []string{exprs.NFT_FAMILY_INET}
for _, fam := range families {
table := getTable(exprs.NFT_CHAIN_FILTER, fam)
table := n.getTable(exprs.NFT_CHAIN_FILTER, fam)
chain := getChain(exprs.NFT_HOOK_INPUT, table)
if table == nil {
log.Error("QueueDNSResponses() Error getting table: %s-filter", fam)
@ -82,7 +82,7 @@ func (n *Nft) QueueConnections(enable bool, logError bool) (error, error) {
if n.conn == nil {
return nil, fmt.Errorf("nftables QueueConnections: netlink connection not active")
}
table := getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
table := n.getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
if table == nil {
return nil, fmt.Errorf("QueueConnections() Error getting table mangle-inet")
}
@ -130,7 +130,7 @@ func (n *Nft) QueueConnections(enable bool, logError bool) (error, error) {
}
func (n *Nft) insertRule(chain, table, family string, position uint64, exprs *[]expr.Any) error {
tbl := getTable(table, family)
tbl := n.getTable(table, family)
if tbl == nil {
return fmt.Errorf("%s addRule, Error getting table: %s, %s", logTag, table, family)
}
@ -157,7 +157,7 @@ func (n *Nft) insertRule(chain, table, family string, position uint64, exprs *[]
}
func (n *Nft) addRule(chain, table, family string, position uint64, exprs *[]expr.Any) error {
tbl := getTable(table, family)
tbl := n.getTable(table, family)
if tbl == nil {
return fmt.Errorf("%s addRule, Error getting table: %s, %s", logTag, table, family)
}

View file

@ -13,16 +13,48 @@ import (
"github.com/google/uuid"
)
// store of tables added to the system
type sysTablesT struct {
tables map[string]*nftables.Table
sync.RWMutex
}
func (t *sysTablesT) Add(name string, tbl *nftables.Table) {
t.Lock()
defer t.Unlock()
t.tables[name] = tbl
}
func (t *sysTablesT) Get(name string) *nftables.Table {
t.RLock()
defer t.RUnlock()
return t.tables[name]
}
func (t *sysTablesT) List() map[string]*nftables.Table {
t.RLock()
defer t.RUnlock()
return t.tables
}
func (t *sysTablesT) Del(name string) {
t.Lock()
defer t.Unlock()
delete(t.tables, name)
}
var (
logTag = "nftables:"
sysTables map[string]*nftables.Table
sysTables *sysTablesT
sysChains *sync.Map
origSysChains map[string]*nftables.Chain
sysSets []*nftables.Set
)
func initMapsStore() {
sysTables = make(map[string]*nftables.Table)
sysTables = &sysTablesT{
tables: make(map[string]*nftables.Table),
}
sysChains = &sync.Map{}
origSysChains = make(map[string]*nftables.Chain)
}

View file

@ -21,12 +21,12 @@ func (n *Nft) AddTable(name, family string) (*nftables.Table, error) {
return nil, fmt.Errorf("%s error adding system firewall table: %s, family: %s (%d)", logTag, name, family, famCode)
}
key := getTableKey(name, family)
sysTables[key] = tbl
sysTables.Add(key, tbl)
return tbl, nil
}
func getTable(name, family string) *nftables.Table {
return sysTables[getTableKey(name, family)]
func (n *Nft) getTable(name, family string) *nftables.Table {
return sysTables.Get(getTableKey(name, family))
}
func getTableKey(name string, family interface{}) string {
@ -73,7 +73,7 @@ func (n *Nft) nonSystemRules(tbl *nftables.Table) int {
}
func (n *Nft) delSystemTables() {
for k, tbl := range sysTables {
for k, tbl := range sysTables.List() {
if n.nonSystemRules(tbl) != 0 {
continue
}
@ -82,6 +82,6 @@ func (n *Nft) delSystemTables() {
log.Warning("error deleting system table: %s", k)
continue
}
delete(sysTables, k)
sysTables.Del(k)
}
}