diff --git a/include/net/tls.h b/include/net/tls.h index 2ad28545b15f..6c642ea18050 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -395,8 +395,12 @@ tls_offload_ctx_tx(const struct tls_context *tls_ctx) static inline bool tls_sw_has_ctx_tx(const struct sock *sk) { - struct tls_context *ctx = tls_get_ctx(sk); + struct tls_context *ctx; + if (!sk_is_inet(sk) || !inet_test_bit(IS_ICSK, sk)) + return false; + + ctx = tls_get_ctx(sk); if (!ctx) return false; return !!tls_sw_ctx_tx(ctx); @@ -404,8 +408,12 @@ static inline bool tls_sw_has_ctx_tx(const struct sock *sk) static inline bool tls_sw_has_ctx_rx(const struct sock *sk) { - struct tls_context *ctx = tls_get_ctx(sk); + struct tls_context *ctx; + if (!sk_is_inet(sk) || !inet_test_bit(IS_ICSK, sk)) + return false; + + ctx = tls_get_ctx(sk); if (!ctx) return false; return !!tls_sw_ctx_rx(ctx);