/* Below is the file 'fritz.cc' from this revision. You can also download the file. */


/* fritz.it
 * A framework for intercepting application input, altering it,
 * and checking for abnormal changes in application behaviour.
 *
 * Copyright (C) 2007 Grahame Bowland.
 * All rights reserved.
 */

#include <dlfcn.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>

#include <exception>
#include <iostream>

/* Some general design notes:
 * data is exchanged via a shared memory segment, and a pair of named
 * semaphores is used to hand control of that segment back and forth
 * between the client and the server. A normal sequence is:
 *
 *    client                          server
 ******************************************
 *                                    [ create semaphores, shared memory ]
 *                                    sem_wait (client)
 *    [ copy data into shm ]
 *    sem_post (client)
 *    sem_wait (server)
 *                                    [ rewrite data in shm ]
 *                                    sem_post (server)
 *    [ copy data back out of shm ]
 */


#include "fritz.hh"

/*** Fritzer ***/

Fritzer::Fritzer (void)
{
    /* No semaphore handles are held until a subclass opens or creates
     * them; start both out as NULL. */
    server_sem = NULL;
    client_sem = NULL;

    /* Pull the shm / semaphore names out of the environment (throws if
     * any of them is missing). */
    initialise_from_environ ();
}

/* Record our pid and read the names of the shared memory segment and the
 * two semaphores from the environment.
 *
 * Throws fritz_exception if any of FRITZ_SHM, FRITZ_CLIENT_SEM or
 * FRITZ_SERVER_SEM is unset. The stored pointers refer directly to the
 * process environment (getenv result), they are not copies. */
void
Fritzer::initialise_from_environ (void)
{
    my_pid = getpid ();
    shm_name = getenv ("FRITZ_SHM");
    client_sem_name = getenv ("FRITZ_CLIENT_SEM");
    server_sem_name = getenv ("FRITZ_SERVER_SEM");
    if (!shm_name || !server_sem_name || !client_sem_name) {
        /* message previously said "FRITZ_CLIENT)SEM", a variable that
         * does not exist */
        throw fritz_exception ("FRITZ_SHM, FRITZ_CLIENT_SEM and FRITZ_SERVER_SEM must be set.");
    }
}

/* Map the shared memory segment (shm_fd, shm_size bytes) read/write and
 * shared, then point fritz_params at the header at the start of the
 * mapping and fritz_contents at the payload area immediately after it.
 *
 * Throws fritz_exception if the mapping fails. */
void
Fritzer::map_the_shm (void)
{
    shm_mem = mmap ((void *)NULL, shm_size, PROT_READ|PROT_WRITE, MAP_SHARED, shm_fd, (off_t)0);
    /* mmap reports failure with MAP_FAILED ((void *)-1), never NULL, so the
     * previous "shm_mem <= 0" test could not detect a failed mapping. */
    if (shm_mem == MAP_FAILED) {
        throw fritz_exception("Unable to map shared memory.");
    }
    fritz_params = (struct fritz_params *)shm_mem;
    fritz_contents = (unsigned char *)shm_mem + sizeof(struct fritz_params);
}

void
Fritzer::update_params (void)
{
    fritz_params->buffer_size = shm_size;
    fritz_params->pid = my_pid;
}

/* Open an already-existing named semaphore (no O_CREAT).
 *
 * name: POSIX semaphore name.
 * Returns the semaphore handle; throws fritz_exception if it cannot be
 * opened (e.g. the server has not created it yet). */
sem_t *
Fritzer::open_semaphore (const char *name)
{
    sem_t *rv = sem_open (name, 0);
    if ((rv == (sem_t *)SEM_FAILED) || (!rv)) {
        /* message previously read "Unable top open semaphore." */
        throw fritz_exception ("Unable to open semaphore.");
    }
    return rv;
}

sem_t *
Fritzer::create_semaphore (const char *name)
{
    /* Throw away any semaphore left over under this name first, so the
     * exclusive create below cannot collide with it. */
    sem_unlink (name);

    /* Fresh semaphore, owner read/write only, initial count 0. */
    sem_t *sem = sem_open (name, O_CREAT|O_EXCL, 0600, 0);
    const bool created = (sem != (sem_t *)SEM_FAILED) && (sem != NULL);
    if (!created)
        throw fritz_exception ("Unable to create semaphore.");
    return sem;
}

/* Close both semaphore handles held by this process and forget them.
 *
 * Handles that are already NULL (never opened, or closed by an earlier
 * call) are skipped, so calling this more than once is harmless instead
 * of handing NULL to sem_close. This does not unlink the named semaphores. */
void
Fritzer::close_semaphores (void)
{
    if (client_sem) {
        sem_close (client_sem);
    }
    if (server_sem) {
        sem_close (server_sem);
    }
    client_sem = server_sem = NULL;
}

size_t
Fritzer::get_message_id (void)
{
    return fritz_params->message_id;
}

/*** FritzServer ***/

/* Create (replacing any stale one) the shared memory segment named by
 * shm_name and size it to shm_size bytes. On success shm_fd holds an open
 * read/write descriptor.
 *
 * Throws fritz_exception if the segment cannot be created or sized. On a
 * sizing failure the descriptor is closed and the half-made segment is
 * unlinked rather than leaked. */
void
FritzServer::create_shared_memory (void)
{
    /* a bit brutal: drop any segment left over from a previous run so the
     * O_EXCL create below succeeds */
    shm_unlink (shm_name);
    shm_fd = shm_open (shm_name, O_RDWR|O_CREAT|O_EXCL, 0600);
    if (shm_fd < 0) {
        throw fritz_exception ("Unable to open shared memory segment.");
    }
    if (ftruncate (shm_fd, shm_size) == -1) {
        close (shm_fd);
        shm_fd = -1;
        shm_unlink (shm_name);
        throw fritz_exception ("Unable to truncate shared memory segment.");
    }
}

/* Block until the client posts client_sem, i.e. hands over the segment.
 *
 * A wait interrupted by a signal handler (EINTR) is simply retried; any
 * other sem_wait failure throws fritz_exception. */
void
FritzServer::wait_for_client (void)
{
    while (sem_wait (client_sem) == -1) {
        if (errno != EINTR) {
            throw fritz_exception ("Error waiting on semaphore for control.");
        }
    }
}

void
FritzServer::create_semaphores (void)
{
    /* The server owns both named semaphores and creates them fresh; the
     * server-side one is made first. */
    sem_t *const created_server = create_semaphore (server_sem_name);
    server_sem = created_server;

    sem_t *const created_client = create_semaphore (client_sem_name);
    client_sem = created_client;
}

FritzServer::FritzServer (void)
{
    mode = "server";

    /* Set up the IPC objects; map_the_shm relies on the descriptor that
     * create_shared_memory leaves in shm_fd. */
    create_semaphores ();
    create_shared_memory ();
    map_the_shm ();

    /* No requests serviced yet. */
    struct fritz_params *header = fritz_params;
    header->message_id = 0;
}

/* Tear down everything the server created: unmap and close the shared
 * memory segment and unlink it, then close and unlink both semaphores.
 *
 * The mapping and descriptor were previously leaked. Since the constructor
 * completed, shm_mem is a valid mapping and shm_fd an open descriptor. */
FritzServer::~FritzServer (void)
{
    if (shm_mem != NULL && shm_mem != MAP_FAILED) {
        munmap (shm_mem, shm_size);
    }
    if (shm_fd >= 0) {
        close (shm_fd);
    }
    shm_unlink (shm_name);
    close_semaphores ();
    sem_unlink (client_sem_name);
    sem_unlink (server_sem_name);
}

void
FritzServer::signal_client (void)
{
    /* Hand the segment back to the client by posting server_sem. */
    const int status = sem_post (server_sem);
    if (status != -1)
        return;
    throw fritz_exception ("Unable to sem_post");
}

void
FritzServer::rewrite_rot13 (void)
{
    /* ROT13 the payload in place: ASCII letters are rotated 13 places
     * within their own case, every other byte is left untouched. */
    for (size_t pos = 0; pos < fritz_params->contents_size; ++pos) {
        unsigned char &ch = fritz_contents[pos];
        char base;
        if (ch >= 'a' && ch <= 'z') {
            base = 'a';
        } else if (ch >= 'A' && ch <= 'Z') {
            base = 'A';
        } else {
            continue;
        }
        ch = base + (13 + ch - base) % 26;
    }
}

void
FritzServer::rewrite (void)
{
    /* Add one to every payload byte (wrapping as unsigned char does). */
    unsigned char *cursor = fritz_contents;
    for (size_t done = 0; done < fritz_params->contents_size; ++done, ++cursor) {
        *cursor += 1;
    }
}

void
FritzServer::run (void)
{
    /* Service requests forever: wait for the client to hand over the
     * segment, refresh the header, bump the request counter, ROT13 the
     * payload and hand the segment back. Errors from the helpers
     * propagate as fritz_exception and end the loop. */
    while (true) {
        wait_for_client ();
        update_params ();
        ++fritz_params->message_id;
        rewrite_rot13 ();
        signal_client ();
    }
}

/*** FritzClient ***/

void
FritzClient::open_semaphores ()
{
    /* Attach to the two semaphores; open_semaphore does not create them,
     * so they must already exist. Server-side one first. */
    sem_t *const attached_server = open_semaphore (server_sem_name);
    server_sem = attached_server;

    sem_t *const attached_client = open_semaphore (client_sem_name);
    client_sem = attached_client;
}

void
FritzClient::open_shared_memory (void)
{
    /* Open the existing segment read/write (no O_CREAT). */
    const int fd = shm_open (shm_name, O_RDWR, 0600);
    shm_fd = fd;
    if (fd >= 0)
        return;
    throw fritz_exception ("Unable to open shared memory segment.");
}

/* Block until the server posts server_sem, i.e. hands the segment back.
 *
 * A wait interrupted by a signal handler (EINTR) is simply retried; any
 * other sem_wait failure throws fritz_exception. */
void
FritzClient::wait_for_server (void)
{
    while (sem_wait (server_sem) == -1) {
        if (errno != EINTR) {
            throw fritz_exception ("Error waiting on semaphore for control.");
        }
    }
}

void
FritzClient::signal_server (void)
{
    /* Hand the segment to the server by posting client_sem. */
    const int status = sem_post (client_sem);
    if (status != -1)
        return;
    throw fritz_exception ("Unable to sem_post");
}

FritzClient::FritzClient (void)
{
    mode = "client";

    /* Attach to the IPC objects the server set up; map_the_shm uses the
     * descriptor that open_shared_memory stores in shm_fd. */
    open_semaphores ();
    open_shared_memory ();
    map_the_shm ();
}

/* Release the client's resources: unmap and close the shared memory
 * segment (previously leaked) and close the semaphore handles.
 *
 * Nothing is unlinked here; the server owns and removes the named objects. */
FritzClient::~FritzClient (void)
{
    if (shm_mem != NULL && shm_mem != MAP_FAILED) {
        munmap (shm_mem, shm_size);
    }
    if (shm_fd >= 0) {
        close (shm_fd);
    }
    close_semaphores ();
}

void
FritzClient::pass_to_server (void)
{
    /* Round trip: give the segment to the server, then sleep until the
     * server gives it back. */
    signal_server ();   /* post client_sem */
    wait_for_server (); /* wait on server_sem */
}

/* Look up symbol `sym` in the shared library whose path is stored in the
 * environment variable `lib_env`.
 *
 * Returns the symbol's address. Throws fritz_exception if the variable is
 * unset, the library cannot be loaded, or the symbol cannot be resolved
 * (the dl error text is written to std::cerr first).
 *
 * The library handle is deliberately kept open on success: dlclose could
 * unload the library and leave the returned pointer dangling. */
void *
FritzClient::grab_symbol_from (const char *lib_env, const char *sym)
{
    const char *dl_file = getenv (lib_env);
    if (!dl_file) {
        throw fritz_exception ("Can't fritz library, as environment variable is not present.");
    }

    void *dl = dlopen (dl_file, RTLD_LAZY);
    if (!dl) {
        /* Without this check a NULL handle reached dlsym, which on glibc
         * means "search the global scope" -- i.e. possibly the wrong object. */
        const char *load_error = dlerror ();
        std::cerr << "error loading library: "
                  << (load_error ? load_error : "(unknown error)") << std::endl;
        throw fritz_exception ("Error grabbing symbol.");
    }

    dlerror ();  /* clear any stale error so the check below is meaningful */
    void *rv = dlsym (dl, sym);
    const char *error = dlerror ();
    if (error != NULL) {
        std::cerr << "error loading symbol: " << error << std::endl;
        dlclose (dl);
        throw fritz_exception ("Error grabbing symbol.");
    }
    std::cerr << "got the symbol \"" << sym << "\" from \"" << dl_file << "\" at " << rv << std::endl;
    return rv;
}

/* Copy `nbytes` bytes from `from` into the payload area of the shared
 * segment and record the length in the header.
 *
 * Throws fritz_exception if the payload would not fit. The usable capacity
 * is the mapping size minus the fritz_params header that precedes the
 * payload. The old code subtracted (params - contents), a negative value,
 * and so allowed writes past the end of the mapping. */
void
FritzClient::copy_data (size_t nbytes, void *from)
{
    update_params ();
    const size_t header = (size_t)(fritz_contents - (unsigned char *)fritz_params);
    const size_t total = (size_t)fritz_params->buffer_size;
    const size_t capacity = (total > header) ? total - header : 0;
    if (nbytes > capacity) {
        throw fritz_exception ("buffer overrun; can't copy this many bytes!");
    }
    fritz_params->contents_size = nbytes;
    memcpy (fritz_contents, from, nbytes);
}

size_t
FritzClient::copy_back (size_t maxbytes, void *buf)
{
    if (fritz_params->contents_size > maxbytes)  {
        throw fritz_exception ("buffer overrun; can't copy this many bytes back!");
    }
    memcpy (buf, fritz_contents, fritz_params->contents_size);
    return fritz_params->contents_size;
}