Fix packet blocking behaviour while UI is open (fixes #40)

There are a few non-obvious reasons why this commit is so big:

The PyQt mainloop must run in main thread.. This was not particularly
easy since the packet callbacks were running in the main thread.

Because of the PyQt running in the main thread thing NetFilterQueue had
to be wrapped up in a thread.

The packet callback is now dispatched to a thread if user has to be
prompted.
Packets are sent over a queue to the ui thread.

SQLite connection must be called from the same thread it was created
in. Thats why all the calls are wrapped up in a lock and create a new
connection. This is not ideal but I would say it's good enough for now.
This commit is contained in:
adisbladis 2017-05-12 23:58:47 +08:00
parent 635aed0732
commit f24bda2a25
Failed to generate hash of commit
5 changed files with 260 additions and 142 deletions

View file

@ -23,7 +23,8 @@ from socket import inet_ntoa, getservbyport
class Connection:
def __init__(self, procmon, desktop_parser, payload):
def __init__(self, packet_id, procmon, desktop_parser, payload):
self.id = packet_id
self.data = payload
self.pkt = ip.IP( self.data )
self.src_addr = inet_ntoa( self.pkt.src )

View file

@ -16,22 +16,17 @@
# program. If not, go to http://www.gnu.org/licenses/gpl.html
# or write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import logging
from scapy.all import DNSRR, DNS
from threading import Lock
from scapy.all import *
import logging
class DNSCollector:
def __init__(self):
self.lock = Lock()
self.hosts = { '127.0.0.1': 'localhost' }
self.hosts = {'127.0.0.1': 'localhost'}
def is_dns_response(self, packet):
if packet.haslayer(DNSRR):
return True
else:
return False
def add_response( self, packet ):
def add_response(self, packet):
if packet.haslayer(DNS) and packet.haslayer(DNSRR):
with self.lock:
try:
@ -39,7 +34,7 @@ class DNSCollector:
i = a_count + 4
while i > 4:
hostname = packet[0][i].rrname
address = packet[0][i].rdata
address = packet[0][i].rdata
i -= 1
if hostname == b'.':
@ -52,15 +47,18 @@ class DNSCollector:
if address.endswith('.'):
address = address[:-1]
logging.debug("Adding DNS response: %s => %s" % (address, hostname))
logging.debug("Adding DNS response: %s => %s", address, hostname) # noqa
self.hosts[address] = hostname.decode()
except Exception as e:
logging.debug("Error while parsing DNS response: %s" % e)
def get_hostname( self, address ):
with self.lock:
if address in self.hosts:
return self.hosts[address]
else:
logging.debug( "No hostname found for address %s" % address )
return address
return True
else:
return False
def get_hostname(self, address):
try:
return self.hosts[address]
except KeyError:
logging.debug("No hostname found for address %s" % address)
return address

View file

@ -96,29 +96,42 @@ class Rules:
if save_option == Rule.FOREVER:
self.db.save_rule(r)
class RulesDB:
def __init__(self, filename):
self._filename = filename
self._lock = Lock()
logging.info("Using rules database from %s" % filename)
self.conn = sqlite3.connect(filename)
self._create_table()
# Only call with lock!
def _get_conn(self):
return sqlite3.connect(self._filename)
def _create_table(self):
c = self.conn.cursor()
c.execute("CREATE TABLE IF NOT EXISTS rules (app_path TEXT, verdict INTEGER, address TEXT, port INTEGER, proto TEXT)")
with self._lock:
conn = self._get_conn()
c = conn.cursor()
c.execute("CREATE TABLE IF NOT EXISTS rules (app_path TEXT, verdict INTEGER, address TEXT, port INTEGER, proto TEXT)") # noqa
def load_rules(self):
c = self.conn.cursor()
c.execute("SELECT * FROM rules")
return [Rule(*item) for item in c.fetchall()]
with self._lock:
conn = self._get_conn()
c = conn.cursor()
c.execute("SELECT * FROM rules")
return [Rule(*item) for item in c.fetchall()]
def save_rule( self, rule ):
c = self.conn.cursor()
c.execute("INSERT INTO rules VALUES (?, ?, ?, ?, ?)", (rule.app_path, rule.verdict, rule.address, rule.port, rule.proto,))
self.conn.commit()
def save_rule(self, rule):
with self._lock:
conn = self._get_conn()
c = conn.cursor()
c.execute("INSERT INTO rules VALUES (?, ?, ?, ?, ?)", (rule.app_path, rule.verdict, rule.address, rule.port, rule.proto,)) # noqa
conn.commit()
def remove_all_app_rules ( self, app_path ):
c = self.conn.cursor()
c.execute("DELETE FROM rules WHERE app_path=?", (app_path,))
self.conn.commit()
def remove_all_app_rules(self, app_path):
with self._lock:
conn = self._get_conn()
c = conn.cursor()
c.execute("DELETE FROM rules WHERE app_path=?", (app_path,))
conn.commit()

View file

@ -16,12 +16,14 @@
# program. If not, go to http://www.gnu.org/licenses/gpl.html
# or write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import os
import logging
from concurrent.futures import Future, TimeoutError
from netfilterqueue import NetfilterQueue
from socket import AF_INET, AF_INET6, inet_ntoa
from threading import Lock
from scapy.all import *
from scapy.all import IP
import threading
import logging
import weakref
import os
from opensnitch.ui import QtApp
from opensnitch.connection import Connection
@ -31,86 +33,146 @@ from opensnitch.procmon import ProcMon
from opensnitch.app import LinuxDesktopParser
MARK_PACKET_DROP = 101285
PACKET_TIMEOUT = 30 # 30 seconds is a good value?
IPTABLES_RULES = (
# Get DNS responses
"INPUT --protocol udp --sport 53 -j NFQUEUE --queue-num 0 --queue-bypass",
# Get connection packets
"OUTPUT -t mangle -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 --queue-bypass", # noqa
# Reject packets marked by OpenSnitch
"OUTPUT --protocol tcp -m mark --mark 101285 -j REJECT")
def drop_packet(pkt, conn):
logging.info(
"Dropping %s from %s" % (conn, conn.get_app_name()))
pkt.set_mark(MARK_PACKET_DROP)
pkt.drop()
class NetfilterQueueWrapper(threading.Thread):
def __init__(self, snitch):
super().__init__()
self.snitch = snitch
self.start()
def run(self):
q = None
try:
for r in IPTABLES_RULES:
logging.debug("Applying iptables rule '%s'", r)
os.system("iptables -I %s" % r)
q = NetfilterQueue()
q.bind(0, self.snitch.pkt_callback, 1024 * 2)
q.run()
finally:
for r in IPTABLES_RULES:
logging.debug("Deleting iptables rule '%s'", r)
os.system("iptables -D %s" % r)
if q is not None:
q.unbind()
class PacketHandler(threading.Thread):
"""Handle a packet asynchronously in a thread"""
def __init__(self, connection, pkt, rules):
super().__init__()
self.future = Future()
self.future.set_running_or_notify_cancel()
self.conn = connection
self.pkt = pkt
self.rules = rules
self.start()
def run(self):
try:
(save_option,
verdict,
apply_for_all) = self.future.result(PACKET_TIMEOUT)
except TimeoutError:
# What to do on timeouts?
# Should we even have timeouts?
self.pkt.accept()
else:
if save_option != Rule.ONCE:
self.rules.add_rule(self.conn, verdict,
apply_for_all, save_option)
if verdict == Rule.DROP:
drop_packet(self.pkt, self.conn)
else:
self.pkt.accept()
class Snitch:
IPTABLES_RULES = ( # Get DNS responses
"INPUT --protocol udp --sport 53 -j NFQUEUE --queue-num 0 --queue-bypass",
# Get connection packets
"OUTPUT -t mangle -m conntrack --ctstate NEW -j NFQUEUE --queue-num 0 --queue-bypass",
# Reject packets marked by OpenSnitch
"OUTPUT --protocol tcp -m mark --mark 101285 -j REJECT" )
# TODO: Support IPv6!
def __init__(self, database):
self.desktop_parser = LinuxDesktopParser()
self.lock = Lock()
self.rules = Rules(database)
self.dns = DNSCollector()
self.q = NetfilterQueue()
self.q = NetfilterQueueWrapper(self)
self.procmon = ProcMon()
self.qt_app = QtApp()
self.desktop_parser = LinuxDesktopParser()
self.q.bind( 0, self.pkt_callback, 1024 * 2 )
def get_verdict(self,c):
verdict = self.rules.get_verdict(c)
if verdict is None:
with self.lock:
c.hostname = self.dns.get_hostname(c.dst_addr)
( save_option, verdict, apply_for_all ) = self.qt_app.prompt_user(c)
if save_option != Rule.ONCE:
self.rules.add_rule( c, verdict, apply_for_all, save_option )
return verdict
def pkt_callback(self,pkt):
verd = Rule.ACCEPT
self.connection_futures = weakref.WeakValueDictionary()
self.qt_app = QtApp(self.connection_futures)
self.latest_packet_id = 0
def pkt_callback(self, pkt):
try:
data = pkt.get_payload()
packet = IP(data)
if self.dns.is_dns_response(packet):
self.dns.add_response(packet)
if self.dns.add_response(IP(data)):
pkt.accept()
return
self.latest_packet_id += 1
conn = Connection(self.latest_packet_id, self.procmon,
self.desktop_parser, data)
if conn.proto is None:
logging.debug("Could not detect protocol for packet.")
return
elif conn.pid is None:
logging.debug("Could not detect process for connection.")
return
# Get verdict, if verdict cannot be found prompt user in thread
verd = self.rules.get_verdict(conn)
if verd == Rule.DROP:
drop_packet(pkt, conn)
elif verd == Rule.ACCEPT:
pkt.accept()
elif verd is None:
conn.hostname = self.dns.get_hostname(conn.dst_addr)
handler = PacketHandler(conn, pkt, self.rules)
self.connection_futures[conn.id] = handler.future
self.qt_app.prompt_user(conn)
else:
conn = Connection(self.procmon, self.desktop_parser, data)
if conn.proto is None:
logging.debug( "Could not detect protocol for packet." )
elif conn.pid is None:
logging.debug( "Could not detect process for connection." )
else:
verd = self.get_verdict( conn )
raise RuntimeError("Unhandled state")
except Exception as e:
logging.exception( "Exception on packet callback:" )
if verd == Rule.DROP:
logging.info( "Dropping %s from %s" % ( conn, conn.get_app_name() ) )
# mark this packet so iptables will drop it
pkt.set_mark(101285)
pkt.drop()
else:
pkt.accept()
logging.exception("Exception on packet callback:")
logging.exception(e)
def start(self):
for r in Snitch.IPTABLES_RULES:
logging.debug( "Applying iptables rule '%s'" % r )
os.system( "iptables -I %s" % r )
if ProcMon.is_ftrace_available():
self.procmon.enable()
self.procmon.start()
self.qt_app.run()
self.q.run()
def stop(self):
for r in Snitch.IPTABLES_RULES:
logging.debug( "Deleting iptables rule '%s'" % r )
os.system( "iptables -D %s" % r )
self.procmon.disable()
self.q.unbind()

View file

@ -18,71 +18,105 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from PyQt5 import QtCore, QtGui, uic, QtWidgets
from opensnitch.rule import Rule
import queue
import sys
import os
# TODO: Implement tray icon and menu.
# TODO: Implement rules editor.
RESOURCES_PATH = "%s/resources/" % os.path.dirname(sys.modules[__name__].__file__)
RESOURCES_PATH = "%s/resources/" % os.path.dirname(
sys.modules[__name__].__file__)
DIALOG_UI_PATH = "%s/dialog_hi.ui" % RESOURCES_PATH
class QtApp:
def __init__(self):
pass
def __init__(self, connection_futures):
self.app = QtWidgets.QApplication([])
self.connection_queue = queue.Queue()
self.dialog = Dialog(self.connection_queue, connection_futures)
def run(self):
self.app = QtWidgets.QApplication([])
self.app.exec()
def prompt_user( self, connection ):
dialog = Dialog( connection )
dialog.show()
self.app.exec_()
return dialog.result
def prompt_user(self, connection):
self.connection_queue.put(connection)
self.dialog.add_connection_signal.emit()
class Dialog( QtWidgets.QDialog, uic.loadUiType(DIALOG_UI_PATH)[0] ):
DEFAULT_RESULT = ( Rule.ONCE, Rule.ACCEPT, False )
def __init__( self, connection, parent=None ):
self.connection = connection
QtWidgets.QDialog.__init__( self, parent, QtCore.Qt.WindowStaysOnTopHint )
class Dialog(QtWidgets.QDialog, uic.loadUiType(DIALOG_UI_PATH)[0]):
DEFAULT_RESULT = (Rule.ONCE, Rule.ACCEPT, False)
MESSAGE_TEMPLATE = "<b>%s</b> (pid=%s) wants to connect to <b>%s</b> on <b>%s port %s%s</b>" # noqa
add_connection_signal = QtCore.pyqtSignal()
def __init__(self, connection_queue, connection_futures, parent=None):
self.connection_queue = connection_queue
self.connection = None
QtWidgets.QDialog.__init__(self, parent,
QtCore.Qt.WindowStaysOnTopHint)
self.setupUi(self)
self.init_widgets()
self.start_listeners()
self.connection_futures = connection_futures
self.add_connection_signal.connect(self.handle_connection)
@QtCore.pyqtSlot()
def handle_connection(self):
# This method will get called again after the user took action
# on the currently handled connection
if self.connection is not None:
return
try:
self.connection = self.connection_queue.get_nowait()
except queue.Empty:
return
self.setup_labels()
self.setup_icon()
self.setup_extra()
self.result = Dialog.DEFAULT_RESULT
self.show()
def setup_labels(self):
self.app_name_label.setText( self.connection.app.name )
self.app_name_label.setText(self.connection.app.name)
message = "<b>%s</b> (pid=%s) wants to connect to <b>%s</b> on <b>%s port %s%s</b>" % ( \
message = self.MESSAGE_TEMPLATE % (
self.connection.get_app_name_and_cmdline(),
self.connection.app.pid,
self.connection.hostname,
self.connection.proto.upper(),
self.connection.dst_port,
" (%s)" % self.connection.service if self.connection.service is not None else '' )
self.message_label.setText( message )
" (%s)" % self.connection.service or '')
self.message_label.setText(message)
def init_widgets(self):
self.app_name_label = self.findChild( QtWidgets.QLabel, "appNameLabel" )
self.message_label = self.findChild( QtWidgets.QLabel, "messageLabel" )
self.action_combo_box = self.findChild( QtWidgets.QComboBox, "actionComboBox" )
self.allow_button = self.findChild( QtWidgets.QPushButton, "allowButton" )
self.deny_button = self.findChild( QtWidgets.QPushButton, "denyButton" )
self.whitelist_button = self.findChild( QtWidgets.QPushButton, "whitelistButton" )
self.block_button = self.findChild( QtWidgets.QPushButton, "blockButton" )
self.icon_label = self.findChild( QtWidgets.QLabel, "iconLabel" )
self.app_name_label = self.findChild(QtWidgets.QLabel,
"appNameLabel")
self.message_label = self.findChild(QtWidgets.QLabel,
"messageLabel")
self.action_combo_box = self.findChild(QtWidgets.QComboBox,
"actionComboBox")
self.allow_button = self.findChild(QtWidgets.QPushButton,
"allowButton")
self.deny_button = self.findChild(QtWidgets.QPushButton,
"denyButton")
self.whitelist_button = self.findChild(QtWidgets.QPushButton,
"whitelistButton")
self.block_button = self.findChild(QtWidgets.QPushButton,
"blockButton")
self.icon_label = self.findChild(QtWidgets.QLabel, "iconLabel")
def start_listeners(self):
self.allow_button.clicked.connect( self._allow_action )
self.deny_button.clicked.connect( self._deny_action )
self.whitelist_button.clicked.connect( self._whitelist_action )
self.block_button.clicked.connect( self._block_action )
self.action_combo_box.currentIndexChanged[str].connect ( self._action_changed )
self.allow_button.clicked.connect(self._allow_action)
self.deny_button.clicked.connect(self._deny_action)
self.whitelist_button.clicked.connect(self._whitelist_action)
self.block_button.clicked.connect(self._block_action)
self.action_combo_box.currentIndexChanged[str].connect(
self._action_changed)
def setup_icon(self):
if self.connection.app.icon is not None:
@ -96,33 +130,43 @@ class Dialog( QtWidgets.QDialog, uic.loadUiType(DIALOG_UI_PATH)[0] ):
def _action_changed(self):
s_option = self.action_combo_box.currentText()
if s_option == "Until Quit" or s_option == "Forever":
self.whitelist_button.show()
self.block_button.show()
self.whitelist_button.show()
self.block_button.show()
elif s_option == "Once":
self.whitelist_button.hide()
self.block_button.hide()
self.whitelist_button.hide()
self.block_button.hide()
def _allow_action(self):
self._action( Rule.ACCEPT, False )
self._action(Rule.ACCEPT, False)
def _deny_action(self):
self._action( Rule.DROP, False )
self._action(Rule.DROP, False)
def _whitelist_action(self):
self._action( Rule.ACCEPT, True )
self._action(Rule.ACCEPT, True)
def _block_action(self):
self._action( Rule.DROP, True )
self._action(Rule.DROP, True)
def _action( self, verdict, apply_to_all=False ):
def _action(self, verdict, apply_to_all=False):
s_option = self.action_combo_box.currentText()
if s_option == "Once":
option = Rule.ONCE
option = Rule.ONCE
elif s_option == "Until Quit":
option = Rule.UNTIL_QUIT
option = Rule.UNTIL_QUIT
elif s_option == "Forever":
option = Rule.FOREVER
option = Rule.FOREVER
self.result = ( option, verdict, apply_to_all )
self.close()
# Set result future
try:
fut = self.connection_futures[self.connection.id]
except KeyError:
pass
else:
fut.set_result((option, verdict, apply_to_all))
# Check if we have any unhandled connections on the queue
self.connection = None # Indicate that next connection can be handled
self.hide()
self.handle_connection()