Below is the file 'tftp.py' from this revision. You can also download the file.

#!/usr/bin/python

import struct
import socket
import os

# TFTP opcode name -> opcode number, as carried in the first two bytes of
# every packet.
opcodes = dict(RRQ=1, WRQ=2, DATA=3, ACK=4, ERROR=5)
# Reverse lookup (opcode number -> name), used when printing verbose traces.
ropcodes = dict((number, name) for name, number in opcodes.items())

# A DATA packet is a 2-byte opcode, a 2-byte block number and at most
# tftp_block_size bytes of payload; that is also the largest packet accepted.
tftp_block_size = 512
tftp_max_packet_size = 2 + 2 + tftp_block_size

def set_nonblocking(fd):
	"""Put socket ``fd`` into non-blocking mode (reads/writes never wait)."""
	fd.setblocking(0)

class TFTPError(Exception):
	"""Raised for malformed packets and TFTP protocol violations."""
	def __init__(self, mesg):
		Exception.__init__(self, mesg)
		# Human-readable description, also shown by repr().
		self._mesg = mesg
	def __repr__(self):
		return "TFTP Error: %s" % self._mesg

# Upper bound on how many times an unacknowledged packet is resent.
MAX_RETRANSMIT = 5

class TFTPPacket:
	"""A single TFTP packet plus the local bookkeeping used to resend it.

	``opcode`` is one of the values in ``opcodes``; ``attrs`` holds the
	opcode-specific fields:

	  RRQ/WRQ: 'filename', 'mode'
	  DATA:    'block', 'data'
	  ACK:     'block'
	  ERROR:   'errorcode', 'errmsg'

	Opcodes, block numbers and error codes are 2-byte unsigned integers on the
	wire, so they are packed with "!H" (a signed "!h" cannot encode block
	numbers above 32767, i.e. files larger than about 16 MB).
	"""
	def __init__(self, last=False, retransmit=False):
		self.opcode = None
		# last: the connection hangs up once this packet has been sent.
		self.last = last
		# retransmit: resend this packet on a timer until it is superseded.
		self.retransmit = retransmit
		self.retransmit_count = 0
		self.attrs = {}
	def is_last(self):
		"""Return True if the connection should hang up after sending this."""
		return self.last
	def should_retransmit(self):
		"""Return True if this packet should be resent until answered."""
		return self.retransmit
	def serialise(self):
		"""Return the wire encoding of this packet.

		Raises TFTPError if ``opcode`` is not a known TFTP opcode.
		"""
		if self.opcode not in opcodes.values():
			raise TFTPError("Unknown opcode: %s" % (self.opcode,))
		data = struct.pack("!H", self.opcode)
		if self.opcode == opcodes['RRQ'] or self.opcode == opcodes['WRQ']:
			data += self.attrs['filename'] + "\0" + self.attrs['mode'] + "\0"
		elif self.opcode == opcodes['DATA']:
			data += struct.pack("!H", self.attrs['block'])
			data += self.attrs['data']
		elif self.opcode == opcodes['ACK']:
			data += struct.pack("!H", self.attrs['block'])
		else: # ERROR
			data += struct.pack("!H", self.attrs['errorcode'])
			data += self.attrs['errmsg'] + "\0"
		return data
	def deserialise(self, data):
		"""Parse a received datagram into ``opcode`` and ``attrs``.

		Raises TFTPError if the datagram is longer than tftp_max_packet_size,
		too short for its opcode, malformed, or carries an unknown opcode.
		"""
		if len(data) > tftp_max_packet_size: raise TFTPError("Packet too long")
		# Report truncated datagrams as TFTPError rather than struct.error.
		if len(data) < 2: raise TFTPError("Packet too short")
		self.attrs = {}
		self.opcode = struct.unpack("!H", data[:2])[0]
		if self.opcode == opcodes['RRQ'] or self.opcode == opcodes['WRQ']:
			# "filename\0mode\0" splits into [filename, mode, ''].
			parts = data[2:].split('\0')
			if len(parts) != 3: raise TFTPError("Malformed RRQ/WRQ packet - wrong number of parts")
			self.attrs['filename'] = parts[0]
			self.attrs['mode'] = parts[1]
			return
		if self.opcode not in (opcodes['DATA'], opcodes['ACK'], opcodes['ERROR']):
			raise TFTPError("Unknown opcode: %d" % (self.opcode))
		# DATA, ACK and ERROR all carry a 2-byte field right after the opcode.
		if len(data) < 4: raise TFTPError("Packet too short")
		field = struct.unpack('!H', data[2:4])[0]
		if self.opcode == opcodes['ERROR']:
			# The message must be NUL-terminated; require the terminator to
			# lie beyond the error-code field.
			if len(data) < 5 or data[-1] != '\0': raise TFTPError("Malformed ERROR packet")
			self.attrs['errorcode'] = field
			self.attrs['errmsg'] = data[4:-1]
		else:
			self.attrs['block'] = field
			if self.opcode == opcodes['DATA']:
				self.attrs['data'] = data[4:]

class TFTPTransfer:
	"""Buffer and block bookkeeping for one file transfer.

	For a write (WRQ) transfer, DATA payloads are collected with append() and
	read back with data().  For a read (RRQ) transfer, the file contents are
	loaded with set_buffer() and handed out block by block with consume().
	"""
	def __init__(self, io, opcode, attrs):
		# io is accepted for interface compatibility; it is not used here.
		self.opcode = opcode
		self.attrs = attrs
		self.buffer = ""
		# Received payloads not yet merged into self.buffer.  They are joined
		# lazily by data(), because repeated "+=" on an attribute copies the
		# whole buffer each time (quadratic for large uploads).
		self._chunks = []
		self.last_block = None
		self.complete = False
	def append(self, block, data):
		"""Store the payload of DATA block ``block`` if it is the next expected.

		Returns True if the data was accepted.  Returns False for a duplicate or
		out-of-order block; the caller keeps waiting for the right one.
		"""
		# The first DATA block of a transfer must be block 1.
		expected = (self.last_block or 0) + 1
		if block != expected: return False
		self.last_block = block
		self._chunks.append(data)
		return True
	def set_buffer(self, data):
		"""Load ``data`` as the contents to send; numbering restarts at block 1."""
		self.buffer = data
		self._chunks = []
		self.block = 1
		self.last_block = None
	def consume(self, max):
		"""Take up to ``max`` bytes off the buffer as the next DATA block.

		Returns (block_number, payload).  A payload shorter than ``max``
		(possibly empty) is the final block and sets ``complete``.
		"""
		block = self.block
		self.block += 1
		self.last_block = block
		rb = len(self.buffer)
		if rb >= max: rb = max
		else: self.complete = True
		data = self.buffer[:rb]
		self.buffer = self.buffer[rb:]
		return block, data
	def data(self):
		"""Return the buffered contents (everything received, for a write)."""
		if self._chunks:
			self.buffer = self.buffer + "".join(self._chunks)
			self._chunks = []
		return self.buffer

class TFTPStore:
	"""In-memory file storage shared by a server and its connections.

	``export`` maps filenames to the contents offered for download (RRQ);
	``received`` maps filenames to the contents uploaded by clients (WRQ).
	"""
	def __init__(self):
		self.export = {}
		self.received = {}
	def clear_received(self, filename):
		"""Forget an uploaded file; unknown names are ignored."""
		self.received.pop(filename, None)
	def set_received(self, filename, data):
		"""Record ``data`` as the uploaded contents of ``filename``."""
		self.received[filename] = data
	def has_received(self, filename):
		"""Return True if ``filename`` has been uploaded."""
		return filename in self.received
	def get_received(self, filename):
		"""Return the uploaded contents of ``filename``, or None."""
		return self.received.get(filename, None)
	def has_export(self, filename):
		"""Return True if ``filename`` is offered for download."""
		return filename in self.export
	def clear_export(self, filename):
		"""Stop offering ``filename``; unknown names are ignored."""
		self.export.pop(filename, None)
	def set_export(self, filename, data):
		"""Offer ``data`` for download under ``filename``."""
		self.export[filename] = data
	def get_export(self, filename):
		"""Return the downloadable contents of ``filename``, or None."""
		return self.export.get(filename, None)

class TFTPIOFunctions:
	"""Bundle of event-loop hooks used by TFTPServer and TFTPConnection.

	set_read/set_write(fd, callback) arm a callback on a socket and
	clear_read/clear_write(fd) disarm it.  set_timeout(interval, callback)
	returns an id accepted by remove_timeout(id); it is called with 3000,
	presumably milliseconds -- confirm against the event loop in use.
	notify(event, details) reports progress, and authenticate(op, addr,
	filename) decides whether a request is allowed.
	"""
	def __init__(self, set_read, set_write, clear_read, clear_write, set_timeout, remove_timeout, notify, authenticate):
		# Socket trigger hooks.
		self.set_read, self.set_write = set_read, set_write
		self.clear_read, self.clear_write = clear_read, clear_write
		# Timer hooks.
		self.set_timeout, self.remove_timeout = set_timeout, remove_timeout
		# Application callbacks.
		self.notify, self.authenticate = notify, authenticate

class TFTPConnection:
	"""A single TFTP transfer with one client, on its own ephemeral UDP port.

	TFTPServer creates one of these per request; the constructor binds a new
	socket and immediately processes the request packet.  Socket activity is
	driven through the ``io`` hooks (a TFTPIOFunctions): read/write triggers
	are re-armed after each use (they appear to be one-shot), and a timer
	resends the most recent packet that asked for retransmission.
	"""
	def __init__(self, store, io, addr, pkt, verbose=False):
		self.store = store
		self.io = io
		self.client_addr = addr
		self.pkt = pkt
		self.verbose = verbose
		self.retransmit_id = None
		self.packet_to_send = None
		self.packet_to_retransmit = None
		self.current_transfer = None
		# Payload length of the last DATA block queued by send_data(); it is
		# reported via "data_sent" once the client acknowledges that block.
		self.last_data_len = 0
		# Set by hangup(); recv_data() stops processing packets once closed.
		self.closed = False

		self.fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
		set_nonblocking(self.fd)
		# Port 0: let the OS pick the port used for this transfer.
		self.fd.bind(('', 0))
		self.process_packet(pkt)
	def set_read_trigger(self):
		"""Arrange for recv_data() to run when the socket is readable."""
		if self.verbose: print("** read trigger set")
		self.io.set_read(self.fd, self.recv_data)
	def set_write_trigger(self):
		"""Arrange for xmit_data() to run when the socket is writable."""
		if self.verbose: print("** write trigger set")
		self.io.set_write(self.fd, self.xmit_data)
	def hangup(self):
		"""Stop the retransmit timer and all socket triggers, then close the socket.

		Safe to call more than once.
		"""
		if self.closed: return
		self.closed = True
		# Otherwise the timer would keep firing against a closed socket.
		self.cancel_retransmission()
		self.io.clear_read(self.fd)
		self.io.clear_write(self.fd)
		self.fd.close()
	def recv_data(self, source, condition, user=None):
		"""Read callback: drain the socket and process packets from the client."""
		if self.verbose: print("** read trigger consumed")
		received = []
		while 1:
			# One byte more than the largest legal packet, so an oversized
			# datagram is rejected by deserialise() instead of being truncated.
			try: received.append(self.fd.recvfrom(tftp_max_packet_size + 1))
			except socket.error: break
		handled = False
		for data, addr in received:
			# An earlier packet in this batch may have ended the transfer.
			if self.closed: break
			# some dodgy clients (eg. Linux "tftp" client)
			# will send their final ACK from a different TID!
			# workaround for them is to ignore the port number
			# when checking.. probably dodgy
			if addr[0] != self.client_addr[0]: continue
			pkt = TFTPPacket()
			try: pkt.deserialise(data)
			except (TFTPError, struct.error): continue # ignore malformed datagrams
			if self.verbose: print("Packet from client (%s): %s" % (ropcodes[pkt.opcode], pkt.attrs))
			handled = True
			self.process_packet(pkt)
		# process_packet() re-arms the read trigger itself when it expects more
		# input.  If every datagram was ignored, re-arm here so we keep listening.
		if not handled and not self.closed:
			self.set_read_trigger()
	def xmit_data(self, source, condition, user=None):
		"""Write callback: send the pending packet and arrange any retransmission."""
		if self.verbose: print("** write trigger consumed")
		pkt = self.packet_to_send
		if not pkt: return
		data = pkt.serialise()
		try: self.fd.sendto(data, 0, self.client_addr)
		except socket.error:
			# Not writable after all; try again on the next write trigger.
			self.set_write_trigger()
			return
		if self.verbose: print("Packet to client (%s): %s" % (ropcodes[pkt.opcode], pkt.attrs))
		self.packet_to_send = None
		if pkt.is_last():
			self.hangup()
			return
		if pkt.should_retransmit():
			# The newest retransmittable packet supersedes any older one; an
			# already-running timer simply starts resending this one.
			self.packet_to_retransmit = pkt
			if not self.retransmit_id:
				self.retransmit_id = self.io.set_timeout(3000, self.retransmit)
		else:
			self.cancel_retransmission()
	def retransmit(self, user_data=None):
		"""Timer callback: resend the outstanding packet.

		Returns True to keep the timer running and False to stop it.  After
		MAX_RETRANSMIT unanswered resends, the client is presumed gone and the
		connection is hung up.
		"""
		pkt = self.packet_to_retransmit
		if not pkt or not pkt.should_retransmit():
			# Nothing outstanding: let the timer stop, and forget its id so a
			# later packet can schedule a fresh timer.
			self.retransmit_id = None
			return False
		if pkt.retransmit_count >= MAX_RETRANSMIT:
			# Give up.  Returning False stops this timer, so clear the id first;
			# otherwise hangup() would also try to remove it.
			self.retransmit_id = None
			self.packet_to_retransmit = None
			self.hangup()
			return False
		if self.verbose: print("Retransmit (%s): %s" % (ropcodes[pkt.opcode], pkt.attrs))
		data = pkt.serialise()
		pkt.retransmit_count += 1
		try: self.fd.sendto(data, 0, self.client_addr)
		except socket.error: pass # best effort; the next tick tries again
		return True
	def send_packet(self, pkt):
		"""Queue ``pkt`` to go out on the next write trigger.

		Only one packet may be pending at a time; raises TFTPError otherwise.
		"""
		if self.packet_to_send: raise TFTPError("There is already a packet in the send buffer")
		self.packet_to_send = pkt
		self.set_write_trigger()
	def send_error(self, code, message):
		"""Notify "error_sent", queue an ERROR packet, and keep listening."""
		pkt = TFTPPacket()
		self.io.notify("error_sent", (self.client_addr, code, message))
		pkt.opcode = opcodes['ERROR']
		pkt.attrs['errorcode'] = code
		pkt.attrs['errmsg'] = message
		self.send_packet(pkt)
		self.set_read_trigger()
	def send_ack(self, block, last=False, retransmit=False):
		"""Queue an ACK for ``block`` (see TFTPPacket for ``last``/``retransmit``)."""
		pkt = TFTPPacket(last=last, retransmit=retransmit)
		pkt.opcode = opcodes['ACK']
		pkt.attrs['block'] = block
		self.send_packet(pkt)
	def send_data(self):
		"""Queue the next DATA block of a read transfer, or finish it."""
		if self.current_transfer.complete:
			# The final (short) block has been acknowledged: nothing more to do.
			self.io.notify("file_sent", (self.client_addr, self.current_transfer.attrs['filename']))
			self.hangup()
			return
		block, data = self.current_transfer.consume(tftp_block_size)
		self.last_data_len = len(data)
		pkt = TFTPPacket(retransmit=True)
		pkt.opcode = opcodes['DATA']
		pkt.attrs['block'] = block
		pkt.attrs['data'] = data
		self.send_packet(pkt)
		self.set_read_trigger()
	def cancel_retransmission(self):
		"""Stop the retransmit timer, if any, and drop the packet it was resending."""
		if self.retransmit_id:
			self.io.remove_timeout(self.retransmit_id)
			self.packet_to_retransmit = None
			self.retransmit_id = None
	def process_packet(self, pkt):
		"""Handle one packet from the client according to its opcode."""
		if pkt.opcode == opcodes['RRQ']:
			# whatever happens, we don't have anything to retransmit (same for WRQ)
			self.cancel_retransmission()
			if self.current_transfer:
				self.send_error(4, "We already have an operation to perform")
				return
			if not self.io.authenticate("RRQ", self.client_addr, pkt.attrs['filename']):
				self.send_error(2, "Access violation.")
				return
			data = self.store.get_export(pkt.attrs['filename'])
			# An empty file is valid (sent as one empty DATA block); only a
			# missing export is "not found".
			if data is None:
				self.send_error(1, "File not found")
				return
			self.current_transfer = TFTPTransfer(self.io, pkt.opcode, pkt.attrs)
			self.current_transfer.set_buffer(data)
			self.send_data()
		elif pkt.opcode == opcodes['WRQ']:
			self.cancel_retransmission()
			if self.current_transfer:
				self.send_error(4, "We already have an operation to perform")
				return
			if not self.io.authenticate("WRQ", self.client_addr, pkt.attrs['filename']):
				self.send_error(2, "Access violation.")
				return
			self.current_transfer = TFTPTransfer(self.io, pkt.opcode, pkt.attrs)
			self.send_ack(0, retransmit=True)
			self.set_read_trigger()
		elif pkt.opcode == opcodes['DATA']:
			if not self.current_transfer or self.current_transfer.opcode != opcodes['WRQ']:
				self.send_error(4, "DATA when not in write operation")
				return
			block = pkt.attrs['block']
			data = pkt.attrs['data']
			filename = self.current_transfer.attrs['filename']
			if self.current_transfer.append(block, data):
				# a new data block; we can stop retransmitting the last ACK now.
				self.cancel_retransmission()
				self.io.notify("data_received", (self.client_addr, filename, len(data)))
			elif block != self.current_transfer.last_block:
				# Neither the next block nor a duplicate of the last accepted
				# one: don't ACK it, and never finish the file from it.
				self.set_read_trigger()
				return
			if len(data) < tftp_block_size:
				self.store.set_received(filename, self.current_transfer.data())
				self.send_ack(block, last=True)
				self.io.notify("file_received", (self.client_addr, filename))
			else:
				self.send_ack(block, retransmit=True)
				self.set_read_trigger()
		elif pkt.opcode == opcodes['ACK']:
			if not self.current_transfer or self.current_transfer.opcode != opcodes['RRQ']:
				# okay, that's just.. whacky. they ACKed when we're not sending something to them.
				self.hangup()
				return
			# ignore duplicate ACKs, go to transmit only when we get an ACK for the last block we sent as DATA
			block = pkt.attrs['block']
			last_block = self.current_transfer.last_block
			if block > last_block:
				# send_error() re-arms the read trigger itself.
				self.send_error(4, "ACK for wrong block - %d (should be %d)" % (block, last_block))
			elif block == last_block:
				# well, that data was successfully sent
				self.io.notify("data_sent", (self.client_addr, self.current_transfer.attrs['filename'], self.last_data_len))
				# stop retransmitting the last DATA block then
				self.cancel_retransmission()
				# now, send the next block
				self.send_data()
			else:
				# duplicate ACK; read until we get the right one
				self.set_read_trigger()
		elif pkt.opcode == opcodes['ERROR']:
			self.io.notify("error_received", (self.client_addr, pkt.attrs['errorcode'], pkt.attrs['errmsg']))
			# Closing the socket here seems to cause problems for some clients,
			# so we deliberately do not hang up.  We don't want to keep sending
			# them junk either, so cancel retransmission.
			self.cancel_retransmission()
		else:
			self.send_error(4, "Illegal operation")

class TFTPServer:
	"""Listens for TFTP read/write requests and starts a TFTPConnection for each.

	Files offered for download and files received from clients live in
	``self.store`` (a TFTPStore), which is what connections read and write.
	``serve_files`` and ``received_files`` alias the store's dictionaries so
	that callers reading them directly see live data.
	"""
	def __init__(self, io, port=69, verbose=False):
		self.port = port
		self.io = io
		self.verbose = verbose
		self.store = TFTPStore()
		# Aliases of the store's dicts (previously separate dicts that no
		# connection ever read or wrote).
		self.serve_files = self.store.export
		self.received_files = self.store.received
		# this is UDP socket upon which new requests will arrive
		self.fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
		set_nonblocking(self.fd)
		self.fd.bind(('', self.port))
		self.set_read_trigger()
	def set_read_trigger(self):
		"""Arrange for recv_data() to run when a request arrives."""
		self.io.set_read(self.fd, self.recv_data)
	def recv_data(self, source, condition, user=None):
		"""Read callback: start a connection for each valid RRQ/WRQ received."""
		received = []
		while 1:
			# One byte more than the largest legal packet, so an oversized
			# datagram is rejected by deserialise() instead of being truncated.
			try: received.append(self.fd.recvfrom(tftp_max_packet_size + 1))
			except socket.error: break
		try:
			for data, addr in received:
				pkt = TFTPPacket()
				try: pkt.deserialise(data)
				except (TFTPError, struct.error):
					# A malformed datagram must not stop the server listening.
					continue
				if self.verbose: print("Packet from client: %s" % (pkt.attrs,))
				if pkt.opcode not in (opcodes['RRQ'], opcodes['WRQ']):
					# Only requests start a transfer; stray DATA/ACK/ERROR sent
					# to the listening port are ignored.
					continue
				TFTPConnection(self.store, self.io, addr, pkt, verbose=self.verbose)
		finally:
			# Always keep listening for new requests, even if a connection
			# raised while being set up.
			self.set_read_trigger()
	def set_file(self, filename, bytes):
		"""Offer ``bytes`` for download under ``filename``."""
		self.store.set_export(filename, bytes)
	def get_files(self):
		"""Return the dict of files uploaded by clients (filename -> contents)."""
		return self.store.received