Lux GPU Core 0.2.0
Lightweight plugin-based GPU acceleration for blockchain and ML
Loading...
Searching...
No Matches
gpu.h
Go to the documentation of this file.
1// Copyright (c) 2024-2026 Lux Industries Inc.
2// SPDX-License-Identifier: BSD-3-Clause-Eco
3//
4// Lux GPU - Unified GPU acceleration with switchable backends
5//
6// Backends:
7// - Metal: Apple Silicon (macOS/iOS)
8// - CUDA: NVIDIA GPUs
9// - Dawn: WebGPU via Dawn (cross-platform)
10// - CPU: SIMD-optimized fallback
11//
12// Usage:
13// #include <lux/gpu.h>
14//
15// LuxGPU* gpu = lux_gpu_create();
16// lux_gpu_set_backend(gpu, LUX_BACKEND_METAL);
17//
18// LuxTensor* a = lux_tensor_zeros(gpu, shape, 2, LUX_FLOAT32);
19// LuxTensor* b = lux_tensor_ones(gpu, shape, 2, LUX_FLOAT32);
20// LuxTensor* c = lux_tensor_add(gpu, a, b);
21//
22// lux_gpu_sync(gpu);
23// lux_gpu_destroy(gpu);
24
25#ifndef LUX_GPU_H
26#define LUX_GPU_H
27
28#include <stddef.h>
29#include <stdint.h>
30#include <stdbool.h>
31
32#ifdef __cplusplus
33extern "C" {
34#endif
35
36// =============================================================================
37// Version
38// =============================================================================
39
40#define LUX_GPU_VERSION_MAJOR 0
41#define LUX_GPU_VERSION_MINOR 2
42#define LUX_GPU_VERSION_PATCH 0
43
44// =============================================================================
45// Backend Types
46// =============================================================================
47
48typedef enum {
49 LUX_BACKEND_AUTO = 0, // Auto-detect best backend
50 LUX_BACKEND_CPU = 1, // CPU with SIMD
51 LUX_BACKEND_METAL = 2, // Apple Metal
52 LUX_BACKEND_CUDA = 3, // NVIDIA CUDA
53 LUX_BACKEND_DAWN = 4, // WebGPU via Dawn
55
66
76
77// =============================================================================
78// Curve Types (for crypto operations)
79// =============================================================================
80
87
88// =============================================================================
89// Opaque Types
90// =============================================================================
91
92typedef struct LuxGPU LuxGPU;
93typedef struct LuxTensor LuxTensor;
94typedef struct LuxStream LuxStream;
95typedef struct LuxEvent LuxEvent;
96
97// =============================================================================
98// Device Info
99// =============================================================================
100
113
114// =============================================================================
115// GPU Context
116// =============================================================================
117
118// Create GPU context (auto-detects best backend)
120
121// Create GPU context with specific backend
123
124// Create GPU context with specific device
125LuxGPU* lux_gpu_create_with_device(LuxBackend backend, int device_index);
126
127// Destroy GPU context
129
130// Get current backend
132
133// Get backend name string
134const char* lux_gpu_backend_name(LuxGPU* gpu);
135
136// Switch backend at runtime. Returns LUX_ERROR_INVALID_ARGUMENT if any
137// LuxTensor created against the current backend is still alive — destroy
138// outstanding tensors before swapping backends. Returns
139// LUX_ERROR_BACKEND_NOT_AVAILABLE if the target backend isn't loadable.
141
142// Get device info
144
145// Synchronize all operations
147
148// Get last error message
149const char* lux_gpu_error(LuxGPU* gpu);
150
151// =============================================================================
152// Backend Query
153// =============================================================================
154
155// Get number of available backends
157
158// Check if backend is available
160
161// Get backend name
162const char* lux_backend_name(LuxBackend backend);
163
164// Get number of devices for backend
166
167// Get device info for backend/index
169
170// =============================================================================
171// Tensor Operations
172// =============================================================================
173
174// Create tensor filled with zeros
175LuxTensor* lux_tensor_zeros(LuxGPU* gpu, const int64_t* shape, int ndim, LuxDtype dtype);
176
177// Create tensor filled with ones
178LuxTensor* lux_tensor_ones(LuxGPU* gpu, const int64_t* shape, int ndim, LuxDtype dtype);
179
180// Create tensor filled with value
181LuxTensor* lux_tensor_full(LuxGPU* gpu, const int64_t* shape, int ndim, LuxDtype dtype, double value);
182
183// Create tensor from data
184LuxTensor* lux_tensor_from_data(LuxGPU* gpu, const void* data, const int64_t* shape, int ndim, LuxDtype dtype);
185
186// Destroy tensor
188
189// Get tensor shape
191int64_t lux_tensor_shape(LuxTensor* tensor, int dim);
192int64_t lux_tensor_size(LuxTensor* tensor);
194
195// Copy tensor data to host
196LuxError lux_tensor_to_host(LuxTensor* tensor, void* data, size_t size);
197
198// Arithmetic operations
204
205// Unary operations
215
216// Reductions (full tensor -> scalar)
221
222// Reductions along axes
223LuxTensor* lux_tensor_sum(LuxGPU* gpu, LuxTensor* t, const int* axes, int naxes);
224LuxTensor* lux_tensor_mean(LuxGPU* gpu, LuxTensor* t, const int* axes, int naxes);
225LuxTensor* lux_tensor_max(LuxGPU* gpu, LuxTensor* t, const int* axes, int naxes);
226LuxTensor* lux_tensor_min(LuxGPU* gpu, LuxTensor* t, const int* axes, int naxes);
227
228// Softmax and normalization
233
234// Transpose and copy
237
238// =============================================================================
239// Crypto Operations: Hash Functions
240// =============================================================================
241
242// Poseidon2 hash (algebraic hash for ZK circuits)
244 const uint64_t* inputs, // [num_hashes * rate]
245 uint64_t* outputs, // [num_hashes]
246 size_t rate, // Poseidon rate parameter
247 size_t num_hashes);
248
249// BLAKE3 hash (high-performance cryptographic hash)
251 const uint8_t* inputs, // Concatenated inputs
252 uint8_t* outputs, // [num_hashes * 32]
253 const size_t* input_lens, // Length of each input
254 size_t num_hashes);
255
256// Keccak-256 hash (Ethereum variant, NOT NIST SHA-3)
257// - Padding: 0x01 || 0x00...0x00 || 0x80 (Keccak, not SHA-3's 0x06)
258// - Output: 32 bytes per input
259// - Primary use: EVM state trie hashing, address derivation
261 const uint8_t* inputs, // Concatenated inputs
262 uint8_t* outputs, // [num_inputs * 32]
263 const size_t* input_lens, // Length of each input
264 size_t num_inputs);
265
266// =============================================================================
267// Crypto Operations: secp256k1 ECDSA Recovery (Ethereum ecrecover)
268// =============================================================================
269
270// Packed signature for ecrecover batch operations.
271// Each entry: r[32] || s[32] || v[1] || pad[3] || msg_hash[32] || pad[28] = 128 bytes
272typedef struct {
273 uint8_t r[32]; // ECDSA r value (big-endian)
274 uint8_t s[32]; // ECDSA s value (big-endian)
275 uint8_t v; // Recovery id (0 or 1)
276 uint8_t _pad[3]; // Alignment padding
277 uint8_t msg_hash[32]; // Message hash (big-endian)
278 uint8_t _pad2[28]; // Pad to 128 bytes
280
281// Output of ecrecover: recovered Ethereum address.
282typedef struct {
283 uint8_t address[20]; // Recovered address (or zeros on failure)
284 uint8_t valid; // 1 if recovery succeeded, 0 otherwise
285 uint8_t _pad[11]; // Pad to 32 bytes
287
288// Batch secp256k1 ECDSA public key recovery → Ethereum address.
289//
290// For each signature (r, s, v, msg_hash):
291// 1. Recover public key Q from the ECDSA signature
292// 2. Compute address = keccak256(Q.x || Q.y)[12:]
293//
294// This is the EVM ecrecover precompile, batched for GPU parallelism.
295// Each GPU thread processes one signature independently.
296//
297// =============================================================================
298// Signature malleability — low-s vs high-s
299// =============================================================================
300// This batch accepts BOTH low-s (s ≤ n/2) and high-s (s > n/2) signatures.
301// That matches the Ethereum 0x01 ecrecover precompile semantics — the
302// precompile recovers an address from any (r, s, v) triple in range
303// regardless of which side of n/2 s falls on. It does NOT match EIP-2's
304// strict low-s rule that Ethereum txpool / EIP-155 transactions enforce
305// at the consensus layer above the precompile.
306//
307// Address recovery is unchanged by s-malleability: (r, s, v) and (r, n-s, v')
308// produce the SAME recovered public key (and therefore the same address) up
309// to a flip of the recovery-id parity. Callers that need to reject malleable
310// signatures (EIP-2 enforcement, txpool admission, replay protection on
311// non-precompile signature surfaces) MUST check `s <= n/2` separately —
312// this function does not.
313//
314// Returns LUX_OK on success (individual failures are indicated by valid=0
315// in the output; the batch call itself only fails on argument errors).
317 const LuxEcrecoverInput* signatures,
318 LuxEcrecoverOutput* addresses,
319 size_t num_signatures);
320
321// =============================================================================
322// Crypto Operations: MSM (Multi-Scalar Multiplication)
323// =============================================================================
324
326 const void* scalars, // Scalar field elements
327 const void* points, // Curve points (affine)
328 void* result, // Single output point
329 size_t count, // Number of scalar-point pairs
330 LuxCurve curve); // Which curve to use
331
332// =============================================================================
333// Crypto Operations: BLS12-381 Curve
334// =============================================================================
335
336// Point addition (G1 or G2)
338 const void* a, const void* b, void* out,
339 size_t count, bool is_g2);
340
341// Scalar multiplication (G1 or G2)
343 const void* points, const void* scalars, void* out,
344 size_t count, bool is_g2);
345
346// Pairing computation (multi-pairing for efficiency)
348 const void* g1_points, const void* g2_points,
349 void* out, size_t count);
350
351// High-level BLS signature verification
353 const uint8_t* sig, size_t sig_len,
354 const uint8_t* msg, size_t msg_len,
355 const uint8_t* pubkey, size_t pubkey_len,
356 bool* result);
357
359 const uint8_t* const* sigs, const size_t* sig_lens,
360 const uint8_t* const* msgs, const size_t* msg_lens,
361 const uint8_t* const* pubkeys, const size_t* pubkey_lens,
362 int count, bool* results);
363
365 const uint8_t* const* sigs, const size_t* sig_lens,
366 int count, uint8_t* out, size_t* out_len);
367
368// =============================================================================
369// Crypto Operations: BN254 Curve
370// =============================================================================
371
372// Point addition (G1 or G2)
374 const void* a, const void* b, void* out,
375 size_t count, bool is_g2);
376
377// Scalar multiplication (G1 or G2)
379 const void* points, const void* scalars, void* out,
380 size_t count, bool is_g2);
381
382// =============================================================================
383// Crypto Operations: KZG Polynomial Commitments
384// =============================================================================
385
386// Commit to polynomial using SRS
388 const void* coeffs, // Polynomial coefficients
389 const void* srs, // SRS G1 points
390 void* commitment, // Output commitment
391 size_t degree, // Polynomial degree
392 LuxCurve curve);
393
394// Open commitment at evaluation point
396 const void* coeffs, // Polynomial coefficients
397 const void* srs, // SRS G1 points
398 const void* point, // Evaluation point
399 void* proof, // Output proof
400 size_t degree, // Polynomial degree
401 LuxCurve curve);
402
403// Verify KZG opening proof
405 const void* commitment, // Commitment point
406 const void* proof, // Proof point
407 const void* point, // Evaluation point
408 const void* value, // Claimed evaluation
409 const void* srs_g2, // G2 element from SRS
410 bool* result, // Verification result
411 LuxCurve curve);
412
413// =============================================================================
414// FHE Operations: NTT (Number Theoretic Transform)
415// =============================================================================
416
417LuxError lux_ntt_forward(LuxGPU* gpu, uint64_t* data, size_t n, uint64_t modulus);
418LuxError lux_ntt_inverse(LuxGPU* gpu, uint64_t* data, size_t n, uint64_t modulus);
419LuxError lux_ntt_batch(LuxGPU* gpu, uint64_t** polys, size_t count, size_t n, uint64_t modulus);
420
421// =============================================================================
422// FHE Operations: Polynomial Arithmetic
423// =============================================================================
424
425// Polynomial multiplication: result = a * b mod (X^n + 1) mod modulus
427 const uint64_t* a, const uint64_t* b,
428 uint64_t* result, size_t n, uint64_t modulus);
429
430// =============================================================================
431// FHE Operations: TFHE
432// =============================================================================
433
434// TFHE programmable bootstrap: evaluates the LUT encoded in test_poly on the
435// encrypted input. BSK shape: [n_lwe][(k+1)*l][k+1][N] u64. lwe_out length
436// is k*N + 1; k = 0 is rejected with LUX_ERROR_INVALID_ARGUMENT. The gadget contract
437// B = 2^base_log requires l * base_log <= log2(q); otherwise the bottom
438// gadget level collapses to zero.
440 const uint64_t* lwe_in, // Input LWE [n_lwe + 1]
441 uint64_t* lwe_out, // Output LWE [k*N + 1]
442 const uint64_t* bsk, // Bootstrapping key
443 const uint64_t* test_poly, // Test polynomial (LUT)
444 uint32_t n_lwe, // Input LWE dimension
445 uint32_t N, // GLWE polynomial degree (power of two)
446 uint32_t k, // GLWE dimension (>= 1)
447 uint32_t l, // Decomposition levels
448 uint32_t base_log, // Bits per gadget digit (B = 2^base_log)
449 uint64_t q); // Modulus
450
451// TFHE key switching: changes LWE key. KSK rows encode an LWE encryption
452// (under the OUT key) of -s_{in_idx} * q / B^{level+1}, B = 2^base_log.
454 const uint64_t* lwe_in, // Input LWE [n_in + 1]
455 uint64_t* lwe_out, // Output LWE [n_out + 1]
456 const uint64_t* ksk, // Key switching key
457 uint32_t n_in, // Input dimension
458 uint32_t n_out, // Output dimension
459 uint32_t l, // Decomposition levels
460 uint32_t base_log, // Base log
461 uint64_t q); // Modulus
462
463// Blind rotation: rotates polynomial accumulator by encrypted amount.
464// Same BSK shape and gadget contract as lux_tfhe_bootstrap.
466 uint64_t* acc, // Accumulator GLWE [(k+1) * N]
467 const uint64_t* bsk, // Bootstrapping key
468 const uint64_t* lwe_a, // LWE 'a' coefficients [n_lwe]
469 uint32_t n_lwe, // LWE dimension
470 uint32_t N, // GLWE polynomial degree (power of two)
471 uint32_t k, // GLWE dimension (>= 1)
472 uint32_t l, // Decomposition levels
473 uint32_t base_log, // Bits per gadget digit (B = 2^base_log)
474 uint64_t q); // Modulus
475
476// =============================================================================
477// FHE Helpers — small inspectors and parameter validators exposed so that
478// callers can pre-check inputs before dispatching to the heavy ops above.
479// Each helper is stateless and constant-time over its inputs.
480// =============================================================================
481
482// Returns true iff N is a power of two in (0, 2^20]. Matches the validation
483// inside lux_tfhe_bootstrap / lux_blind_rotate.
484bool lux_fhe_is_valid_N(uint32_t N);
485
486// Returns true iff (l, base_log) satisfies the gadget contract on a 64-bit
487// modulus q: l in [1, 64], base_log in [1, 64], l*base_log <= 64. The bottom
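488// gadget level q / B^l collapses to zero when l*base_log >= 64. NOTE(review): l*base_log == 64 satisfies both the accepted range and the collapse condition as written; confirm whether the bound is meant to be strict.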
489bool lux_fhe_is_valid_gadget(uint32_t l, uint32_t base_log);
490
491// Returns true iff the full TFHE-AP PBS parameter set is well-formed:
492// k >= 1, N a power of two in (0, 2^20], gadget contract holds, q != 0.
493bool lux_fhe_is_valid_pbs(uint32_t n_lwe, uint32_t N, uint32_t k,
494 uint32_t l, uint32_t base_log, uint64_t q);
495
496// Total BSK length in u64 words: n_lwe * (k+1) * l * (k+1) * N.
497// Returns 0 if the parameters fail lux_fhe_is_valid_pbs.
498size_t lux_fhe_bsk_words(uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l);
499
500// KSK length in u64 words: n_in * l * (n_out + 1).
501// Returns 0 if any of (n_in, n_out, l, base_log) is degenerate.
502size_t lux_fhe_ksk_words(uint32_t n_in, uint32_t n_out,
503 uint32_t l, uint32_t base_log);
504
505// Output LWE length for sample-extract: k*N + 1.
506// Returns 0 if k = 0 or N is not a power of two.
507size_t lux_fhe_lwe_out_words(uint32_t N, uint32_t k);
508
509// Accumulator GLWE length: (k+1) * N.
510// Returns 0 if k = 0 or N is not a power of two.
511size_t lux_fhe_acc_words(uint32_t N, uint32_t k);
512
513// Suggested base_log for a given l on a 64-bit modulus q: floor(log2(q) / l)
514// clamped to [1, 64]. Returns 0 when l = 0 or q = 0.
515uint32_t lux_fhe_suggest_base_log(uint32_t l, uint64_t q);
516
517// Signed-digit decomposition of one u64 value at gadget level `level`, with B = 2^base_log.
518// Returns the centered digit (range [-2^{base_log-1}, 2^{base_log-1})).
519// Helper exposed so callers can build BSK / KSK fixtures with the exact
520// gadget encoding the bootstrap consumes.
521//
522// IMPORTANT: this is the legacy per-digit form. It extracts each digit
523// INDEPENDENTLY (no carry propagation), so for typical inputs the per-digit
524// residue can be as large as q/B. For numerically-stable PBS / keyswitch use
525// lux_fhe_signed_decomp_all below, which carry-propagates and bounds the
526// total residue at q/B^l.
527int64_t lux_fhe_signed_decomp_digit(uint64_t value,
528 uint32_t level, uint32_t base_log);
529
530// Carry-propagating signed-radix decomposition. Writes l digits into out[],
531// top-down: out[0] has gadget weight q/B (largest), out[l-1] has weight
532// q/B^l (smallest). |out[lvl]| <= B/2. Aggregate approximation error is
533// bounded by q/B^l, matching OpenFHE's SignedDigitDecompose. Returns false
534// (and leaves out[] unchanged) if (l, base_log) fail lux_fhe_is_valid_gadget.
535bool lux_fhe_signed_decomp_all(uint64_t value, uint32_t l, uint32_t base_log,
536 int64_t* out);
537
538// Compute a_tilde = round(a * 2N / q) mod 2N (the rotation amount used by the
539// blind-rotation loop). Exposes the exact rounding the implementation uses.
540uint32_t lux_fhe_compute_a_tilde(uint64_t a, uint32_t N, uint64_t q);
541
542// Encode plaintext m ∈ [0, modulus) as Δ · m mod q where Δ = q / modulus.
543// Test-vector helper to keep callers from re-implementing the same encoding
544// inconsistently across the test suite.
545uint64_t lux_fhe_encode_message(uint64_t m, uint64_t modulus, uint64_t q);
546
547// Decode an LWE phase to its nearest multiple of Δ = q / modulus, returning
548// the message in [0, modulus). Inverse of lux_fhe_encode_message under noise
549// up to Δ/2.
550uint64_t lux_fhe_decode_phase(uint64_t phase, uint64_t modulus, uint64_t q);
551
552// Returns the canonical "plateau" test polynomial used by the FHE test
553// suite: coefficients in [N/4, 3N/4) take value Delta, all other slots are
554// zero. Writes N words into out. Returns false if N is not a valid PBS N.
555// Use it to drive bootstrap correctness tests with a wide noise margin.
556bool lux_fhe_test_poly_plateau(uint32_t N, uint64_t Delta, uint64_t* out);
557
558// Returns the canonical "half-plane" test polynomial: 0 for i < N/2, Delta
559// for i >= N/2. Conventional TFHE LUT for a single bit; sensitive to
560// noise at the N/2 boundary.
561bool lux_fhe_test_poly_half(uint32_t N, uint64_t Delta, uint64_t* out);
562
563// Returns gadget[lvl] = q >> ((lvl+1) * base_log). Convenience that mirrors
564// the bootstrap and keyswitch internal gadget construction exactly.
565// Returns 0 on invalid (l, base_log) — note this matches the contract that
566// the bottom gadget collapses to 0 when l*base_log >= 64.
567uint64_t lux_fhe_gadget_value(uint32_t level, uint32_t base_log, uint64_t q);
568
569// Reconstruct val ≈ Σ digit[lvl] · gadget[lvl] mod q, in canonical Z_q.
570// `digits` is l-long, top-down, signed. Returns 0 on validation failure.
571// The reconstruction error |reconstructed - val| is bounded by q/B^l for
572// digits returned from lux_fhe_signed_decomp_all.
573uint64_t lux_fhe_gadget_reconstruct(const int64_t* digits, uint32_t l,
574 uint32_t base_log, uint64_t q);
575
576// Total decomp scratch length in i64 words: (k+1) * l * N. Mirrors the
577// internal scratch the blind-rotation step allocates per AP iteration.
578// Returns 0 on validation failure.
579size_t lux_fhe_decomp_words(uint32_t N, uint32_t k, uint32_t l);
580
581// Validate KSK-shape parameters independently of PBS-shape parameters.
582// Returns true iff (n_in, n_out, l, base_log) is well-formed and q != 0.
583bool lux_fhe_is_valid_keyswitch(uint32_t n_in, uint32_t n_out, uint32_t l,
584 uint32_t base_log, uint64_t q);
585
586// Convenience: returns the LWE-input length for keyswitch = n_in + 1, or 0.
587size_t lux_fhe_keyswitch_in_words(uint32_t n_in);
588
589// Convenience: returns the LWE-output length for keyswitch = n_out + 1, or 0.
590size_t lux_fhe_keyswitch_out_words(uint32_t n_out);
591
592// Returns the current FHE ABI revision so callers can refuse to load a
593// plugin compiled against an older bootstrap signature.
594uint32_t lux_fhe_abi_revision(void);
595
596// =============================================================================
597// ZK Primitives: Field Elements and High-Level Operations
598// =============================================================================
599
600// BN254 scalar field element (Fr) - 256-bit integer in 4 x 64-bit limbs
601// Represents elements of the scalar field of BN254 curve
602typedef struct {
603 uint64_t limbs[4];
604} LuxFr256;
605
606// Poseidon2 compression: out[i] = Poseidon2(left[i], right[i])
607// Poseidon2 is an algebraic hash function optimized for ZK circuits.
609 LuxFr256* out,
610 const LuxFr256* left,
611 const LuxFr256* right,
612 size_t n);
613
614// Merkle tree root computation using Poseidon2 hash
615// Computes root from n leaves (pads to next power of 2 internally)
617 LuxFr256* out,
618 const LuxFr256* leaves,
619 size_t n);
620
621// Pedersen-style commitment: out[i] = Poseidon2(Poseidon2(values[i], blindings[i]), salts[i])
622// Suitable for hiding commitments in ZK protocols
624 LuxFr256* out,
625 const LuxFr256* values,
626 const LuxFr256* blindings,
627 const LuxFr256* salts,
628 size_t n);
629
630// Nullifier derivation: out[i] = Poseidon2(Poseidon2(keys[i], commitments[i]), indices[i])
631// Used to prevent double-spending in ZK protocols
633 LuxFr256* out,
634 const LuxFr256* keys,
635 const LuxFr256* commitments,
636 const LuxFr256* indices,
637 size_t n);
638
639// =============================================================================
640// Crypto Operations: Post-Quantum Signatures
641// =============================================================================
642
643// ML-DSA-65 (FIPS 204, CRYSTALS-Dilithium) batch signature verification
644// pubkeys: array of public keys (1952 bytes each)
645// messages: array of message hashes (64 bytes each)
646// signatures: array of signatures (3360 bytes each, padded)
647// results: output boolean array (1=valid, 0=invalid)
649 const uint8_t* const* pubkeys,
650 const uint8_t* const* messages,
651 const uint8_t* const* signatures,
652 bool* results,
653 size_t count);
654
655// ML-KEM-768 (FIPS 203, CRYSTALS-Kyber) batch decapsulation
656// secret_keys: array of decapsulation keys (2400 bytes each)
657// ciphertexts: array of ciphertexts (1088 bytes each)
658// shared_secrets: output array of shared secrets (32 bytes each)
660 const uint8_t* const* secret_keys,
661 const uint8_t* const* ciphertexts,
662 uint8_t** shared_secrets,
663 size_t count);
664
665// SLH-DSA (FIPS 205, SPHINCS+) batch signature verification
666// pubkeys: array of public keys (32 bytes each for SHAKE-128f)
667// messages: array of message hashes (32 bytes each)
668// signatures: array of signatures (up to 17088 bytes each)
670 const uint8_t* const* pubkeys,
671 const uint8_t* const* messages,
672 const uint8_t* const* signatures,
673 bool* results,
674 size_t count);
675
676// =============================================================================
677// Crypto Operations: Threshold Signatures
678// =============================================================================
679
680// Ringtail lattice-based threshold partial signing
681// shares: array of secret shares (1024 bytes each, 256 int32 coefficients)
682// messages: array of message hashes (32 bytes each)
683// partial_sigs: output partial signatures (1024 bytes each)
685 const uint8_t* const* shares,
686 const uint8_t* const* messages,
687 uint8_t** partial_sigs,
688 size_t count);
689
690// Ringtail threshold combine: merge `threshold` partial sigs into one combined signature
691// partial_sigs: array of partial signatures [count * threshold]
692// lagrange_coeffs: Lagrange interpolation coefficients [count * threshold]
693// combined_sigs: output combined signatures [count]
695 const uint8_t* const* partial_sigs,
696 const int32_t* lagrange_coeffs,
697 uint8_t** combined_sigs,
698 size_t threshold,
699 size_t count);
700
701// FROST threshold Schnorr partial signature verification
702// commitments: participant commitments (66 bytes each)
703// signatures: partial signature scalars (32 bytes each)
704// pubkeys: public key shares (33 bytes each)
705// challenges: pre-computed c*lambda_i scalars (32 bytes each)
707 const uint8_t* const* commitments,
708 const uint8_t* const* signatures,
709 const uint8_t* const* pubkeys,
710 const uint8_t* const* challenges,
711 bool* results,
712 size_t count);
713
714// CGGMP21 threshold ECDSA partial signing
715// inputs: k_share[32] || chi_share[32] || msg_hash[32] || gamma_share[32] per entry
716// r_x: x-coordinate of combined nonce R (32 bytes)
717// partial_sigs: output sigma_i values (32 bytes each)
719 const uint8_t* const* inputs,
720 const uint8_t* r_x,
721 uint8_t** partial_sigs,
722 size_t count);
723
724// =============================================================================
725// Crypto Operations: Ed25519 / sr25519
726// =============================================================================
727
728// Ed25519 batch signature verification
729// pubkeys: 32-byte compressed points
730// messages: 64-byte pre-computed H(R||A||M), reduced mod L by host
731// signatures: 64-byte signatures (R[32] || S[32])
733 const uint8_t* const* pubkeys,
734 const uint8_t* const* messages,
735 const uint8_t* const* signatures,
736 bool* results,
737 size_t count);
738
739// sr25519 (Schnorrkel/Ristretto255) batch signature verification
740// pubkeys: 32-byte Ristretto255 compressed points
741// messages: 64-byte pre-computed transcript hashes
742// signatures: 64-byte signatures (R[32] || s[32])
744 const uint8_t* const* pubkeys,
745 const uint8_t* const* messages,
746 const uint8_t* const* signatures,
747 bool* results,
748 size_t count);
749
750// =============================================================================
751// Stream/Event Management
752// =============================================================================
753
757
763
764#ifdef __cplusplus
765}
766#endif
767
768#endif // LUX_GPU_H
LuxError lux_tensor_to_host(LuxTensor *tensor, void *data, size_t size)
uint64_t lux_fhe_gadget_reconstruct(const int64_t *digits, uint32_t l, uint32_t base_log, uint64_t q)
LuxError lux_ntt_batch(LuxGPU *gpu, uint64_t **polys, size_t count, size_t n, uint64_t modulus)
size_t lux_fhe_bsk_words(uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l)
uint64_t lux_fhe_encode_message(uint64_t m, uint64_t modulus, uint64_t q)
LuxTensor * lux_tensor_add(LuxGPU *gpu, LuxTensor *a, LuxTensor *b)
void lux_gpu_destroy(LuxGPU *gpu)
LuxTensor * lux_tensor_gelu(LuxGPU *gpu, LuxTensor *t)
LuxTensor * lux_tensor_mean(LuxGPU *gpu, LuxTensor *t, const int *axes, int naxes)
int lux_device_count(LuxBackend backend)
bool lux_fhe_is_valid_gadget(uint32_t l, uint32_t base_log)
LuxError lux_gpu_keccak256_batch(LuxGPU *gpu, const uint8_t *inputs, uint8_t *outputs, const size_t *input_lens, size_t num_inputs)
void lux_event_destroy(LuxEvent *event)
struct LuxTensor LuxTensor
Definition gpu.h:93
LuxDtype lux_tensor_dtype(LuxTensor *tensor)
LuxTensor * lux_tensor_rms_norm(LuxGPU *gpu, LuxTensor *t, LuxTensor *weight, float eps)
LuxError lux_device_info(LuxBackend backend, int index, LuxDeviceInfo *info)
const char * lux_gpu_error(LuxGPU *gpu)
LuxError lux_event_wait(LuxEvent *event, LuxStream *stream)
LuxError lux_event_record(LuxEvent *event, LuxStream *stream)
LuxError lux_gpu_slhdsa_verify_batch(LuxGPU *gpu, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
float lux_tensor_reduce_min(LuxGPU *gpu, LuxTensor *t)
LuxGPU * lux_gpu_create(void)
LuxTensor * lux_tensor_copy(LuxGPU *gpu, LuxTensor *t)
LuxError lux_gpu_cggmp21_partial_sign_batch(LuxGPU *gpu, const uint8_t *const *inputs, const uint8_t *r_x, uint8_t **partial_sigs, size_t count)
LuxError lux_ntt_inverse(LuxGPU *gpu, uint64_t *data, size_t n, uint64_t modulus)
bool lux_fhe_test_poly_half(uint32_t N, uint64_t Delta, uint64_t *out)
float lux_tensor_reduce_mean(LuxGPU *gpu, LuxTensor *t)
LuxError lux_gpu_sync(LuxGPU *gpu)
struct LuxStream LuxStream
Definition gpu.h:94
const char * lux_gpu_backend_name(LuxGPU *gpu)
LuxTensor * lux_tensor_matmul(LuxGPU *gpu, LuxTensor *a, LuxTensor *b)
bool lux_fhe_test_poly_plateau(uint32_t N, uint64_t Delta, uint64_t *out)
LuxTensor * lux_tensor_softmax(LuxGPU *gpu, LuxTensor *t, int axis)
size_t lux_fhe_decomp_words(uint32_t N, uint32_t k, uint32_t l)
uint32_t lux_fhe_abi_revision(void)
LuxTensor * lux_tensor_sqrt(LuxGPU *gpu, LuxTensor *t)
LuxDtype
Definition gpu.h:56
@ LUX_UINT64
Definition gpu.h:63
@ LUX_INT64
Definition gpu.h:61
@ LUX_BFLOAT16
Definition gpu.h:59
@ LUX_BOOL
Definition gpu.h:64
@ LUX_UINT32
Definition gpu.h:62
@ LUX_INT32
Definition gpu.h:60
@ LUX_FLOAT32
Definition gpu.h:57
@ LUX_FLOAT16
Definition gpu.h:58
LuxError lux_msm(LuxGPU *gpu, const void *scalars, const void *points, void *result, size_t count, LuxCurve curve)
LuxTensor * lux_tensor_from_data(LuxGPU *gpu, const void *data, const int64_t *shape, int ndim, LuxDtype dtype)
LuxError lux_gpu_ecrecover_batch(LuxGPU *gpu, const LuxEcrecoverInput *signatures, LuxEcrecoverOutput *addresses, size_t num_signatures)
LuxError lux_gpu_mlkem_decapsulate_batch(LuxGPU *gpu, const uint8_t *const *secret_keys, const uint8_t *const *ciphertexts, uint8_t **shared_secrets, size_t count)
LuxError lux_gpu_device_info(LuxGPU *gpu, LuxDeviceInfo *info)
LuxTensor * lux_tensor_log_softmax(LuxGPU *gpu, LuxTensor *t, int axis)
uint32_t lux_fhe_compute_a_tilde(uint64_t a, uint32_t N, uint64_t q)
float lux_event_elapsed(LuxEvent *start, LuxEvent *end)
const char * lux_backend_name(LuxBackend backend)
int64_t lux_tensor_shape(LuxTensor *tensor, int dim)
LuxTensor * lux_tensor_zeros(LuxGPU *gpu, const int64_t *shape, int ndim, LuxDtype dtype)
size_t lux_fhe_keyswitch_out_words(uint32_t n_out)
LuxError lux_bn254_add(LuxGPU *gpu, const void *a, const void *b, void *out, size_t count, bool is_g2)
bool lux_fhe_is_valid_pbs(uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l, uint32_t base_log, uint64_t q)
LuxGPU * lux_gpu_create_with_backend(LuxBackend backend)
LuxError lux_gpu_nullifier(LuxGPU *gpu, LuxFr256 *out, const LuxFr256 *keys, const LuxFr256 *commitments, const LuxFr256 *indices, size_t n)
int64_t lux_tensor_size(LuxTensor *tensor)
LuxError lux_bls12_381_add(LuxGPU *gpu, const void *a, const void *b, void *out, size_t count, bool is_g2)
LuxError lux_poseidon2_hash(LuxGPU *gpu, const uint64_t *inputs, uint64_t *outputs, size_t rate, size_t num_hashes)
void lux_tensor_destroy(LuxTensor *tensor)
LuxError lux_stream_sync(LuxStream *stream)
bool lux_fhe_signed_decomp_all(uint64_t value, uint32_t l, uint32_t base_log, int64_t *out)
LuxError lux_bn254_mul(LuxGPU *gpu, const void *points, const void *scalars, void *out, size_t count, bool is_g2)
size_t lux_fhe_lwe_out_words(uint32_t N, uint32_t k)
size_t lux_fhe_keyswitch_in_words(uint32_t n_in)
LuxTensor * lux_tensor_relu(LuxGPU *gpu, LuxTensor *t)
LuxTensor * lux_tensor_neg(LuxGPU *gpu, LuxTensor *t)
LuxError lux_gpu_sr25519_verify_batch(LuxGPU *gpu, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
float lux_tensor_reduce_max(LuxGPU *gpu, LuxTensor *t)
bool lux_fhe_is_valid_keyswitch(uint32_t n_in, uint32_t n_out, uint32_t l, uint32_t base_log, uint64_t q)
LuxTensor * lux_tensor_layer_norm(LuxGPU *gpu, LuxTensor *t, LuxTensor *gamma, LuxTensor *beta, float eps)
size_t lux_fhe_acc_words(uint32_t N, uint32_t k)
LuxError lux_tfhe_keyswitch(LuxGPU *gpu, const uint64_t *lwe_in, uint64_t *lwe_out, const uint64_t *ksk, uint32_t n_in, uint32_t n_out, uint32_t l, uint32_t base_log, uint64_t q)
LuxCurve
Definition gpu.h:81
@ LUX_CURVE_BLS12_381
Definition gpu.h:82
@ LUX_CURVE_SECP256K1
Definition gpu.h:84
@ LUX_CURVE_BN254
Definition gpu.h:83
@ LUX_CURVE_ED25519
Definition gpu.h:85
LuxError lux_kzg_open(LuxGPU *gpu, const void *coeffs, const void *srs, const void *point, void *proof, size_t degree, LuxCurve curve)
LuxError lux_kzg_verify(LuxGPU *gpu, const void *commitment, const void *proof, const void *point, const void *value, const void *srs_g2, bool *result, LuxCurve curve)
uint32_t lux_fhe_suggest_base_log(uint32_t l, uint64_t q)
LuxError lux_blake3_hash(LuxGPU *gpu, const uint8_t *inputs, uint8_t *outputs, const size_t *input_lens, size_t num_hashes)
LuxBackend
Definition gpu.h:48
@ LUX_BACKEND_DAWN
Definition gpu.h:53
@ LUX_BACKEND_AUTO
Definition gpu.h:49
@ LUX_BACKEND_CUDA
Definition gpu.h:52
@ LUX_BACKEND_CPU
Definition gpu.h:50
@ LUX_BACKEND_METAL
Definition gpu.h:51
LuxTensor * lux_tensor_min(LuxGPU *gpu, LuxTensor *t, const int *axes, int naxes)
LuxTensor * lux_tensor_transpose(LuxGPU *gpu, LuxTensor *t)
LuxTensor * lux_tensor_sigmoid(LuxGPU *gpu, LuxTensor *t)
LuxTensor * lux_tensor_sum(LuxGPU *gpu, LuxTensor *t, const int *axes, int naxes)
LuxError lux_gpu_commitment(LuxGPU *gpu, LuxFr256 *out, const LuxFr256 *values, const LuxFr256 *blindings, const LuxFr256 *salts, size_t n)
LuxTensor * lux_tensor_exp(LuxGPU *gpu, LuxTensor *t)
LuxTensor * lux_tensor_ones(LuxGPU *gpu, const int64_t *shape, int ndim, LuxDtype dtype)
void lux_stream_destroy(LuxStream *stream)
LuxError lux_kzg_commit(LuxGPU *gpu, const void *coeffs, const void *srs, void *commitment, size_t degree, LuxCurve curve)
LuxTensor * lux_tensor_log(LuxGPU *gpu, LuxTensor *t)
LuxError lux_bls_verify(LuxGPU *gpu, const uint8_t *sig, size_t sig_len, const uint8_t *msg, size_t msg_len, const uint8_t *pubkey, size_t pubkey_len, bool *result)
LuxError lux_gpu_frost_partial_verify_batch(LuxGPU *gpu, const uint8_t *const *commitments, const uint8_t *const *signatures, const uint8_t *const *pubkeys, const uint8_t *const *challenges, bool *results, size_t count)
struct LuxGPU LuxGPU
Definition gpu.h:92
LuxTensor * lux_tensor_mul(LuxGPU *gpu, LuxTensor *a, LuxTensor *b)
size_t lux_fhe_ksk_words(uint32_t n_in, uint32_t n_out, uint32_t l, uint32_t base_log)
LuxEvent * lux_event_create(LuxGPU *gpu)
LuxError lux_ntt_forward(LuxGPU *gpu, uint64_t *data, size_t n, uint64_t modulus)
LuxError lux_gpu_mldsa_verify_batch(LuxGPU *gpu, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
LuxError lux_bls_aggregate(LuxGPU *gpu, const uint8_t *const *sigs, const size_t *sig_lens, int count, uint8_t *out, size_t *out_len)
float lux_tensor_reduce_sum(LuxGPU *gpu, LuxTensor *t)
LuxBackend lux_gpu_backend(LuxGPU *gpu)
LuxError lux_gpu_set_backend(LuxGPU *gpu, LuxBackend backend)
LuxTensor * lux_tensor_div(LuxGPU *gpu, LuxTensor *a, LuxTensor *b)
LuxError lux_gpu_merkle_root(LuxGPU *gpu, LuxFr256 *out, const LuxFr256 *leaves, size_t n)
LuxError lux_gpu_poseidon2(LuxGPU *gpu, LuxFr256 *out, const LuxFr256 *left, const LuxFr256 *right, size_t n)
int lux_backend_count(void)
LuxError lux_bls_verify_batch(LuxGPU *gpu, const uint8_t *const *sigs, const size_t *sig_lens, const uint8_t *const *msgs, const size_t *msg_lens, const uint8_t *const *pubkeys, const size_t *pubkey_lens, int count, bool *results)
LuxError
Definition gpu.h:67
@ LUX_OK
Definition gpu.h:68
@ LUX_ERROR_INVALID_ARGUMENT
Definition gpu.h:69
@ LUX_ERROR_NOT_SUPPORTED
Definition gpu.h:74
@ LUX_ERROR_KERNEL_FAILED
Definition gpu.h:73
@ LUX_ERROR_DEVICE_NOT_FOUND
Definition gpu.h:72
@ LUX_ERROR_OUT_OF_MEMORY
Definition gpu.h:70
@ LUX_ERROR_BACKEND_NOT_AVAILABLE
Definition gpu.h:71
LuxError lux_gpu_ed25519_verify_batch(LuxGPU *gpu, const uint8_t *const *pubkeys, const uint8_t *const *messages, const uint8_t *const *signatures, bool *results, size_t count)
uint64_t lux_fhe_gadget_value(uint32_t level, uint32_t base_log, uint64_t q)
LuxError lux_gpu_ringtail_combine_batch(LuxGPU *gpu, const uint8_t *const *partial_sigs, const int32_t *lagrange_coeffs, uint8_t **combined_sigs, size_t threshold, size_t count)
int64_t lux_fhe_signed_decomp_digit(uint64_t value, uint32_t level, uint32_t base_log)
LuxError lux_blind_rotate(LuxGPU *gpu, uint64_t *acc, const uint64_t *bsk, const uint64_t *lwe_a, uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l, uint32_t base_log, uint64_t q)
LuxGPU * lux_gpu_create_with_device(LuxBackend backend, int device_index)
uint64_t lux_fhe_decode_phase(uint64_t phase, uint64_t modulus, uint64_t q)
LuxError lux_gpu_ringtail_partial_sign_batch(LuxGPU *gpu, const uint8_t *const *shares, const uint8_t *const *messages, uint8_t **partial_sigs, size_t count)
LuxError lux_bls12_381_mul(LuxGPU *gpu, const void *points, const void *scalars, void *out, size_t count, bool is_g2)
LuxStream * lux_stream_create(LuxGPU *gpu)
LuxError lux_bls12_381_pairing(LuxGPU *gpu, const void *g1_points, const void *g2_points, void *out, size_t count)
bool lux_fhe_is_valid_N(uint32_t N)
LuxError lux_poly_mul(LuxGPU *gpu, const uint64_t *a, const uint64_t *b, uint64_t *result, size_t n, uint64_t modulus)
LuxTensor * lux_tensor_sub(LuxGPU *gpu, LuxTensor *a, LuxTensor *b)
LuxTensor * lux_tensor_tanh(LuxGPU *gpu, LuxTensor *t)
LuxTensor * lux_tensor_abs(LuxGPU *gpu, LuxTensor *t)
LuxTensor * lux_tensor_full(LuxGPU *gpu, const int64_t *shape, int ndim, LuxDtype dtype, double value)
LuxError lux_tfhe_bootstrap(LuxGPU *gpu, const uint64_t *lwe_in, uint64_t *lwe_out, const uint64_t *bsk, const uint64_t *test_poly, uint32_t n_lwe, uint32_t N, uint32_t k, uint32_t l, uint32_t base_log, uint64_t q)
int lux_tensor_ndim(LuxTensor *tensor)
struct LuxEvent LuxEvent
Definition gpu.h:95
bool lux_backend_available(LuxBackend backend)
LuxTensor * lux_tensor_max(LuxGPU *gpu, LuxTensor *t, const int *axes, int naxes)
int max_workgroup_size
Definition gpu.h:111
uint64_t memory_total
Definition gpu.h:106
bool is_discrete
Definition gpu.h:108
const char * vendor
Definition gpu.h:105
bool is_unified_memory
Definition gpu.h:109
LuxBackend backend
Definition gpu.h:102
const char * name
Definition gpu.h:104
uint64_t memory_available
Definition gpu.h:107
int compute_units
Definition gpu.h:110
int index
Definition gpu.h:103
uint8_t v
Definition gpu.h:275
uint8_t valid
Definition gpu.h:284