13#ifndef LUX_GPU_BACKEND_PLUGIN_H
14#define LUX_GPU_BACKEND_PLUGIN_H
33#define LUX_GPU_BACKEND_ABI_VERSION 5
133 const uint64_t* a,
const uint64_t* b, uint64_t* result,
134 size_t n, uint64_t modulus
147 const uint64_t* lwe_in, uint64_t* lwe_out,
148 const uint64_t* bsk,
const uint64_t* test_poly,
149 uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l,
150 uint32_t base_log, uint64_t q
161 const uint64_t* lwe_in, uint64_t* lwe_out,
163 uint32_t n_in, uint32_t n_out, uint32_t l, uint32_t base_log, uint64_t q
169 uint64_t* acc,
const uint64_t* bsk,
const uint64_t* lwe_a,
170 uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l,
171 uint32_t base_log, uint64_t q
177 const uint64_t* inputs, uint64_t* outputs,
178 size_t rate,
size_t num_hashes
183 const uint8_t* inputs, uint8_t* outputs,
184 const size_t* input_lens,
size_t num_hashes
189 const uint8_t* inputs, uint8_t* outputs,
190 const size_t* input_lens,
size_t num_inputs
196 const void* a,
const void* b,
void* out,
size_t n,
bool is_g2
200 const void* points,
const void* scalars,
void* out,
size_t n,
bool is_g2
204 const void* g1_points,
const void* g2_points,
void* out,
size_t n
210 const void* a,
const void* b,
void* out,
size_t n,
bool is_g2
214 const void* points,
const void* scalars,
void* out,
size_t n,
bool is_g2
220 const void* scalars,
const void* points,
void* result,
221 size_t n,
int curve_type
227 const void* coeffs,
const void* srs,
void* commitment,
228 size_t degree,
int curve_type
232 const void* coeffs,
const void* srs,
const void* point,
void* proof,
233 size_t degree,
int curve_type
237 const void* commitment,
const void* proof,
238 const void* point,
const void* value,
const void* srs_g2,
239 bool* result,
int curve_type
245 const void* signatures,
void* addresses,
size_t num_signatures
256 const uint8_t*
const* pubkeys,
257 const uint8_t*
const* messages,
258 const uint8_t*
const* signatures,
267 const uint8_t*
const* secret_keys,
268 const uint8_t*
const* ciphertexts,
269 uint8_t** shared_secrets,
277 const uint8_t*
const* pubkeys,
278 const uint8_t*
const* messages,
279 const uint8_t*
const* signatures,
292 const uint8_t*
const* shares,
293 const uint8_t*
const* messages,
294 uint8_t** partial_sigs,
302 const uint8_t*
const* partial_sigs,
303 const int32_t* lagrange_coeffs,
304 uint8_t** combined_sigs,
314 const uint8_t*
const* commitments,
315 const uint8_t*
const* signatures,
316 const uint8_t*
const* pubkeys,
317 const uint8_t*
const* challenges,
328 const uint8_t*
const* inputs,
330 uint8_t** partial_sigs,
342 const uint8_t*
const* pubkeys,
343 const uint8_t*
const* messages,
344 const uint8_t*
const* signatures,
354 const uint8_t*
const* pubkeys,
355 const uint8_t*
const* messages,
356 const uint8_t*
const* signatures,
387#define LUX_CAP_TENSOR_OPS (1u << 0)
388#define LUX_CAP_MATMUL (1u << 1)
389#define LUX_CAP_NTT (1u << 2)
390#define LUX_CAP_MSM (1u << 3)
391#define LUX_CAP_UNIFIED_MEMORY (1u << 4)
392#define LUX_CAP_FHE (1u << 5)
393#define LUX_CAP_TFHE (1u << 6)
394#define LUX_CAP_REDUCE (1u << 7)
395#define LUX_CAP_SOFTMAX (1u << 8)
396#define LUX_CAP_UNARY (1u << 9)
397#define LUX_CAP_NORMALIZATION (1u << 10)
398#define LUX_CAP_BLS12_381 (1u << 11)
399#define LUX_CAP_BN254 (1u << 12)
400#define LUX_CAP_KZG (1u << 13)
401#define LUX_CAP_POSEIDON2 (1u << 14)
402#define LUX_CAP_BLAKE3 (1u << 15)
403#define LUX_CAP_KECCAK256 (1u << 16)
404#define LUX_CAP_ECRECOVER (1u << 17)
405#define LUX_CAP_BLIND_ROTATE (1u << 18)
406#define LUX_CAP_POLY_MUL (1u << 19)
407#define LUX_CAP_MLDSA (1u << 20)
408#define LUX_CAP_MLKEM (1u << 21)
409#define LUX_CAP_SLHDSA (1u << 22)
410#define LUX_CAP_RINGTAIL (1u << 23)
411#define LUX_CAP_FROST (1u << 24)
412#define LUX_CAP_CGGMP21 (1u << 25)
413#define LUX_CAP_ED25519 (1u << 26)
414#define LUX_CAP_SR25519 (1u << 27)
424#define LUX_GPU_BACKEND_INIT_SYMBOL "lux_gpu_backend_init"
427#define LUX_GPU_BACKEND_EXPORT __declspec(dllexport)
429#define LUX_GPU_BACKEND_EXPORT __attribute__((visibility("default")))
433#define LUX_GPU_DECLARE_BACKEND(init_func) \
434 extern "C" LUX_GPU_BACKEND_EXPORT bool \
435 lux_gpu_backend_init(lux_gpu_backend_desc* out) { return init_func(out); }
437#define LUX_GPU_DECLARE_BACKEND(init_func) \
438 LUX_GPU_BACKEND_EXPORT bool \
439 lux_gpu_backend_init(lux_gpu_backend_desc* out) { return init_func(out); }
struct LuxBackendContext LuxBackendContext
bool(* lux_gpu_backend_init_fn)(lux_gpu_backend_desc *out)
@ LUX_BACKEND_ERROR_INTERNAL
@ LUX_BACKEND_ERROR_NOT_SUPPORTED
@ LUX_BACKEND_ERROR_DEVICE_LOST
@ LUX_BACKEND_ERROR_INVALID_ARGUMENT
@ LUX_BACKEND_ERROR_OUT_OF_MEMORY
struct LuxBackendBuffer LuxBackendBuffer
uint64_t memory_available
const lux_gpu_backend_vtbl * vtbl
const char * backend_version
const char * backend_name
LuxBackendError(* op_kzg_verify)(LuxBackendContext *ctx, const void *commitment, const void *proof, const void *point, const void *value, const void *srs_g2, bool *result, int curve_type)
LuxBackendError(* op_sqrt_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_poly_mul)(LuxBackendContext *ctx, const uint64_t *a, const uint64_t *b, uint64_t *result, size_t n, uint64_t modulus)
LuxBackendError(* op_mldsa_verify_batch)(LuxBackendContext *ctx, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
LuxBackendError(* op_neg_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_bls12_381_mul)(LuxBackendContext *ctx, const void *points, const void *scalars, void *out, size_t n, bool is_g2)
LuxBackendError(* op_reduce_sum_axis_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t outer_size, size_t inner_size)
LuxBackendError(* get_device_info)(LuxBackendContext *ctx, LuxBackendDeviceInfo *info)
LuxBackendError(* op_reduce_max_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_reduce_mean_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_abs_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_msm)(LuxBackendContext *ctx, const void *scalars, const void *points, void *result, size_t n, int curve_type)
LuxBackendError(* op_ecrecover_batch)(LuxBackendContext *ctx, const void *signatures, void *addresses, size_t num_signatures)
LuxBackendError(* op_mlkem_decapsulate_batch)(LuxBackendContext *ctx, const uint8_t *const *secret_keys, const uint8_t *const *ciphertexts, uint8_t **shared_secrets, size_t count)
LuxBackendError(* op_gelu_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_ntt_forward)(LuxBackendContext *ctx, uint64_t *data, size_t n, uint64_t modulus)
LuxBackendError(* op_sigmoid_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_log_softmax_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t batch_size, size_t dim)
LuxBackendError(* op_blind_rotate)(LuxBackendContext *ctx, uint64_t *acc, const uint64_t *bsk, const uint64_t *lwe_a, uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l, uint32_t base_log, uint64_t q)
LuxBackendError(* op_kzg_open)(LuxBackendContext *ctx, const void *coeffs, const void *srs, const void *point, void *proof, size_t degree, int curve_type)
LuxBackendError(* op_softmax_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t batch_size, size_t dim)
LuxBackendError(* op_ringtail_combine_batch)(LuxBackendContext *ctx, const uint8_t *const *partial_sigs, const int32_t *lagrange_coeffs, uint8_t **combined_sigs, size_t threshold, size_t count)
LuxBackendError(* op_tanh_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_cggmp21_partial_sign_batch)(LuxBackendContext *ctx, const uint8_t *const *inputs, const uint8_t *r_x, uint8_t **partial_sigs, size_t count)
LuxBackendError(* op_bls12_381_pairing)(LuxBackendContext *ctx, const void *g1_points, const void *g2_points, void *out, size_t n)
LuxBackendError(* op_reduce_max_axis_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t outer_size, size_t inner_size)
LuxBackendError(* op_ringtail_partial_sign_batch)(LuxBackendContext *ctx, const uint8_t *const *shares, const uint8_t *const *messages, uint8_t **partial_sigs, size_t count)
LuxBackendError(* op_kzg_commit)(LuxBackendContext *ctx, const void *coeffs, const void *srs, void *commitment, size_t degree, int curve_type)
void(* destroy_context)(LuxBackendContext *ctx)
LuxBackendError(* op_sub_f32)(LuxBackendContext *ctx, LuxBackendBuffer *a, LuxBackendBuffer *b, LuxBackendBuffer *out, size_t n)
LuxBackendError(* get_device_count)(int *count)
LuxBackendError(* op_copy_f32)(LuxBackendContext *ctx, LuxBackendBuffer *src, LuxBackendBuffer *dst, size_t n)
LuxBackendError(* op_exp_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_bls12_381_add)(LuxBackendContext *ctx, const void *a, const void *b, void *out, size_t n, bool is_g2)
LuxBackendError(* sync)(LuxBackendContext *ctx)
LuxBackendError(* op_slhdsa_verify_batch)(LuxBackendContext *ctx, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
LuxBackendError(* op_blake3_hash)(LuxBackendContext *ctx, const uint8_t *inputs, uint8_t *outputs, const size_t *input_lens, size_t num_hashes)
LuxBackendError(* op_rms_norm_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, LuxBackendBuffer *weight, size_t batch_size, size_t dim, float eps)
LuxBackendError(* op_tfhe_bootstrap)(LuxBackendContext *ctx, const uint64_t *lwe_in, uint64_t *lwe_out, const uint64_t *bsk, const uint64_t *test_poly, uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l, uint32_t base_log, uint64_t q)
LuxBackendError(* op_layer_norm_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, LuxBackendBuffer *gamma, LuxBackendBuffer *beta, size_t batch_size, size_t dim, float eps)
LuxBackendError(* op_log_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_div_f32)(LuxBackendContext *ctx, LuxBackendBuffer *a, LuxBackendBuffer *b, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_matmul_f32)(LuxBackendContext *ctx, LuxBackendBuffer *a, LuxBackendBuffer *b, LuxBackendBuffer *out, int M, int K, int N)
LuxBackendError(* op_bn254_add)(LuxBackendContext *ctx, const void *a, const void *b, void *out, size_t n, bool is_g2)
LuxBackendError(* op_reduce_sum_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_keccak256_hash)(LuxBackendContext *ctx, const uint8_t *inputs, uint8_t *outputs, const size_t *input_lens, size_t num_inputs)
LuxBackendError(* op_ed25519_verify_batch)(LuxBackendContext *ctx, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
LuxBackendError(* op_poseidon2_hash)(LuxBackendContext *ctx, const uint64_t *inputs, uint64_t *outputs, size_t rate, size_t num_hashes)
LuxBackendError(* op_add_f32)(LuxBackendContext *ctx, LuxBackendBuffer *a, LuxBackendBuffer *b, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_transpose_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, int rows, int cols)
LuxBackendError(* op_mul_f32)(LuxBackendContext *ctx, LuxBackendBuffer *a, LuxBackendBuffer *b, LuxBackendBuffer *out, size_t n)
LuxBackendError(* buffer_copy_from_host)(LuxBackendContext *ctx, LuxBackendBuffer *buf, const void *src, size_t bytes)
LuxBackendError(* op_relu_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_tfhe_keyswitch)(LuxBackendContext *ctx, const uint64_t *lwe_in, uint64_t *lwe_out, const uint64_t *ksk, uint32_t n_in, uint32_t n_out, uint32_t l, uint32_t base_log, uint64_t q)
void(* buffer_free)(LuxBackendContext *ctx, LuxBackendBuffer *buf)
LuxBackendError(* buffer_copy_to_host)(LuxBackendContext *ctx, LuxBackendBuffer *buf, void *dst, size_t bytes)
LuxBackendError(* op_reduce_min_f32)(LuxBackendContext *ctx, LuxBackendBuffer *in, LuxBackendBuffer *out, size_t n)
LuxBackendError(* op_bn254_mul)(LuxBackendContext *ctx, const void *points, const void *scalars, void *out, size_t n, bool is_g2)
LuxBackendError(* op_frost_partial_verify_batch)(LuxBackendContext *ctx, const uint8_t *const *commitments, const uint8_t *const *signatures, const uint8_t *const *pubkeys, const uint8_t *const *challenges, bool *results, size_t count)
LuxBackendError(* op_sr25519_verify_batch)(LuxBackendContext *ctx, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
LuxBackendError(* op_ntt_inverse)(LuxBackendContext *ctx, uint64_t *data, size_t n, uint64_t modulus)