GordianCoreLMSSpec.java

/*
 * GordianKnot: Security Suite
 * Copyright 2026. Tony Washer
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License.  You may obtain a copy
 * of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package io.github.tonywasher.joceanus.gordianknot.impl.core.spec.keypair;

import io.github.tonywasher.joceanus.gordianknot.api.base.GordianLength;
import io.github.tonywasher.joceanus.gordianknot.api.keypair.spec.GordianLMSSpec;
import io.github.tonywasher.joceanus.gordianknot.impl.core.spec.base.GordianSpecConstants;
import org.bouncycastle.pqc.crypto.lms.LMOtsParameters;
import org.bouncycastle.pqc.crypto.lms.LMSParameters;
import org.bouncycastle.pqc.crypto.lms.LMSigParameters;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * LMS KeyTypes.
 */
/**
 * Core implementation of an LMS/HSS key specification.
 * <p>
 * A spec is the combination of a hash, a Winternitz width, a tree height, an output length and
 * an HSS tree depth. Valid specs are mapped onto the matching BouncyCastle {@link LMSParameters}
 * (signature parameters plus one-time-signature parameters).
 */
public class GordianCoreLMSSpec
        implements GordianLMSSpec {
    /**
     * Max depth for HSS key (inclusive: depths 1 to MAX_DEPTH are valid).
     */
    public static final int MAX_DEPTH = 8;

    /**
     * Invalid length error.
     */
    private static final String INVALID_LENGTH = "Invalid Length: ";

    /**
     * The hash.
     */
    private final GordianLMSHash theHash;

    /**
     * The width.
     */
    private final GordianLMSWidth theWidth;

    /**
     * The height.
     */
    private final GordianLMSHeight theHeight;

    /**
     * The length.
     */
    private final GordianLength theLength;

    /**
     * The tree depth.
     */
    private final int theDepth;

    /**
     * The Parameters (null if the spec is invalid).
     */
    private final LMSParameters theParams;

    /**
     * The Validity.
     */
    private final boolean isValid;

    /**
     * The String name (lazily built by {@link #toString()}).
     */
    private String theName;

    /**
     * Constructor for a single-tree (depth 1) spec.
     *
     * @param pHashType the hashType
     * @param pHeight   the height
     * @param pWidth    the width
     * @param pLength   the length
     */
    GordianCoreLMSSpec(final GordianLMSHash pHashType,
                       final GordianLMSHeight pHeight,
                       final GordianLMSWidth pWidth,
                       final GordianLength pLength) {
        this(pHashType, pHeight, pWidth, pLength, 1);
    }

    /**
     * Constructor that copies the tree parameters of an existing spec with a new depth.
     *
     * @param pLMSSpec   the LMSSpec
     * @param pTreeDepth the treeDepth
     */
    private GordianCoreLMSSpec(final GordianLMSSpec pLMSSpec,
                               final int pTreeDepth) {
        this(pLMSSpec.getHash(), pLMSSpec.getHeight(), pLMSSpec.getWidth(), pLMSSpec.getLength(), pTreeDepth);
    }

    /**
     * Constructor.
     *
     * @param pHashType  the hashType
     * @param pHeight    the height
     * @param pWidth     the width
     * @param pLength    the length
     * @param pTreeDepth the treeDepth
     */
    GordianCoreLMSSpec(final GordianLMSHash pHashType,
                       final GordianLMSHeight pHeight,
                       final GordianLMSWidth pWidth,
                       final GordianLength pLength,
                       final int pTreeDepth) {
        /* Store parameters */
        theHash = pHashType;
        theWidth = pWidth;
        theHeight = pHeight;
        theLength = pLength;
        theDepth = pTreeDepth;

        /* Check validity */
        /* NOTE(review): checkValidity() is protected and therefore overridable; an override runs
         * here before any subclass fields have been initialised. */
        isValid = checkValidity();

        /* Calculate parameters (only meaningful for a valid spec) */
        final LMSigParameters mySig = isValid ? getSigParameter() : null;
        final LMOtsParameters myOts = isValid ? getOtsParameter() : null;
        theParams = isValid ? new LMSParameters(mySig, myOts) : null;
    }

    @Override
    public GordianLMSHash getHash() {
        return theHash;
    }

    @Override
    public GordianLMSHeight getHeight() {
        return theHeight;
    }

    @Override
    public GordianLMSWidth getWidth() {
        return theWidth;
    }

    @Override
    public GordianLength getLength() {
        return theLength;
    }

    @Override
    public int getTreeDepth() {
        return theDepth;
    }

    /**
     * Obtain the parameters.
     *
     * @return the parameters, or null if the spec is invalid
     */
    public LMSParameters getParameters() {
        return theParams;
    }

    /**
     * Is the keySpec high (height &ge; 15)?
     *
     * @return true/false (always false for an invalid spec).
     */
    public boolean isHigh() {
        return isValid && isHigh(theHeight);
    }

    @Override
    public boolean isValid() {
        return isValid;
    }

    /**
     * Check spec validity.
     * <p>
     * All components must be non-null, the depth must lie in 1..{@link #MAX_DEPTH} and the
     * length must be 192 or 256 bits.
     *
     * @return valid true/false
     */
    protected boolean checkValidity() {
        if (theWidth == null || theHeight == null || theHash == null || theLength == null) {
            return false;
        }
        if (theDepth < 1 || theDepth > MAX_DEPTH) {
            return false;
        }
        return switch (theLength) {
            case LEN_192, LEN_256 -> true;
            default -> false;
        };
    }

    @Override
    public String toString() {
        /* If we have not yet loaded the name */
        if (theName == null) {
            /* If the keySpec is valid */
            if (isValid) {
                /* Load the name */
                theName = theHash.toString() + GordianSpecConstants.SEP + theWidth.toString()
                        + GordianSpecConstants.SEP + theHeight.toString() + GordianSpecConstants.SEP + theLength.toString();
                /* Multi-level trees are reported as HSS with their depth as a prefix */
                if (theDepth > 1) {
                    theName = "HSS-" + theDepth + GordianSpecConstants.SEP + theName;
                }
            } else {
                /* Report invalid spec */
                theName = "InvalidLMSKeySpec: " + theHash + GordianSpecConstants.SEP + theWidth
                        + GordianSpecConstants.SEP + theHeight + GordianSpecConstants.SEP + theLength;
                if (theDepth != 1) {
                    theName += GordianSpecConstants.SEP + theDepth;
                }
            }
        }

        /* return the name */
        return theName;
    }

    @Override
    public boolean equals(final Object pThat) {
        /* Handle the trivial cases */
        if (this == pThat) {
            return true;
        }
        if (pThat == null) {
            return false;
        }

        /* Check fields */
        return pThat instanceof GordianCoreLMSSpec myThat
                && theHash == myThat.theHash
                && theLength == myThat.theLength
                && theWidth == myThat.theWidth
                && theHeight == myThat.theHeight
                && theDepth == myThat.theDepth;
    }

    @Override
    public int hashCode() {
        return Objects.hash(theHash, theHeight, theWidth, theLength, theDepth);
    }

    /**
     * Is the parameter high (height &ge; 15)?
     *
     * @param pHeight the height
     * @return true/false.
     */
    private boolean isHigh(final GordianLMSHeight pHeight) {
        return switch (pHeight) {
            case H15, H20, H25 -> true;
            default -> false;
        };
    }

    /**
     * Obtain the sigParameter for the configured height.
     *
     * @return the parameter
     * @throws IllegalStateException on an unrecognised height
     */
    private LMSigParameters getSigParameter() {
        return switch (theHeight) {
            case H5 -> getH5Parameter();
            case H10 -> getH10Parameter();
            case H15 -> getH15Parameter();
            case H20 -> getH20Parameter();
            case H25 -> getH25Parameter();
            default -> throw new IllegalStateException();
        };
    }

    /**
     * Obtain the H5 sigParameter.
     *
     * @return the parameter
     */
    private LMSigParameters getH5Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n24_h5 : LMSigParameters.lms_shake256_n24_h5;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n32_h5 : LMSigParameters.lms_shake256_n32_h5;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the H10 sigParameter.
     *
     * @return the parameter
     */
    private LMSigParameters getH10Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n24_h10 : LMSigParameters.lms_shake256_n24_h10;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n32_h10 : LMSigParameters.lms_shake256_n32_h10;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the H15 sigParameter.
     *
     * @return the parameter
     */
    private LMSigParameters getH15Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n24_h15 : LMSigParameters.lms_shake256_n24_h15;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n32_h15 : LMSigParameters.lms_shake256_n32_h15;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the H20 sigParameter.
     *
     * @return the parameter
     */
    private LMSigParameters getH20Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n24_h20 : LMSigParameters.lms_shake256_n24_h20;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n32_h20 : LMSigParameters.lms_shake256_n32_h20;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the H25 sigParameter.
     *
     * @return the parameter
     */
    private LMSigParameters getH25Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n24_h25 : LMSigParameters.lms_shake256_n24_h25;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMSigParameters.lms_sha256_n32_h25 : LMSigParameters.lms_shake256_n32_h25;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the otsParameter for the configured width.
     *
     * @return the parameter
     * @throws IllegalStateException on an unrecognised width
     */
    private LMOtsParameters getOtsParameter() {
        return switch (theWidth) {
            case W1 -> getW1Parameter();
            case W2 -> getW2Parameter();
            case W4 -> getW4Parameter();
            case W8 -> getW8Parameter();
            default -> throw new IllegalStateException();
        };
    }

    /**
     * Obtain the W1 otsParameter.
     *
     * @return the parameter
     */
    private LMOtsParameters getW1Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n24_w1 : LMOtsParameters.shake256_n24_w1;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n32_w1 : LMOtsParameters.shake256_n32_w1;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the W2 otsParameter.
     *
     * @return the parameter
     */
    private LMOtsParameters getW2Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n24_w2 : LMOtsParameters.shake256_n24_w2;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n32_w2 : LMOtsParameters.shake256_n32_w2;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the W4 otsParameter.
     *
     * @return the parameter
     */
    private LMOtsParameters getW4Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n24_w4 : LMOtsParameters.shake256_n24_w4;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n32_w4 : LMOtsParameters.shake256_n32_w4;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain the W8 otsParameter.
     *
     * @return the parameter
     */
    private LMOtsParameters getW8Parameter() {
        return switch (theLength) {
            case LEN_192 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n24_w8 : LMOtsParameters.shake256_n24_w8;
            case LEN_256 ->
                    theHash == GordianLMSHash.SHA256 ? LMOtsParameters.sha256_n32_w8 : LMOtsParameters.shake256_n32_w8;
            default -> throw new IllegalArgumentException(INVALID_LENGTH + theLength);
        };
    }

    /**
     * Obtain a list of all possible LMS/HSS specs.
     * <p>
     * Each single-tree spec from {@link #listPossibleLMSSpecs()} is expanded to every tree depth
     * from 1 to {@link #MAX_DEPTH} inclusive.
     *
     * @return the list
     */
    public static List<GordianLMSSpec> listAllPossibleSpecs() {
        /* Create the list */
        final List<GordianLMSSpec> mySpecs = new ArrayList<>();

        /* Add the specs; depth MAX_DEPTH is valid (see checkValidity) so the bound is inclusive */
        for (final GordianLMSSpec mySpec : listPossibleLMSSpecs()) {
            for (int i = 1; i <= MAX_DEPTH; i++) {
                mySpecs.add(new GordianCoreLMSSpec(mySpec, i));
            }
        }

        /* Return the list */
        return mySpecs;
    }

    /**
     * Obtain a list of all possible single-tree (depth 1) LMS specs.
     * <p>
     * Covers every height and width for SHA256 and SHAKE256 at 256 and 192 bit lengths.
     *
     * @return the list
     */
    public static List<GordianLMSSpec> listPossibleLMSSpecs() {
        /* Create the list */
        final List<GordianLMSSpec> mySpecs = new ArrayList<>();

        /* Add the specs */
        for (final GordianLMSHeight myHeight : GordianLMSHeight.values()) {
            for (final GordianLMSWidth myWidth : GordianLMSWidth.values()) {
                mySpecs.add(new GordianCoreLMSSpec(GordianLMSHash.SHA256, myHeight, myWidth, GordianLength.LEN_256));
                mySpecs.add(new GordianCoreLMSSpec(GordianLMSHash.SHA256, myHeight, myWidth, GordianLength.LEN_192));
                mySpecs.add(new GordianCoreLMSSpec(GordianLMSHash.SHAKE256, myHeight, myWidth, GordianLength.LEN_256));
                mySpecs.add(new GordianCoreLMSSpec(GordianLMSHash.SHAKE256, myHeight, myWidth, GordianLength.LEN_192));
            }
        }

        /* Return the list */
        return mySpecs;
    }

    /**
     * Match keySpec against LMSParameters.
     * <p>
     * Only single-tree specs are searched, so the returned spec always has a tree depth of 1.
     *
     * @param pSigParams the sigParameters
     * @param pOtsParams the otsParameters
     * @return the matching keySpec
     * @throws IllegalArgumentException if no spec matches the parameters
     */
    public static GordianLMSSpec determineSpec(final LMSigParameters pSigParams,
                                               final LMOtsParameters pOtsParams) {
        for (final GordianLMSSpec mySpec : listPossibleLMSSpecs()) {
            /* Every listed spec is valid, so its parameters are non-null */
            final LMSParameters myParams = ((GordianCoreLMSSpec) mySpec).getParameters();
            if (pSigParams.equals(myParams.getLMSigParam())
                    && pOtsParams.equals(myParams.getLMOTSParam())) {
                return mySpec;
            }
        }
        throw new IllegalArgumentException("Unsupported LMSSpec");
    }
}