auth_tls.c 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (c) 2021, 2022 Oracle.  All rights reserved.
 *
 * The AUTH_TLS credential is used only to probe a remote peer
 * for RPC-over-TLS support.
 */
  8. #include <linux/types.h>
  9. #include <linux/module.h>
  10. #include <linux/sunrpc/clnt.h>
  11. static const char *starttls_token = "STARTTLS";
  12. static const size_t starttls_len = 8;
  13. static struct rpc_auth tls_auth;
  14. static struct rpc_cred tls_cred;
  15. static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
  16. const void *obj)
  17. {
  18. }
  19. static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
  20. void *obj)
  21. {
  22. return 0;
  23. }
  24. static const struct rpc_procinfo rpcproc_tls_probe = {
  25. .p_encode = tls_encode_probe,
  26. .p_decode = tls_decode_probe,
  27. };
  28. static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
  29. {
  30. task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
  31. rpc_call_start(task);
  32. }
  33. static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
  34. {
  35. }
  36. static const struct rpc_call_ops rpc_tls_probe_ops = {
  37. .rpc_call_prepare = rpc_tls_probe_call_prepare,
  38. .rpc_call_done = rpc_tls_probe_call_done,
  39. };
  40. static int tls_probe(struct rpc_clnt *clnt)
  41. {
  42. struct rpc_message msg = {
  43. .rpc_proc = &rpcproc_tls_probe,
  44. };
  45. struct rpc_task_setup task_setup_data = {
  46. .rpc_client = clnt,
  47. .rpc_message = &msg,
  48. .rpc_op_cred = &tls_cred,
  49. .callback_ops = &rpc_tls_probe_ops,
  50. .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
  51. };
  52. struct rpc_task *task;
  53. int status;
  54. task = rpc_run_task(&task_setup_data);
  55. if (IS_ERR(task))
  56. return PTR_ERR(task);
  57. status = task->tk_status;
  58. rpc_put_task(task);
  59. return status;
  60. }
  61. static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
  62. struct rpc_clnt *clnt)
  63. {
  64. refcount_inc(&tls_auth.au_count);
  65. return &tls_auth;
  66. }
  67. static void tls_destroy(struct rpc_auth *auth)
  68. {
  69. }
  70. static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
  71. struct auth_cred *acred, int flags)
  72. {
  73. return get_rpccred(&tls_cred);
  74. }
  75. static void tls_destroy_cred(struct rpc_cred *cred)
  76. {
  77. }
  78. static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
  79. {
  80. return 1;
  81. }
  82. static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
  83. {
  84. __be32 *p;
  85. p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
  86. if (!p)
  87. return -EMSGSIZE;
  88. /* Credential */
  89. *p++ = rpc_auth_tls;
  90. *p++ = xdr_zero;
  91. /* Verifier */
  92. *p++ = rpc_auth_null;
  93. *p = xdr_zero;
  94. return 0;
  95. }
  96. static int tls_refresh(struct rpc_task *task)
  97. {
  98. set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
  99. return 0;
  100. }
  101. static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
  102. {
  103. __be32 *p;
  104. void *str;
  105. p = xdr_inline_decode(xdr, XDR_UNIT);
  106. if (!p)
  107. return -EIO;
  108. if (*p != rpc_auth_null)
  109. return -EIO;
  110. if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
  111. return -EPROTONOSUPPORT;
  112. if (memcmp(str, starttls_token, starttls_len))
  113. return -EPROTONOSUPPORT;
  114. return 0;
  115. }
  116. const struct rpc_authops authtls_ops = {
  117. .owner = THIS_MODULE,
  118. .au_flavor = RPC_AUTH_TLS,
  119. .au_name = "NULL",
  120. .create = tls_create,
  121. .destroy = tls_destroy,
  122. .lookup_cred = tls_lookup_cred,
  123. .ping = tls_probe,
  124. };
  125. static struct rpc_auth tls_auth = {
  126. .au_cslack = NUL_CALLSLACK,
  127. .au_rslack = NUL_REPLYSLACK,
  128. .au_verfsize = NUL_REPLYSLACK,
  129. .au_ralign = NUL_REPLYSLACK,
  130. .au_ops = &authtls_ops,
  131. .au_flavor = RPC_AUTH_TLS,
  132. .au_count = REFCOUNT_INIT(1),
  133. };
  134. static const struct rpc_credops tls_credops = {
  135. .cr_name = "AUTH_TLS",
  136. .crdestroy = tls_destroy_cred,
  137. .crmatch = tls_match,
  138. .crmarshal = tls_marshal,
  139. .crwrap_req = rpcauth_wrap_req_encode,
  140. .crrefresh = tls_refresh,
  141. .crvalidate = tls_validate,
  142. .crunwrap_resp = rpcauth_unwrap_resp_decode,
  143. };
  144. static struct rpc_cred tls_cred = {
  145. .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru),
  146. .cr_auth = &tls_auth,
  147. .cr_ops = &tls_credops,
  148. .cr_count = REFCOUNT_INIT(2),
  149. .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
  150. };