static_stub.c 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. // SPDX-License-Identifier: GPL-2.0
  2. /*
  3. * KUnit function redirection (static stubbing) API.
  4. *
  5. * Copyright (C) 2022, Google LLC.
  6. * Author: David Gow <davidgow@google.com>
  7. */
  8. #include <kunit/test.h>
  9. #include <kunit/static_stub.h>
  10. #include "hooks-impl.h"
  /*
   * struct kunit_static_stub_ctx - bookkeeping for one active static stub.
   * An instance lives in the ->data field of a KUnit resource.
   *
   * @real_fn_addr: address of the function being stubbed out; this is the
   *                key compared when searching a test's resources.
   * @replacement_addr: address of the function calls are redirected to.
   */
struct kunit_static_stub_ctx {
	void *real_fn_addr;		/* lookup key */
	void *replacement_addr;		/* redirect target */
};
  16. static void __kunit_static_stub_resource_free(struct kunit_resource *res)
  17. {
  18. kfree(res->data);
  19. }
  20. /* Matching function for kunit_find_resource(). match_data is real_fn_addr. */
  21. static bool __kunit_static_stub_resource_match(struct kunit *test,
  22. struct kunit_resource *res,
  23. void *match_real_fn_addr)
  24. {
  25. /* This pointer is only valid if res is a static stub resource. */
  26. struct kunit_static_stub_ctx *ctx = res->data;
  27. /* Make sure the resource is a static stub resource. */
  28. if (res->free != &__kunit_static_stub_resource_free)
  29. return false;
  30. return ctx->real_fn_addr == match_real_fn_addr;
  31. }
  32. /* Hook to return the address of the replacement function. */
  33. void *__kunit_get_static_stub_address_impl(struct kunit *test, void *real_fn_addr)
  34. {
  35. struct kunit_resource *res;
  36. struct kunit_static_stub_ctx *ctx;
  37. void *replacement_addr;
  38. res = kunit_find_resource(test,
  39. __kunit_static_stub_resource_match,
  40. real_fn_addr);
  41. if (!res)
  42. return NULL;
  43. ctx = res->data;
  44. replacement_addr = ctx->replacement_addr;
  45. kunit_put_resource(res);
  46. return replacement_addr;
  47. }
  48. void kunit_deactivate_static_stub(struct kunit *test, void *real_fn_addr)
  49. {
  50. struct kunit_resource *res;
  51. KUNIT_ASSERT_PTR_NE_MSG(test, real_fn_addr, NULL,
  52. "Tried to deactivate a NULL stub.");
  53. /* Look up the existing stub for this function. */
  54. res = kunit_find_resource(test,
  55. __kunit_static_stub_resource_match,
  56. real_fn_addr);
  57. /* Error out if the stub doesn't exist. */
  58. KUNIT_ASSERT_PTR_NE_MSG(test, res, NULL,
  59. "Tried to deactivate a nonexistent stub.");
  60. /* Free the stub. We 'put' twice, as we got a reference
  61. * from kunit_find_resource()
  62. */
  63. kunit_remove_resource(test, res);
  64. kunit_put_resource(res);
  65. }
  66. EXPORT_SYMBOL_GPL(kunit_deactivate_static_stub);
  67. /* Helper function for kunit_activate_static_stub(). The macro does
  68. * typechecking, so use it instead.
  69. */
  70. void __kunit_activate_static_stub(struct kunit *test,
  71. void *real_fn_addr,
  72. void *replacement_addr)
  73. {
  74. struct kunit_static_stub_ctx *ctx;
  75. struct kunit_resource *res;
  76. KUNIT_ASSERT_PTR_NE_MSG(test, real_fn_addr, NULL,
  77. "Tried to activate a stub for function NULL");
  78. /* If the replacement address is NULL, deactivate the stub. */
  79. if (!replacement_addr) {
  80. kunit_deactivate_static_stub(test, real_fn_addr);
  81. return;
  82. }
  83. /* Look up any existing stubs for this function, and replace them. */
  84. res = kunit_find_resource(test,
  85. __kunit_static_stub_resource_match,
  86. real_fn_addr);
  87. if (res) {
  88. ctx = res->data;
  89. ctx->replacement_addr = replacement_addr;
  90. /* We got an extra reference from find_resource(), so put it. */
  91. kunit_put_resource(res);
  92. } else {
  93. ctx = kmalloc_obj(*ctx);
  94. KUNIT_ASSERT_NOT_ERR_OR_NULL(test, ctx);
  95. ctx->real_fn_addr = real_fn_addr;
  96. ctx->replacement_addr = replacement_addr;
  97. res = kunit_alloc_resource(test, NULL,
  98. &__kunit_static_stub_resource_free,
  99. GFP_KERNEL, ctx);
  100. }
  101. }
  102. EXPORT_SYMBOL_GPL(__kunit_activate_static_stub);