// SPDX-License-Identifier: GPL-2.0-only /* * Cryptographic API. * * Copyright (c) 2017-present, Facebook, Inc. */ #include #include #include #include #include #include #include #include #include #include #define ZSTD_DEF_LEVEL 3 #define ZSTD_MAX_WINDOWLOG 18 #define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG) struct zstd_ctx { zstd_cctx *cctx; zstd_dctx *dctx; size_t wksp_size; zstd_parameters params; u8 wksp[] __aligned(8); }; static DEFINE_MUTEX(zstd_stream_lock); static void *zstd_alloc_stream(void) { zstd_parameters params; struct zstd_ctx *ctx; size_t wksp_size; params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE); wksp_size = max_t(size_t, zstd_cstream_workspace_bound(¶ms.cParams), zstd_dstream_workspace_bound(ZSTD_MAX_SIZE)); if (!wksp_size) return ERR_PTR(-EINVAL); ctx = kvmalloc(sizeof(*ctx) + wksp_size, GFP_KERNEL); if (!ctx) return ERR_PTR(-ENOMEM); ctx->params = params; ctx->wksp_size = wksp_size; return ctx; } static void zstd_free_stream(void *ctx) { kvfree(ctx); } static struct crypto_acomp_streams zstd_streams = { .alloc_ctx = zstd_alloc_stream, .free_ctx = zstd_free_stream, }; static int zstd_init(struct crypto_acomp *acomp_tfm) { int ret = 0; mutex_lock(&zstd_stream_lock); ret = crypto_acomp_alloc_streams(&zstd_streams); mutex_unlock(&zstd_stream_lock); return ret; } static void zstd_exit(struct crypto_acomp *acomp_tfm) { crypto_acomp_free_streams(&zstd_streams); } static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx, const void *src, void *dst, unsigned int *dlen) { unsigned int out_len; ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size); if (!ctx->cctx) return -EINVAL; out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen, &ctx->params); if (zstd_is_error(out_len)) return -EINVAL; *dlen = out_len; return 0; } static int zstd_compress(struct acomp_req *req) { struct crypto_acomp_stream *s; unsigned int pos, scur, dcur; unsigned int total_out = 0; bool data_available = true; zstd_out_buffer outbuf; struct acomp_walk walk; zstd_in_buffer inbuf; struct zstd_ctx *ctx; size_t pending_bytes; size_t num_bytes; int ret; s = crypto_acomp_lock_stream_bh(&zstd_streams); ctx = s->ctx; ret = acomp_walk_virt(&walk, req, true); if (ret) goto out; ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size); if (!ctx->cctx) { ret = -EINVAL; goto out; } do { dcur = acomp_walk_next_dst(&walk); if (!dcur) { ret = -ENOSPC; goto out; } outbuf.pos = 0; outbuf.dst = (u8 *)walk.dst.virt.addr; outbuf.size = dcur; do { scur = acomp_walk_next_src(&walk); if (dcur == req->dlen && scur == req->slen) { ret = zstd_compress_one(req, ctx, walk.src.virt.addr, walk.dst.virt.addr, &total_out); acomp_walk_done_src(&walk, scur); acomp_walk_done_dst(&walk, dcur); goto out; } if (scur) { inbuf.pos = 0; inbuf.src = walk.src.virt.addr; inbuf.size = scur; } else { data_available = false; break; } num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf); if (ZSTD_isError(num_bytes)) { ret = -EIO; goto out; } pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf); if (ZSTD_isError(pending_bytes)) { ret = -EIO; goto out; } acomp_walk_done_src(&walk, inbuf.pos); } while (dcur != outbuf.pos); total_out += outbuf.pos; acomp_walk_done_dst(&walk, dcur); } while (data_available); pos = outbuf.pos; num_bytes = zstd_end_stream(ctx->cctx, &outbuf); if (ZSTD_isError(num_bytes)) ret = -EIO; else total_out += (outbuf.pos - pos); out: if (ret) req->dlen = 0; else req->dlen = total_out; crypto_acomp_unlock_stream_bh(s); return ret; } static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx, const void *src, void *dst, unsigned int *dlen) { size_t out_len; ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size); if (!ctx->dctx) return -EINVAL; out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen); if (zstd_is_error(out_len)) return -EINVAL; *dlen = out_len; return 0; } static int zstd_decompress(struct acomp_req *req) { struct crypto_acomp_stream *s; unsigned int total_out = 0; unsigned int scur, dcur; zstd_out_buffer outbuf; struct acomp_walk walk; zstd_in_buffer inbuf; struct zstd_ctx *ctx; size_t pending_bytes; int ret; s = crypto_acomp_lock_stream_bh(&zstd_streams); ctx = s->ctx; ret = acomp_walk_virt(&walk, req, true); if (ret) goto out; ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size); if (!ctx->dctx) { ret = -EINVAL; goto out; } do { scur = acomp_walk_next_src(&walk); if (scur) { inbuf.pos = 0; inbuf.size = scur; inbuf.src = walk.src.virt.addr; } else { break; } do { dcur = acomp_walk_next_dst(&walk); if (dcur == req->dlen && scur == req->slen) { ret = zstd_decompress_one(req, ctx, walk.src.virt.addr, walk.dst.virt.addr, &total_out); acomp_walk_done_dst(&walk, dcur); acomp_walk_done_src(&walk, scur); goto out; } if (!dcur) { ret = -ENOSPC; goto out; } outbuf.pos = 0; outbuf.dst = (u8 *)walk.dst.virt.addr; outbuf.size = dcur; pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf); if (ZSTD_isError(pending_bytes)) { ret = -EIO; goto out; } total_out += outbuf.pos; acomp_walk_done_dst(&walk, outbuf.pos); } while (inbuf.pos != scur); acomp_walk_done_src(&walk, scur); } while (ret == 0); out: if (ret) req->dlen = 0; else req->dlen = total_out; crypto_acomp_unlock_stream_bh(s); return ret; } static struct acomp_alg zstd_acomp = { .base = { .cra_name = "zstd", .cra_driver_name = "zstd-generic", .cra_flags = CRYPTO_ALG_REQ_VIRT, .cra_module = THIS_MODULE, }, .init = zstd_init, .exit = zstd_exit, .compress = zstd_compress, .decompress = zstd_decompress, }; static int __init zstd_mod_init(void) { return crypto_register_acomp(&zstd_acomp); } static void __exit zstd_mod_fini(void) { crypto_unregister_acomp(&zstd_acomp); } module_init(zstd_mod_init); module_exit(zstd_mod_fini); MODULE_LICENSE("GPL"); MODULE_DESCRIPTION("Zstd Compression Algorithm"); MODULE_ALIAS_CRYPTO("zstd");