scyther/gui/mpa.py

97 lines
2.7 KiB
Python
Executable File

#!/usr/bin/python
"""
Test script to execute multi-protocol attacks on some test set.
"""
import Scyther
def MyScyther(protocollist,filter=None):
"""
Evaluate the composition of the protocols in protocollist.
If there is a filter, i.e. "ns3,I1" then only this specific claim
will be evaluated.
"""
s = Scyther.Scyther()
s.options = "--match=2"
if filter:
s.options += " --filter=%s" % (filter)
for protocol in protocollist:
s.addFile(protocol)
s.verify()
return s
def getCorrectIsolatedClaims(protocolset):
"""
Given a set of protocols, determine the correct claims when run in
isolation.
Returns a tuple, consisting of
- a list of compiling protocols
- a list of tuples (protocol,claimid) wich denote correct claims
"""
correctclaims = []
goodprotocols = []
for protocol in protocolset:
# verify protocol in isolation
s = MyScyther([protocol])
# investigate the results
if not s.errors:
goodprotocols.append(protocol)
for claim in s.claims:
if claim.okay:
correctclaims.append((protocol,claim.id))
return (goodprotocols,correctclaims)
def findMPA(protocolset,protocol,claimid,maxcount=3):
"""
The protocol claim is assumed to be correct. When does it break?
"""
count = 2
if len(protocolset) < maxcount:
maxcount = len(protocolset)
def verifyMPAlist(mpalist):
# This should be a more restricted verification
s = MyScyther(mpalist,claimid)
cl = s.getClaim(claimid)
if cl:
if not cl.okay:
# This is an MPA attack!
print "I've found a multi-protocol attack on claim %s in the context %s." % (claimid,str(mpalist))
return mpalist
def constructMPAlist(mpalist,start,callback):
if len(mpalist) < count:
for pn in range(start,len(protocolset)):
p = protocolset[pn]
if p not in mpalist:
constructMPAlist(mpalist + [p],pn+1,callback)
else:
callback(mpalist)
while count <= maxcount:
constructMPAlist([protocol],0,verifyMPAlist)
count += 1
return None
def findAllMPA(protocolset,maxcount=3):
"""
Given a set of protocols, find multi-protocol attacks
"""
(protocolset,correct) = getCorrectIsolatedClaims(protocolset)
print correct
for (protocol,claimid) in correct:
mpalist = findMPA(protocolset,protocol,claimid,maxcount=3)
if mpalist:
print "Darn, MPA on %s (%s) using %s" % (claimid,protocol,mpalist)
if __name__ == '__main__':
list = ['me.spdl','ns3.spdl','nsl3.spdl']
findAllMPA(list)