Skip to content

Commit

Permalink
Use references to obtain the signed elements in a signature (#188)
Browse files Browse the repository at this point in the history
Closes keycloak/keycloak-private#191

Signed-off-by: rmartinc <rmartinc@redhat.com>
  • Loading branch information
rmartinc authored and stianst committed Sep 17, 2024
1 parent f253f90 commit ae6a686
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
import org.w3c.dom.Element;
import org.w3c.dom.Node;

import javax.xml.crypto.dsig.XMLSignature;
import javax.xml.datatype.XMLGregorianCalendar;
import javax.xml.stream.XMLEventReader;

Expand Down Expand Up @@ -315,7 +314,7 @@ public static boolean isSignedElement(Element element) {
}

protected static Element getSignature(Element element) {
return DocumentUtil.getDirectChildElement(element, XMLSignature.XMLNS, "Signature");
return XMLSignatureUtil.getSignature(element);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,22 @@
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import javax.xml.crypto.AlgorithmMethod;
import javax.xml.crypto.Data;
import javax.xml.crypto.KeySelector;
import javax.xml.crypto.KeySelectorException;
import javax.xml.crypto.KeySelectorResult;
import javax.xml.crypto.NodeSetData;
import javax.xml.crypto.URIReferenceException;
import javax.xml.crypto.XMLCryptoContext;
import javax.xml.crypto.dom.DOMStructure;
import org.keycloak.rotation.KeyLocator;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.util.SecurityActions;

/**
Expand Down Expand Up @@ -170,6 +177,52 @@ private static XMLSignatureFactory getXMLSignatureFactory() {
return xsf;
}

/**
* Returns the element that contains the signature for the passed element.
*
* @param element The element to search for the signature
* @return The signature element or null
*/
public static Element getSignature(Element element) {
Document doc = element.getOwnerDocument();
NodeList nl = doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature");
if (element.getAttributeNode(JBossSAMLConstants.ID.get()) != null) {
// set the saml ID to be found
element.setIdAttribute(JBossSAMLConstants.ID.get(), true);
}
KeySelector nullSelector = new KeySelector() {
@Override
public KeySelectorResult select(KeyInfo ki, KeySelector.Purpose prps, AlgorithmMethod am, XMLCryptoContext xmlcc) throws KeySelectorException {
return () -> null;
}
};

try {
for (int i = 0; i < nl.getLength(); i++) {
Element signatureElement = (Element) nl.item(i);
DOMValidateContext valContext = new DOMValidateContext(nullSelector, signatureElement);
DOMStructure structure = new DOMStructure(signatureElement);
XMLSignature signature = fac.unmarshalXMLSignature(structure);
for (Reference ref : (List<Reference>) signature.getSignedInfo().getReferences()) {
try {
Data data = fac.getURIDereferencer().dereference(ref, valContext);
if (data instanceof NodeSetData) {
Iterator<Node> it = ((NodeSetData) data).iterator();
if (it.hasNext() && element.equals(it.next())) {
return signatureElement;
}
}
} catch (URIReferenceException e) {
logger.trace("Invalid URI reference in signature " + ref.getURI());
}
}
}
} catch (MarshalException e) {
logger.trace("Error unmarshalling signature", e);
}
return null;
}

/**
* Use this method to not include the KeyInfo in the signature
*
Expand Down Expand Up @@ -404,7 +457,7 @@ public static Document sign(SignatureUtilTransferObject dto, String canonicaliza
* this way both assertions and the containing document are verified when signed.
*
* @param signedDoc
* @param publicKey
* @param locator
*
* @return
*
Expand All @@ -428,39 +481,46 @@ public static boolean validate(Document signedDoc, final KeyLocator locator) thr
if (locator == null)
throw logger.nullValueError("Public Key");

int signedAssertions = 0;
String assertionNameSpaceUri = null;
HashSet<Node> signedNodes = new HashSet<>();

for (int i = 0; i < nl.getLength(); i++) {
Node signatureNode = nl.item(i);
Node parent = signatureNode.getParentNode();
if (parent != null && JBossSAMLConstants.ASSERTION.get().equals(parent.getLocalName())) {
++signedAssertions;
if (assertionNameSpaceUri == null) {
assertionNameSpaceUri = parent.getNamespaceURI();
}
if (!validateSingleNode(signatureNode, locator, signedNodes)) {
return false;
}
}

if (! validateSingleNode(signatureNode, locator)) return false;
if (signedNodes.contains(signedDoc.getDocumentElement())) {
logger.trace("All signatures are OK and root document is signed");
return true;
}

NodeList assertions = signedDoc.getElementsByTagNameNS(assertionNameSpaceUri, JBossSAMLConstants.ASSERTION.get());
NodeList assertions = signedDoc.getElementsByTagNameNS(JBossSAMLURIConstants.ASSERTION_NSURI.get(), JBossSAMLConstants.ASSERTION.get());

if (signedAssertions > 0 && assertions != null && assertions.getLength() != signedAssertions) {
if (logger.isDebugEnabled()) {
logger.debug("SAML Response document may contain malicious assertions. Signature validation will fail.");
if (assertions.getLength() > 0) {
// if document is not fully signed check if all the assertions are signed
for (int i = 0; i < assertions.getLength(); i++) {
if (!signedNodes.contains(assertions.item(i))) {
logger.debug("SAML Response document may contain malicious assertions. Signature validation will fail.");
// there are unsigned assertions mixed with signed ones
return false;
}
}
// there are unsigned assertions mixed with signed ones
return false;
logger.trace("Document not signed but all assertions are signed OK");
return true;
}

return true;
return false;
}

public static boolean validateSingleNode(Node signatureNode, final KeyLocator locator) throws MarshalException, XMLSignatureException {
return validateSingleNode(signatureNode, locator, new HashSet<>());
}

public static boolean validateSingleNode(Node signatureNode, final KeyLocator locator, Set<Node> signedNodes) throws MarshalException, XMLSignatureException {
KeySelectorUtilizingKeyNameHint sel = new KeySelectorUtilizingKeyNameHint(locator);
try {
if (validateUsingKeySelector(signatureNode, sel)) {
if (validateUsingKeySelector(signatureNode, sel, signedNodes)) {
return true;
}
if (sel.wasKeyLocated()) {
Expand All @@ -477,7 +537,7 @@ public static boolean validateSingleNode(Node signatureNode, final KeyLocator lo

for (Key key : locator) {
try {
if (validateUsingKeySelector(signatureNode, KeySelector.singletonKeySelector(key))) {
if (validateUsingKeySelector(signatureNode, KeySelector.singletonKeySelector(key), signedNodes)) {
return true;
}
} catch (XMLSignatureException ex) { // pass through MarshalException
Expand All @@ -489,12 +549,26 @@ public static boolean validateSingleNode(Node signatureNode, final KeyLocator lo
return false;
}

private static boolean validateUsingKeySelector(Node signatureNode, KeySelector validationKeySelector) throws XMLSignatureException, MarshalException {
private static boolean validateUsingKeySelector(Node signatureNode, KeySelector validationKeySelector, Set<Node> signedNodes) throws XMLSignatureException, MarshalException {
DOMValidateContext valContext = new DOMValidateContext(validationKeySelector, signatureNode);
XMLSignature signature = fac.unmarshalXMLSignature(valContext);
boolean coreValidity = signature.validate(valContext);

if (! coreValidity) {
if (coreValidity) {
for (Reference ref : (List<Reference>) signature.getSignedInfo().getReferences()) {
try {
Data data = fac.getURIDereferencer().dereference(ref, valContext);
if (data instanceof NodeSetData) {
Iterator<Node> it = ((NodeSetData) data).iterator();
if (it.hasNext()) {
signedNodes.add(it.next()); // add the first referenced object as signed element
}
}
} catch (URIReferenceException e) {
// ignored as signature was ok so reference can be obtained
}
}
} else {
if (logger.isTraceEnabled()) {
boolean sv = signature.getSignatureValue().validate(valContext);
logger.trace("Signature validation status: " + sv);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.keycloak.testsuite.util.RealmBuilder;
import org.keycloak.testsuite.util.RoleBuilder;
import org.keycloak.testsuite.util.RolesBuilder;
import org.junit.Assert;
import org.junit.Test;
import org.keycloak.testsuite.util.SamlClient.Binding;
import org.keycloak.testsuite.util.SamlClientBuilder;
Expand Down Expand Up @@ -66,6 +67,7 @@
import static org.hamcrest.MatcherAssert.assertThat;
import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.ASSERTION_NSURI;
import static org.keycloak.saml.common.constants.JBossSAMLURIConstants.PROTOCOL_NSURI;
import org.keycloak.saml.common.util.DocumentUtil;
import static org.keycloak.testsuite.adapter.AbstractServletsAdapterTest.samlServletDeployment;
import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_NAME;
import static org.keycloak.testsuite.saml.AbstractSamlTest.REALM_PRIVATE_KEY;
Expand Down Expand Up @@ -182,6 +184,20 @@ public static void applyXSW8(Document document){
originalSignature.appendChild(object);
object.appendChild(assertion);
}

public static void noDocumentSignatureOnlyOneAssertionSignedBelowResponse(Document document){
// remove the signature for the whole response
removeDocumentSignature(document);
// move the signature from the assertion to the response level
Element assertion = (Element) document.getElementsByTagNameNS(ASSERTION_NSURI.get(), "Assertion").item(0);
Element signature = (Element) assertion.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature").item(0);
assertion.removeChild(signature);
document.getDocumentElement().appendChild(signature);
// create a second assertion without signature
Element evilAssertion = (Element) assertion.cloneNode(true);
evilAssertion.setAttribute("ID", "_evil_assertion_ID");
document.getDocumentElement().insertBefore(evilAssertion, assertion);
}
}

@Page
Expand Down Expand Up @@ -322,11 +338,21 @@ private static void removeAllSignatures(Document doc) throws DOMException {
}
}

private static void removeDocumentSignature(Document doc) throws DOMException {
Element responseSignature = (Element) doc.getElementsByTagNameNS(XMLSignature.XMLNS, "Signature").item(0);
Assert.assertNotNull(doc.getDocumentElement().removeChild(responseSignature));
}

@Test
public void testNoChange() throws Exception {
testSamlResponseModifications(r -> {}, true);
}

@Test
public void testOnlyAssertionSignature() throws Exception {
testSamlResponseModifications(SamlSignatureTest::removeDocumentSignature, true);
}

@Test
public void testRemoveSignatures() throws Exception {
testSamlResponseModifications(SamlSignatureTest::removeAllSignatures, false);
Expand Down Expand Up @@ -372,4 +398,8 @@ public void testXSW8() throws Exception {
testSamlResponseModifications(XSWHelpers::applyXSW8, false);
}

@Test
public void testNoDocumentSignatureOnlyOneAssertionSignedBelowResponse() throws Exception {
testSamlResponseModifications(XSWHelpers::noDocumentSignatureOnlyOneAssertionSignedBelowResponse, false);
}
}

0 comments on commit ae6a686

Please sign in to comment.