netlink: fixed connections querying

also code simplified.
This commit is contained in:
Gustavo Iñiguez Goia 2020-02-18 02:05:15 +01:00
parent 25c27511e1
commit a13f42d98b
2 changed files with 37 additions and 85 deletions

View file

@ -3,8 +3,6 @@ package netlink
import (
"syscall"
"net"
"github.com/gustavo-iniguez-goya/opensnitch/daemon/log"
)
func GetSocketInfo(proto string, srcIP net.IP, srcPort uint, dstIP net.IP, dstPort uint) (uid, inode int) {
@ -15,31 +13,19 @@ func GetSocketInfo(proto string, srcIP net.IP, srcPort uint, dstIP net.IP, dstPo
family = syscall.AF_INET6
}
var s *Socket
var err error
if proto[:3] == "udp" {
ipproto = syscall.IPPROTO_UDP
if protoLen >=7 && proto[:7] == "udplite" {
ipproto = syscall.IPPROTO_UDPLITE
}
srcAddr := &net.UDPAddr{ IP: srcIP, Port: int(srcPort), }
dstAddr := &net.UDPAddr{ IP: dstIP, Port: int(dstPort), }
s, err = SocketGet(family, ipproto, srcAddr, dstAddr)
} else if proto[:3] == "tcp" {
srcAddr := &net.TCPAddr{ IP: srcIP, Port: int(srcPort), }
dstAddr := &net.TCPAddr{ IP: dstIP, Port: int(dstPort), }
s, err = SocketGet(family, ipproto, srcAddr, dstAddr)
} else {
log.Debug("Unknown protocol, not implemented", proto)
return -1, -1
}
if err == nil && s.INode > 0 && s.INode != 0xffffffff {
if s.UID == 0xffffffff {
return -1, int(s.INode)
sock, err := SocketGet(family, ipproto, uint16(srcPort), uint16(dstPort), srcIP, dstIP)
if err == nil && sock.INode > 0 && sock.INode != 0xffffffff {
if sock.UID == 0xffffffff {
return -1, int(sock.INode)
}
return int(s.UID), int(s.INode)
} else if err != nil {
log.Debug("Netlink socket error", err)
return int(sock.UID), int(sock.INode)
}
return -1, -1

View file

@ -8,7 +8,6 @@ import (
"syscall"
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)
// This is a copy of https://github.com/vishvananda/netlink socket_linux.go
@ -140,71 +139,38 @@ func (s *Socket) deserialize(b []byte) error {
}
// SocketGet returns the Socket identified by its local and remote addresses.
func SocketGet(family uint8, proto uint8, local, remote net.Addr) (*Socket, error) {
var sPort, dPort uint16
func SocketGet(family uint8, proto uint8, srcPort, dstPort uint16, local, remote net.IP) (*Socket, error) {
var localIP, remoteIP net.IP
_Id := SocketID{}
if proto == unix.IPPROTO_UDP || proto == unix.IPPROTO_UDPLITE {
localUDP, ok := local.(*net.UDPAddr)
if !ok {
return nil, errors.New ("UDP IP error: invalid source IP")
}
remoteUDP, ok := remote.(*net.UDPAddr)
if !ok {
return nil, errors.New ("UDP IP error: invalid remote IP")
}
if family == unix.AF_INET6 {
localIP = localUDP.IP.To16()
remoteIP = remoteUDP.IP.To16()
if family == syscall.AF_INET6 {
localIP = local.To16()
remoteIP = remote.To16()
} else {
localIP = localUDP.IP.To4()
remoteIP = remoteUDP.IP.To4()
localIP = local.To4()
remoteIP = remote.To4()
}
sPort = uint16(localUDP.Port)
dPort = uint16(remoteUDP.Port)
} else {
localTCP, ok := local.(*net.TCPAddr)
if !ok {
return nil, errors.New ("TCP IP error: invalid source IP")
}
remoteTCP, ok := remote.(*net.TCPAddr)
if !ok {
return nil, errors.New ("TCP IP error: invalid remote IP")
}
if family == unix.AF_INET6 {
localIP = localTCP.IP.To16()
remoteIP = remoteTCP.IP.To16()
} else {
localIP = localTCP.IP.To4()
remoteIP = remoteTCP.IP.To4()
}
sPort = uint16(localTCP.Port)
dPort = uint16(remoteTCP.Port)
}
_Id = SocketID{
SourcePort: sPort,
DestinationPort: dPort,
_Id := SocketID{
SourcePort: srcPort,
DestinationPort: dstPort,
Source: localIP,
Destination: remoteIP,
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
}
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
req.AddData(&SocketRequest{
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, syscall.NLM_F_DUMP)
sockReq := &SocketRequest{
Family: family,
Protocol: proto,
States: TCP_ALL,
ID: _Id,
})
}
req.AddData(sockReq)
msgs, err := req.Execute(syscall.NETLINK_INET_DIAG, 0)
if err != nil {
return nil, err
}
if len(msgs) == 0 {
return nil, errors.New("no message nor error from netlink")
return nil, errors.New("Warning, no message nor error from netlink")
}
if len(msgs) > 2 {
return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs))