| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- // SPDX-License-Identifier: GPL-2.0-only
- /*
- * Cryptographic API.
- *
- * Copyright (c) 2017-present, Facebook, Inc.
- */
- #include <linux/crypto.h>
- #include <linux/init.h>
- #include <linux/interrupt.h>
- #include <linux/mm.h>
- #include <linux/module.h>
- #include <linux/net.h>
- #include <linux/overflow.h>
- #include <linux/vmalloc.h>
- #include <linux/zstd.h>
- #include <crypto/internal/acompress.h>
- #include <crypto/scatterwalk.h>
/* Compression level passed to zstd_get_params() in zstd_alloc_stream(). */
- #define ZSTD_DEF_LEVEL 3
/* log2 of ZSTD_MAX_SIZE. */
- #define ZSTD_MAX_WINDOWLOG 18
/*
 * 256 KiB (1 << 18).  Used as the source-size hint for zstd_get_params(), for
 * both workspace bounds in zstd_alloc_stream(), and as the first argument of
 * zstd_init_dstream() in zstd_decompress().  NOTE(review): that argument is
 * presumably the largest window the decoder accepts - confirm against
 * include/linux/zstd.h.
 */
- #define ZSTD_MAX_SIZE BIT(ZSTD_MAX_WINDOWLOG)
/*
 * Per-stream state created by zstd_alloc_stream().  The compression and
 * decompression contexts are both built inside the same trailing workspace
 * and are re-initialised at the start of every request, so only the one set
 * up by the current request is meaningful.
 */
- struct zstd_ctx {
/* Set by zstd_init_cstream()/zstd_init_cctx() at the start of compression. */
- zstd_cctx *cctx;
/* Set by zstd_init_dstream()/zstd_init_dctx() at the start of decompression. */
- zstd_dctx *dctx;
/* Length of wksp: the larger of the cstream and dstream workspace bounds. */
- size_t wksp_size;
/* zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE); used for compression. */
- zstd_parameters params;
/* Workspace backing cctx/dctx; 8-byte aligned, wksp_size bytes long. */
- u8 wksp[] __aligned(8) __counted_by(wksp_size);
- };
/* Serialises the crypto_acomp_alloc_streams() call made from zstd_init(). */
- static DEFINE_MUTEX(zstd_stream_lock);
- static void *zstd_alloc_stream(void)
- {
- zstd_parameters params;
- struct zstd_ctx *ctx;
- size_t wksp_size;
- params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
- wksp_size = max(zstd_cstream_workspace_bound(¶ms.cParams),
- zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
- if (!wksp_size)
- return ERR_PTR(-EINVAL);
- ctx = kvmalloc_flex(*ctx, wksp, wksp_size);
- if (!ctx)
- return ERR_PTR(-ENOMEM);
- ctx->params = params;
- ctx->wksp_size = wksp_size;
- return ctx;
- }
/*
 * Release a context previously returned by zstd_alloc_stream().
 * The context and its trailing workspace are a single kvmalloc'ed object.
 */
static void zstd_free_stream(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	kvfree(zctx);
}
- static struct crypto_acomp_streams zstd_streams = {
- .alloc_ctx = zstd_alloc_stream,
- .free_ctx = zstd_free_stream,
- };
- static int zstd_init(struct crypto_acomp *acomp_tfm)
- {
- int ret = 0;
- mutex_lock(&zstd_stream_lock);
- ret = crypto_acomp_alloc_streams(&zstd_streams);
- mutex_unlock(&zstd_stream_lock);
- return ret;
- }
- static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
- const void *src, void *dst, unsigned int *dlen)
- {
- size_t out_len;
- ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
- if (!ctx->cctx)
- return -EINVAL;
- out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
- &ctx->params);
- if (zstd_is_error(out_len))
- return -EINVAL;
- *dlen = out_len;
- return 0;
- }
- static int zstd_compress(struct acomp_req *req)
- {
- struct crypto_acomp_stream *s;
- unsigned int pos, scur, dcur;
- unsigned int total_out = 0;
- bool data_available = true;
- zstd_out_buffer outbuf;
- struct acomp_walk walk;
- zstd_in_buffer inbuf;
- struct zstd_ctx *ctx;
- size_t pending_bytes;
- size_t num_bytes;
- int ret;
- s = crypto_acomp_lock_stream_bh(&zstd_streams);
- ctx = s->ctx;
- ret = acomp_walk_virt(&walk, req, true);
- if (ret)
- goto out;
- ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
- if (!ctx->cctx) {
- ret = -EINVAL;
- goto out;
- }
- do {
- dcur = acomp_walk_next_dst(&walk);
- if (!dcur) {
- ret = -ENOSPC;
- goto out;
- }
- outbuf.pos = 0;
- outbuf.dst = (u8 *)walk.dst.virt.addr;
- outbuf.size = dcur;
- do {
- scur = acomp_walk_next_src(&walk);
- if (dcur == req->dlen && scur == req->slen) {
- ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
- walk.dst.virt.addr, &total_out);
- acomp_walk_done_src(&walk, scur);
- acomp_walk_done_dst(&walk, dcur);
- goto out;
- }
- if (scur) {
- inbuf.pos = 0;
- inbuf.src = walk.src.virt.addr;
- inbuf.size = scur;
- } else {
- data_available = false;
- break;
- }
- num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
- if (ZSTD_isError(num_bytes)) {
- ret = -EIO;
- goto out;
- }
- pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
- if (ZSTD_isError(pending_bytes)) {
- ret = -EIO;
- goto out;
- }
- acomp_walk_done_src(&walk, inbuf.pos);
- } while (dcur != outbuf.pos);
- total_out += outbuf.pos;
- acomp_walk_done_dst(&walk, dcur);
- } while (data_available);
- pos = outbuf.pos;
- num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
- if (ZSTD_isError(num_bytes))
- ret = -EIO;
- else
- total_out += (outbuf.pos - pos);
- out:
- if (ret)
- req->dlen = 0;
- else
- req->dlen = total_out;
- crypto_acomp_unlock_stream_bh(s);
- return ret;
- }
- static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
- const void *src, void *dst, unsigned int *dlen)
- {
- size_t out_len;
- ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
- if (!ctx->dctx)
- return -EINVAL;
- out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
- if (zstd_is_error(out_len))
- return -EINVAL;
- *dlen = out_len;
- return 0;
- }
- static int zstd_decompress(struct acomp_req *req)
- {
- struct crypto_acomp_stream *s;
- unsigned int total_out = 0;
- unsigned int scur, dcur;
- zstd_out_buffer outbuf;
- struct acomp_walk walk;
- zstd_in_buffer inbuf;
- struct zstd_ctx *ctx;
- size_t pending_bytes;
- int ret;
- s = crypto_acomp_lock_stream_bh(&zstd_streams);
- ctx = s->ctx;
- ret = acomp_walk_virt(&walk, req, true);
- if (ret)
- goto out;
- ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
- if (!ctx->dctx) {
- ret = -EINVAL;
- goto out;
- }
- do {
- scur = acomp_walk_next_src(&walk);
- if (scur) {
- inbuf.pos = 0;
- inbuf.size = scur;
- inbuf.src = walk.src.virt.addr;
- } else {
- break;
- }
- do {
- dcur = acomp_walk_next_dst(&walk);
- if (dcur == req->dlen && scur == req->slen) {
- ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
- walk.dst.virt.addr, &total_out);
- acomp_walk_done_dst(&walk, dcur);
- acomp_walk_done_src(&walk, scur);
- goto out;
- }
- if (!dcur) {
- ret = -ENOSPC;
- goto out;
- }
- outbuf.pos = 0;
- outbuf.dst = (u8 *)walk.dst.virt.addr;
- outbuf.size = dcur;
- pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
- if (ZSTD_isError(pending_bytes)) {
- ret = -EIO;
- goto out;
- }
- total_out += outbuf.pos;
- acomp_walk_done_dst(&walk, outbuf.pos);
- } while (inbuf.pos != scur);
- acomp_walk_done_src(&walk, scur);
- } while (ret == 0);
- out:
- if (ret)
- req->dlen = 0;
- else
- req->dlen = total_out;
- crypto_acomp_unlock_stream_bh(s);
- return ret;
- }
- static struct acomp_alg zstd_acomp = {
- .base = {
- .cra_name = "zstd",
- .cra_driver_name = "zstd-generic",
- .cra_flags = CRYPTO_ALG_REQ_VIRT,
- .cra_module = THIS_MODULE,
- },
- .init = zstd_init,
- .compress = zstd_compress,
- .decompress = zstd_decompress,
- };
- static int __init zstd_mod_init(void)
- {
- return crypto_register_acomp(&zstd_acomp);
- }
- static void __exit zstd_mod_fini(void)
- {
- crypto_unregister_acomp(&zstd_acomp);
- crypto_acomp_free_streams(&zstd_streams);
- }
/* Load/unload hooks for this module. */
- module_init(zstd_mod_init);
- module_exit(zstd_mod_fini);
/* Module metadata; the crypto alias names the "zstd" algorithm registered above. */
- MODULE_LICENSE("GPL");
- MODULE_DESCRIPTION("Zstd Compression Algorithm");
- MODULE_ALIAS_CRYPTO("zstd");
|