/* ext_mlkem.c
 *
 * Copyright (C) 2006-2025 wolfSSL Inc.
 *
 * This file is part of wolfSSL.
 *
 * wolfSSL is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * wolfSSL is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA
 */

#include <wolfssl/wolfcrypt/libwolfssl_sources.h>

#if defined(WOLFSSL_HAVE_MLKEM) && !defined(WOLFSSL_WC_MLKEM)
#include <wolfssl/wolfcrypt/ext_mlkem.h>

#ifdef NO_INLINE
    #include <wolfssl/wolfcrypt/misc.h>
#else
    #define WOLFSSL_MISC_INCLUDED
    #include <wolfcrypt/src/misc.c>
#endif

#if defined (HAVE_LIBOQS)

#include <wolfssl/wolfcrypt/port/liboqs/liboqs.h>

/**
 * Map a wolfSSL ML-KEM/Kyber type identifier to its liboqs algorithm name.
 *
 * @param  [in]  id  Type identifier (WC_ML_KEM_* or KYBER_LEVEL*).
 * @return  liboqs algorithm name string for the type.
 * @return  NULL when the identifier is unknown or not compiled in.
 */
static const char* OQS_ID2name(int id)
{
    const char* oqsName = NULL;

    switch (id) {
    #ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
            oqsName = OQS_KEM_alg_ml_kem_512;
            break;
        case WC_ML_KEM_768:
            oqsName = OQS_KEM_alg_ml_kem_768;
            break;
        case WC_ML_KEM_1024:
            oqsName = OQS_KEM_alg_ml_kem_1024;
            break;
    #endif
    #ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
            oqsName = OQS_KEM_alg_kyber_512;
            break;
        case KYBER_LEVEL3:
            oqsName = OQS_KEM_alg_kyber_768;
            break;
        case KYBER_LEVEL5:
            oqsName = OQS_KEM_alg_kyber_1024;
            break;
    #endif
        default:
            /* Unknown or not compiled in: leave the name as NULL. */
            break;
    }

    return oqsName;
}

/**
 * Query liboqs for whether the algorithm behind a wolfSSL ML-KEM/Kyber type
 * is enabled.
 *
 * @param  [in]  id  Type identifier (WC_ML_KEM_* or KYBER_LEVEL*).
 * @return  Result of OQS_KEM_alg_is_enabled() for the matching liboqs name.
 *          Unrecognized ids map to a NULL name, which is passed through
 *          unchanged (presumably reported as not enabled by liboqs - confirm).
 */
int ext_mlkem_enabled(int id)
{
    return OQS_KEM_alg_is_enabled(OQS_ID2name(id));
}
#endif

/******************************************************************************/
/* Initializer and cleanup functions. */

/**
 * Initialize an ML-KEM/Kyber key object.
 *
 * On success the whole object is zeroed and the key type recorded. When
 * WOLF_CRYPTO_CB is defined, the device id is stored and the device context
 * cleared. On failure the object is left untouched.
 *
 * @param  [out]  key    ML-KEM/Kyber key object to initialize.
 * @param  [in]   type   Key type: WC_ML_KEM_512/768/1024 and/or
 *                       KYBER_LEVEL1/3/5, depending on build options. Without
 *                       HAVE_LIBOQS only WC_ML_KEM_512 and KYBER_LEVEL1 are
 *                       accepted.
 * @param  [in]   heap   Dynamic memory hint. Not used.
 * @param  [in]   devId  Device id. Only stored when WOLF_CRYPTO_CB is defined.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key is NULL or type is not supported in this
 *          build.
 */
int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
{
    int typeOk = 0;

    (void)heap;
    (void)devId;

    if (key == NULL) {
        return BAD_FUNC_ARG;
    }

    /* Accept only the types this build can handle. */
    switch (type) {
#ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
    #ifdef HAVE_LIBOQS
        case WC_ML_KEM_768:
        case WC_ML_KEM_1024:
    #endif /* HAVE_LIBOQS */
#endif
#ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
    #ifdef HAVE_LIBOQS
        case KYBER_LEVEL3:
        case KYBER_LEVEL5:
    #endif /* HAVE_LIBOQS */
#endif
            typeOk = 1;
            break;
        default:
            /* No other values supported. */
            break;
    }
    if (!typeOk) {
        return BAD_FUNC_ARG;
    }

    /* Start from a clean object, then record the parameter set. */
    XMEMSET(key, 0, sizeof(*key));
    key->type = type;
#ifdef WOLF_CRYPTO_CB
    key->devCtx = NULL;
    key->devId = devId;
#endif

    return 0;
}

/**
 * Dispose of an ML-KEM/Kyber key object.
 *
 * The entire object, including any private key material, is overwritten
 * with zeros using ForceZero(). A NULL key is accepted and ignored.
 *
 * @param  [in, out]  key  ML-KEM/Kyber key object. May be NULL.
 * @return  0 always.
 */
int wc_MlKemKey_Free(MlKemKey* key)
{
    if (key == NULL) {
        return 0;
    }

    /* Wipe everything so no secret data is left behind. */
    ForceZero(key, sizeof(*key));
    return 0;
}

/******************************************************************************/
/* Data size getters. */

/**
 * Get the length in bytes of the encoded private key for the key's type.
 *
 * Lengths come from the liboqs OQS_KEM_*_length_secret_key constants.
 * *len is written only on success.
 * NOTE(review): without HAVE_LIBOQS this returns 0 and leaves *len unchanged.
 *
 * @param  [in]   key  ML-KEM/Kyber key object.
 * @param  [out]  len  Receives the encoded private key length in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or len is NULL, or the key type is not
 *          recognized.
 */
int wc_MlKemKey_PrivateKeySize(MlKemKey* key, word32* len)
{
    if ((key == NULL) || (len == NULL)) {
        return BAD_FUNC_ARG;
    }

#ifdef HAVE_LIBOQS
    /* SHAKE and AES variants share the same private key length. */
    switch (key->type) {
    #ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
            *len = OQS_KEM_ml_kem_512_length_secret_key;
            break;
        case WC_ML_KEM_768:
            *len = OQS_KEM_ml_kem_768_length_secret_key;
            break;
        case WC_ML_KEM_1024:
            *len = OQS_KEM_ml_kem_1024_length_secret_key;
            break;
    #endif
    #ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
            *len = OQS_KEM_kyber_512_length_secret_key;
            break;
        case KYBER_LEVEL3:
            *len = OQS_KEM_kyber_768_length_secret_key;
            break;
        case KYBER_LEVEL5:
            *len = OQS_KEM_kyber_1024_length_secret_key;
            break;
    #endif
        default:
            /* No other values supported. */
            return BAD_FUNC_ARG;
    }
#endif /* HAVE_LIBOQS */

    return 0;
}

/**
 * Get the length in bytes of the encoded public key for the key's type.
 *
 * Lengths come from the liboqs OQS_KEM_*_length_public_key constants.
 * *len is written only on success.
 * NOTE(review): without HAVE_LIBOQS this returns 0 and leaves *len unchanged.
 *
 * @param  [in]   key  ML-KEM/Kyber key object.
 * @param  [out]  len  Receives the encoded public key length in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or len is NULL, or the key type is not
 *          recognized.
 */
int wc_MlKemKey_PublicKeySize(MlKemKey* key, word32* len)
{
    if ((key == NULL) || (len == NULL)) {
        return BAD_FUNC_ARG;
    }

#ifdef HAVE_LIBOQS
    /* SHAKE and AES variants share the same public key length. */
    switch (key->type) {
    #ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
            *len = OQS_KEM_ml_kem_512_length_public_key;
            break;
        case WC_ML_KEM_768:
            *len = OQS_KEM_ml_kem_768_length_public_key;
            break;
        case WC_ML_KEM_1024:
            *len = OQS_KEM_ml_kem_1024_length_public_key;
            break;
    #endif
    #ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
            *len = OQS_KEM_kyber_512_length_public_key;
            break;
        case KYBER_LEVEL3:
            *len = OQS_KEM_kyber_768_length_public_key;
            break;
        case KYBER_LEVEL5:
            *len = OQS_KEM_kyber_1024_length_public_key;
            break;
    #endif
        default:
            /* No other values supported. */
            return BAD_FUNC_ARG;
    }
#endif /* HAVE_LIBOQS */

    return 0;
}

/**
 * Get the length in bytes of the cipher text for the key's type.
 *
 * Lengths come from the liboqs OQS_KEM_*_length_ciphertext constants.
 * *len is written only on success.
 * NOTE(review): without HAVE_LIBOQS this returns 0 and leaves *len unchanged.
 *
 * @param  [in]   key  ML-KEM/Kyber key object.
 * @param  [out]  len  Receives the cipher text length in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or len is NULL, or the key type is not
 *          recognized.
 */
int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len)
{
    if ((key == NULL) || (len == NULL)) {
        return BAD_FUNC_ARG;
    }

#ifdef HAVE_LIBOQS
    /* SHAKE and AES variants share the same cipher text length. */
    switch (key->type) {
    #ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
            *len = OQS_KEM_ml_kem_512_length_ciphertext;
            break;
        case WC_ML_KEM_768:
            *len = OQS_KEM_ml_kem_768_length_ciphertext;
            break;
        case WC_ML_KEM_1024:
            *len = OQS_KEM_ml_kem_1024_length_ciphertext;
            break;
    #endif
    #ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
            *len = OQS_KEM_kyber_512_length_ciphertext;
            break;
        case KYBER_LEVEL3:
            *len = OQS_KEM_kyber_768_length_ciphertext;
            break;
        case KYBER_LEVEL5:
            *len = OQS_KEM_kyber_1024_length_ciphertext;
            break;
    #endif
        default:
            /* No other values supported. */
            return BAD_FUNC_ARG;
    }
#endif /* HAVE_LIBOQS */

    return 0;
}

/**
 * Get the size in bytes of a shared secret. This is always KYBER_SS_SZ,
 * independent of the key.
 *
 * @param  [in]   key  ML-KEM/Kyber key object. Not used; may be NULL.
 * @param  [out]  len  Receives the shared secret size in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when len is NULL.
 */
int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
{
    int ret = BAD_FUNC_ARG;

    (void)key;

    if (len != NULL) {
        *len = KYBER_SS_SZ;
        ret = 0;
    }

    return ret;
}

/******************************************************************************/
/* Cryptographic operations. */

/**
 * Make a new ML-KEM/Kyber key pair into the key object.
 *
 * When WOLF_CRYPTO_CB is defined, the crypto callback is tried first. Its
 * result is returned unless it reports CRYPTOCB_UNAVAILABLE. Otherwise
 * liboqs generates the key pair into key->pub and key->priv. rng is passed
 * to wolfSSL_liboqsRngMutexLock() for the duration of key generation.
 *
 * If the liboqs path fails, the whole key object is zeroed.
 *
 * @param  [in, out]  key  ML-KEM/Kyber key object; key->type selects the
 *                         parameter set.
 * @param  [in]       rng  Random number generator. May be NULL (as used by
 *                         wc_MlKemKey_MakeKeyWithRandom()).
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key is NULL, the key type is not supported, or
 *          liboqs fails to create or run the algorithm.
 * @return  Error code from wolfSSL_liboqsRngMutexLock() when locking fails.
 */
int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
{
    int ret = 0;
#ifdef HAVE_LIBOQS
    const char* algName = NULL;
    OQS_KEM *kem = NULL;
    int rngLocked = 0;
#endif

    /* rng is only consumed by the crypto callback and liboqs paths. */
    (void)rng;

    /* Validate parameter. */
    if (key == NULL) {
        return BAD_FUNC_ARG;
    }

#ifdef WOLF_CRYPTO_CB
    #ifndef WOLF_CRYPTO_CB_FIND
    if (key->devId != INVALID_DEVID)
    #endif
    {
        ret = wc_CryptoCb_MakePqcKemKey(rng, WC_PQC_KEM_TYPE_KYBER,
                                        key->type, key);
        if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE))
            return ret;
        /* fall-through when unavailable */
        ret = 0;
    }
#endif

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        algName = OQS_ID2name(key->type);
        if (algName == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        kem = OQS_KEM_new(algName);
        if (kem == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        ret = wolfSSL_liboqsRngMutexLock(rng);
        if (ret == 0) {
            rngLocked = 1;
        }
    }
    if (ret == 0) {
        if (OQS_KEM_keypair(kem, key->pub, key->priv) != OQS_SUCCESS) {
            ret = BAD_FUNC_ARG;
        }
    }
    /* Only release the RNG mutex if this call acquired it. Unlocking a mutex
     * this thread does not hold (possibly held by another thread) is an
     * error. */
    if (rngLocked) {
        wolfSSL_liboqsRngMutexUnlock();
    }
    OQS_KEM_free(kem);
#endif /* HAVE_LIBOQS */

    if (ret != 0) {
        ForceZero(key, sizeof(*key));
    }

    return ret;
}

/**
 * Make a key pair, nominally from caller-supplied random data.
 *
 * liboqs offers no way to inject external randomness, so rand and len are
 * ignored. This call is equivalent to wc_MlKemKey_MakeKey(key, NULL), and
 * the resulting key is NOT derived from rand.
 *
 * @param  [in, out]  key   ML-KEM/Kyber key object.
 * @param  [in]       rand  Random data. Ignored.
 * @param  [in]       len   Length of random data in bytes. Ignored.
 * @return  Same values as wc_MlKemKey_MakeKey().
 */
int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
    int len)
{
    /* OQS doesn't support external randomness. */
    (void)len;
    (void)rand;

    return wc_MlKemKey_MakeKey(key, NULL);
}

/**
 * Encapsulate: generate a shared secret and the cipher text that carries it,
 * using the public key in key->pub.
 *
 * When WOLF_CRYPTO_CB is defined, the crypto callback is tried first. Its
 * result is returned unless it reports CRYPTOCB_UNAVAILABLE. Otherwise
 * liboqs performs the encapsulation. rng is passed to
 * wolfSSL_liboqsRngMutexLock() for the duration of the operation.
 *
 * @param  [in]   key  ML-KEM/Kyber key object holding a public key.
 * @param  [out]  ct   Buffer receiving the cipher text; at least
 *                     wc_MlKemKey_CipherTextSize() bytes.
 * @param  [out]  ss   Buffer receiving the shared secret; at least
 *                     KYBER_SS_SZ bytes.
 * @param  [in]   rng  Random number generator. May be NULL (as used by
 *                     wc_MlKemKey_EncapsulateWithRandom()).
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key, ct or ss is NULL, the key type is not
 *          supported, or liboqs fails.
 * @return  Error code from wolfSSL_liboqsRngMutexLock() when locking fails.
 */
int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* ct, unsigned char* ss,
    WC_RNG* rng)
{
    int ret = 0;
#ifdef WOLF_CRYPTO_CB
    word32 ctlen = 0;
#endif
#ifdef HAVE_LIBOQS
    const char * algName = NULL;
    OQS_KEM *kem = NULL;
    int rngLocked = 0;
#endif

    (void)rng;

    /* Validate parameters. */
    if ((key == NULL) || (ct == NULL) || (ss == NULL)) {
        ret = BAD_FUNC_ARG;
    }

#ifdef WOLF_CRYPTO_CB
    if (ret == 0) {
        ret = wc_MlKemKey_CipherTextSize(key, &ctlen);
    }
    if ((ret == 0)
    #ifndef WOLF_CRYPTO_CB_FIND
        && (key->devId != INVALID_DEVID)
    #endif
    ) {
        ret = wc_CryptoCb_PqcEncapsulate(ct, ctlen, ss, KYBER_SS_SZ, rng,
                                         WC_PQC_KEM_TYPE_KYBER, key);
        if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE))
            return ret;
        /* fall-through when unavailable */
        ret = 0;
    }
#endif

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        algName = OQS_ID2name(key->type);
        if (algName == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        kem = OQS_KEM_new(algName);
        if (kem == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        ret = wolfSSL_liboqsRngMutexLock(rng);
        if (ret == 0) {
            rngLocked = 1;
        }
    }
    if (ret == 0) {
        if (OQS_KEM_encaps(kem, ct, ss, key->pub) != OQS_SUCCESS) {
            ret = BAD_FUNC_ARG;
        }
    }
    /* Only release the RNG mutex if this call acquired it. Unlocking a mutex
     * this thread does not hold (possibly held by another thread) is an
     * error. */
    if (rngLocked) {
        wolfSSL_liboqsRngMutexUnlock();
    }
    OQS_KEM_free(kem);
#endif /* HAVE_LIBOQS */

    return ret;
}

/**
 * Encapsulate, nominally using caller-supplied random data.
 *
 * liboqs offers no way to inject external randomness, so rand and len are
 * ignored. This call is equivalent to
 * wc_MlKemKey_Encapsulate(key, ct, ss, NULL).
 *
 * @param  [in]   key   ML-KEM/Kyber key object holding a public key.
 * @param  [out]  ct    Buffer receiving the cipher text.
 * @param  [out]  ss    Buffer receiving the shared secret.
 * @param  [in]   rand  Random data. Ignored.
 * @param  [in]   len   Length of random data in bytes. Ignored.
 * @return  Same values as wc_MlKemKey_Encapsulate().
 */
int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* ct,
    unsigned char* ss, const unsigned char* rand, int len)
{
    /* OQS doesn't support external randomness. */
    (void)len;
    (void)rand;

    return wc_MlKemKey_Encapsulate(key, ct, ss, NULL);
}

/**
 * Decapsulate: recover the shared secret from cipher text using the private
 * key in key->priv.
 *
 * When WOLF_CRYPTO_CB is defined, the crypto callback is tried first. Its
 * result is returned unless it reports CRYPTOCB_UNAVAILABLE. Otherwise
 * liboqs performs the decapsulation.
 *
 * @param  [in]   key  ML-KEM/Kyber key object holding a private key.
 * @param  [out]  ss   Buffer receiving the shared secret (KYBER_SS_SZ bytes).
 * @param  [in]   ct   Cipher text.
 * @param  [in]   len  Length of cipher text in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key, ss or ct is NULL, the key type is not
 *          supported, or liboqs fails.
 * @return  BUFFER_E when len is not the cipher text length for the key type.
 */
int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
    const unsigned char* ct, word32 len)
{
    int ret;
    word32 expectedLen = 0;
#ifdef HAVE_LIBOQS
    const char* oqsName = NULL;
    OQS_KEM* oqsKem = NULL;
#endif

    /* Validate parameters and the cipher text length. */
    if ((key == NULL) || (ss == NULL) || (ct == NULL)) {
        ret = BAD_FUNC_ARG;
    }
    else {
        ret = wc_MlKemKey_CipherTextSize(key, &expectedLen);
        if ((ret == 0) && (len != expectedLen)) {
            ret = BUFFER_E;
        }
    }

#ifdef WOLF_CRYPTO_CB
    if ((ret == 0)
    #ifndef WOLF_CRYPTO_CB_FIND
        && (key->devId != INVALID_DEVID)
    #endif
    ) {
        ret = wc_CryptoCb_PqcDecapsulate(ct, expectedLen, ss, KYBER_SS_SZ,
                                         WC_PQC_KEM_TYPE_KYBER, key);
        if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE))
            return ret;
        /* fall-through when unavailable */
        ret = 0;
    }
#endif

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        oqsName = OQS_ID2name(key->type);
        if (oqsName == NULL) {
            ret = BAD_FUNC_ARG;
        }
        else {
            oqsKem = OQS_KEM_new(oqsName);
            if (oqsKem == NULL) {
                ret = BAD_FUNC_ARG;
            }
        }
    }
    if ((ret == 0) &&
            (OQS_KEM_decaps(oqsKem, ss, ct, key->priv) != OQS_SUCCESS)) {
        ret = BAD_FUNC_ARG;
    }

    OQS_KEM_free(oqsKem);
#endif /* HAVE_LIBOQS */

    return ret;
}

/******************************************************************************/
/* Encoding and decoding functions. */

/**
 * Decode (load) an encoded private key into the key object.
 *
 * The encoding is copied verbatim into key->priv. key->pub is not touched, so
 * a key loaded only this way cannot be used for encapsulation. That operation
 * is generally not done with a private key anyway.
 *
 * @param  [in, out]  key  ML-KEM/Kyber key object; key->type selects the
 *                         expected length.
 * @param  [in]       in   Buffer holding the encoded private key.
 * @param  [in]       len  Length of data in buffer, in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or in is NULL, or the key type is not
 *          recognized.
 * @return  BUFFER_E when len is not exactly the private key size.
 */
int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
    word32 len)
{
    word32 expected = 0;
    int ret;

    if ((key == NULL) || (in == NULL)) {
        return BAD_FUNC_ARG;
    }

    ret = wc_MlKemKey_PrivateKeySize(key, &expected);
    if (ret != 0) {
        return ret;
    }

    /* The encoding must be exactly the size for the key type. */
    if (len != expected) {
        return BUFFER_E;
    }

    XMEMCPY(key->priv, in, expected);
    return 0;
}

/**
 * Decode (load) an encoded public key into the key object.
 *
 * The encoding is copied verbatim into key->pub.
 *
 * @param  [in, out]  key  ML-KEM/Kyber key object; key->type selects the
 *                         expected length.
 * @param  [in]       in   Buffer holding the encoded public key.
 * @param  [in]       len  Length of data in buffer, in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or in is NULL, or the key type is not
 *          recognized.
 * @return  BUFFER_E when len is not exactly the public key size.
 */
int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in,
    word32 len)
{
    word32 expected = 0;
    int ret;

    if ((key == NULL) || (in == NULL)) {
        return BAD_FUNC_ARG;
    }

    ret = wc_MlKemKey_PublicKeySize(key, &expected);
    if (ret != 0) {
        return ret;
    }

    /* The encoding must be exactly the size for the key type. */
    if (len != expected) {
        return BUFFER_E;
    }

    XMEMCPY(key->pub, in, expected);
    return 0;
}

/**
 * Encode the private key.
 *
 * The private key is stored as an opaque blob, so it is copied out verbatim.
 *
 * @param  [in]   key  ML-KEM/Kyber key object.
 * @param  [out]  out  Buffer to hold the encoded private key.
 * @param  [in]   len  Size of buffer in bytes; must equal the private key
 *                     size for the key type.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or out is NULL, or the key type is not
 *          recognized.
 * @return  BUFFER_E when len is not exactly the private key size.
 */
int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len)
{
    int ret = 0;
    /* word32 to match the out-parameter type of wc_MlKemKey_PrivateKeySize();
     * unsigned int is not the same type on every platform. */
    word32 privLen = 0;

    if ((key == NULL) || (out == NULL)) {
        ret = BAD_FUNC_ARG;
    }

    if (ret == 0) {
        ret = wc_MlKemKey_PrivateKeySize(key, &privLen);
    }

    /* The buffer must be exactly the encoded size for the key type. */
    if ((ret == 0) && (len != privLen)) {
        ret = BUFFER_E;
    }

    if (ret == 0) {
        XMEMCPY(out, key->priv, privLen);
    }

    return ret;
}

/**
 * Encode the public key.
 *
 * The public key is stored as an opaque blob, so it is copied out verbatim.
 *
 * @param  [in]   key  ML-KEM/Kyber key object.
 * @param  [out]  out  Buffer to hold the encoded public key.
 * @param  [in]   len  Size of buffer in bytes; must equal the public key
 *                     size for the key type.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or out is NULL, or the key type is not
 *          recognized.
 * @return  BUFFER_E when len is not exactly the public key size.
 */
int wc_MlKemKey_EncodePublicKey(MlKemKey* key, unsigned char* out, word32 len)
{
    int ret = 0;
    /* word32 to match the out-parameter type of wc_MlKemKey_PublicKeySize();
     * unsigned int is not the same type on every platform. */
    word32 pubLen = 0;

    if ((key == NULL) || (out == NULL)) {
        ret = BAD_FUNC_ARG;
    }

    if (ret == 0) {
        ret = wc_MlKemKey_PublicKeySize(key, &pubLen);
    }

    /* The buffer must be exactly the encoded size for the key type. */
    if ((ret == 0) && (len != pubLen)) {
        ret = BUFFER_E;
    }

    if (ret == 0) {
        XMEMCPY(out, key->pub, pubLen);
    }

    return ret;
}

#endif /* WOLFSSL_HAVE_MLKEM && !WOLFSSL_WC_MLKEM */
