mirror of
https://github.com/evilsocket/opensnitch.git
synced 2025-03-04 08:34:40 +01:00
netlink: fixed connections querying
also code simplified.
This commit is contained in:
parent
25c27511e1
commit
a13f42d98b
2 changed files with 37 additions and 85 deletions
|
@ -3,8 +3,6 @@ package netlink
|
|||
import (
|
||||
"syscall"
|
||||
"net"
|
||||
|
||||
"github.com/gustavo-iniguez-goya/opensnitch/daemon/log"
|
||||
)
|
||||
|
||||
func GetSocketInfo(proto string, srcIP net.IP, srcPort uint, dstIP net.IP, dstPort uint) (uid, inode int) {
|
||||
|
@ -15,31 +13,19 @@ func GetSocketInfo(proto string, srcIP net.IP, srcPort uint, dstIP net.IP, dstPo
|
|||
family = syscall.AF_INET6
|
||||
}
|
||||
|
||||
var s *Socket
|
||||
var err error
|
||||
if proto[:3] == "udp" {
|
||||
ipproto = syscall.IPPROTO_UDP
|
||||
if protoLen >=7 && proto[:7] == "udplite" {
|
||||
ipproto = syscall.IPPROTO_UDPLITE
|
||||
}
|
||||
srcAddr := &net.UDPAddr{ IP: srcIP, Port: int(srcPort), }
|
||||
dstAddr := &net.UDPAddr{ IP: dstIP, Port: int(dstPort), }
|
||||
s, err = SocketGet(family, ipproto, srcAddr, dstAddr)
|
||||
} else if proto[:3] == "tcp" {
|
||||
srcAddr := &net.TCPAddr{ IP: srcIP, Port: int(srcPort), }
|
||||
dstAddr := &net.TCPAddr{ IP: dstIP, Port: int(dstPort), }
|
||||
s, err = SocketGet(family, ipproto, srcAddr, dstAddr)
|
||||
} else {
|
||||
log.Debug("Unknown protocol, not implemented", proto)
|
||||
return -1, -1
|
||||
}
|
||||
if err == nil && s.INode > 0 && s.INode != 0xffffffff {
|
||||
if s.UID == 0xffffffff {
|
||||
return -1, int(s.INode)
|
||||
sock, err := SocketGet(family, ipproto, uint16(srcPort), uint16(dstPort), srcIP, dstIP)
|
||||
if err == nil && sock.INode > 0 && sock.INode != 0xffffffff {
|
||||
if sock.UID == 0xffffffff {
|
||||
return -1, int(sock.INode)
|
||||
}
|
||||
return int(s.UID), int(s.INode)
|
||||
} else if err != nil {
|
||||
log.Debug("Netlink socket error", err)
|
||||
return int(sock.UID), int(sock.INode)
|
||||
}
|
||||
|
||||
return -1, -1
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"syscall"
|
||||
|
||||
"github.com/vishvananda/netlink/nl"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// This is a copy of https://github.com/vishvananda/netlink socket_linux.go
|
||||
|
@ -140,71 +139,38 @@ func (s *Socket) deserialize(b []byte) error {
|
|||
}
|
||||
|
||||
// SocketGet returns the Socket identified by its local and remote addresses.
|
||||
func SocketGet(family uint8, proto uint8, local, remote net.Addr) (*Socket, error) {
|
||||
var sPort, dPort uint16
|
||||
func SocketGet(family uint8, proto uint8, srcPort, dstPort uint16, local, remote net.IP) (*Socket, error) {
|
||||
var localIP, remoteIP net.IP
|
||||
_Id := SocketID{}
|
||||
|
||||
if proto == unix.IPPROTO_UDP || proto == unix.IPPROTO_UDPLITE {
|
||||
localUDP, ok := local.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return nil, errors.New ("UDP IP error: invalid source IP")
|
||||
}
|
||||
remoteUDP, ok := remote.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return nil, errors.New ("UDP IP error: invalid remote IP")
|
||||
}
|
||||
if family == unix.AF_INET6 {
|
||||
localIP = localUDP.IP.To16()
|
||||
remoteIP = remoteUDP.IP.To16()
|
||||
if family == syscall.AF_INET6 {
|
||||
localIP = local.To16()
|
||||
remoteIP = remote.To16()
|
||||
} else {
|
||||
localIP = localUDP.IP.To4()
|
||||
remoteIP = remoteUDP.IP.To4()
|
||||
localIP = local.To4()
|
||||
remoteIP = remote.To4()
|
||||
}
|
||||
|
||||
sPort = uint16(localUDP.Port)
|
||||
dPort = uint16(remoteUDP.Port)
|
||||
} else {
|
||||
localTCP, ok := local.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return nil, errors.New ("TCP IP error: invalid source IP")
|
||||
}
|
||||
remoteTCP, ok := remote.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return nil, errors.New ("TCP IP error: invalid remote IP")
|
||||
}
|
||||
if family == unix.AF_INET6 {
|
||||
localIP = localTCP.IP.To16()
|
||||
remoteIP = remoteTCP.IP.To16()
|
||||
} else {
|
||||
localIP = localTCP.IP.To4()
|
||||
remoteIP = remoteTCP.IP.To4()
|
||||
}
|
||||
|
||||
sPort = uint16(localTCP.Port)
|
||||
dPort = uint16(remoteTCP.Port)
|
||||
}
|
||||
|
||||
_Id = SocketID{
|
||||
SourcePort: sPort,
|
||||
DestinationPort: dPort,
|
||||
_Id := SocketID{
|
||||
SourcePort: srcPort,
|
||||
DestinationPort: dstPort,
|
||||
Source: localIP,
|
||||
Destination: remoteIP,
|
||||
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
|
||||
}
|
||||
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
|
||||
req.AddData(&SocketRequest{
|
||||
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, syscall.NLM_F_DUMP)
|
||||
sockReq := &SocketRequest{
|
||||
Family: family,
|
||||
Protocol: proto,
|
||||
States: TCP_ALL,
|
||||
ID: _Id,
|
||||
})
|
||||
}
|
||||
req.AddData(sockReq)
|
||||
msgs, err := req.Execute(syscall.NETLINK_INET_DIAG, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(msgs) == 0 {
|
||||
return nil, errors.New("no message nor error from netlink")
|
||||
return nil, errors.New("Warning, no message nor error from netlink")
|
||||
}
|
||||
if len(msgs) > 2 {
|
||||
return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs))
|
||||
|
|
Loading…
Add table
Reference in a new issue