/* * Copyright (c) 2010, Manuel Mausz. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * - Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * - Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * - The names of the authors may not be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS * IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ import static java.lang.System.err; import static java.lang.System.out; import javax.crypto.*; import java.security.*; import org.bouncycastle.openssl.PEMReader; import org.bouncycastle.util.encoders.Base64; import org.bouncycastle.openssl.PasswordFinder; import org.bouncycastle.util.encoders.Hex; import java.nio.ByteBuffer; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.nio.channels.DatagramChannel; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.MissingResourceException; import java.util.Collections; import java.util.Enumeration; import java.util.ArrayList; import java.util.Calendar; import java.util.Iterator; import java.util.HashMap; import java.util.Arrays; import java.util.Map; import java.net.*; import java.io.*; /* * Proxy implementation for Lab#1 of DSLab WS10 * See angabe.pdf for details * * This code is not documented at all. 
This is volitional * * @author Manuel Mausz (0728348) */ public class Proxy { public class FSRecord implements Comparable { public final String host; public final int port; public int usage; public boolean online; public long lastUpdate; FSRecord(String host, int port) { this.host = host; this.port = port; usage = 0; online = true; lastUpdate = Calendar.getInstance().getTimeInMillis(); } public void ping() { online = true; lastUpdate = Calendar.getInstance().getTimeInMillis(); } public boolean equals(FSRecord o) { return online == o.online && usage == o.usage; } public int compareTo(FSRecord o) { return usage - o.usage; } } public class FSRecords extends HashMap {} /*==========================================================================*/ public class UserRecord { public final String name; public String pass; public int credits; public ArrayList loggedin; UserRecord(String name, int credits) { this.name = name; this.pass = null; this.credits = credits; loggedin = new ArrayList(); } } public class UserRecords extends HashMap {} /*==========================================================================*/ public class ProxyConnection implements Runnable { private final String host; private final int port; private FSRecords fileservers; ProxyConnection(String host, int port, FSRecords fileservers) { this.host = host; this.port = port; this.fileservers = fileservers; } /*------------------------------------------------------------------------*/ public void run() { String key = host + ":" + port; synchronized(fileservers) { FSRecord record = fileservers.get(key); if (record == null) { fileservers.put(key, new FSRecord(host, port)); out.println("New fileserver registered: " + key); } else { if (!record.online) out.println("Fileserver is online again: " + key); record.ping(); } } } } /*==========================================================================*/ public class UDPSocketReader implements Runnable { private final DatagramChannel dchannel; private 
FSRecords fileservers; private final Object mainLock; private final ExecutorService pool; UDPSocketReader(DatagramChannel dchannel, FSRecords fileservers, Object mainLock) { this.dchannel = dchannel; this.fileservers = fileservers; this.mainLock = mainLock; this.pool = Executors.newCachedThreadPool(); } /*------------------------------------------------------------------------*/ public void run() { try { String tmp = "!alive 12345"; ByteBuffer buffer = ByteBuffer.allocate(tmp.getBytes().length); while(true) { buffer.clear(); InetSocketAddress proxyaddr = (InetSocketAddress) dchannel.receive(buffer); String msg = new String(buffer.array()); if (msg.length() != tmp.length()) continue; assert msg.matches("!alive 1[0-9]{4}"); try { pool.execute(new ProxyConnection(proxyaddr.getHostName(), Integer.parseInt(msg.substring("!alive ".length())), fileservers)); } catch(NumberFormatException e) { /* simple ignore that packet */ } } } catch(IOException e) { /* ignore that exception * thread will shutdown and unlock the main thread * which will shutdown the application */ } pool.shutdown(); try { if (!pool.awaitTermination(100, TimeUnit.MILLISECONDS)) out.println("Trying to shutdown the UDP Proxy connections. This may take up to 15 seconds..."); if (!pool.awaitTermination(5, TimeUnit.SECONDS)) { pool.shutdownNow(); if (!pool.awaitTermination(5, TimeUnit.SECONDS)) err.println("Error: UDP Proxy connections did not terminate. 
You may have to kill that appplication."); } } catch(InterruptedException e) { pool.shutdownNow(); } synchronized(mainLock) { mainLock.notify(); } } } /*==========================================================================*/ public class ClientConnection extends CommandNetwork implements Runnable { private final SocketChannel sock; private final FSRecords fileservers; private final UserRecords users; private final Utils.EncObjectInputStream clin; private final Utils.EncObjectOutputStream clout; private UserRecord user = null; private final String clientaddr; HashMap filelist = null; private CommandNetwork fscmd; private ArrayList fsschannel; private ArrayList fsin; private ArrayList fsout; private ArrayList fileserver; private int curconn; private String mychallenge; ClientConnection(SocketChannel sock, FSRecords fileservers, UserRecords users) throws NoSuchMethodException, IOException { this.sock = sock; this.clout = new Utils.EncObjectOutputStream(sock.socket().getOutputStream()); this.clin = new Utils.EncObjectInputStream(new BufferedInputStream(sock.socket().getInputStream())); this.fileservers = fileservers; this.users = users; this.clientaddr = "tcp:/" + sock.socket().getInetAddress() + ":" + sock.socket().getPort(); fsschannel = new ArrayList(); fsin = new ArrayList(); fsout = new ArrayList(); fileserver = new ArrayList(); clin.setCipher(rsadecrypt); cmdHandler.register("!login", this, "cmdLogin"); cmdHandler.register("!buy", this, "cmdBuy"); cmdHandler.register("!credits", this, "cmdCredits"); cmdHandler.register("!list", this, "cmdList"); cmdHandler.register("!download", this, "cmdDownload"); cmdHandler.register("!upload", this, "cmdUpload"); cmdHandler.register("!upload2", this, "cmdUpload2"); cmdHandler.register("unknown", this, "cmdUnknown"); fscmd = new CommandNetwork(); fscmd.setOneCommandMode(true); fscmd.cmdHandler.register("!error", this, "cmdFSRelayOutput"); fscmd.cmdHandler.register("!output", this, "cmdFSRelayOutput"); 
fscmd.cmdHandler.register("!download", this, "cmdFSDownload"); fscmd.cmdHandler.register("!list", this, "cmdFSList"); fscmd.cmdHandler.register("!hasherr", this, "cmdFSHashError"); fscmd.cmdHandler.register("unknown", this, "cmdFSRelayOutput"); } /*------------------------------------------------------------------------*/ public boolean checkLogin() throws IOException { if (user == null) { Utils.sendError(clout, "Not logged in"); return false; } return true; } /*------------------------------------------------------------------------*/ public void cmdLogin(String cmd, String[] args) throws IOException { if (user != null) { Utils.sendError(clout, "Already logged in"); return; } if (args.length != 2) { Utils.sendError(clout, "Invalid Syntax: !login "); return; } String firstMessage = cmd + " " + Utils.join(Arrays.asList(args), " "); assert firstMessage.matches("!login \\w+ [" + B64 + "]{43}=") : "1st message"; synchronized(users) { UserRecord record = users.get(args[0]); if (record == null) { Utils.sendError(clout, "Invalid username"); return; } user = record; } /* read and init users public key */ File pemfile = new File(keysDir, args[0] + ".pub.pem"); if (!pemfile.isFile()) { Utils.sendError(clout, "No public keyfile"); return; } if (!pemfile.canRead()) { Utils.sendError(clout, "Your public keyfile is not readable"); return; } try { PEMReader in = new PEMReader(new FileReader(pemfile)); PublicKey publicKey = (PublicKey) in.readObject(); rsaencrypt.init(Cipher.ENCRYPT_MODE, publicKey); clout.setCipher(rsaencrypt); } catch(FileNotFoundException e) { Utils.sendError(clout, "Your public keyfile is not readable"); return; } catch(IOException e) { Utils.sendError(clout, "While reading users public key"); return; } catch(InvalidKeyException e) { Utils.sendError(clout, "invalid public key file: " + e.getMessage()); return; } String clichallenge = args[1]; /* generates a 32 byte secure random number */ SecureRandom secureRandom = new SecureRandom(); byte[] tmp = new 
byte[32]; secureRandom.nextBytes(tmp); mychallenge = new String(Base64.encode(tmp)); SecretKey seckey = null; javax.crypto.spec.IvParameterSpec iv = null; try { /* generate aes key */ KeyGenerator generator = KeyGenerator.getInstance("AES"); generator.init(256); seckey = generator.generateKey(); /* generate iv */ byte[] tmpiv = new byte[16]; secureRandom.nextBytes(tmpiv); iv = new javax.crypto.spec.IvParameterSpec(tmpiv); /* init aes */ aesencrypt.init(Cipher.ENCRYPT_MODE, seckey, iv); aesdecrypt.init(Cipher.DECRYPT_MODE, seckey, iv); } catch(NoSuchAlgorithmException e) { err.println("Error: Unable to generate AES key: " + e.getMessage()); return; } catch(InvalidKeyException e) { err.println("Error: invalid AES key: " + e.getMessage()); return; } catch(InvalidAlgorithmParameterException e) { err.println("Error: invalid AES parameters: " + e.getMessage()); return; } ArrayList msg = new ArrayList(); msg.add("!ok"); msg.add(clichallenge); msg.add(mychallenge); msg.add(new String(Base64.encode(seckey.getEncoded()))); msg.add(new String(Base64.encode(iv.getIV()))); String secondMessage = Utils.join(msg, " "); assert secondMessage.matches("!ok [" + B64 + "]{43}= [" + B64 + "]{43}= [" + B64 + "]{43}= [" + B64 + "]{22}==") : "2nd message"; clout.writeLine(secondMessage); clout.flush(); clout.setCipher(aesencrypt); clin.setCipher(aesdecrypt); } /*------------------------------------------------------------------------*/ public void cmdBuy(String cmd, String[] args) throws IOException { if (!checkLogin()) return; if (args.length != 1) { Utils.sendError(clout, "Invalid Syntax: !buy "); return; } int add = 0; try { add = Integer.parseInt(args[0]); if (add <= 0) throw new NumberFormatException(""); } catch(NumberFormatException e) { Utils.sendError(clout, "Credits must be numberic and positive"); return; } synchronized(users) { if (user.credits > Integer.MAX_VALUE - add) Utils.sendError(clout, "You can't buy that much/more credits"); else { user.credits += add; 
Utils.sendOutput(clout, "You now have " + user.credits + " credits"); } } } /*------------------------------------------------------------------------*/ public void cmdCredits(String cmd, String[] args) throws IOException { if (!checkLogin()) return; if (args.length != 0) { Utils.sendError(clout, "Invalid Syntax: !credits"); return; } synchronized(users) { Utils.sendOutput(clout, "You have " + user.credits + " credits left"); } } /*------------------------------------------------------------------------*/ public int connectFileserver(FSRecord fileserver) throws IOException { if (fileserver == null || !fileserver.online) { Utils.sendError(clout, "Unable to execute command. Fileserver not online"); return -1; } synchronized(fsschannel) { int conn = fsschannel.size(); try { fsschannel.add(conn, SocketChannel.open(new InetSocketAddress(fileserver.host, fileserver.port))); fsin.add(conn, new Utils.EncObjectInputStream(new BufferedInputStream( fsschannel.get(conn).socket().getInputStream()))); fsout.add(conn, new Utils.EncObjectOutputStream(fsschannel.get(conn).socket().getOutputStream())); this.fileserver.add(conn, fileserver); fsin.get(conn).setMAC(hmac); fsout.get(conn).setMAC(hmac); } catch(IOException e) { err.println("Error: Unable to connect to fileserver: " + e.getMessage()); Utils.sendError(clout, "Unable to connect to fileserver"); synchronized(fileservers) { fileserver.online = false; out.println("Fileserver marked as offline: " + fileserver.host + ":" + fileserver.port); } disconnectFileserver(conn); return -1; } return conn; } } /*------------------------------------------------------------------------*/ public void disconnectFileserver(int conn) { if (conn < 0) return; synchronized(fsschannel) { try { if (fsout.get(conn) != null) fsout.get(conn).flush(); } catch(IndexOutOfBoundsException e) {} catch(IOException e) {} try { if (fsin.get(conn) != null) fsin.get(conn).close(); fsin.set(conn, null); for(int i = fsin.size() - 1; i >= 0; --i) { if (fsin.get(i) != 
null) break; fsin.remove(i); } } catch(IndexOutOfBoundsException e) {} catch(IOException e) {} try { if (fsout.get(conn) != null) fsout.get(conn).close(); fsout.set(conn, null); for(int i = fsout.size() - 1; i >= 0; --i) { if (fsout.get(i) != null) break; fsout.remove(i); } } catch(IndexOutOfBoundsException e) {} catch(IOException e) {} try { if (fsschannel.get(conn) != null) fsschannel.get(conn).close(); fsschannel.set(conn, null); for(int i = fsschannel.size() - 1; i >= 0; --i) { if (fsschannel.get(i) != null) break; fsschannel.remove(i); } } catch(IndexOutOfBoundsException e) {} catch(IOException e) {} try { fileserver.set(conn, null); for(int i = fileserver.size() - 1; i >= 0; --i) { if (fileserver.get(i) != null) break; fileserver.remove(i); } } catch(IndexOutOfBoundsException e) {} } } /*------------------------------------------------------------------------*/ public void cmdList(String cmd, String[] args) throws IOException { if (!checkLogin()) return; if (args.length != 0) { Utils.sendError(clout, "Invalid Syntax: !list"); return; } filelist = new HashMap(); synchronized(fileservers) { for(FSRecord fileserver : fileservers.values()) { if (!fileserver.online) continue; int conn = -1; try { if ((conn = connectFileserver(fileserver)) < 0) return; fsout.get(conn).writeLine(cmd + " " + Utils.join(Arrays.asList(args), " ")); fsout.get(conn).flush(); curconn = conn; fscmd.run(fsin.get(conn)); } catch(Utils.HashError e) { fileserver.online = false; } catch(IOException e) { fileserver.online = false; Utils.sendError(clout, "Connection to fileserver " + fileserver.host + ":" + fileserver.port + " terminated unexpected"); } disconnectFileserver(conn); } } if (filelist.size() == 0) Utils.sendOutput(clout, "No files available"); else { ArrayList tmp = new ArrayList(); for(Map.Entry file : filelist.entrySet()) tmp.add(file.getKey() + " v=" + file.getValue()); Utils.sendOutput(clout, tmp.toArray(new String[tmp.size()])); } clout.flush(); } 
/*------------------------------------------------------------------------*/ public void cmdDownload(String cmd, String[] args) throws IOException { if (!checkLogin()) return; if (args.length != 1) { Utils.sendError(clout, "Invalid Syntax: !download "); return; } FSRecord curfs = null; long curversion = -1; synchronized(fileservers) { calcGifford(); for(FSRecord fileserver : rfileservers) { filelist = new HashMap(); int conn = -1; try { if ((conn = connectFileserver(fileserver)) < 0) return; fsout.get(conn).writeLine("!list"); fsout.get(conn).flush(); curconn = conn; fscmd.run(fsin.get(conn)); } catch(Utils.HashError e) { fileserver.online = false; } catch(IOException e) { fileserver.online = false; Utils.sendError(clout, "Connection to fileserver " + fileserver.host + ":" + fileserver.port + " terminated unexpected"); } disconnectFileserver(conn); Long version = filelist.get(args[0]); if (version != null && version > curversion) { curfs = fileserver; curversion = version; } } } if (curfs == null) { Utils.sendError(clout, "File not found or no fileservers online"); return; } int conn = -1; try { if ((conn = connectFileserver(curfs)) < 0) return; synchronized(users) { fsout.get(conn).writeLine(cmd + " " + Utils.join(Arrays.asList(args), " ") + " " + user.credits); } fsout.get(conn).flush(); curconn = conn; fscmd.run(fsin.get(conn)); } catch(Utils.HashError e) { curfs.online = false; } catch(IOException e) { Utils.sendError(clout, "Connection to fileserver " + fileserver.get(conn).host + ":" + fileserver.get(conn).port + "terminated unexpected"); } clout.flush(); disconnectFileserver(conn); } /*------------------------------------------------------------------------*/ public void cmdUpload(String cmd, String[] args) throws IOException { if (!checkLogin()) return; if (args.length != 1) { Utils.sendError(clout, "Invalid Syntax: !upload "); return; } synchronized(fileservers) { if (fileservers.size() == 0) { Utils.sendError(clout, "Unable to execute command. 
No fileservers online"); return; } } clout.writeLine(cmd + " " + args[0]); clout.flush(); } /*------------------------------------------------------------------------*/ public void cmdUpload2(String cmd, String[] args) throws IOException { if (!checkLogin()) return; if (args.length < 2 || args[1].length() <= 0) { err.println("Error: Invalid " + cmd + "-command paket from client. Ignoring..."); Utils.sendError(clout, "Internal Error: Invalid packet from client"); return; } long filesize, filesizecpy; if ((filesize = Utils.parseHeaderNum(args, 1)) < 0) return; filesizecpy = filesize; long curversion = 0; synchronized(fileservers) { calcGifford(); for(FSRecord fileserver : rfileservers) { filelist = new HashMap(); int conn = -1; try { if ((conn = connectFileserver(fileserver)) < 0) return; fsout.get(conn).writeLine("!list"); fsout.get(conn).flush(); curconn = conn; fscmd.run(fsin.get(conn)); } catch(Utils.HashError e) { fileserver.online = false; } catch(IOException e) { fileserver.online = false; Utils.sendError(clout, "Connection to fileserver " + fileserver.host + ":" + fileserver.port + " terminated unexpected"); } disconnectFileserver(conn); Long version = filelist.get(args[0]); if (version != null && version > curversion) curversion = version; } ArrayList conns = new ArrayList(); for(FSRecord fileserver : wfileservers) { int conn = -1; if ((conn = connectFileserver(fileserver)) < 0) return; conns.add(conn); } try { for(Integer conn : conns) fsout.get(conn).writeLine(cmd + " " + Utils.join(Arrays.asList(args), " ") + " " + (curversion + 1)); byte[] buffer = new byte[8 * 1024]; int toread = buffer.length; while(filesize > 0) { if (filesize < toread) toread = (int) filesize; int count = clin.read(buffer, 0, toread); if (count == -1) throw new IOException("Connection reset by peer"); /* decode + hash that chunk */ byte[] decbuffer = new byte[aesdecrypt.getOutputSize(count)]; int deccount = aesdecrypt.update(buffer, 0, count, decbuffer); hmac.update(decbuffer, 0, 
deccount); for(Integer conn : conns) fsout.get(conn).write(decbuffer, 0, deccount); filesize -= count; } /* decryption + hash must be finalized */ byte[] decbuffer = aesdecrypt.doFinal(); byte[] hash = hmac.doFinal(decbuffer); for(Integer conn : conns) fsout.get(conn).write(decbuffer); for(Integer conn : conns) fsout.get(conn).write(hash); for(Integer conn : conns) fsout.get(conn).flush(); synchronized(users) { user.credits += 2 * filesizecpy; } for(Integer conn : conns) fileserver.get(conn).usage += filesizecpy; Utils.sendOutput(clout, "File '" + args[0] + "' successfully uploaded."); } catch(IOException e) { err.println("Error during file transfer: " + e.getMessage() + ". Closing connection to client"); stop(); for(Integer conn : conns) { fileserver.get(conn).usage += filesizecpy; fileserver.get(conn).online = false; out.println("Fileserver marked as offline: " + fileserver.get(conn).host + ":" + fileserver.get(conn).port); } } catch(GeneralSecurityException e) { err.println("Error during encrypting file transfer: " + e.getMessage() + ". 
Closing connection to client"); stop(); } for(Integer conn : conns) disconnectFileserver(conn); } } /*------------------------------------------------------------------------*/ public void cmdUnknown(String cmd, String[] args) throws IOException { if (args.length == 0 && cmd.equals(mychallenge)) { String thirdMessage = cmd; assert thirdMessage.matches("[" + B64 + "]{43}=") : "3rd message"; synchronized(users) { user.loggedin.add(clientaddr); } Utils.sendOutput(clout, "Successfully logged in"); return; } err.println("Error: Unknown data from client: " + cmd + " " + Utils.join(Arrays.asList(args), " ")); Utils.sendError(clout, "Unknown command"); } /*------------------------------------------------------------------------*/ @SuppressWarnings("deprecation") public void cmdFSRelayOutput(String cmd, String[] args) throws IOException { long num; if ((num = Utils.parseHeaderNum(args, 0)) < 0) return; String msg; clout.writeLine(cmd + " " + Utils.join(Arrays.asList(args), " ")); for (; num > 0 && (msg = fsin.get(curconn).readLine()) != null; --num) clout.writeLine(msg); clout.flush(); } /*------------------------------------------------------------------------*/ public void cmdFSDownload(String cmd, String[] args) throws IOException { if (args.length < 2 || args[1].length() <= 0) { err.println("Error: Invalid " + cmd + "-command paket from fileserver. 
Ignoring..."); Utils.sendError(clout, "Internal Error: Invalid packet from fileserver"); return; } String file = args[0]; long filesize, filesizecpy; if ((filesize = Utils.parseHeaderNum(args, 1)) < 0) return; filesizecpy = filesize; clout.writeLine(cmd + " " + Utils.join(Arrays.asList(args), " ")); try { byte[] buffer = new byte[8 * 1024]; int toread = buffer.length; while(filesize > 0) { if (filesize < toread) toread = (int) filesize; int count = fsin.get(curconn).read(buffer, 0, toread); if (count == -1) throw new IOException("Connection reset by peer"); hmac.update(buffer, 0, count); /* encode that chunk */ byte[] encbuffer = new byte[aesencrypt.getOutputSize(count)]; int enccount = aesencrypt.update(buffer, 0, count, encbuffer); clout.write(encbuffer, 0, enccount); filesize -= count; } /* encryption must be finalized */ clout.write(aesencrypt.doFinal()); byte[] hashc = hmac.doFinal(); byte[] hasht = new byte[hmac.getMacLength()]; fsin.get(curconn).readFully(hasht); if (!Arrays.equals(hashc, hasht)) { synchronized(fileservers) { err.println("Error: invalid MAC during filetransfer. Taking fileserver offline and restarting download."); fileserver.get(curconn).online = false; out.println("Fileserver marked as offline: " + fileserver.get(curconn).host + ":" + fileserver.get(curconn).port); String[] tmp = { file }; cmdDownload("!download", tmp); return; } } Utils.sendOutput(clout, "File '" + file + "' successfully downloaded."); synchronized(users) { user.credits -= filesizecpy; } synchronized(fileservers) { fileserver.get(curconn).usage += filesizecpy; } } catch(IOException e) { err.println("Error during file transfer: " + e.getMessage() + ". 
Closing connection to client"); stop(); synchronized(fileservers) { fileserver.get(curconn).usage += filesizecpy; fileserver.get(curconn).online = false; out.println("Fileserver marked as offline: " + fileserver.get(curconn).host + ":" + fileserver.get(curconn).port); } } catch(GeneralSecurityException e) { err.println("Error during encrypting file transfer: " + e.getMessage() + ". Closing connection to client"); stop(); } } /*------------------------------------------------------------------------*/ @SuppressWarnings("unchecked") public void cmdFSList(String cmd, String[] args) throws IOException { long num; if ((num = Utils.parseHeaderNum(args, 0)) < 0) return; String msg; for (; num > 0 && (msg = fsin.get(curconn).readLine()) != null; --num) { int ix = msg.lastIndexOf(' '); if (ix == -1) { err.println("Error: Unknown filelist-message. Ignoring..."); continue; } try { String file = msg.substring(0, ix); Long fileversion = Long.valueOf(msg.substring(ix + 1)); Long version = filelist.get(file); if (version == null || fileversion > version) filelist.put(file, fileversion); } catch(NumberFormatException e) { err.println("Error: Unable to parse file version. 
Ignoring..."); continue; } } } /*------------------------------------------------------------------------*/ public void cmdFSHashError(String cmd, String[] args) throws IOException { err.println("Error: Hasherror from fileserver!"); synchronized(fileservers) { fileserver.get(curconn).online = false; out.println("Fileserver marked as offline: " + fileserver.get(curconn).host + ":" + fileserver.get(curconn).port); } cmdHandler.call2(); } /*------------------------------------------------------------------------*/ public void shutdown() { for(int i = fsschannel.size() - 1; i >= 0; --i) disconnectFileserver(i); try { clout.flush(); } catch(IOException e) {} try { clin.close(); } catch(IOException e) {} try { clout.close(); } catch(IOException e) {} try { if (sock.isOpen()) sock.close(); } catch(IOException e) {} } /*------------------------------------------------------------------------*/ public void run() { try { out.println("[" + Thread.currentThread().getId() + "] New client connection from " + clientaddr); run(clin); clout.flush(); } catch(CommandHandler.Exception e) { err.println("Internal Error: " + e.getMessage()); e.printStackTrace(); } catch(IOException e) { /* ignore that exception * it's usually a closed connection from client so * we can't do anything about it anyway */ } if (user != null) { synchronized(users) { user.loggedin.remove(user.loggedin.indexOf(clientaddr)); } } out.println("[" + Thread.currentThread().getId() + "] Connection closed"); shutdown(); } } /*==========================================================================*/ public class TCPSocketReader implements Runnable { private final ServerSocketChannel sschannel; private final FSRecords fileservers; private final UserRecords users; private final Object mainLock; private final ExecutorService pool; TCPSocketReader(ServerSocketChannel sschannel, FSRecords fileservers, UserRecords users, Object mainLock) { this.sschannel = sschannel; this.fileservers = fileservers; this.users = users; 
this.mainLock = mainLock; this.pool = Executors.newCachedThreadPool(); } /*------------------------------------------------------------------------*/ public void run() { try { while(true) pool.execute(new ClientConnection(sschannel.accept(), fileservers, users)); } catch(NoSuchMethodException e) { err.println("Error: Unable to setup remote command handler"); } catch(IOException e) { /* ignore that exception * thread will shutdown and unlock the main thread * which will shutdown the application */ } pool.shutdown(); try { if (!pool.awaitTermination(100, TimeUnit.MILLISECONDS)) out.println("Trying to shutdown the client connections. This may take up to 15 seconds..."); if (!pool.awaitTermination(5, TimeUnit.SECONDS)) { pool.shutdownNow(); if (!pool.awaitTermination(5, TimeUnit.SECONDS)) err.println("Error: Client connections did not terminate. You may have to kill that appplication."); } } catch(InterruptedException e) { pool.shutdownNow(); } synchronized(mainLock) { mainLock.notify(); } } } /*==========================================================================*/ public class Interactive extends CommandInteractive implements Runnable { private final InputStream sin; private final Object mainLock; Interactive(InputStream sin, Object mainLock) throws NoSuchMethodException { this.sin = sin; this.mainLock = mainLock; cmdHandler.register("unknown", this, "cmdUnknown"); cmdHandler.register("!fileservers", this, "cmdFileservers"); cmdHandler.register("!users", this, "cmdUsers"); cmdHandler.register("!exit", this, "cmdExit"); } /*------------------------------------------------------------------------*/ public void cmdUnknown(String cmd, String[] args) { err.println("Unknown command: " + cmd + " " + Utils.join(Arrays.asList(args), " ")); } /*------------------------------------------------------------------------*/ public void cmdFileservers(String cmd, String[] args) { synchronized(fileservers) { if (fileservers.size() == 0) out.println("No fileservers registered"); 
else { calcGifford(); int line = 1; int pad = Integer.toString(fileservers.size()).length(); for(Map.Entry entry : fileservers.entrySet()) { FSRecord record = entry.getValue(); out.println(String.format("%0" + pad + "d. IP: %s, Port: %d, %s, Usage: %d, Quorum: %s%s", line, record.host, record.port, (record.online) ? "online" : "offline", record.usage, (rfileservers.contains(record)) ? "R" : "", (wfileservers.contains(record)) ? "W" : "")); ++line; } } } } /*------------------------------------------------------------------------*/ public void cmdUsers(String cmd, String[] args) { synchronized(users) { if (users.size() == 0) out.println("No users registered"); else { int line = 1; int pad = Integer.toString(users.size()).length(); for(Map.Entry entry : users.entrySet()) { UserRecord record = entry.getValue(); out.println(String.format("%0" + pad + "d. User: %s, %s, Credits: %d", line, record.name, (record.loggedin.size() > 0) ? "online" : "offline", record.credits)); for(String host : record.loggedin) out.println(String.format("%1$#" + (pad + 1) + "s Client: %2$s", " ", host)); ++line; } } } } /*------------------------------------------------------------------------*/ public void cmdExit(String cmd, String[] args) { stop(); } /*------------------------------------------------------------------------*/ public void printPrompt() { out.print(">: "); out.flush(); } /*------------------------------------------------------------------------*/ public void run() { try { run(sin); } catch(CommandHandler.Exception e) { err.println("Internal Error: " + e.getMessage()); } catch (IOException e) { /* ignore that exception * thread will shutdown and unlock the main thread * which will shutdown the application */ } synchronized(mainLock) { mainLock.notify(); } } } /*==========================================================================*/ public class CheckFSTask implements Runnable { private FSRecords fileservers; private final int fserverTimeout; CheckFSTask(FSRecords 
fileservers, int fserverTimeout) {
      this.fileservers = fileservers;
      this.fserverTimeout = fserverTimeout;
    }

    /*------------------------------------------------------------------------*/

    public void run() {
      synchronized(fileservers) {
        long curTime = Calendar.getInstance().getTimeInMillis();
        for(Map.Entry entry : fileservers.entrySet()) {
          if (entry.getValue().online && entry.getValue().lastUpdate + fserverTimeout < curTime) {
            entry.getValue().online = false;
            out.println("Fileserver has gone offline: " + entry.getKey());
          }
        }
      }
    }
  }

  /*==========================================================================*/

  /* Settings read from the "proxy" configuration by parseConfig(). */
  private static int tcpPort;
  private static int udpPort;
  private static int fserverTimeout;
  private static int checkPeriod;

  /* Known fileservers; code in this class synchronizes on this object
   * before reading or updating the fileserver state. */
  private FSRecords fileservers;
  /* Read and write quorums, recomputed by calcGifford(). */
  private ArrayList rfileservers;
  private ArrayList wfileservers;
  /* Users loaded by parseUsers(). */
  private UserRecords users;

  /* Background components started by run() and torn down by shutdown(). */
  private ScheduledExecutorService scheduler = null;
  private DatagramChannel dchannel = null;
  private Thread tUDPSocketReader = null;
  private ServerSocketChannel sschannel = null;
  private Thread tTCPSocketReader = null;
  private Thread tInteractive = null;
  private InputStream stdin = null;
  /* run() blocks in mainLock.wait(); the lock is handed to the worker
   * threads, which presumably notify it to trigger termination. */
  private final Object mainLock = new Object();

  private static String keysDir;
  private static String proxyKey;
  private PrivateKey privateKey = null;
  private Cipher rsaencrypt, rsadecrypt;
  private Cipher aesencrypt, aesdecrypt;
  private Mac hmac;
  /* Base64 alphabet in the form of a regex character-class body;
   * presumably used to build patterns for protocol messages. */
  private final String B64 = "a-zA-Z0-9/+";

  /*--------------------------------------------------------------------------*/

  /*
   * Creates empty fileserver/user registries and the RSA (OAEP) and AES (CTR)
   * cipher instances. Keys are attached later (see parseConfig()).
   * Terminates via bailout() if a cipher is unavailable.
   */
  Proxy() {
    fileservers = new FSRecords();
    rfileservers = new ArrayList();
    wfileservers = new ArrayList();
    users = new UserRecords();

    try {
      rsaencrypt = Cipher.getInstance("RSA/NONE/OAEPWithSHA256AndMGF1Padding");
      rsadecrypt = Cipher.getInstance("RSA/NONE/OAEPWithSHA256AndMGF1Padding");
      aesencrypt = Cipher.getInstance("AES/CTR/NoPadding");
      aesdecrypt = Cipher.getInstance("AES/CTR/NoPadding");
    } catch(NoSuchAlgorithmException e) {
      bailout("Unable to initialize cipher: " + e.getMessage());
    } catch(NoSuchPaddingException e) {
      bailout("Unable to initialize cipher: " + e.getMessage());
    }
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Prints the usage line and terminates the program.
   *
   * Termination is signalled by throwing Utils.Shutdown, which main() catches,
   * instead of calling System.exit(). The drawback is that no non-zero exit
   * status (1 was intended here) can be reported without System.exit().
   */
  public static void usage() throws Utils.Shutdown {
    out.println("Usage: Proxy\n");
    throw new Utils.Shutdown("FUCK YOU JAVA");
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Prints the given error to stderr, releases all resources via shutdown()
   * and terminates by throwing Utils.Shutdown (caught in main()).
   * As with usage(), the intended exit status (2) cannot be reported without
   * System.exit().
   *
   * @param error human readable error description
   */
  public void bailout(String error) throws Utils.Shutdown {
    err.println("Error: " + error);
    shutdown();
    throw new Utils.Shutdown("FUCK YOU JAVA");
  }

  /*--------------------------------------------------------------------------*/

  /*
   * The proxy accepts no command line arguments; any argument triggers usage().
   */
  public void parseArgs(String[] args) {
    if (args.length != 0)
      usage();
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Reads and validates the "proxy" configuration: listening ports, fileserver
   * timeout and check period, the keys directory, the proxy's RSA private key
   * (the pass phrase is prompted for on stdin) and the shared HMAC key.
   * Any invalid or missing setting terminates via bailout().
   */
  public void parseConfig() {
    Config config = null;
    try {
      config = new Config("proxy");
    } catch(MissingResourceException e) {
      bailout("configuration file doesn't exist or isn't readable");
    }

    /* name of the directive currently processed; used in error messages */
    String directive = null;
    try {
      directive = "tcp.port";
      tcpPort = config.getInt(directive);
      /* valid TCP/UDP ports are 1 - 65535 */
      if (tcpPort <= 0 || tcpPort > 65535)
        bailout("configuration directive '" + directive
          + "' must be a valid port number (1 - 65535)");

      directive = "udp.port";
      udpPort = config.getInt(directive);
      if (udpPort <= 0 || udpPort > 65535)
        bailout("configuration directive '" + directive
          + "' must be a valid port number (1 - 65535)");

      directive = "fileserver.timeout";
      fserverTimeout = config.getInt(directive);
      if (fserverTimeout <= 0)
        bailout("configuration directive '" + directive
          + "' must be a positive number");

      directive = "fileserver.checkPeriod";
      checkPeriod = config.getInt(directive);
      if (checkPeriod <= 0)
        bailout("configuration directive '" + directive
          + "' must be a positive number");

      directive = "keys.dir";
      keysDir = config.getString(directive);
      File dir = new File(keysDir);
      if (!dir.isDirectory())
        bailout("configuration directive '" + directive + "' is not a directory");
      if (!dir.canRead())
        bailout("configuration directive '" + directive + "' is not readable");

      directive = "key";
      proxyKey = config.getString(directive);
      File key = new File(proxyKey);
      if (!key.isFile())
        bailout("configuration directive '" + directive + "' is not a file");
      if (!key.canRead())
        bailout("configuration directive '" + directive + "' is not readable");

      PEMReader in = new PEMReader(new FileReader(key), new PasswordFinder() {
        @Override
        public char[] getPassword() {
          try {
            /* reads the password from standard input for decrypting the private key */
            out.println("Enter pass phrase for proxy key:");
            String line = new BufferedReader(new InputStreamReader(System.in)).readLine();
            /* readLine() returns null at end of input: treat as empty pass phrase */
            return (line == null) ? new char[0] : line.toCharArray();
          } catch(IOException e) {
            char[] tmp = {};
            return tmp;
          }
        }
      });

      try {
        Object keyObject = in.readObject();
        /* readObject() yields null (no PEM object) or another type for
         * unexpected file contents; reject those instead of crashing */
        if (!(keyObject instanceof KeyPair))
          bailout("private key file of proxy doesn't contain a key pair");
        privateKey = ((KeyPair) keyObject).getPrivate();
        rsadecrypt.init(Cipher.DECRYPT_MODE, privateKey);
      } catch(IOException e) {
        bailout("Error while reading private key of proxy. Maybe wrong pass phrase");
      } finally {
        try {
          in.close();
        } catch(IOException ignored) {
          /* best effort: the key has already been read (or we are bailing out) */
        }
      }

      directive = "hmac.key";
      File hmackey = new File(config.getString(directive));
      if (!hmackey.isFile())
        bailout("configuration directive '" + directive + "' is not a file");
      if (!hmackey.canRead())
        bailout("configuration directive '" + directive + "' is not readable");

      /* read at most 1024 bytes of hex text; only the bytes actually read are
       * decoded (decoding the unused zero-filled tail would corrupt the key) */
      byte[] keybytes = new byte[1024];
      int keylen = 0;
      FileInputStream fis = new FileInputStream(hmackey);
      try {
        while (keylen < keybytes.length) {
          int n = fis.read(keybytes, keylen, keybytes.length - keylen);
          if (n < 0)
            break;
          keylen += n;
        }
      } finally {
        fis.close();
      }

      byte[] hmacKeyBytes = Hex.decode(Arrays.copyOf(keybytes, keylen));
      if (hmacKeyBytes.length == 0)
        bailout("configuration directive '" + directive + "' refers to an empty key file");
      hmac = Mac.getInstance("HmacSHA256");
      hmac.init(new javax.crypto.spec.SecretKeySpec(hmacKeyBytes, "HmacSHA256"));
    } catch(FileNotFoundException e) {
      bailout("unable to read file of directive '" + directive + "'");
    } catch(IOException e) {
      bailout("Error while reading file: " + e.getMessage());
    } catch(InvalidKeyException e) {
      bailout("invalid key file: " + e.getMessage());
    } catch(MissingResourceException e) {
      bailout("configuration directive '" + directive + "' is not set");
    } catch(NumberFormatException e) {
      bailout("configuration directive '" + directive + "' must be numeric");
    } catch(NoSuchAlgorithmException e) {
      bailout("Unable to initialize cipher: " + e.getMessage());
    }
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Loads users from the "user" configuration in two passes:
   *  1. keys without a dot ("alice") hold the user's credits and create the user;
   *  2. dotted keys ("alice.password") set properties of an already known user.
   * Invalid entries are reported on stderr and skipped.
   */
  public void parseUsers() {
    Config config = null;
    try {
      config = new Config("user");
    } catch(MissingResourceException e) {
      bailout("Unable to read from user properties file: " + e.getMessage());
    }

    /* first load the users only */
    for (Enumeration e = config.getKeys(); e.hasMoreElements();) {
      String prop = (String) e.nextElement();
      String[] pieces = prop.split("\\.", 2);
      String user = pieces[0];
      if (pieces.length != 1)
        continue;

      int credits;
      try {
        credits = config.getInt(prop);
      } catch(NumberFormatException nfe) {
        /* a malformed credits value must not abort loading the other users */
        err.println("Property " + prop + " must be numeric. Skipping...");
        continue;
      }
      if (credits < 0) {
        err.println("Property " + prop + " must be positive number. Skipping...");
        continue;
      }
      users.put(user, new UserRecord(user, credits));
    }

    /* next load their properties */
    for (Enumeration e = config.getKeys(); e.hasMoreElements();) {
      String prop = (String) e.nextElement();
      String[] pieces = prop.split("\\.", 2);
      String user = pieces[0];
      if (pieces.length == 1)
        continue;
      else if (pieces.length == 2) {
        UserRecord record = users.get(user);
        if (record == null) {
          err.println("Can't load user properties for unknown user '" + user + "'. Skipping...");
          continue;
        }

        if (pieces[1].equals("password"))
          record.pass = config.getString(prop);
        else
          err.println("Property " + prop + " is unknown. Skipping...");
      }
      else
        err.println("Property " + prop + " is unknown. Skipping...");
    }
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Recomputes the read and write quorums (Gifford's scheme) from the fileservers
   * that are currently online. With n online servers the write quorum has
   * n/2 + 1 members and the read quorum n - W + 1, so R + W = n + 1.
   * Both quorums are prefixes of the same sorted list (order defined by
   * FSRecord's natural ordering). Leaves both lists empty if no server is online.
   */
  public void calcGifford() {
    synchronized(fileservers) {
      rfileservers.clear();
      wfileservers.clear();

      ArrayList fslist = new ArrayList();
      for(FSRecord record : fileservers.values()) {
        if (!record.online)
          continue;
        fslist.add(record);
      }
      if (fslist.isEmpty())
        return;

      Collections.sort(fslist);
      /* integer division already rounds down */
      int numwrite = fslist.size() / 2 + 1;
      int numread = fslist.size() - numwrite + 1;
      for(int i = 0; i < numread; ++i)
        rfileservers.add(fslist.get(i));
      for(int i = 0; i < numwrite; ++i)
        wfileservers.add(fslist.get(i));
    }
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Best-effort teardown: stops the scheduler, closes the UDP and TCP channels
   * and joins their reader threads, interrupts and joins the interactive
   * thread, then closes stdin. Components that were never started (null) are
   * skipped; errors during teardown are deliberately ignored.
   */
  public void shutdown() {
    try {
      if (scheduler != null) {
        scheduler.shutdownNow();
        scheduler.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
      }
    } catch(InterruptedException e) {}

    try {
      if (dchannel != null)
        dchannel.close();
    } catch(IOException e) {}

    try {
      if (tUDPSocketReader != null)
        tUDPSocketReader.join();
    } catch(InterruptedException e) {}

    try {
      if (sschannel != null)
        sschannel.close();
    } catch(IOException e) {}

    try {
      if (tTCPSocketReader != null)
        tTCPSocketReader.join();
    } catch(InterruptedException e) {}

    try {
      if (tInteractive != null) {
        tInteractive.interrupt();
        tInteractive.join();
      }
    } catch(InterruptedException e) {}

    try {
      if (stdin != null)
        stdin.close();
    } catch(IOException e) {}
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Main loop: parses arguments, configuration and users, starts the periodic
   * fileserver timeout check, the UDP and TCP listeners and the interactive
   * command handler, then waits on mainLock. When woken up it checks whether a
   * listener died unexpectedly (bailout) and finally shuts everything down.
   */
  public void run(String[] args) {
    parseArgs(args);
    parseConfig();

    synchronized(users) {
      parseUsers();
      out.println("Users loaded successfully: " + users.size() + " users loaded");
    }

    synchronized(mainLock) {
      scheduler = Executors.newScheduledThreadPool(1);
      scheduler.scheduleAtFixedRate(
        new CheckFSTask(fileservers, fserverTimeout),
        0, checkPeriod, TimeUnit.MILLISECONDS);

      try {
        dchannel = DatagramChannel.open();
        dchannel.socket().bind(new InetSocketAddress(udpPort));
        tUDPSocketReader = new Thread(new UDPSocketReader(dchannel, fileservers, mainLock));
        tUDPSocketReader.start();
        out.println("Listening on udp:/" + dchannel.socket().getLocalSocketAddress());
      } catch(IOException e) {
        bailout("Unable to create UDP Socket: " + e.getMessage());
      }

      try {
        sschannel = ServerSocketChannel.open();
        sschannel.socket().bind(new InetSocketAddress(tcpPort));
        tTCPSocketReader = new Thread(new TCPSocketReader(sschannel,
          fileservers, users, mainLock));
        tTCPSocketReader.start();
        out.println("Listening on tcp:/" + sschannel.socket().getLocalSocketAddress());
      } catch(IOException e) {
        bailout("Unable to create TCP Socket: " + e.getMessage());
      }

      try {
        /* assign the field (not a local) so shutdown() can close the stream */
        stdin = java.nio.channels.Channels.newInputStream(
          new FileInputStream(FileDescriptor.in).getChannel());
        tInteractive = new Thread(new Interactive(stdin, mainLock));
        tInteractive.start();
      } catch(NoSuchMethodException e) {
        bailout("Unable to setup interactive command handler");
      }

      out.println("Proxy startup successful!");
      try {
        mainLock.wait();
      } catch(InterruptedException e) {
        /* if we get interrupted -> ignore */
      }

      try {
        /* let the threads shutdown */
        Thread.sleep(100);
      } catch(InterruptedException e) {}
    }

    if (tUDPSocketReader != null && !tUDPSocketReader.isAlive())
      bailout("Listening UDP socket closed unexpected. Terminating...");
    if (tTCPSocketReader != null && !tTCPSocketReader.isAlive())
      bailout("Listening TCP socket closed unexpected. Terminating...");

    shutdown();
  }

  /*--------------------------------------------------------------------------*/

  /*
   * Entry point. Utils.Shutdown is the program's termination signal (thrown by
   * usage() and bailout()) and is therefore swallowed here.
   */
  public static void main(String[] args) {
    try {
      Proxy proxy = new Proxy();
      proxy.run(args);
    } catch(Utils.Shutdown e) {}
  }
}