tests: updated sys fw tests

This commit is contained in:
Gustavo Iñiguez Goia 2023-07-06 14:27:54 +02:00
parent 344819eb30
commit 18e583d20e
Failed to generate hash of commit
8 changed files with 278 additions and 205 deletions

View file

@ -1,26 +1,27 @@
package nftables package nftables_test
import ( import (
"testing" "testing"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs" "github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/nftest"
"github.com/google/nftables" "github.com/google/nftables"
) )
func TestChains(t *testing.T) { func TestChains(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
if nft.addInterceptionTables() != nil { if nftest.Fw.AddInterceptionTables() != nil {
t.Error("Error adding interception tables") t.Error("Error adding interception tables")
} }
t.Run("AddChain", func(t *testing.T) { t.Run("AddChain", func(t *testing.T) {
filterPolicy := nftables.ChainPolicyAccept filterPolicy := nftables.ChainPolicyAccept
chn := nft.AddChain( chn := nftest.Fw.AddChain(
exprs.NFT_HOOK_INPUT, exprs.NFT_HOOK_INPUT,
exprs.NFT_CHAIN_FILTER, exprs.NFT_CHAIN_FILTER,
exprs.NFT_FAMILY_INET, exprs.NFT_FAMILY_INET,
@ -31,54 +32,54 @@ func TestChains(t *testing.T) {
if chn == nil { if chn == nil {
t.Error("chain input-filter-inet not created") t.Error("chain input-filter-inet not created")
} }
if !nft.Commit() { if !nftest.Fw.Commit() {
t.Error("error adding input-filter-inet chain") t.Error("error adding input-filter-inet chain")
} }
}) })
t.Run("getChain", func(t *testing.T) { t.Run("getChain", func(t *testing.T) {
tblfilter := nft.getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET) tblfilter := nftest.Fw.GetTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
if tblfilter == nil { if tblfilter == nil {
t.Error("table filter-inet not created") t.Error("table filter-inet not created")
} }
chn := nft.getChain(exprs.NFT_HOOK_INPUT, tblfilter, exprs.NFT_FAMILY_INET) chn := nftest.Fw.GetChain(exprs.NFT_HOOK_INPUT, tblfilter, exprs.NFT_FAMILY_INET)
if chn == nil { if chn == nil {
t.Error("chain input-filter-inet not added") t.Error("chain input-filter-inet not added")
} }
}) })
t.Run("delChain", func(t *testing.T) { t.Run("delChain", func(t *testing.T) {
tblfilter := nft.getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET) tblfilter := nftest.Fw.GetTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
if tblfilter == nil { if tblfilter == nil {
t.Error("table filter-inet not created") t.Error("table filter-inet not created")
} }
chn := nft.getChain(exprs.NFT_HOOK_INPUT, tblfilter, exprs.NFT_FAMILY_INET) chn := nftest.Fw.GetChain(exprs.NFT_HOOK_INPUT, tblfilter, exprs.NFT_FAMILY_INET)
if chn == nil { if chn == nil {
t.Error("chain input-filter-inet not added") t.Error("chain input-filter-inet not added")
} }
if err := nft.delChain(chn); err != nil { if err := nftest.Fw.DelChain(chn); err != nil {
t.Error("error deleting chain input-filter-inet") t.Error("error deleting chain input-filter-inet")
} }
}) })
nft.delSystemTables() nftest.Fw.DelSystemTables()
} }
// TestAddInterceptionChains checks if the needed tables and chains have been created. // TestAddInterceptionChains checks if the needed tables and chains have been created.
// We use 2: output-mangle-inet for intercepting outbound connections, and input-filter-inet for DNS responses interception // We use 2: output-mangle-inet for intercepting outbound connections, and input-filter-inet for DNS responses interception
func TestAddInterceptionChains(t *testing.T) { func TestAddInterceptionChains(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
if err := nft.addInterceptionTables(); err != nil { if err := nftest.Fw.AddInterceptionTables(); err != nil {
t.Errorf("Error adding interception tables: %s", err) t.Errorf("Error adding interception tables: %s", err)
} }
if err := nft.addInterceptionChains(); err != nil { if err := nftest.Fw.AddInterceptionChains(); err != nil {
t.Errorf("Error adding interception chains: %s", err) t.Errorf("Error adding interception chains: %s", err)
} }
nft.delSystemTables() nftest.Fw.DelSystemTables()
} }

View file

@ -1,21 +1,23 @@
package nftables package nftables_test
import ( import (
"testing" "testing"
"time" "time"
"github.com/evilsocket/opensnitch/daemon/firewall/common" "github.com/evilsocket/opensnitch/daemon/firewall/common"
nftb "github.com/evilsocket/opensnitch/daemon/firewall/nftables"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs" "github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/nftest"
"github.com/google/nftables" "github.com/google/nftables"
) )
// mimic EnableInterception() but without NewRulesChecker() // mimic EnableInterception() but without NewRulesChecker()
func addInterceptionRules(nft *Nft, t *testing.T) { func addInterceptionRules(nft *nftb.Nft, t *testing.T) {
if err := nft.addInterceptionTables(); err != nil { if err := nft.AddInterceptionTables(); err != nil {
t.Errorf("Error while adding interception tables: %s", err) t.Errorf("Error while adding interception tables: %s", err)
return return
} }
if err := nft.addInterceptionChains(); err != nil { if err := nft.AddInterceptionChains(); err != nil {
t.Errorf("Error while adding interception chains: %s", err) t.Errorf("Error while adding interception chains: %s", err)
return return
} }
@ -28,12 +30,12 @@ func addInterceptionRules(nft *Nft, t *testing.T) {
} }
} }
func _testMonitorReload(t *testing.T, conn *nftables.Conn, nft *Nft) { func _testMonitorReload(t *testing.T, conn *nftables.Conn, nft *nftb.Nft) {
tblfilter := nft.getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET) tblfilter := nft.GetTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
if tblfilter == nil || tblfilter.Name != exprs.NFT_CHAIN_FILTER { if tblfilter == nil || tblfilter.Name != exprs.NFT_CHAIN_FILTER {
t.Error("table filter-inet not in the list") t.Error("table filter-inet not in the list")
} }
chnFilterInput := nft.getChain(exprs.NFT_HOOK_INPUT, tblfilter, exprs.NFT_FAMILY_INET) chnFilterInput := nftest.Fw.GetChain(exprs.NFT_HOOK_INPUT, tblfilter, exprs.NFT_FAMILY_INET)
if chnFilterInput == nil { if chnFilterInput == nil {
t.Error("chain input-filter-inet not in the list") t.Error("chain input-filter-inet not in the list")
} }
@ -42,12 +44,12 @@ func _testMonitorReload(t *testing.T, conn *nftables.Conn, nft *Nft) {
t.Error("DNS interception rule not added") t.Error("DNS interception rule not added")
} }
conn.FlushChain(chnFilterInput) conn.FlushChain(chnFilterInput)
nft.Commit() nftest.Fw.Commit()
// the rules checker checks the rules every 10s // the rules checker checks the rules every 10s
reloaded := false reloaded := false
for i := 0; i < 15; i++ { for i := 0; i < 15; i++ {
if r, _ := getRule(t, conn, tblfilter.Name, exprs.NFT_HOOK_INPUT, interceptionRuleKey, 0); r != nil { if r, _ := getRule(t, conn, tblfilter.Name, exprs.NFT_HOOK_INPUT, nftb.InterceptionRuleKey, 0); r != nil {
reloaded = true reloaded = true
break break
} }
@ -59,35 +61,35 @@ func _testMonitorReload(t *testing.T, conn *nftables.Conn, nft *Nft) {
} }
func TestAreRulesLoaded(t *testing.T) { func TestAreRulesLoaded(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
addInterceptionRules(nft, t) addInterceptionRules(nftest.Fw, t)
if !nft.AreRulesLoaded() { if !nftest.Fw.AreRulesLoaded() {
t.Error("interception rules not loaded, and they should") t.Error("interception rules not loaded, and they should")
} }
nft.delInterceptionRules() nftest.Fw.DelInterceptionRules()
if nft.AreRulesLoaded() { if nftest.Fw.AreRulesLoaded() {
t.Error("interception rules are loaded, and the shouldn't") t.Error("interception rules are loaded, and the shouldn't")
} }
} }
func TestMonitorReload(t *testing.T) { func TestMonitorReload(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
nft.EnableInterception() nftest.Fw.EnableInterception()
// test that rules are reloaded after being deleted, but also // test that rules are reloaded after being deleted, but also
// that the monitor is not stopped after the first reload. // that the monitor is not stopped after the first reload.
_testMonitorReload(t, conn, nft) _testMonitorReload(t, conn, nftest.Fw)
_testMonitorReload(t, conn, nft) _testMonitorReload(t, conn, nftest.Fw)
_testMonitorReload(t, conn, nft) _testMonitorReload(t, conn, nftest.Fw)
} }

View file

@ -0,0 +1,63 @@
package nftest
import (
"os"
"runtime"
"testing"
nftb "github.com/evilsocket/opensnitch/daemon/firewall/nftables"
"github.com/google/nftables"
"github.com/vishvananda/netns"
)
var (
conn *nftables.Conn
newNS netns.NsHandle
// Fw represents the nftables Fw object.
Fw, _ = nftb.Fw()
)
func init() {
nftb.InitMapsStore()
}
// SkipIfNotPrivileged will skip the test from where it's invoked,
// to skip the test if we don't have root privileges.
// This may occur when executing the tests on restricted environments,
// such as containers, chroots, etc.
func SkipIfNotPrivileged(t *testing.T) {
if os.Getenv("PRIVILEGED_TESTS") == "" {
t.Skip("Set PRIVILEGED_TESTS to 1 to launch these tests, and launch them as root, or as a user allowed to create new namespaces.")
}
}
// OpenSystemConn opens a new connection with the kernel in a new namespace.
// https://github.com/google/nftables/blob/8f2d395e1089dea4966c483fbeae7e336917c095/internal/nftest/system_conn.go#L15
func OpenSystemConn(t *testing.T) (*nftables.Conn, netns.NsHandle) {
t.Helper()
// We lock the goroutine into the current thread, as namespace operations
// such as those invoked by `netns.New()` are thread-local. This is undone
// in nftest.CleanupSystemConn().
runtime.LockOSThread()
ns, err := netns.New()
if err != nil {
t.Fatalf("netns.New() failed: %v", err)
}
t.Log("OpenSystemConn() with NS:", ns)
c, err := nftables.New(nftables.WithNetNSFd(int(ns)))
if err != nil {
t.Fatalf("nftables.New() failed: %v", err)
}
return c, ns
}
// CleanupSystemConn closes the given namespace.
func CleanupSystemConn(t *testing.T, newNS netns.NsHandle) {
defer runtime.UnlockOSThread()
if err := newNS.Close(); err != nil {
t.Fatalf("newNS.Close() failed: %v", err)
}
}

View file

@ -0,0 +1,44 @@
package nftest
import (
"testing"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs"
"github.com/google/nftables"
"github.com/google/nftables/expr"
)
func AddTestRule(t *testing.T, conn *nftables.Conn, exp *[]expr.Any) (*nftables.Rule, *nftables.Chain) {
_, err := Fw.AddTable("yyy", exprs.NFT_FAMILY_INET)
if err != nil {
t.Error("pre step add_table() yyy-inet failed")
return nil, nil
}
chn := Fw.AddChain(
exprs.NFT_HOOK_INPUT,
"yyy",
exprs.NFT_FAMILY_INET,
nftables.ChainPriorityFilter,
nftables.ChainTypeFilter,
nftables.ChainHookInput,
nftables.ChainPolicyAccept)
if chn == nil {
t.Error("pre step add_chain() input-yyy-inet failed")
return nil, nil
}
//nft.Commit()
r, err := Fw.AddRule(
exprs.NFT_HOOK_INPUT, "yyy", exprs.NFT_FAMILY_INET,
0,
"key-yyy",
exp)
if err != nil {
t.Errorf("Error adding rule: %s", err)
return nil, nil
}
t.Logf("Rule: %+v", r)
return r, chn
}

View file

@ -1,9 +1,11 @@
package nftables package nftables_test
import ( import (
"testing" "testing"
nftb "github.com/evilsocket/opensnitch/daemon/firewall/nftables"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs" "github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/nftest"
"github.com/google/nftables" "github.com/google/nftables"
) )
@ -14,7 +16,7 @@ func getRulesList(t *testing.T, conn *nftables.Conn, family, tblName, chnName st
} }
for rdx, c := range chains { for rdx, c := range chains {
if c.Table.Family == getFamilyCode(family) && c.Table.Name == tblName && c.Name == chnName { if c.Table.Family == nftb.GetFamilyCode(family) && c.Table.Name == tblName && c.Name == chnName {
rules, err := conn.GetRule(c.Table, c) rules, err := conn.GetRule(c.Table, c)
if err != nil { if err != nil {
return nil, -1 return nil, -1
@ -54,36 +56,41 @@ func getRule(t *testing.T, conn *nftables.Conn, tblName, chnName, key string, ru
} }
func TestAddRule(t *testing.T) { func TestAddRule(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
_, err := nft.AddTable("yyy", exprs.NFT_FAMILY_INET) r, chn := nftest.AddTestRule(t, conn, exprs.NewNoTrack())
if err != nil {
t.Error("pre step add_table() yyy-inet failed") /*
} _, err := nft.AddTable("yyy", exprs.NFT_FAMILY_INET)
chn := nft.AddChain( if err != nil {
exprs.NFT_HOOK_INPUT, t.Error("pre step add_table() yyy-inet failed")
"yyy", }
exprs.NFT_FAMILY_INET, chn := nft.AddChain(
nftables.ChainPriorityFilter, exprs.NFT_HOOK_INPUT,
nftables.ChainTypeFilter, "yyy",
nftables.ChainHookInput, exprs.NFT_FAMILY_INET,
nftables.ChainPolicyAccept) nftables.ChainPriorityFilter,
if chn == nil { nftables.ChainTypeFilter,
t.Error("pre step add_chain() input-yyy-inet failed") nftables.ChainHookInput,
} nftables.ChainPolicyAccept)
if chn == nil {
t.Error("pre step add_chain() input-yyy-inet failed")
}
r, err := nft.addRule(
exprs.NFT_HOOK_INPUT, "yyy", exprs.NFT_FAMILY_INET,
0,
"key-yyy",
exprs.NewNoTrack())
if err != nil {
t.Errorf("Error adding rule: %s", err)
}
*/
r, err := nft.addRule(
exprs.NFT_HOOK_INPUT, "yyy", exprs.NFT_FAMILY_INET,
0,
"key-yyy",
exprs.NewNoTrack())
if err != nil {
t.Errorf("Error adding rule: %s", err)
}
rules, err := conn.GetRules(chn.Table, chn) rules, err := conn.GetRules(chn.Table, chn)
if err != nil || len(rules) != 1 { if err != nil || len(rules) != 1 {
t.Errorf("Rule not added, total: %d", len(rules)) t.Errorf("Rule not added, total: %d", len(rules))
@ -92,17 +99,17 @@ func TestAddRule(t *testing.T) {
} }
func TestInsertRule(t *testing.T) { func TestInsertRule(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
_, err := nft.AddTable("yyy", exprs.NFT_FAMILY_INET) _, err := nftest.Fw.AddTable("yyy", exprs.NFT_FAMILY_INET)
if err != nil { if err != nil {
t.Error("pre step add_table() yyy-inet failed") t.Error("pre step add_table() yyy-inet failed")
} }
chn := nft.AddChain( chn := nftest.Fw.AddChain(
exprs.NFT_HOOK_INPUT, exprs.NFT_HOOK_INPUT,
"yyy", "yyy",
exprs.NFT_FAMILY_INET, exprs.NFT_FAMILY_INET,
@ -114,7 +121,7 @@ func TestInsertRule(t *testing.T) {
t.Error("pre step add_chain() input-yyy-inet failed") t.Error("pre step add_chain() input-yyy-inet failed")
} }
err = nft.insertRule( err = nftest.Fw.InsertRule(
exprs.NFT_HOOK_INPUT, "yyy", exprs.NFT_FAMILY_INET, exprs.NFT_HOOK_INPUT, "yyy", exprs.NFT_FAMILY_INET,
0, 0,
exprs.NewNoTrack()) exprs.NewNoTrack())
@ -128,17 +135,17 @@ func TestInsertRule(t *testing.T) {
} }
func TestQueueConnections(t *testing.T) { func TestQueueConnections(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
_, err := nft.AddTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET) _, err := nftest.Fw.AddTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
if err != nil { if err != nil {
t.Error("pre step add_table() mangle-inet failed") t.Error("pre step add_table() mangle-inet failed")
} }
chn := nft.AddChain( chn := nftest.Fw.AddChain(
exprs.NFT_HOOK_OUTPUT, exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET, exprs.NFT_HOOK_OUTPUT, exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET,
nftables.ChainPriorityFilter, nftables.ChainPriorityFilter,
nftables.ChainTypeFilter, nftables.ChainTypeFilter,
@ -148,31 +155,31 @@ func TestQueueConnections(t *testing.T) {
t.Error("pre step add_chain() output-mangle-inet failed") t.Error("pre step add_chain() output-mangle-inet failed")
} }
if err1, err2 := nft.QueueConnections(true, true); err1 != nil && err2 != nil { if err1, err2 := nftest.Fw.QueueConnections(true, true); err1 != nil && err2 != nil {
t.Errorf("rule to queue connections not added: %s, %s", err1, err2) t.Errorf("rule to queue connections not added: %s, %s", err1, err2)
} }
r, _ := getRule(t, conn, exprs.NFT_CHAIN_MANGLE, exprs.NFT_HOOK_OUTPUT, interceptionRuleKey, 0) r, _ := getRule(t, conn, exprs.NFT_CHAIN_MANGLE, exprs.NFT_HOOK_OUTPUT, nftb.InterceptionRuleKey, 0)
if r == nil { if r == nil {
t.Error("rule to queue connections not in the list") t.Error("rule to queue connections not in the list")
} }
if string(r.UserData) != interceptionRuleKey { if string(r.UserData) != nftb.InterceptionRuleKey {
t.Errorf("invalid UserData: %s", string(r.UserData)) t.Errorf("invalid UserData: %s", string(r.UserData))
} }
} }
func TestQueueDNSResponses(t *testing.T) { func TestQueueDNSResponses(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
_, err := nft.AddTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET) _, err := nftest.Fw.AddTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
if err != nil { if err != nil {
t.Error("pre step add_table() filter-inet failed") t.Error("pre step add_table() filter-inet failed")
} }
chn := nft.AddChain( chn := nftest.Fw.AddChain(
exprs.NFT_HOOK_INPUT, exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET, exprs.NFT_HOOK_INPUT, exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET,
nftables.ChainPriorityFilter, nftables.ChainPriorityFilter,
nftables.ChainTypeFilter, nftables.ChainTypeFilter,
@ -182,15 +189,15 @@ func TestQueueDNSResponses(t *testing.T) {
t.Error("pre step add_chain() input-filter-inet failed") t.Error("pre step add_chain() input-filter-inet failed")
} }
if err1, err2 := nft.QueueDNSResponses(true, true); err1 != nil && err2 != nil { if err1, err2 := nftest.Fw.QueueDNSResponses(true, true); err1 != nil && err2 != nil {
t.Errorf("rule to queue DNS responses not added: %s, %s", err1, err2) t.Errorf("rule to queue DNS responses not added: %s, %s", err1, err2)
} }
r, _ := getRule(t, conn, exprs.NFT_CHAIN_FILTER, exprs.NFT_HOOK_INPUT, interceptionRuleKey, 0) r, _ := getRule(t, conn, exprs.NFT_CHAIN_FILTER, exprs.NFT_HOOK_INPUT, nftb.InterceptionRuleKey, 0)
if r == nil { if r == nil {
t.Error("rule to queue DNS responses not in the list") t.Error("rule to queue DNS responses not in the list")
} }
if string(r.UserData) != interceptionRuleKey { if string(r.UserData) != nftb.InterceptionRuleKey {
t.Errorf("invalid UserData: %s", string(r.UserData)) t.Errorf("invalid UserData: %s", string(r.UserData))
} }
@ -201,7 +208,7 @@ func TestQueueDNSResponses(t *testing.T) {
/*if err1, err2 := nft.QueueDNSResponses(false, true); err1 != nil && err2 != nil { /*if err1, err2 := nft.QueueDNSResponses(false, true); err1 != nil && err2 != nil {
t.Errorf("rule to queue DNS responses not deleted: %s, %s", err1, err2) t.Errorf("rule to queue DNS responses not deleted: %s, %s", err1, err2)
} }
r, _ = getRule(t, conn, exprs.NFT_CHAIN_FILTER, exprs.NFT_HOOK_INPUT, interceptionRuleKey, 0) r, _ = getRule(t, conn, exprs.NFT_CHAIN_FILTER, exprs.NFT_HOOK_INPUT, nftb.InterceptionRuleKey, 0)
if r != nil { if r != nil {
t.Error("rule to queue DNS responses should have been deleted") t.Error("rule to queue DNS responses should have been deleted")
}*/ }*/

View file

@ -1,9 +1,10 @@
package nftables package nftables_test
import ( import (
"testing" "testing"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs" "github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/nftest"
) )
type sysChainsListT struct { type sysChainsListT struct {
@ -14,13 +15,13 @@ type sysChainsListT struct {
} }
func TestAddSystemRules(t *testing.T) { func TestAddSystemRules(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
cfg, err := nft.NewSystemFwConfig(nft.preloadConfCallback, nft.reloadConfCallback) cfg, err := nftest.Fw.NewSystemFwConfig(nftest.Fw.PreloadConfCallback, nftest.Fw.ReloadConfCallback)
if err != nil { if err != nil {
t.Logf("Error creating fw config: %s", err) t.Logf("Error creating fw config: %s", err)
} }
@ -30,7 +31,7 @@ func TestAddSystemRules(t *testing.T) {
t.Errorf("Error loading config from disk: %s", err) t.Errorf("Error loading config from disk: %s", err)
} }
nft.AddSystemRules(false, false) nftest.Fw.AddSystemRules(false, false)
rules, _ := getRulesList(t, conn, exprs.NFT_FAMILY_INET, exprs.NFT_CHAIN_FILTER, exprs.NFT_HOOK_INPUT) rules, _ := getRulesList(t, conn, exprs.NFT_FAMILY_INET, exprs.NFT_CHAIN_FILTER, exprs.NFT_HOOK_INPUT)
// 3 rules in total, 1 disabled. // 3 rules in total, 1 disabled.
@ -62,13 +63,13 @@ func TestAddSystemRules(t *testing.T) {
} }
func TestFwConfDisabled(t *testing.T) { func TestFwConfDisabled(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
cfg, err := nft.NewSystemFwConfig(nft.preloadConfCallback, nft.reloadConfCallback) cfg, err := nftest.Fw.NewSystemFwConfig(nftest.Fw.PreloadConfCallback, nftest.Fw.ReloadConfCallback)
if err != nil { if err != nil {
t.Logf("Error creating fw config: %s", err) t.Logf("Error creating fw config: %s", err)
} }
@ -78,7 +79,7 @@ func TestFwConfDisabled(t *testing.T) {
t.Errorf("Error loading config from disk: %s", err) t.Errorf("Error loading config from disk: %s", err)
} }
nft.AddSystemRules(false, false) nftest.Fw.AddSystemRules(false, false)
tests := []sysChainsListT{ tests := []sysChainsListT{
{ {
@ -101,13 +102,13 @@ func TestFwConfDisabled(t *testing.T) {
} }
func TestDeleteSystemRules(t *testing.T) { func TestDeleteSystemRules(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
cfg, err := nft.NewSystemFwConfig(nft.preloadConfCallback, nft.reloadConfCallback) cfg, err := nftest.Fw.NewSystemFwConfig(nftest.Fw.PreloadConfCallback, nftest.Fw.ReloadConfCallback)
if err != nil { if err != nil {
t.Logf("Error creating fw config: %s", err) t.Logf("Error creating fw config: %s", err)
} }
@ -117,7 +118,7 @@ func TestDeleteSystemRules(t *testing.T) {
t.Errorf("Error loading config from disk: %s", err) t.Errorf("Error loading config from disk: %s", err)
} }
nft.AddSystemRules(false, false) nftest.Fw.AddSystemRules(false, false)
tests := []sysChainsListT{ tests := []sysChainsListT{
{ {
@ -138,14 +139,14 @@ func TestDeleteSystemRules(t *testing.T) {
} }
t.Run("test-delete-system-rules", func(t *testing.T) { t.Run("test-delete-system-rules", func(t *testing.T) {
nft.DeleteSystemRules(false, false, true) nftest.Fw.DeleteSystemRules(false, false, true)
for _, tt := range tests { for _, tt := range tests {
rules, _ := getRulesList(t, conn, tt.family, tt.table, tt.chain) rules, _ := getRulesList(t, conn, tt.family, tt.table, tt.chain)
if len(rules) != 0 { if len(rules) != 0 {
t.Errorf("%d rules found, there should be 0", len(rules)) t.Errorf("%d rules found, there should be 0", len(rules))
} }
tbl := nft.getTable(tt.table, tt.family) tbl := nftest.Fw.GetTable(tt.table, tt.family)
if tbl == nil { if tbl == nil {
t.Errorf("table %s-%s should exist", tt.table, tt.family) t.Errorf("table %s-%s should exist", tt.table, tt.family)
} }

View file

@ -1,63 +1,17 @@
package nftables package nftables_test
import ( import (
"os"
"runtime"
"testing" "testing"
nftb "github.com/evilsocket/opensnitch/daemon/firewall/nftables"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs" "github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/nftest"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/vishvananda/netns"
) )
var ( func tableExists(t *testing.T, conn *nftables.Conn, origtbl *nftables.Table, family string) bool {
conn *nftables.Conn
newNS netns.NsHandle
nft, _ = Fw()
)
func init() {
initMapsStore()
}
func skipIfNotPrivileged(t *testing.T) {
if os.Getenv("PRIVILEGED_TESTS") == "" {
t.Skip("Set PRIVILEGED_TESTS to 1 to launch these tests, and launch them as root, or as a user allowed to create new namespaces.")
}
}
// https://github.com/google/nftables/blob/8f2d395e1089dea4966c483fbeae7e336917c095/internal/nftest/system_conn.go#L15
func OpenSystemConn(t *testing.T) (*nftables.Conn, netns.NsHandle) {
t.Helper()
// We lock the goroutine into the current thread, as namespace operations
// such as those invoked by `netns.New()` are thread-local. This is undone
// in nftest.CleanupSystemConn().
runtime.LockOSThread()
ns, err := netns.New()
if err != nil {
t.Fatalf("netns.New() failed: %v", err)
}
t.Log("OpenSystemConn() with NS:", ns)
c, err := nftables.New(nftables.WithNetNSFd(int(ns)))
if err != nil {
t.Fatalf("nftables.New() failed: %v", err)
}
return c, ns
}
func CleanupSystemConn(t *testing.T, newNS netns.NsHandle) {
defer runtime.UnlockOSThread()
if err := newNS.Close(); err != nil {
t.Fatalf("newNS.Close() failed: %v", err)
}
}
func tableExists(t *testing.T, origtbl *nftables.Table, family string) bool {
tables, err := conn.ListTablesOfFamily( tables, err := conn.ListTablesOfFamily(
getFamilyCode(family), nftb.GetFamilyCode(family),
) )
if err != nil { if err != nil {
return false return false
@ -73,54 +27,54 @@ func tableExists(t *testing.T, origtbl *nftables.Table, family string) bool {
} }
func TestAddTable(t *testing.T) { func TestAddTable(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
t.Run("inet family", func(t *testing.T) { t.Run("inet family", func(t *testing.T) {
tblxxx, err := nft.AddTable("xxx", exprs.NFT_FAMILY_INET) tblxxx, err := nftest.Fw.AddTable("xxx", exprs.NFT_FAMILY_INET)
if err != nil { if err != nil {
t.Error("table xxx-inet not added:", err) t.Error("table xxx-inet not added:", err)
} }
if tableExists(t, tblxxx, exprs.NFT_FAMILY_INET) == false { if tableExists(t, nftest.Fw.Conn, tblxxx, exprs.NFT_FAMILY_INET) == false {
t.Error("table xxx-inet not in the list") t.Error("table xxx-inet not in the list")
} }
nft.delSystemTables() nftest.Fw.DelSystemTables()
if tableExists(t, tblxxx, exprs.NFT_FAMILY_INET) { if tableExists(t, nftest.Fw.Conn, tblxxx, exprs.NFT_FAMILY_INET) {
t.Errorf("table xxx-inet still exists: %+v", sysTables) t.Error("table xxx-inet still exists")
} }
}) })
t.Run("ip family", func(t *testing.T) { t.Run("ip family", func(t *testing.T) {
tblxxx, err := nft.AddTable("xxx", exprs.NFT_FAMILY_IP) tblxxx, err := nftest.Fw.AddTable("xxx", exprs.NFT_FAMILY_IP)
if err != nil { if err != nil {
t.Error("table xxx-ip not added:", err) t.Error("table xxx-ip not added:", err)
} }
if tableExists(t, tblxxx, exprs.NFT_FAMILY_IP) == false { if tableExists(t, nftest.Fw.Conn, tblxxx, exprs.NFT_FAMILY_IP) == false {
t.Error("table xxx-ip not in the list") t.Error("table xxx-ip not in the list")
} }
nft.delSystemTables() nftest.Fw.DelSystemTables()
if tableExists(t, tblxxx, exprs.NFT_FAMILY_IP) { if tableExists(t, nftest.Fw.Conn, tblxxx, exprs.NFT_FAMILY_IP) {
t.Errorf("table xxx-ip still exists: %+v", sysTables) t.Errorf("table xxx-ip still exists:") // %+v", sysTables)
} }
}) })
t.Run("ip6 family", func(t *testing.T) { t.Run("ip6 family", func(t *testing.T) {
tblxxx, err := nft.AddTable("xxx", exprs.NFT_FAMILY_IP6) tblxxx, err := nftest.Fw.AddTable("xxx", exprs.NFT_FAMILY_IP6)
if err != nil { if err != nil {
t.Error("table xxx-ip6 not added:", err) t.Error("table xxx-ip6 not added:", err)
} }
if tableExists(t, tblxxx, exprs.NFT_FAMILY_IP6) == false { if tableExists(t, nftest.Fw.Conn, tblxxx, exprs.NFT_FAMILY_IP6) == false {
t.Error("table xxx-ip6 not in the list") t.Error("table xxx-ip6 not in the list")
} }
nft.delSystemTables() nftest.Fw.DelSystemTables()
if tableExists(t, tblxxx, exprs.NFT_FAMILY_IP6) { if tableExists(t, nftest.Fw.Conn, tblxxx, exprs.NFT_FAMILY_IP6) {
t.Errorf("table xxx-ip6 still exists: %+v", sysTables) t.Errorf("table xxx-ip6 still exists:") // %+v", sysTables)
} }
}) })
} }
@ -128,31 +82,31 @@ func TestAddTable(t *testing.T) {
// TestAddInterceptionTables checks if the needed tables have been created. // TestAddInterceptionTables checks if the needed tables have been created.
// We use 2: mangle-inet for intercepting outbound connections, and filter-inet for DNS responses interception // We use 2: mangle-inet for intercepting outbound connections, and filter-inet for DNS responses interception
func TestAddInterceptionTables(t *testing.T) { func TestAddInterceptionTables(t *testing.T) {
skipIfNotPrivileged(t) nftest.SkipIfNotPrivileged(t)
conn, newNS = OpenSystemConn(t) conn, newNS := nftest.OpenSystemConn(t)
defer CleanupSystemConn(t, newNS) defer nftest.CleanupSystemConn(t, newNS)
nft.conn = conn nftest.Fw.Conn = conn
if err := nft.addInterceptionTables(); err != nil { if err := nftest.Fw.AddInterceptionTables(); err != nil {
t.Errorf("addInterceptionTables() error: %s", err) t.Errorf("addInterceptionTables() error: %s", err)
} }
t.Run("mangle-inet", func(t *testing.T) { t.Run("mangle-inet", func(t *testing.T) {
tblmangle := nft.getTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET) tblmangle := nftest.Fw.GetTable(exprs.NFT_CHAIN_MANGLE, exprs.NFT_FAMILY_INET)
if tblmangle == nil { if tblmangle == nil {
t.Error("interception table mangle-inet not in the list") t.Error("interception table mangle-inet not in the list")
} }
if tableExists(t, tblmangle, exprs.NFT_FAMILY_INET) == false { if tableExists(t, nftest.Fw.Conn, tblmangle, exprs.NFT_FAMILY_INET) == false {
t.Error("table mangle-inet not in the list") t.Error("table mangle-inet not in the list")
} }
}) })
t.Run("filter-inet", func(t *testing.T) { t.Run("filter-inet", func(t *testing.T) {
tblfilter := nft.getTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET) tblfilter := nftest.Fw.GetTable(exprs.NFT_CHAIN_FILTER, exprs.NFT_FAMILY_INET)
if tblfilter == nil { if tblfilter == nil {
t.Error("interception table filter-inet not in the list") t.Error("interception table filter-inet not in the list")
} }
if tableExists(t, tblfilter, exprs.NFT_FAMILY_INET) == false { if tableExists(t, nftest.Fw.Conn, tblfilter, exprs.NFT_FAMILY_INET) == false {
t.Error("table filter-inet not in the list") t.Error("table filter-inet not in the list")
} }
}) })

View file

@ -1,8 +1,9 @@
package nftables package nftables_test
import ( import (
"testing" "testing"
nftb "github.com/evilsocket/opensnitch/daemon/firewall/nftables"
"github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs" "github.com/evilsocket/opensnitch/daemon/firewall/nftables/exprs"
"github.com/google/nftables" "github.com/google/nftables"
) )
@ -23,28 +24,28 @@ type chainPrioT struct {
func TestGetConntrackPriority(t *testing.T) { func TestGetConntrackPriority(t *testing.T) {
t.Run("hook-prerouting", func(t *testing.T) { t.Run("hook-prerouting", func(t *testing.T) {
cprio, ctype := getConntrackPriority(exprs.NFT_HOOK_PREROUTING) cprio, ctype := nftb.GetConntrackPriority(exprs.NFT_HOOK_PREROUTING)
if cprio != nftables.ChainPriorityConntrack && ctype != nftables.ChainTypeFilter { if cprio != nftables.ChainPriorityConntrack && ctype != nftables.ChainTypeFilter {
t.Errorf("invalid conntrack priority or type for hook PREROUTING: %+v, %+v", cprio, ctype) t.Errorf("invalid conntrack priority or type for hook PREROUTING: %+v, %+v", cprio, ctype)
} }
}) })
t.Run("hook-output", func(t *testing.T) { t.Run("hook-output", func(t *testing.T) {
cprio, ctype := getConntrackPriority(exprs.NFT_HOOK_OUTPUT) cprio, ctype := nftb.GetConntrackPriority(exprs.NFT_HOOK_OUTPUT)
if cprio != nftables.ChainPriorityNATSource && ctype != nftables.ChainTypeFilter { if cprio != nftables.ChainPriorityNATSource && ctype != nftables.ChainTypeFilter {
t.Errorf("invalid conntrack priority or type for hook OUTPUT: %+v, %+v", cprio, ctype) t.Errorf("invalid conntrack priority or type for hook OUTPUT: %+v, %+v", cprio, ctype)
} }
}) })
t.Run("hook-postrouting", func(t *testing.T) { t.Run("hook-postrouting", func(t *testing.T) {
cprio, ctype := getConntrackPriority(exprs.NFT_HOOK_POSTROUTING) cprio, ctype := nftb.GetConntrackPriority(exprs.NFT_HOOK_POSTROUTING)
if cprio != nftables.ChainPriorityConntrackHelper && ctype != nftables.ChainTypeNAT { if cprio != nftables.ChainPriorityConntrackHelper && ctype != nftables.ChainTypeNAT {
t.Errorf("invalid conntrack priority or type for hook POSTROUTING: %+v, %+v", cprio, ctype) t.Errorf("invalid conntrack priority or type for hook POSTROUTING: %+v, %+v", cprio, ctype)
} }
}) })
t.Run("hook-input", func(t *testing.T) { t.Run("hook-input", func(t *testing.T) {
cprio, ctype := getConntrackPriority(exprs.NFT_HOOK_INPUT) cprio, ctype := nftb.GetConntrackPriority(exprs.NFT_HOOK_INPUT)
if cprio != nftables.ChainPriorityConntrackConfirm && ctype != nftables.ChainTypeFilter { if cprio != nftables.ChainPriorityConntrackConfirm && ctype != nftables.ChainTypeFilter {
t.Errorf("invalid conntrack priority or type for hook INPUT: %+v, %+v", cprio, ctype) t.Errorf("invalid conntrack priority or type for hook INPUT: %+v, %+v", cprio, ctype)
} }
@ -146,7 +147,7 @@ func TestGetChainPriority(t *testing.T) {
for _, testChainPrio := range matrixTests { for _, testChainPrio := range matrixTests {
t.Run(testChainPrio.test, func(t *testing.T) { t.Run(testChainPrio.test, func(t *testing.T) {
chainPrio, chainType := getChainPriority(testChainPrio.family, testChainPrio.chain, testChainPrio.hook) chainPrio, chainType := nftb.GetChainPriority(testChainPrio.family, testChainPrio.chain, testChainPrio.hook)
if testChainPrio.checkEqual { if testChainPrio.checkEqual {
if chainPrio != testChainPrio.chainPrio && chainType != testChainPrio.chainType { if chainPrio != testChainPrio.chainPrio && chainType != testChainPrio.chainType {
@ -204,7 +205,7 @@ func TestInvalidChainPriority(t *testing.T) {
for _, testChainPrio := range matrixTests { for _, testChainPrio := range matrixTests {
t.Run(testChainPrio.test, func(t *testing.T) { t.Run(testChainPrio.test, func(t *testing.T) {
chainPrio, chainType := getChainPriority(testChainPrio.family, testChainPrio.chain, testChainPrio.hook) chainPrio, chainType := nftb.GetChainPriority(testChainPrio.family, testChainPrio.chain, testChainPrio.hook)
if testChainPrio.checkEqual { if testChainPrio.checkEqual {
if chainPrio != testChainPrio.chainPrio && chainType != testChainPrio.chainType { if chainPrio != testChainPrio.chainPrio && chainType != testChainPrio.chainType {