Below is the file 'tftp.py' from this revision. You can also download the file.
#!/usr/bin/python
"""Minimal event-driven TFTP (RFC 1350) server.

The module never blocks or schedules I/O itself: every read/write trigger
and retransmission timer goes through a caller-supplied TFTPIOFunctions
object.  Files to serve and files received are kept in memory in a
TFTPStore.
"""

import struct
import socket
import os

opcodes = {
    "RRQ": 1,
    "WRQ": 2,
    "DATA": 3,
    "ACK": 4,
    "ERROR": 5,
}
# Reverse lookup (numeric opcode -> name), used for verbose output.
ropcodes = dict((number, name) for name, number in opcodes.items())

tftp_max_packet_size = 516  # 4-byte header + tftp_block_size bytes of data
tftp_block_size = 512

# Block numbers are unsigned 16-bit values on the wire, so they wrap.
_BLOCK_MODULUS = 0x10000


def set_nonblocking(fd):
    """Put socket *fd* into non-blocking mode."""
    fd.setblocking(False)


def _as_bytes(value):
    """Return *value* as a byte string, UTF-8-encoding text if necessary."""
    if isinstance(value, bytes):
        return value
    return value.encode("utf-8")


class TFTPError(Exception):
    """Raised for malformed packets and protocol misuse."""

    def __init__(self, mesg):
        Exception.__init__(self, mesg)
        self._mesg = mesg

    def __repr__(self):
        return "TFTP Error: %s" % (self._mesg)


# Retransmissions of one packet stop once its retransmit_count exceeds this.
MAX_RETRANSMIT = 5


class TFTPPacket:
    """One TFTP datagram: a numeric opcode plus opcode-specific ``attrs``."""

    def __init__(self, last=False, retransmit=False):
        self.opcode = None
        # last: the connection hangs up once this packet has been sent.
        self.last = last
        # retransmit: resend on a timer until something cancels it.
        self.retransmit = retransmit
        self.retransmit_count = 0
        self.attrs = {}

    def is_last(self):
        return self.last

    def should_retransmit(self):
        return self.retransmit

    def serialise(self):
        """Encode this packet into its wire format.

        Raises:
            TFTPError: if ``opcode`` is not a known TFTP opcode.
        """
        op = self.opcode
        attrs = self.attrs
        if op in (opcodes['RRQ'], opcodes['WRQ']):
            body = (_as_bytes(attrs['filename']) + b"\0" +
                    _as_bytes(attrs['mode']) + b"\0")
        elif op == opcodes['DATA']:
            body = struct.pack("!H", attrs['block']) + _as_bytes(attrs['data'])
        elif op == opcodes['ACK']:
            body = struct.pack("!H", attrs['block'])
        elif op == opcodes['ERROR']:
            body = (struct.pack("!H", attrs['errorcode']) +
                    _as_bytes(attrs['errmsg']) + b"\0")
        else:
            raise TFTPError("Unknown opcode: %s" % (op,))
        return struct.pack("!H", op) + body

    def deserialise(self, data):
        """Populate ``opcode`` and ``attrs`` from the raw datagram *data*.

        Raises:
            TFTPError: if the datagram is too long, too short, malformed,
                or carries an unknown opcode.
        """
        if len(data) > tftp_max_packet_size:
            raise TFTPError("Packet too long")
        if len(data) < 2:
            raise TFTPError("Packet too short")
        self.attrs = {}
        self.opcode = struct.unpack("!H", data[:2])[0]
        if self.opcode in (opcodes['RRQ'], opcodes['WRQ']):
            # "filename\0mode\0", optionally followed by RFC 2347
            # "option\0value\0" pairs. Options are ignored, which RFC 2347
            # permits; the client then falls back to plain RFC 1350.
            parts = data[2:].split(b"\0")
            if len(parts) < 3:
                raise TFTPError("Malformed RRQ/WRQ packet - wrong number of parts")
            self.attrs['filename'] = parts[0]
            self.attrs['mode'] = parts[1]
        elif self.opcode == opcodes['DATA']:
            if len(data) < 4:
                raise TFTPError("Malformed DATA packet")
            self.attrs['block'] = struct.unpack("!H", data[2:4])[0]
            self.attrs['data'] = data[4:]
        elif self.opcode == opcodes['ACK']:
            if len(data) < 4:
                raise TFTPError("Malformed ACK packet")
            self.attrs['block'] = struct.unpack("!H", data[2:4])[0]
        elif self.opcode == opcodes['ERROR']:
            # slicing (not indexing) keeps this byte-safe on Python 2 and 3
            if len(data) < 5 or data[-1:] != b"\0":
                raise TFTPError("Malformed ERROR packet")
            self.attrs['errorcode'] = struct.unpack("!H", data[2:4])[0]
            self.attrs['errmsg'] = data[4:-1]
        else:
            raise TFTPError("Unknown opcode: %d" % (self.opcode))


class TFTPTransfer:
    """Buffer and block bookkeeping for one upload (WRQ) or download (RRQ)."""

    def __init__(self, io, opcode, attrs):
        self.opcode = opcode
        self.attrs = attrs
        self.buffer = b""
        # Upload chunks are collected in a list and joined lazily in data(),
        # which avoids quadratic string concatenation.
        self._chunks = []
        self.last_block = None
        self.complete = False

    def append(self, block, data):
        """Store upload block *block* if it is the next one expected.

        The first block must be 1. After 65535 the client is assumed to
        roll over to 0 (NOTE(review): some clients roll over to 1).

        Returns:
            True if the block was stored, or False if it was a duplicate or
            out of order and was dropped.
        """
        if self.last_block is None:
            expected = 1
        else:
            expected = (self.last_block + 1) % _BLOCK_MODULUS
        if block != expected:
            # out of order; wait for the right packet to appear
            return False
        self.last_block = block
        self._chunks.append(data)
        return True

    def set_buffer(self, data):
        """Load *data* to be sent, and restart the download from block 1."""
        self.buffer = data
        self._chunks = []
        self.block = 1
        self.last_block = None
        self.complete = False

    def consume(self, max):
        """Take the next chunk of at most *max* bytes for sending.

        Returns:
            A ``(block_number, chunk)`` tuple. A chunk shorter than *max*
            (possibly empty) marks the transfer complete, as TFTP requires.
        """
        block = self.block
        self.block = (self.block + 1) % _BLOCK_MODULUS
        self.last_block = block
        chunk = self.buffer[:max]
        self.buffer = self.buffer[max:]
        if len(chunk) < max:
            self.complete = True
        return block, chunk

    def data(self):
        """Return the buffered bytes (everything uploaded so far for a WRQ)."""
        if self._chunks:
            self.buffer += b"".join(self._chunks)
            self._chunks = []
        return self.buffer


class TFTPStore:
    """In-memory files: ``export`` (served to clients) and ``received``."""

    def __init__(self):
        self.export = {}
        self.received = {}

    def clear_received(self, filename):
        self.received.pop(filename, None)

    def set_received(self, filename, data):
        self.received[filename] = data

    def has_received(self, filename):
        return filename in self.received

    def get_received(self, filename):
        return self.received.get(filename, None)

    def has_export(self, filename):
        return filename in self.export

    def clear_export(self, filename):
        self.export.pop(filename, None)

    def set_export(self, filename, data):
        self.export[filename] = data

    def get_export(self, filename):
        return self.export.get(filename, None)


class TFTPIOFunctions:
    """Caller-supplied hooks for the event loop.

    Presumed contracts, inferred from how this module calls them (verify
    against the caller):

    - set_read(fd, cb) / set_write(fd, cb): arm a one-shot trigger that calls
      ``cb(source, condition)``.
    - clear_read(fd) / clear_write(fd): disarm those triggers.
    - set_timeout(interval, cb) -> id: call ``cb()`` repeatedly until it
      returns False. The interval is probably in ms; 3000 is passed.
    - remove_timeout(id): cancel such a timer.
    - notify(event_name, args_tuple): event callback.
    - authenticate(op, client_addr, filename) -> bool.
    """

    def __init__(self, set_read, set_write, clear_read, clear_write,
                 set_timeout, remove_timeout, notify, authenticate):
        self.set_read = set_read
        self.set_write = set_write
        self.clear_read = clear_read
        self.clear_write = clear_write
        self.set_timeout = set_timeout
        self.remove_timeout = remove_timeout
        self.notify = notify
        self.authenticate = authenticate


class TFTPConnection:
    """State machine for one transfer with one client on its own UDP port."""

    def __init__(self, store, io, addr, pkt, verbose=False):
        self.store = store
        self.io = io
        self.client_addr = addr
        self.pkt = pkt
        self.verbose = verbose
        self.retransmit_id = None
        self.packet_to_send = None
        self.packet_to_retransmit = None
        self.current_transfer = None
        # payload size of the most recent DATA packet, reported by "data_sent"
        self.last_data_len = 0
        self.closed = False
        self.fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        set_nonblocking(self.fd)
        self.fd.bind(('', 0))
        self.process_packet(pkt)

    def set_read_trigger(self):
        if self.verbose:
            print("** read trigger set")
        self.io.set_read(self.fd, self.recv_data)

    def set_write_trigger(self):
        if self.verbose:
            print("** write trigger set")
        self.io.set_write(self.fd, self.xmit_data)

    def hangup(self):
        """Stop retransmitting, disarm all triggers and close the socket (idempotent)."""
        if self.closed:
            return
        self.closed = True
        self.cancel_retransmission()
        self.io.clear_read(self.fd)
        self.io.clear_write(self.fd)
        self.fd.close()

    def recv_data(self, source, condition, user=None):
        """Read callback: drain the socket and process each client packet."""
        if self.verbose:
            print("** read trigger consumed")
        received = []
        while True:
            try:
                received.append(self.fd.recvfrom(tftp_max_packet_size))
            except socket.error:
                # non-blocking socket drained (or closed)
                break
        handled = False
        for data, addr in received:
            if self.closed:
                break
            # some dodgy clients (eg. Linux "tftp" client)
            # will send their final ACK from a different TID!
            # workaround for them is to ignore the port number
            # when checking.. probably dodgy
            if addr[0] != self.client_addr[0]:
                continue
            pkt = TFTPPacket()
            try:
                pkt.deserialise(data)
            except TFTPError:
                if self.verbose:
                    print("Dropping malformed packet from %s" % (addr,))
                continue
            if self.verbose:
                print("Packet from client (%s): %s" % (ropcodes[pkt.opcode], pkt.attrs))
            handled = True
            self.process_packet(pkt)
        if not handled and not self.closed:
            # Nothing processed re-armed the consumed read trigger, so re-arm
            # it here or the connection would stall.
            self.set_read_trigger()

    def xmit_data(self, source, condition, user=None):
        """Write callback: send the pending packet and arrange any retransmission."""
        if self.verbose:
            print("** write trigger consumed")
        pkt = self.packet_to_send
        if not pkt or self.closed:
            return
        try:
            self.fd.sendto(pkt.serialise(), 0, self.client_addr)
        except socket.error:
            # socket not writable yet; try again on the next trigger
            self.set_write_trigger()
            return
        if self.verbose:
            print("Packet to client (%s): %s" % (ropcodes[pkt.opcode], pkt.attrs))
        self.packet_to_send = None
        if pkt.is_last():
            self.hangup()
            return
        if pkt.should_retransmit():
            # resend this packet on a timer until a reply cancels it
            self.packet_to_retransmit = pkt
            if not self.retransmit_id:
                self.retransmit_id = self.io.set_timeout(3000, self.retransmit)
        else:
            self.packet_to_retransmit = None

    def retransmit(self, user_data=None):
        """Timer callback that resends ``packet_to_retransmit``.

        Returns:
            True to keep the timer running, or False to stop it.
        """
        pkt = self.packet_to_retransmit
        if self.closed or not pkt or not pkt.should_retransmit():
            # the timer stops when we return False; forget its id
            self.retransmit_id = None
            return False
        if pkt.retransmit_count > MAX_RETRANSMIT:
            # The client has stopped answering: give up and release the socket.
            # Clear the id first so hangup() does not remove this timer while
            # it is running.
            self.retransmit_id = None
            self.packet_to_retransmit = None
            self.hangup()
            return False
        if self.verbose:
            print("Retransmit (%s): %s" % (ropcodes[pkt.opcode], pkt.attrs))
        pkt.retransmit_count += 1
        try:
            self.fd.sendto(pkt.serialise(), 0, self.client_addr)
        except socket.error:
            pass  # best effort; the timer fires again
        return True

    def send_packet(self, pkt):
        """Queue *pkt* and arm the write trigger.

        Raises:
            TFTPError: if another packet is still waiting to be sent.
        """
        if self.packet_to_send:
            raise TFTPError("There is already a packet in the send buffer")
        self.packet_to_send = pkt
        self.set_write_trigger()

    def send_error(self, code, message):
        """Notify the caller, queue an ERROR packet, and keep listening."""
        pkt = TFTPPacket()
        self.io.notify("error_sent", (self.client_addr, code, message))
        pkt.opcode = opcodes['ERROR']
        pkt.attrs['errorcode'] = code
        pkt.attrs['errmsg'] = message
        self.send_packet(pkt)
        self.set_read_trigger()

    def send_ack(self, block, last=False, retransmit=False):
        pkt = TFTPPacket(last=last, retransmit=retransmit)
        pkt.opcode = opcodes['ACK']
        pkt.attrs['block'] = block
        self.send_packet(pkt)

    def send_data(self):
        """Send the next DATA block, or finish the download if complete."""
        transfer = self.current_transfer
        if transfer.complete:
            # nothing more to do
            self.io.notify("file_sent", (self.client_addr, transfer.attrs['filename']))
            self.hangup()
            return
        block, data = transfer.consume(tftp_block_size)
        self.last_data_len = len(data)
        pkt = TFTPPacket(retransmit=True)
        pkt.opcode = opcodes['DATA']
        pkt.attrs['block'] = block
        pkt.attrs['data'] = data
        self.send_packet(pkt)
        self.set_read_trigger()

    def cancel_retransmission(self):
        if self.retransmit_id:
            self.io.remove_timeout(self.retransmit_id)
        self.packet_to_retransmit = None
        self.retransmit_id = None

    def process_packet(self, pkt):
        """Dispatch a decoded client packet to the handler for its opcode."""
        handlers = {
            opcodes['RRQ']: self._handle_rrq,
            opcodes['WRQ']: self._handle_wrq,
            opcodes['DATA']: self._handle_data,
            opcodes['ACK']: self._handle_ack,
            opcodes['ERROR']: self._handle_error,
        }
        handler = handlers.get(pkt.opcode)
        if handler is None:
            self.send_error(4, "Illegal operation")
        else:
            handler(pkt)

    def _handle_rrq(self, pkt):
        """Start sending an exported file to the client."""
        # whatever happens, we don't have anything to retransmit (same for WRQ)
        self.cancel_retransmission()
        if self.current_transfer:
            self.send_error(4, "We already have an operation to perform")
            return
        filename = pkt.attrs['filename']
        if not self.io.authenticate("RRQ", self.client_addr, filename):
            self.send_error(2, "Access violation.")
            return
        data = self.store.get_export(filename)
        if data is None:
            # only a missing entry is "not found"; an empty file is servable
            self.send_error(1, "File not found")
            return
        self.current_transfer = TFTPTransfer(self.io, pkt.opcode, pkt.attrs)
        self.current_transfer.set_buffer(data)
        self.send_data()

    def _handle_wrq(self, pkt):
        """Accept an upload: acknowledge block 0 and wait for DATA."""
        self.cancel_retransmission()
        if self.current_transfer:
            self.send_error(4, "We already have an operation to perform")
            return
        if not self.io.authenticate("WRQ", self.client_addr, pkt.attrs['filename']):
            self.send_error(2, "Access violation.")
            return
        self.current_transfer = TFTPTransfer(self.io, pkt.opcode, pkt.attrs)
        self.send_ack(0, retransmit=True)
        self.set_read_trigger()

    def _handle_data(self, pkt):
        """Store an upload block and acknowledge it; a short block ends the upload."""
        transfer = self.current_transfer
        if not transfer or transfer.opcode != opcodes['WRQ']:
            self.send_error(4, "DATA when not in write operation")
            return
        payload = pkt.attrs['data']
        if not transfer.append(pkt.attrs['block'], payload):
            # Duplicate or out-of-order block: drop it and keep listening.
            # The pending ACK retransmission prompts the client to resend.
            self.set_read_trigger()
            return
        # a new data block; we can stop retransmitting the last ACK now.
        self.cancel_retransmission()
        filename = transfer.attrs['filename']
        self.io.notify("data_received", (self.client_addr, filename, len(payload)))
        if len(payload) < tftp_block_size:
            self.store.set_received(filename, transfer.data())
            self.send_ack(pkt.attrs['block'], last=True)
            self.io.notify("file_received", (self.client_addr, filename))
        else:
            self.send_ack(pkt.attrs['block'], retransmit=True)
            self.set_read_trigger()

    def _handle_ack(self, pkt):
        """Advance a download when the client ACKs the block we last sent."""
        transfer = self.current_transfer
        if not transfer or transfer.opcode != opcodes['RRQ']:
            # okay, that's just.. whacky. they ACKed when we're not sending something to them.
            self.hangup()
            return
        block = pkt.attrs['block']
        # Distance from the last block sent, modulo the 16-bit block space,
        # so the comparison survives block-number wraparound.
        ahead = (block - transfer.last_block) % _BLOCK_MODULUS
        if ahead == 0:
            # well, that data was successfully sent
            self.io.notify("data_sent", (self.client_addr, transfer.attrs['filename'],
                                         self.last_data_len))
            # stop retransmitting the last DATA block, then send the next one
            self.cancel_retransmission()
            self.send_data()
        elif ahead < _BLOCK_MODULUS // 2:
            # ACK for a block we have not sent yet
            self.send_error(4, "ACK for wrong block - %d (should be %d)"
                            % (block, transfer.last_block))
        else:
            # duplicate ACK; read until we get the right one
            self.set_read_trigger()

    def _handle_error(self, pkt):
        """Report a client ERROR and stop retransmitting."""
        self.io.notify("error_received", (self.client_addr, pkt.attrs['errorcode'],
                                          pkt.attrs['errmsg']))
        # closing the socket here seems to cause problems for some clients.
        # however, we don't want to keep sending them junk either. so cancel retransmission
        self.cancel_retransmission()


class TFTPServer:
    """Listens on *port* and starts a TFTPConnection per RRQ/WRQ request."""

    def __init__(self, io, port=69, verbose=False):
        self.port = port
        self.io = io
        self.verbose = verbose
        self.store = TFTPStore()
        # Alias the store's dicts under the legacy names, so that set_file()
        # really exports a file and get_files() reports what was received.
        self.serve_files = self.store.export
        self.received_files = self.store.received
        # this is UDP socket upon which new requests will arrive
        self.fd = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        set_nonblocking(self.fd)
        self.fd.bind(('', self.port))
        self.set_read_trigger()

    def set_read_trigger(self):
        self.io.set_read(self.fd, self.recv_data)

    def recv_data(self, source, condition, user=None):
        """Read callback: turn each well-formed request into a connection."""
        received = []
        while True:
            try:
                received.append(self.fd.recvfrom(tftp_max_packet_size))
            except socket.error:
                break
        for data, addr in received:
            pkt = TFTPPacket()
            try:
                pkt.deserialise(data)
            except TFTPError:
                continue  # ignore garbage on the well-known port
            if self.verbose:
                print("Packet from client: %s" % (pkt.attrs,))
            if pkt.opcode not in (opcodes['RRQ'], opcodes['WRQ']):
                # Only requests open a transfer. Stray DATA/ERROR packets
                # would otherwise create connections that are never closed.
                continue
            TFTPConnection(self.store, self.io, addr, pkt, verbose=self.verbose)
        self.set_read_trigger()

    def set_file(self, filename, bytes):
        """Export *bytes* to clients under *filename*."""
        self.serve_files[filename] = bytes

    def get_files(self):
        """Return the dict of filename -> data received from clients."""
        return self.received_files