--- a/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java Fri May 11 14:55:56 2018 -0700
+++ b/src/java.base/share/classes/sun/security/rsa/RSAKeyFactory.java Fri May 11 15:53:12 2018 -0700
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2003, 2011, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2003, 2018, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
@@ -32,10 +32,15 @@
import java.security.spec.*;
import sun.security.action.GetPropertyAction;
+import sun.security.x509.AlgorithmId;
+import static sun.security.rsa.RSAUtil.KeyType;
/**
- * KeyFactory for RSA keys. Keys must be instances of PublicKey or PrivateKey
- * and getAlgorithm() must return "RSA". For such keys, it supports conversion
+ * KeyFactory for RSA keys, e.g. "RSA", "RSASSA-PSS".
+ * Keys must be instances of PublicKey or PrivateKey
+ * and getAlgorithm() must return a value which matches the type which are
+ * specified during construction time of the KeyFactory object.
+ * For such keys, it supports conversion
* between the following:
*
* For public keys:
@@ -58,21 +63,21 @@
* @since 1.5
* @author Andreas Sterbenz
*/
-public final class RSAKeyFactory extends KeyFactorySpi {
+public class RSAKeyFactory extends KeyFactorySpi {
- private static final Class<?> rsaPublicKeySpecClass =
- RSAPublicKeySpec.class;
- private static final Class<?> rsaPrivateKeySpecClass =
- RSAPrivateKeySpec.class;
- private static final Class<?> rsaPrivateCrtKeySpecClass =
- RSAPrivateCrtKeySpec.class;
-
- private static final Class<?> x509KeySpecClass = X509EncodedKeySpec.class;
- private static final Class<?> pkcs8KeySpecClass = PKCS8EncodedKeySpec.class;
+ private static final Class<?> RSA_PUB_KEYSPEC_CLS = RSAPublicKeySpec.class;
+ private static final Class<?> RSA_PRIV_KEYSPEC_CLS =
+ RSAPrivateKeySpec.class;
+ private static final Class<?> RSA_PRIVCRT_KEYSPEC_CLS =
+ RSAPrivateCrtKeySpec.class;
+ private static final Class<?> X509_KEYSPEC_CLS = X509EncodedKeySpec.class;
+ private static final Class<?> PKCS8_KEYSPEC_CLS = PKCS8EncodedKeySpec.class;
public static final int MIN_MODLEN = 512;
public static final int MAX_MODLEN = 16384;
+ private final KeyType type;
+
/*
* If the modulus length is above this value, restrict the size of
* the exponent to something that can be reasonably computed. We
@@ -87,11 +92,18 @@
"true".equalsIgnoreCase(GetPropertyAction.privilegedGetProperty(
"sun.security.rsa.restrictRSAExponent", "true"));
- // instance used for static translateKey();
- private static final RSAKeyFactory INSTANCE = new RSAKeyFactory();
+ static RSAKeyFactory getInstance(KeyType type) {
+ return new RSAKeyFactory(type);
+ }
- public RSAKeyFactory() {
- // empty
+ // Internal utility method for checking key algorithm
+ private static void checkKeyAlgo(Key key, String expectedAlg)
+ throws InvalidKeyException {
+ String keyAlg = key.getAlgorithm();
+ if (!(keyAlg.equalsIgnoreCase(expectedAlg))) {
+ throw new InvalidKeyException("Expected a " + expectedAlg
+ + " key, but got " + keyAlg);
+ }
}
/**
@@ -107,7 +119,14 @@
(key instanceof RSAPublicKeyImpl)) {
return (RSAKey)key;
} else {
- return (RSAKey)INSTANCE.engineTranslateKey(key);
+ try {
+ String keyAlgo = key.getAlgorithm();
+ KeyType type = KeyType.lookup(keyAlgo);
+ RSAKeyFactory kf = RSAKeyFactory.getInstance(type);
+ return (RSAKey) kf.engineTranslateKey(key);
+ } catch (ProviderException e) {
+ throw new InvalidKeyException(e);
+ }
}
}
@@ -171,6 +190,15 @@
}
}
+ // disallowed as KeyType is required
+ private RSAKeyFactory() {
+ this.type = KeyType.RSA;
+ }
+
+ public RSAKeyFactory(KeyType type) {
+ this.type = type;
+ }
+
/**
* Translate an RSA key into a SunRsaSign RSA key. If conversion is
* not possible, throw an InvalidKeyException.
@@ -180,9 +208,14 @@
if (key == null) {
throw new InvalidKeyException("Key must not be null");
}
- String keyAlg = key.getAlgorithm();
- if (keyAlg.equals("RSA") == false) {
- throw new InvalidKeyException("Not an RSA key: " + keyAlg);
+ // ensure the key algorithm matches the current KeyFactory instance
+ checkKeyAlgo(key, type.keyAlgo());
+
+ // no translation needed if the key is already our own impl
+ if ((key instanceof RSAPrivateKeyImpl) ||
+ (key instanceof RSAPrivateCrtKeyImpl) ||
+ (key instanceof RSAPublicKeyImpl)) {
+ return key;
}
if (key instanceof PublicKey) {
return translatePublicKey((PublicKey)key);
@@ -221,22 +254,22 @@
private PublicKey translatePublicKey(PublicKey key)
throws InvalidKeyException {
if (key instanceof RSAPublicKey) {
- if (key instanceof RSAPublicKeyImpl) {
- return key;
- }
RSAPublicKey rsaKey = (RSAPublicKey)key;
try {
return new RSAPublicKeyImpl(
+ RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
rsaKey.getModulus(),
- rsaKey.getPublicExponent()
- );
- } catch (RuntimeException e) {
+ rsaKey.getPublicExponent());
+ } catch (ProviderException e) {
// catch providers that incorrectly implement RSAPublicKey
throw new InvalidKeyException("Invalid key", e);
}
} else if ("X.509".equals(key.getFormat())) {
byte[] encoded = key.getEncoded();
- return new RSAPublicKeyImpl(encoded);
+ RSAPublicKey translated = new RSAPublicKeyImpl(encoded);
+ // ensure the key algorithm matches the current KeyFactory instance
+ checkKeyAlgo(translated, type.keyAlgo());
+ return translated;
} else {
throw new InvalidKeyException("Public keys must be instance "
+ "of RSAPublicKey or have X.509 encoding");
@@ -247,12 +280,10 @@
private PrivateKey translatePrivateKey(PrivateKey key)
throws InvalidKeyException {
if (key instanceof RSAPrivateCrtKey) {
- if (key instanceof RSAPrivateCrtKeyImpl) {
- return key;
- }
RSAPrivateCrtKey rsaKey = (RSAPrivateCrtKey)key;
try {
return new RSAPrivateCrtKeyImpl(
+ RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
rsaKey.getModulus(),
rsaKey.getPublicExponent(),
rsaKey.getPrivateExponent(),
@@ -262,27 +293,28 @@
rsaKey.getPrimeExponentQ(),
rsaKey.getCrtCoefficient()
);
- } catch (RuntimeException e) {
+ } catch (ProviderException e) {
// catch providers that incorrectly implement RSAPrivateCrtKey
throw new InvalidKeyException("Invalid key", e);
}
} else if (key instanceof RSAPrivateKey) {
- if (key instanceof RSAPrivateKeyImpl) {
- return key;
- }
RSAPrivateKey rsaKey = (RSAPrivateKey)key;
try {
return new RSAPrivateKeyImpl(
+ RSAUtil.createAlgorithmId(type, rsaKey.getParams()),
rsaKey.getModulus(),
rsaKey.getPrivateExponent()
);
- } catch (RuntimeException e) {
+ } catch (ProviderException e) {
// catch providers that incorrectly implement RSAPrivateKey
throw new InvalidKeyException("Invalid key", e);
}
} else if ("PKCS#8".equals(key.getFormat())) {
byte[] encoded = key.getEncoded();
- return RSAPrivateCrtKeyImpl.newKey(encoded);
+ RSAPrivateKey translated = RSAPrivateCrtKeyImpl.newKey(encoded);
+ // ensure the key algorithm matches the current KeyFactory instance
+ checkKeyAlgo(translated, type.keyAlgo());
+ return translated;
} else {
throw new InvalidKeyException("Private keys must be instance "
+ "of RSAPrivate(Crt)Key or have PKCS#8 encoding");
@@ -294,13 +326,21 @@
throws GeneralSecurityException {
if (keySpec instanceof X509EncodedKeySpec) {
X509EncodedKeySpec x509Spec = (X509EncodedKeySpec)keySpec;
- return new RSAPublicKeyImpl(x509Spec.getEncoded());
+ RSAPublicKey generated = new RSAPublicKeyImpl(x509Spec.getEncoded());
+ // ensure the key algorithm matches the current KeyFactory instance
+ checkKeyAlgo(generated, type.keyAlgo());
+ return generated;
} else if (keySpec instanceof RSAPublicKeySpec) {
RSAPublicKeySpec rsaSpec = (RSAPublicKeySpec)keySpec;
- return new RSAPublicKeyImpl(
- rsaSpec.getModulus(),
- rsaSpec.getPublicExponent()
- );
+ try {
+ return new RSAPublicKeyImpl(
+ RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+ rsaSpec.getModulus(),
+ rsaSpec.getPublicExponent()
+ );
+ } catch (ProviderException e) {
+ throw new InvalidKeySpecException(e);
+ }
} else {
throw new InvalidKeySpecException("Only RSAPublicKeySpec "
+ "and X509EncodedKeySpec supported for RSA public keys");
@@ -312,25 +352,38 @@
throws GeneralSecurityException {
if (keySpec instanceof PKCS8EncodedKeySpec) {
PKCS8EncodedKeySpec pkcsSpec = (PKCS8EncodedKeySpec)keySpec;
- return RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded());
+ RSAPrivateKey generated = RSAPrivateCrtKeyImpl.newKey(pkcsSpec.getEncoded());
+ // ensure the key algorithm matches the current KeyFactory instance
+ checkKeyAlgo(generated, type.keyAlgo());
+ return generated;
} else if (keySpec instanceof RSAPrivateCrtKeySpec) {
RSAPrivateCrtKeySpec rsaSpec = (RSAPrivateCrtKeySpec)keySpec;
- return new RSAPrivateCrtKeyImpl(
- rsaSpec.getModulus(),
- rsaSpec.getPublicExponent(),
- rsaSpec.getPrivateExponent(),
- rsaSpec.getPrimeP(),
- rsaSpec.getPrimeQ(),
- rsaSpec.getPrimeExponentP(),
- rsaSpec.getPrimeExponentQ(),
- rsaSpec.getCrtCoefficient()
- );
+ try {
+ return new RSAPrivateCrtKeyImpl(
+ RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+ rsaSpec.getModulus(),
+ rsaSpec.getPublicExponent(),
+ rsaSpec.getPrivateExponent(),
+ rsaSpec.getPrimeP(),
+ rsaSpec.getPrimeQ(),
+ rsaSpec.getPrimeExponentP(),
+ rsaSpec.getPrimeExponentQ(),
+ rsaSpec.getCrtCoefficient()
+ );
+ } catch (ProviderException e) {
+ throw new InvalidKeySpecException(e);
+ }
} else if (keySpec instanceof RSAPrivateKeySpec) {
RSAPrivateKeySpec rsaSpec = (RSAPrivateKeySpec)keySpec;
- return new RSAPrivateKeyImpl(
- rsaSpec.getModulus(),
- rsaSpec.getPrivateExponent()
- );
+ try {
+ return new RSAPrivateKeyImpl(
+ RSAUtil.createAlgorithmId(type, rsaSpec.getParams()),
+ rsaSpec.getModulus(),
+ rsaSpec.getPrivateExponent()
+ );
+ } catch (ProviderException e) {
+ throw new InvalidKeySpecException(e);
+ }
} else {
throw new InvalidKeySpecException("Only RSAPrivate(Crt)KeySpec "
+ "and PKCS8EncodedKeySpec supported for RSA private keys");
@@ -349,12 +402,13 @@
}
if (key instanceof RSAPublicKey) {
RSAPublicKey rsaKey = (RSAPublicKey)key;
- if (rsaPublicKeySpecClass.isAssignableFrom(keySpec)) {
+ if (RSA_PUB_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
return keySpec.cast(new RSAPublicKeySpec(
rsaKey.getModulus(),
- rsaKey.getPublicExponent()
+ rsaKey.getPublicExponent(),
+ rsaKey.getParams()
));
- } else if (x509KeySpecClass.isAssignableFrom(keySpec)) {
+ } else if (X509_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
return keySpec.cast(new X509EncodedKeySpec(key.getEncoded()));
} else {
throw new InvalidKeySpecException
@@ -362,9 +416,9 @@
+ "X509EncodedKeySpec for RSA public keys");
}
} else if (key instanceof RSAPrivateKey) {
- if (pkcs8KeySpecClass.isAssignableFrom(keySpec)) {
+ if (PKCS8_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
return keySpec.cast(new PKCS8EncodedKeySpec(key.getEncoded()));
- } else if (rsaPrivateCrtKeySpecClass.isAssignableFrom(keySpec)) {
+ } else if (RSA_PRIVCRT_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
if (key instanceof RSAPrivateCrtKey) {
RSAPrivateCrtKey crtKey = (RSAPrivateCrtKey)key;
return keySpec.cast(new RSAPrivateCrtKeySpec(
@@ -375,17 +429,19 @@
crtKey.getPrimeQ(),
crtKey.getPrimeExponentP(),
crtKey.getPrimeExponentQ(),
- crtKey.getCrtCoefficient()
+ crtKey.getCrtCoefficient(),
+ crtKey.getParams()
));
} else {
throw new InvalidKeySpecException
("RSAPrivateCrtKeySpec can only be used with CRT keys");
}
- } else if (rsaPrivateKeySpecClass.isAssignableFrom(keySpec)) {
+ } else if (RSA_PRIV_KEYSPEC_CLS.isAssignableFrom(keySpec)) {
RSAPrivateKey rsaKey = (RSAPrivateKey)key;
return keySpec.cast(new RSAPrivateKeySpec(
rsaKey.getModulus(),
- rsaKey.getPrivateExponent()
+ rsaKey.getPrivateExponent(),
+ rsaKey.getParams()
));
} else {
throw new InvalidKeySpecException
@@ -397,4 +453,16 @@
throw new InvalidKeySpecException("Neither public nor private key");
}
}
+
+ public static final class Legacy extends RSAKeyFactory {
+ public Legacy() {
+ super(KeyType.RSA);
+ }
+ }
+
+ public static final class PSS extends RSAKeyFactory {
+ public PSS() {
+ super(KeyType.PSS);
+ }
+ }
}