zstd.c 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. // SPDX-License-Identifier: GPL-2.0-only
  2. /*
  3. * Cryptographic API.
  4. *
  5. * Copyright (c) 2017-present, Facebook, Inc.
  6. */
  7. #include <linux/crypto.h>
  8. #include <linux/init.h>
  9. #include <linux/interrupt.h>
  10. #include <linux/mm.h>
  11. #include <linux/module.h>
  12. #include <linux/net.h>
  13. #include <linux/overflow.h>
  14. #include <linux/vmalloc.h>
  15. #include <linux/zstd.h>
  16. #include <crypto/internal/acompress.h>
  17. #include <crypto/scatterwalk.h>
  /* Compression level handed to zstd_get_params() when a stream is created. */
  #define ZSTD_DEF_LEVEL		3
  /* log2 of ZSTD_MAX_SIZE. */
  #define ZSTD_MAX_WINDOWLOG	18
  /*
   * 256 KiB: the source-size hint given to zstd_get_params() and the window
   * size used to size and initialise the decompression stream.
   */
  #define ZSTD_MAX_SIZE		BIT(ZSTD_MAX_WINDOWLOG)
  21. struct zstd_ctx {
  22. zstd_cctx *cctx;
  23. zstd_dctx *dctx;
  24. size_t wksp_size;
  25. zstd_parameters params;
  26. u8 wksp[] __aligned(8) __counted_by(wksp_size);
  27. };
  /* Serialises the crypto_acomp_alloc_streams() call made from zstd_init(). */
  static DEFINE_MUTEX(zstd_stream_lock);
  29. static void *zstd_alloc_stream(void)
  30. {
  31. zstd_parameters params;
  32. struct zstd_ctx *ctx;
  33. size_t wksp_size;
  34. params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);
  35. wksp_size = max(zstd_cstream_workspace_bound(&params.cParams),
  36. zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
  37. if (!wksp_size)
  38. return ERR_PTR(-EINVAL);
  39. ctx = kvmalloc_flex(*ctx, wksp, wksp_size);
  40. if (!ctx)
  41. return ERR_PTR(-ENOMEM);
  42. ctx->params = params;
  43. ctx->wksp_size = wksp_size;
  44. return ctx;
  45. }
  /* .free_ctx callback: release a context returned by zstd_alloc_stream(). */
  static void zstd_free_stream(void *ctx)
  {
  	struct zstd_ctx *zctx = ctx;

  	kvfree(zctx);
  }
  50. static struct crypto_acomp_streams zstd_streams = {
  51. .alloc_ctx = zstd_alloc_stream,
  52. .free_ctx = zstd_free_stream,
  53. };
  54. static int zstd_init(struct crypto_acomp *acomp_tfm)
  55. {
  56. int ret = 0;
  57. mutex_lock(&zstd_stream_lock);
  58. ret = crypto_acomp_alloc_streams(&zstd_streams);
  59. mutex_unlock(&zstd_stream_lock);
  60. return ret;
  61. }
  62. static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
  63. const void *src, void *dst, unsigned int *dlen)
  64. {
  65. size_t out_len;
  66. ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
  67. if (!ctx->cctx)
  68. return -EINVAL;
  69. out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
  70. &ctx->params);
  71. if (zstd_is_error(out_len))
  72. return -EINVAL;
  73. *dlen = out_len;
  74. return 0;
  75. }
  76. static int zstd_compress(struct acomp_req *req)
  77. {
  78. struct crypto_acomp_stream *s;
  79. unsigned int pos, scur, dcur;
  80. unsigned int total_out = 0;
  81. bool data_available = true;
  82. zstd_out_buffer outbuf;
  83. struct acomp_walk walk;
  84. zstd_in_buffer inbuf;
  85. struct zstd_ctx *ctx;
  86. size_t pending_bytes;
  87. size_t num_bytes;
  88. int ret;
  89. s = crypto_acomp_lock_stream_bh(&zstd_streams);
  90. ctx = s->ctx;
  91. ret = acomp_walk_virt(&walk, req, true);
  92. if (ret)
  93. goto out;
  94. ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
  95. if (!ctx->cctx) {
  96. ret = -EINVAL;
  97. goto out;
  98. }
  99. do {
  100. dcur = acomp_walk_next_dst(&walk);
  101. if (!dcur) {
  102. ret = -ENOSPC;
  103. goto out;
  104. }
  105. outbuf.pos = 0;
  106. outbuf.dst = (u8 *)walk.dst.virt.addr;
  107. outbuf.size = dcur;
  108. do {
  109. scur = acomp_walk_next_src(&walk);
  110. if (dcur == req->dlen && scur == req->slen) {
  111. ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
  112. walk.dst.virt.addr, &total_out);
  113. acomp_walk_done_src(&walk, scur);
  114. acomp_walk_done_dst(&walk, dcur);
  115. goto out;
  116. }
  117. if (scur) {
  118. inbuf.pos = 0;
  119. inbuf.src = walk.src.virt.addr;
  120. inbuf.size = scur;
  121. } else {
  122. data_available = false;
  123. break;
  124. }
  125. num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
  126. if (ZSTD_isError(num_bytes)) {
  127. ret = -EIO;
  128. goto out;
  129. }
  130. pending_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
  131. if (ZSTD_isError(pending_bytes)) {
  132. ret = -EIO;
  133. goto out;
  134. }
  135. acomp_walk_done_src(&walk, inbuf.pos);
  136. } while (dcur != outbuf.pos);
  137. total_out += outbuf.pos;
  138. acomp_walk_done_dst(&walk, dcur);
  139. } while (data_available);
  140. pos = outbuf.pos;
  141. num_bytes = zstd_end_stream(ctx->cctx, &outbuf);
  142. if (ZSTD_isError(num_bytes))
  143. ret = -EIO;
  144. else
  145. total_out += (outbuf.pos - pos);
  146. out:
  147. if (ret)
  148. req->dlen = 0;
  149. else
  150. req->dlen = total_out;
  151. crypto_acomp_unlock_stream_bh(s);
  152. return ret;
  153. }
  154. static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
  155. const void *src, void *dst, unsigned int *dlen)
  156. {
  157. size_t out_len;
  158. ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
  159. if (!ctx->dctx)
  160. return -EINVAL;
  161. out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
  162. if (zstd_is_error(out_len))
  163. return -EINVAL;
  164. *dlen = out_len;
  165. return 0;
  166. }
  167. static int zstd_decompress(struct acomp_req *req)
  168. {
  169. struct crypto_acomp_stream *s;
  170. unsigned int total_out = 0;
  171. unsigned int scur, dcur;
  172. zstd_out_buffer outbuf;
  173. struct acomp_walk walk;
  174. zstd_in_buffer inbuf;
  175. struct zstd_ctx *ctx;
  176. size_t pending_bytes;
  177. int ret;
  178. s = crypto_acomp_lock_stream_bh(&zstd_streams);
  179. ctx = s->ctx;
  180. ret = acomp_walk_virt(&walk, req, true);
  181. if (ret)
  182. goto out;
  183. ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
  184. if (!ctx->dctx) {
  185. ret = -EINVAL;
  186. goto out;
  187. }
  188. do {
  189. scur = acomp_walk_next_src(&walk);
  190. if (scur) {
  191. inbuf.pos = 0;
  192. inbuf.size = scur;
  193. inbuf.src = walk.src.virt.addr;
  194. } else {
  195. break;
  196. }
  197. do {
  198. dcur = acomp_walk_next_dst(&walk);
  199. if (dcur == req->dlen && scur == req->slen) {
  200. ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
  201. walk.dst.virt.addr, &total_out);
  202. acomp_walk_done_dst(&walk, dcur);
  203. acomp_walk_done_src(&walk, scur);
  204. goto out;
  205. }
  206. if (!dcur) {
  207. ret = -ENOSPC;
  208. goto out;
  209. }
  210. outbuf.pos = 0;
  211. outbuf.dst = (u8 *)walk.dst.virt.addr;
  212. outbuf.size = dcur;
  213. pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
  214. if (ZSTD_isError(pending_bytes)) {
  215. ret = -EIO;
  216. goto out;
  217. }
  218. total_out += outbuf.pos;
  219. acomp_walk_done_dst(&walk, outbuf.pos);
  220. } while (inbuf.pos != scur);
  221. acomp_walk_done_src(&walk, scur);
  222. } while (ret == 0);
  223. out:
  224. if (ret)
  225. req->dlen = 0;
  226. else
  227. req->dlen = total_out;
  228. crypto_acomp_unlock_stream_bh(s);
  229. return ret;
  230. }
  231. static struct acomp_alg zstd_acomp = {
  232. .base = {
  233. .cra_name = "zstd",
  234. .cra_driver_name = "zstd-generic",
  235. .cra_flags = CRYPTO_ALG_REQ_VIRT,
  236. .cra_module = THIS_MODULE,
  237. },
  238. .init = zstd_init,
  239. .compress = zstd_compress,
  240. .decompress = zstd_decompress,
  241. };
  242. static int __init zstd_mod_init(void)
  243. {
  244. return crypto_register_acomp(&zstd_acomp);
  245. }
  246. static void __exit zstd_mod_fini(void)
  247. {
  248. crypto_unregister_acomp(&zstd_acomp);
  249. crypto_acomp_free_streams(&zstd_streams);
  250. }
  251. module_init(zstd_mod_init);
  252. module_exit(zstd_mod_fini);
  253. MODULE_LICENSE("GPL");
  254. MODULE_DESCRIPTION("Zstd Compression Algorithm");
  255. MODULE_ALIAS_CRYPTO("zstd");