AES暗号ユーティリティ

AES 暗号をお手軽に使えるよう InputStream/OutputStream にかぶせて使うアダプター Stream をつくってみた。

暗号化アダプター

暗号化したいデータを格納した InputStream is があるとして、これに 16 バイト(あるいは 24、もしくは 32 バイト)の鍵データを適用すると IV 込みで暗号化したデータを読み出せる AesEncInputStream。

InputStream wrapped = new AesEncInputStream(key, is, null);

三つめの引数は IV を生成する Random オブジェクトをオプションで渡せるようにしている。 null を渡すとパッケージローカルの SecureRandom オブジェクトを使う。

もしくは OutputStream os に書き込むときに暗号化を施すアダプター AesEncOutputStream。

OutputStream wrapped = new AesEncOutputStream(key, os, null);

三つめの引数は同様に IV 生成につかうオプション Random。パディングは AesEncOutputStream#close() を呼び出したときに付け足すので #close() 忘れに注意。

復号アダプター

冒頭に IV 16 バイトが付された暗号化データを読み出せる InputStream is があるとして、これに鍵データを適用すると、復号したデータを読み出せる AesDecInputStream。

InputStream wrapped = new AesDecInputStream(key, is);

もしくは OutputStream os に書き込むときに復号を実行する AesDecOutputStream。

OutputStream wrapped = new AesDecOutputStream(key, os);

最後のパディングの処理は AesDecOutputStream#close() を呼び出したときにするため、ちょっと使いにくいかもしれない。

実装

以下実装。

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.SequenceInputStream;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Enumeration;
import java.util.NoSuchElementException;
import java.util.Random;

import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import static javax.crypto.Cipher.DECRYPT_MODE;
import static javax.crypto.Cipher.ENCRYPT_MODE;

/**
 * Package-private holder for the default random source used by the AES stream adapters.
 *
 * <p>The encrypting adapters fall back to {@link #random} to generate the IV when the caller
 * passes {@code null} instead of a {@code Random} of its own.
 */
final class Local {
    /** Shared {@link SecureRandom} used as the default source of IV bytes. */
    static final SecureRandom random = new SecureRandom();

    private Local() {
        // Static holder only; never instantiated.
    }
}

/**
 * Input stream adapter that yields the encrypted form of another stream.
 *
 * <p>Reading from this stream first returns a freshly generated 16-byte IV, followed by the
 * {@code AES/CBC/PKCS5Padding} ciphertext of the wrapped stream.
 */
class AesEncInputStream extends SequenceInputStream {

    /**
     * Wraps {@code is} so that reads return the IV followed by the ciphertext.
     *
     * @param key    raw AES key bytes
     * @param is     plaintext source
     * @param random source of IV bytes, or {@code null} to use {@code Local.random}
     * @throws GeneralSecurityException if the cipher cannot be created or initialized
     */
    public AesEncInputStream(byte[] key, InputStream is, Random random)
            throws GeneralSecurityException {
        super(new Helper(key, is, random));
    }

    /** Enumerates the two parts of the output: the IV bytes, then the encrypting stream. */
    private static class Helper implements Enumeration<InputStream> {
        private final InputStream[] iss = new InputStream[2];
        private int pos = 0;

        Helper(byte[] key, InputStream is, Random random) throws GeneralSecurityException {
            final byte[] iv = new byte[16];
            (random != null ? random : Local.random).nextBytes(iv);

            final Cipher c = Cipher.getInstance("AES/CBC/PKCS5Padding");
            c.init(ENCRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));

            iss[0] = new ByteArrayInputStream(iv);
            iss[1] = new CipherInputStream(is, c);
        }

        @Override
        public boolean hasMoreElements() {
            return pos < iss.length;
        }

        @Override
        public InputStream nextElement() {
            // Honour the Enumeration contract instead of leaking an array-index error.
            if (pos >= iss.length) {
                throw new NoSuchElementException();
            }
            return iss[pos++];
        }

    }

}

/**
 * Output stream adapter that encrypts everything written to it with
 * {@code AES/CBC/PKCS5Padding}.
 *
 * <p>The 16-byte IV is written to the target stream during construction, ahead of any
 * ciphertext.
 */
class AesEncOutputStream extends CipherOutputStream {

    /**
     * @param key    raw AES key bytes
     * @param os     stream that receives the IV and the ciphertext
     * @param random source of IV bytes, or {@code null} to use {@code Local.random}
     * @throws GeneralSecurityException if the IV cannot be written or the cipher cannot be
     *                                  created or initialized
     */
    public AesEncOutputStream(byte[] key, OutputStream os, Random random)
            throws GeneralSecurityException {
        super(os, emitIvAndBuildCipher(key, os, random));
    }

    /** Generates an IV, writes it to {@code sink}, then returns a cipher initialized with it. */
    private static Cipher emitIvAndBuildCipher(byte[] key, OutputStream sink, Random ivSource)
            throws GeneralSecurityException {
        final byte[] ivBytes = new byte[16];
        final Random rng;
        if (ivSource == null) {
            rng = Local.random;
        } else {
            rng = ivSource;
        }
        rng.nextBytes(ivBytes);

        // The IV goes out before the cipher is created, exactly once, in the clear.
        try {
            sink.write(ivBytes, 0, 16);
        } catch (IOException ioFailure) {
            throw new GeneralSecurityException("Cannot write IV (16bytes)", ioFailure);
        }

        final Cipher cipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
        final SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        final IvParameterSpec ivSpec = new IvParameterSpec(ivBytes);
        cipher.init(ENCRYPT_MODE, keySpec, ivSpec);
        return cipher;
    }

}

/**
 * Input stream adapter that decrypts a stream consisting of a 16-byte IV followed by
 * {@code AES/CBC/PKCS5Padding} ciphertext.
 *
 * <p>The IV is consumed from the wrapped stream during construction.
 */
class AesDecInputStream extends CipherInputStream {

    /**
     * @param key raw AES key bytes
     * @param is  source positioned at the start of the IV
     * @throws GeneralSecurityException if the 16-byte IV cannot be read or the cipher cannot
     *                                  be created or initialized
     */
    public AesDecInputStream(byte[] key, InputStream is)
            throws GeneralSecurityException {
        super(is, getCipher(key, is));
    }

    /** Reads the IV from {@code is} and returns a cipher initialized for decryption with it. */
    private static Cipher getCipher(byte[] key, InputStream is) throws GeneralSecurityException {
        final byte[] iv = new byte[16];
        try {
            // A single read() may legally return fewer bytes than requested, so keep
            // reading until the IV is complete or the stream ends.
            int filled = 0;
            while (filled < iv.length) {
                final int read = is.read(iv, filled, iv.length - filled);
                if (read < 0) {
                    throw new GeneralSecurityException("Cannot read IV (16bytes)");
                }
                filled += read;
            }
        } catch (IOException e) {
            throw new GeneralSecurityException("Cannot read IV (16bytes)", e);
        }

        final Cipher c = Cipher.getInstance("AES/CBC/PKCS5Padding");
        c.init(DECRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
        return c;
    }

}

/**
 * Output stream adapter that decrypts data written to it and forwards the plaintext to a
 * target stream.
 *
 * <p>The first 16 bytes written are taken as the IV; everything after that is treated as
 * {@code AES/CBC/PKCS5Padding} ciphertext. The final block and its padding are only
 * resolved in {@link #close()}, so the stream must be closed to obtain the complete output.
 */
class AesDecOutputStream extends OutputStream {

    private final Cipher c;
    private final OutputStream os;
    /** Private copy of the key, so the caller may wipe its array right after construction. */
    private final SecretKeySpec keySpec;
    private final byte[] iv = new byte[16];
    /** Number of IV bytes collected so far; 16 means the cipher has been initialized. */
    private int pos = 0;
    private boolean closed = false;

    /**
     * @param key raw AES key bytes
     * @param os  stream that receives the decrypted data; closed when this stream is closed
     * @throws GeneralSecurityException if the cipher cannot be created
     */
    public AesDecOutputStream(byte[] key, OutputStream os) throws GeneralSecurityException {
        c = Cipher.getInstance("AES/CBC/PKCS5Padding");
        this.keySpec = new SecretKeySpec(key, "AES");
        this.os = os;
    }

    @Override
    public void write(int oneByte) throws IOException {
        final byte[] buffer = {(byte) oneByte};
        write(buffer, 0, 1);
    }

    /**
     * Consumes IV bytes until 16 have been collected, then decrypts the rest.
     *
     * @throws IOException if the cipher cannot be initialized or the target stream fails
     */
    @Override
    public void write(byte[] buffer, int offset, int count) throws IOException {
        // n = how many of the incoming bytes still belong to the IV.
        final int n = pos == 16 ? 0 : Math.min(16 - pos, count);
        if (n > 0) {
            System.arraycopy(buffer, offset, iv, pos, n);
            pos += n;
            if (pos == 16) {
                try {
                    c.init(DECRYPT_MODE, keySpec, new IvParameterSpec(iv));
                } catch (GeneralSecurityException e) {
                    throw new IOException("Failed to initialize cipher", e);
                }
            }
        }
        final int remaining = count - n;
        if (remaining > 0) {
            // Cipher.update may return null when no full block is available yet.
            final byte[] plain = c.update(buffer, offset + n, remaining);
            if (plain != null) {
                os.write(plain);
            }
        }
    }

    @Override
    public void flush() throws IOException {
        os.flush();
    }

    /**
     * Decrypts the final block, writes it, then flushes and closes the target stream.
     * Calling this method more than once has no further effect.
     *
     * @throws IOException if only part of the IV was received, if the final block or its
     *                     padding is invalid, or if the target stream fails
     */
    @Override
    public void close() throws IOException {
        if (closed) {
            return;
        }
        closed = true;
        try {
            if (pos == 16) {
                final byte[] tail = c.doFinal();
                if (tail != null) {
                    os.write(tail);
                }
            } else if (pos > 0) {
                // Some, but not all, IV bytes arrived: the input was truncated.
                throw new IOException("Cannot read IV (16bytes)");
            }
            os.flush();
        } catch (GeneralSecurityException e) {
            throw new IOException("Failed to finalize cipher", e);
        } finally {
            os.close();
        }
    }

}
使用例
    /**
     * Round-trips a string through {@code AesEncInputStream} and {@code AesDecInputStream},
     * reading both streams in randomly sized chunks of 1 to 8 bytes, and checks that the
     * decrypted text equals the original.
     */
    public void testInPair() throws Exception {
        // Fresh random 16-byte key for each run.
        final byte[] key = new byte[16];
        new SecureRandom().nextBytes(key);
        final String expect = "hello, world! hello, world! hello, world! hello, world!!";

        final byte[] encrypted;
        {
            final InputStream is = new AesEncInputStream(
                    key, new ByteArrayInputStream(expect.getBytes("UTF-8")), null);
            final ByteArrayOutputStream os = new ByteArrayOutputStream();
            final byte[] buffer = new byte[8];
            for (int read; (read = is.read(buffer, 0, Local.random.nextInt(8) + 1)) != -1; )
                os.write(buffer, 0, read);
            encrypted = os.toByteArray();
        }

        final byte[] decrypted;
        {
            final InputStream is = new AesDecInputStream(key, new ByteArrayInputStream(encrypted));
            final ByteArrayOutputStream os = new ByteArrayOutputStream();
            final byte[] buffer = new byte[8];
            for (int read; (read = is.read(buffer, 0, Local.random.nextInt(8) + 1)) != -1; )
                os.write(buffer, 0, read);
            decrypted = os.toByteArray();
        }

        assertEquals(expect, new String(decrypted, "UTF-8"));
    }