sys,fw: export some internal utils.

This commit is contained in:
Gustavo Iñiguez Goia 2023-07-06 13:36:57 +02:00
parent b366f5f8b3
commit d474e7f57a
Failed to generate hash of commit
9 changed files with 70 additions and 67 deletions

View file

@ -19,8 +19,8 @@ func getChainKey(name string, table *nftables.Table) string {
return fmt.Sprintf("%s-%s-%d", name, table.Name, table.Family)
}
// get an existing chain
func getChain(name string, table *nftables.Table) *nftables.Chain {
// GetChain gets an existing chain
func GetChain(name string, table *nftables.Table) *nftables.Chain {
key := getChainKey(name, table)
if ch, ok := sysChains.Load(key); ok {
return ch.(*nftables.Chain)
@ -34,7 +34,7 @@ func (n *Nft) AddChain(name, table, family string, priority *nftables.ChainPrior
if family == "" {
family = exprs.NFT_FAMILY_INET
}
tbl := n.getTable(table, family)
tbl := n.GetTable(table, family)
if tbl == nil {
log.Error("%s addChain, Error getting table: %s, %s", logTag, table, family)
return nil
@ -51,10 +51,10 @@ func (n *Nft) AddChain(name, table, family string, priority *nftables.ChainPrior
sysChains.Delete(key)
}
chain.Policy = &policy
n.conn.AddChain(chain)
n.Conn.AddChain(chain)
} else {
// nft list chains
chain = n.conn.AddChain(&nftables.Chain{
chain = n.Conn.AddChain(&nftables.Chain{
Name: strings.ToLower(name),
Table: tbl,
Type: ctype,
@ -74,9 +74,9 @@ func (n *Nft) AddChain(name, table, family string, priority *nftables.ChainPrior
// getChain checks if a chain in the given table exists.
func (n *Nft) getChain(name string, table *nftables.Table, family string) *nftables.Chain {
if chains, err := n.conn.ListChains(); err == nil {
if chains, err := n.Conn.ListChains(); err == nil {
for _, c := range chains {
if name == c.Name && table.Name == c.Table.Name && getFamilyCode(family) == c.Table.Family {
if name == c.Name && table.Name == c.Table.Name && GetFamilyCode(family) == c.Table.Family {
return c
}
}
@ -87,12 +87,12 @@ func (n *Nft) getChain(name string, table *nftables.Table, family string) *nftab
// regular chains are user-defined chains, to better organize fw rules.
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains#Adding_regular_chains
func (n *Nft) addRegularChain(name, table, family string) error {
tbl := n.getTable(table, family)
tbl := n.GetTable(table, family)
if tbl == nil {
return fmt.Errorf("%s addRegularChain, Error getting table: %s, %s", logTag, table, family)
}
chain := n.conn.AddChain(&nftables.Chain{
chain := n.Conn.AddChain(&nftables.Chain{
Name: name,
Table: tbl,
})
@ -105,13 +105,13 @@ func (n *Nft) addRegularChain(name, table, family string) error {
return nil
}
func (n *Nft) addInterceptionChains() error {
func (n *Nft) AddInterceptionChains() error {
var filterPolicy nftables.ChainPolicy
var manglePolicy nftables.ChainPolicy
filterPolicy = nftables.ChainPolicyAccept
manglePolicy = nftables.ChainPolicyAccept
tbl := n.getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
tbl := n.GetTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
if tbl != nil {
key := getChainKey(exprs.NFT_HOOK_INPUT, tbl)
ch, found := sysChains.Load(key)
@ -119,7 +119,7 @@ func (n *Nft) addInterceptionChains() error {
filterPolicy = *ch.(*nftables.Chain).Policy
}
}
tbl = n.getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
tbl = n.GetTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
if tbl != nil {
key := getChainKey(exprs.NFT_HOOK_OUTPUT, tbl)
ch, found := sysChains.Load(key)
@ -140,8 +140,8 @@ func (n *Nft) addInterceptionChains() error {
log.Error("(1) Error adding interception chain mangle-output-inet, trying with type Filter instead of Route")
// Workaround for kernels 4.x and maybe others.
// @see firewall/nftables/utils.go:getChainPriority()
chainPrio, chainType := getChainPriority(exprs.NFT_FAMILY_INET, exprs.NFT_CHAIN_MANGLE, exprs.NFT_HOOK_OUTPUT)
// @see firewall/nftables/utils.go:GetChainPriority()
chainPrio, chainType := GetChainPriority(exprs.NFT_FAMILY_INET, exprs.NFT_CHAIN_MANGLE, exprs.NFT_HOOK_OUTPUT)
n.AddChain(exprs.NFT_HOOK_OUTPUT, exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET,
chainPrio, chainType, nftables.ChainHookOutput, manglePolicy)
if !n.Commit() {
@ -153,7 +153,7 @@ func (n *Nft) addInterceptionChains() error {
}
func (n *Nft) delChain(chain *nftables.Chain) error {
n.conn.DelChain(chain)
n.Conn.DelChain(chain)
sysChains.Delete(getChainKey(chain.Name, chain.Table))
if !n.Commit() {
return fmt.Errorf("[nftables] error deleting chain %s, %s", chain.Name, chain.Table.Name)
@ -166,7 +166,7 @@ func (n *Nft) delChain(chain *nftables.Chain) error {
// If the user configures the chain policy to Drop, we need to set it back to Accept,
// in order not to block incoming connections.
func (n *Nft) backupExistingChains() {
if chains, err := n.conn.ListChains(); err == nil {
if chains, err := n.Conn.ListChains(); err == nil {
for _, c := range chains {
if c.Policy != nil && *c.Policy == nftables.ChainPolicyAccept {
log.Debug("%s backing up existing chain with policy ACCEPT: %s, %s", logTag, c.Name, c.Table.Name)
@ -180,7 +180,7 @@ func (n *Nft) restoreBackupChains() {
for _, c := range origSysChains {
log.Debug("%s Restoring chain policy to accept: %s, %s", logTag, c.Name, c.Table.Name)
*c.Policy = nftables.ChainPolicyAccept
n.conn.AddChain(c)
n.Conn.AddChain(c)
}
n.Commit()
}

View file

@ -14,14 +14,14 @@ func (n *Nft) AreRulesLoaded() bool {
defer n.Unlock()
nRules := 0
chains, err := n.conn.ListChains()
chains, err := n.Conn.ListChains()
if err != nil {
log.Warning("[nftables] error listing nftables chains: %s", err)
return false
}
for _, c := range chains {
rules, err := n.conn.GetRule(c.Table, c)
rules, err := n.Conn.GetRule(c.Table, c)
if err != nil {
log.Warning("[nftables] Error listing rules: %s", err)
continue

View file

@ -45,7 +45,7 @@ type Nft struct {
config.Config
common.Common
conn *nftables.Conn
Conn *nftables.Conn
chains iptables.SystemChains
}
@ -75,9 +75,9 @@ func (n *Nft) Init(qNum *int) {
if n.IsRunning() {
return
}
initMapsStore()
InitMapsStore()
n.SetQueueNum(qNum)
n.conn = NewNft()
n.Conn = NewNft()
// In order to clean up any existing firewall rule before start,
// we need to load the fw configuration first to know what rules
@ -109,11 +109,11 @@ func (n *Nft) Stop() {
// EnableInterception adds firewall rules to intercept connections
func (n *Nft) EnableInterception() {
if err := n.addInterceptionTables(); err != nil {
if err := n.AddInterceptionTables(); err != nil {
log.Error("Error while adding interception tables: %s", err)
return
}
if err := n.addInterceptionChains(); err != nil {
if err := n.AddInterceptionChains(); err != nil {
log.Error("Error while adding interception chains: %s", err)
return
}
@ -144,7 +144,7 @@ func (n *Nft) CleanRules(logErrors bool) {
// You add rules, chains or tables, and after calling to Flush() they're added to the system.
// NOTE: it's very important not to call Flush() without queued tasks.
func (n *Nft) Commit() bool {
if err := n.conn.Flush(); err != nil {
if err := n.Conn.Flush(); err != nil {
log.Warning("%s error applying changes: %s", logTag, err)
return false
}

View file

@ -187,7 +187,7 @@ func (n *Nft) parseExpression(table, chain, family string, expression *config.Ex
counterObj.Packets = 1
}
}
n.conn.AddObj(counterObj)
n.Conn.AddObj(counterObj)
exprList = append(exprList, *exprs.NewExprCounter(defaultCounterName)...)
}

View file

@ -14,7 +14,7 @@ import (
// rules examples: https://github.com/google/nftables/blob/master/nftables_test.go
func (n *Nft) buildICMPRule(table, family string, icmpProtoVersion string, icmpOptions []*config.ExprValues) *[]expr.Any {
tbl := n.getTable(table, family)
tbl := n.GetTable(table, family)
if tbl == nil {
return nil
}
@ -86,7 +86,7 @@ func (n *Nft) buildICMPRule(table, family string, icmpProtoVersion string, icmpO
Table: tbl,
KeyType: setType,
}
if err := n.conn.AddSet(set, setElements); err != nil {
if err := n.Conn.AddSet(set, setElements); err != nil {
log.Warning("%s AddSet() error: %s", logTag, err)
return nil
}
@ -155,7 +155,7 @@ Exit:
// [ payload load 2b @ transport header + 2 => reg 1 ]
// [ cmp eq reg 1 0x00003500 ]
func (n *Nft) buildL4ProtoRule(table, family, l4prots string, cmpOp *expr.CmpOp) (*[]expr.Any, error) {
tbl := n.getTable(table, family)
tbl := n.GetTable(table, family)
if tbl == nil {
return nil, fmt.Errorf("Invalid table (%s, %s)", table, family)
}
@ -168,7 +168,7 @@ func (n *Nft) buildL4ProtoRule(table, family, l4prots string, cmpOp *expr.CmpOp)
KeyType: nftables.TypeInetProto,
}
protoSet := exprs.NewExprProtoSet(l4prots)
if err := n.conn.AddSet(set, *protoSet); err != nil {
if err := n.Conn.AddSet(set, *protoSet); err != nil {
log.Warning("%s protoSet, AddSet() error: %s", logTag, err)
return nil, err
}
@ -186,7 +186,7 @@ func (n *Nft) buildL4ProtoRule(table, family, l4prots string, cmpOp *expr.CmpOp)
}
func (n *Nft) buildPortsRule(table, family, ports string, cmpOp *expr.CmpOp) (*[]expr.Any, error) {
tbl := n.getTable(table, family)
tbl := n.GetTable(table, family)
if tbl == nil {
return nil, fmt.Errorf("Invalid table (%s, %s)", table, family)
}
@ -199,7 +199,7 @@ func (n *Nft) buildPortsRule(table, family, ports string, cmpOp *expr.CmpOp) (*[
KeyType: nftables.TypeInetService,
}
setElements := exprs.NewExprPortSet(ports)
if err := n.conn.AddSet(set, *setElements); err != nil {
if err := n.Conn.AddSet(set, *setElements); err != nil {
log.Warning("%s portSet, AddSet() error: %s", logTag, err)
return nil, err
}

View file

@ -17,13 +17,13 @@ import (
// This rule must be added in top of the system rules, otherwise it may get bypassed.
// nft insert rule ip filter input udp sport 53 queue num 0 bypass
func (n *Nft) QueueDNSResponses(enable bool, logError bool) (error, error) {
if n.conn == nil {
if n.Conn == nil {
return nil, nil
}
families := []string{exprs.NFT_FAMILY_INET}
for _, fam := range families {
table := n.getTable(exprs.NFT_CHAIN_FILTER, fam)
chain := getChain(exprs.NFT_HOOK_INPUT, table)
table := n.GetTable(exprs.NFT_CHAIN_FILTER, fam)
chain := GetChain(exprs.NFT_HOOK_INPUT, table)
if table == nil {
log.Error("QueueDNSResponses() Error getting table: %s-filter", fam)
continue
@ -34,7 +34,7 @@ func (n *Nft) QueueDNSResponses(enable bool, logError bool) (error, error) {
}
// nft list ruleset -a
n.conn.InsertRule(&nftables.Rule{
n.Conn.InsertRule(&nftables.Rule{
Position: 0,
Table: table,
Chain: chain,
@ -79,19 +79,19 @@ func (n *Nft) QueueDNSResponses(enable bool, logError bool) (error, error) {
// rules above this one to exclude a service/app from being intercepted.
// nft insert rule ip mangle OUTPUT ct state new queue num 0 bypass
func (n *Nft) QueueConnections(enable bool, logError bool) (error, error) {
if n.conn == nil {
if n.Conn == nil {
return nil, fmt.Errorf("nftables QueueConnections: netlink connection not active")
}
table := n.getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
table := n.GetTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
if table == nil {
return nil, fmt.Errorf("QueueConnections() Error getting table mangle-inet")
}
chain := getChain(exprs.NFT_HOOK_OUTPUT, table)
chain := GetChain(exprs.NFT_HOOK_OUTPUT, table)
if chain == nil {
return nil, fmt.Errorf("QueueConnections() Error getting outputChain: output-%s", table.Name)
}
n.conn.AddRule(&nftables.Rule{
n.Conn.AddRule(&nftables.Rule{
Position: 0,
Table: table,
Chain: chain,
@ -130,7 +130,7 @@ func (n *Nft) QueueConnections(enable bool, logError bool) (error, error) {
}
func (n *Nft) insertRule(chain, table, family string, position uint64, exprs *[]expr.Any) error {
tbl := n.getTable(table, family)
tbl := n.GetTable(table, family)
if tbl == nil {
return fmt.Errorf("%s addRule, Error getting table: %s, %s", logTag, table, family)
}
@ -148,7 +148,7 @@ func (n *Nft) insertRule(chain, table, family string, position uint64, exprs *[]
Exprs: *exprs,
UserData: []byte(systemRuleKey),
}
n.conn.InsertRule(rule)
n.Conn.InsertRule(rule)
if !n.Commit() {
return fmt.Errorf("%s Error adding rule", logTag)
}
@ -156,8 +156,9 @@ func (n *Nft) insertRule(chain, table, family string, position uint64, exprs *[]
return nil
}
func (n *Nft) addRule(chain, table, family string, position uint64, key string, exprs *[]expr.Any) (*nftables.Rule, error) {
tbl := n.getTable(table, family)
// AddRule adds a rule to the system.
func (n *Nft) AddRule(chain, table, family string, position uint64, key string, exprs *[]expr.Any) (*nftables.Rule, error) {
tbl := n.GetTable(table, family)
if tbl == nil {
return nil, fmt.Errorf("%s addRule, Error getting table: %s, %s", logTag, table, family)
}
@ -175,7 +176,7 @@ func (n *Nft) addRule(chain, table, family string, position uint64, key string,
Exprs: *exprs,
UserData: []byte(key),
}
n.conn.AddRule(rule)
n.Conn.AddRule(rule)
if !n.Commit() {
return nil, fmt.Errorf("%s Error adding rule", logTag)
}
@ -184,12 +185,12 @@ func (n *Nft) addRule(chain, table, family string, position uint64, key string,
}
func (n *Nft) delRulesByKey(key string) error {
chains, err := n.conn.ListChains()
chains, err := n.Conn.ListChains()
if err != nil {
return fmt.Errorf("error listing nftables chains (%s): %s", key, err)
}
for _, c := range chains {
rules, err := n.conn.GetRule(c.Table, c)
rules, err := n.Conn.GetRule(c.Table, c)
if err != nil {
log.Warning("Error listing rules (%s): %s", key, err)
continue
@ -200,7 +201,7 @@ func (n *Nft) delRulesByKey(key string) error {
continue
}
// just passing the r object doesn't work.
if err := n.conn.DelRule(&nftables.Rule{
if err := n.Conn.DelRule(&nftables.Rule{
Table: c.Table,
Chain: c,
Handle: r.Handle,

View file

@ -51,7 +51,8 @@ var (
sysSets []*nftables.Set
)
func initMapsStore() {
// InitMapsStore initializes internal stores of chains and maps.
func InitMapsStore() {
sysTables = &sysTablesT{
tables: make(map[string]*nftables.Table),
}
@ -81,8 +82,8 @@ func (n *Nft) CreateSystemRule(chain *config.FwChain, logErrors bool) bool {
chainPolicy = nftables.ChainPolicyDrop
}
chainHook := getHook(chain.Hook)
chainPrio, chainType := getChainPriority(chain.Family, chain.Type, chain.Hook)
chainHook := GetHook(chain.Hook)
chainPrio, chainType := GetChainPriority(chain.Family, chain.Type, chain.Hook)
if chainPrio == nil {
log.Warning("%s Invalid system firewall combination: %s, %s", logTag, chain.Type, chain.Hook)
return false
@ -144,7 +145,7 @@ func (n *Nft) DeleteSystemRules(force, restoreExistingChains, logErrors bool) {
n.restoreBackupChains()
}
if force {
n.delSystemTables()
n.DelSystemTables()
}
}
@ -160,7 +161,7 @@ func (n *Nft) AddSystemRule(rule *config.FwRule, chain *config.FwChain) (err4, e
}
}
if len(exprList) > 0 {
exprVerdict := exprs.NewExprVerdict(rule.Target, rule.TargetParameters)
exprVerdict := exprs.NewExprVerdict(chain.Family, rule.Target, rule.TargetParameters)
exprList = append(exprList, *exprVerdict...)
if err := n.insertRule(chain.Name, chain.Table, chain.Family, rule.Position, &exprList); err != nil {
log.Warning("error adding rule: %v", rule)

View file

@ -10,12 +10,12 @@ import (
// AddTable adds a new table to nftables.
func (n *Nft) AddTable(name, family string) (*nftables.Table, error) {
famCode := getFamilyCode(family)
famCode := GetFamilyCode(family)
tbl := &nftables.Table{
Family: famCode,
Name: name,
}
n.conn.AddTable(tbl)
n.Conn.AddTable(tbl)
if !n.Commit() {
return nil, fmt.Errorf("%s error adding system firewall table: %s, family: %s (%d)", logTag, name, family, famCode)
@ -25,7 +25,7 @@ func (n *Nft) AddTable(name, family string) (*nftables.Table, error) {
return tbl, nil
}
func (n *Nft) getTable(name, family string) *nftables.Table {
func (n *Nft) GetTable(name, family string) *nftables.Table {
return sysTables.Get(getTableKey(name, family))
}
@ -33,7 +33,7 @@ func getTableKey(name string, family interface{}) string {
return fmt.Sprint(name, "-", family)
}
func (n *Nft) addInterceptionTables() error {
func (n *Nft) AddInterceptionTables() error {
if _, err := n.AddTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET); err != nil {
return err
}
@ -53,7 +53,7 @@ func (n *Nft) addSystemTables() {
// return the number of rules that we didn't add.
func (n *Nft) nonSystemRules(tbl *nftables.Table) int {
chains, err := n.conn.ListChains()
chains, err := n.Conn.ListChains()
if err != nil {
return -1
}
@ -62,7 +62,7 @@ func (n *Nft) nonSystemRules(tbl *nftables.Table) int {
if tbl.Name != c.Table.Name && tbl.Family != c.Table.Family {
continue
}
rules, err := n.conn.GetRule(c.Table, c)
rules, err := n.Conn.GetRule(c.Table, c)
if err != nil {
return -1
}
@ -72,12 +72,13 @@ func (n *Nft) nonSystemRules(tbl *nftables.Table) int {
return t
}
func (n *Nft) delSystemTables() {
// DelSystemTables deletes tables created from fw configuration.
func (n *Nft) DelSystemTables() {
for k, tbl := range sysTables.List() {
if n.nonSystemRules(tbl) != 0 {
continue
}
n.conn.DelTable(tbl)
n.Conn.DelTable(tbl)
if !n.Commit() {
log.Warning("error deleting system table: %s", k)
continue

View file

@ -8,7 +8,7 @@ import (
"github.com/google/nftables"
)
func getFamilyCode(family string) nftables.TableFamily {
func GetFamilyCode(family string) nftables.TableFamily {
famCode := nftables.TableFamilyINet
switch family {
// [filter]: prerouting forward input output postrouting
@ -32,7 +32,7 @@ func getFamilyCode(family string) nftables.TableFamily {
return famCode
}
func getHook(chain string) *nftables.ChainHook {
func GetHook(chain string) *nftables.ChainHook {
hook := nftables.ChainHookOutput
// https://github.com/google/nftables/blob/master/chain.go#L33
@ -52,12 +52,12 @@ func getHook(chain string) *nftables.ChainHook {
return hook
}
// getChainPriority gets the corresponding priority for the given chain, based
// GetChainPriority gets the corresponding priority for the given chain, based
// on the following configuration matrix:
// https://wiki.nftables.org/wiki-nftables/index.php/Netfilter_hooks#Priority_within_hook
// https://github.com/google/nftables/blob/master/chain.go#L48
// man nft (table 6.)
func getChainPriority(family, cType, hook string) (*nftables.ChainPriority, nftables.ChainType) {
func GetChainPriority(family, cType, hook string) (*nftables.ChainPriority, nftables.ChainType) {
// types: route, nat, filter
chainType := nftables.ChainTypeFilter
// priorities: raw, conntrack, mangle, natdest, filter, security
@ -119,7 +119,7 @@ func getChainPriority(family, cType, hook string) (*nftables.ChainPriority, nfta
chainPrio = nftables.ChainPriorityRaw
case exprs.NFT_CHAIN_CONNTRACK:
chainPrio, chainType = getConntrackPriority(hook)
chainPrio, chainType = GetConntrackPriority(hook)
case exprs.NFT_CHAIN_NATDEST:
// hook: prerouting
@ -152,7 +152,7 @@ func getChainPriority(family, cType, hook string) (*nftables.ChainPriority, nfta
}
// https://wiki.nftables.org/wiki-nftables/index.php/Netfilter_hooks#Priority_within_hook
func getConntrackPriority(hook string) (*nftables.ChainPriority, nftables.ChainType) {
func GetConntrackPriority(hook string) (*nftables.ChainPriority, nftables.ChainType) {
chainType := nftables.ChainTypeFilter
chainPrio := nftables.ChainPriorityConntrack
switch hook {