/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2025 Brett A C Sheffield <bacs@librecast.net> */

#include "test.h"
#include "testnet.h"
#include <fcntl.h>
#include <librecast/net.h>
#include <librecast/if.h>
#include <pthread.h>
#include <semaphore.h>
#include <unistd.h>

#define WAITS 4 /* timeout seconds */
#define TEST_SIZE 1024
#define TEST_NAME "netlink interface detection"

#ifdef HAVE_LINUX_NETLINK_H
/* posted by thread_recv() after it has checked the received data;
 * test_netlink() waits on it with a timeout of WAITS seconds */
static sem_t sem_done;
/* posted by thread_recv() just before it starts receiving; thread_send()
 * waits on it before creating the new interface and sending */
static sem_t sem_recv;

/* thread roles, in the same order as the thread_f[] table in test_netlink().
 * NOTE(review): these constants are not referenced in the visible code */
enum {
	RECV,
	SEND
};
/* test payload: filled with random bytes by test_netlink(), sent by
 * thread_send() and compared with the received data in thread_recv() */
static char send_buf[TEST_SIZE];

/*
 * Receiver thread.
 *
 * Creates its own context, socket and channel (TEST_NAME), binds and joins
 * the channel, then posts sem_recv so the sender can proceed. Receives one
 * message of up to TEST_SIZE bytes, checks length and content against
 * send_buf, and posts sem_done.
 *
 * Returns NULL if the context could not be created, otherwise `arg`.
 * Setup failures are reported through test_assert() and skip both posts.
 */
static void *thread_recv(void *arg)
{
	char rxbuf[BUFSIZ] = {0};
	lc_channel_t *ch;
	lc_socket_t *sk;
	lc_ctx_t *ctx;
	ssize_t nread;
	int err;

	ctx = lc_ctx_new();
	if (!test_assert(ctx != NULL, "lc_ctx_new()")) {
		return NULL;
	}
	sk = lc_socket_new(ctx);
	if (!test_assert(sk != NULL, "lc_socket_new()")) {
		goto out;
	}
	ch = lc_channel_new(ctx, TEST_NAME);
	if (!test_assert(ch != NULL, "lc_channel_new()")) {
		goto out;
	}
	err = lc_channel_bind(sk, ch);
	if (!test_assert(err == 0, "lc_channel_bind()")) {
		goto out;
	}
	err = lc_channel_join(ch);
	if (!test_assert(err == 0, "lc_channel_join()")) {
		goto out;
	}

	/* setup complete: release the sender */
	sem_post(&sem_recv);

	nread = lc_channel_recv(ch, rxbuf, TEST_SIZE, 0);
	if (nread == -1) {
		perror("lc_channel_recv");
	}
	test_assert(nread == TEST_SIZE, "lc_channel_recv() returned %zi", nread);
	test_assert(memcmp(send_buf, rxbuf, TEST_SIZE) == 0, "received data matches");

	/* let the controlling thread know the checks have run */
	sem_post(&sem_done);

out:
	lc_ctx_free(ctx);
	return arg;
}

/*
 * Turn off IPv6 Duplicate Address Detection for interface `ifname` by
 * writing "0" to /proc/sys/net/ipv6/conf/<ifname>/accept_dad.
 *
 * ifname: NUL-terminated interface name.
 *
 * Failures are reported through test_assert(); nothing is returned.
 */
static void disable_dad(char *ifname)
{
	char fname[128];
	int fd;

	snprintf(fname, sizeof fname, "/proc/sys/net/ipv6/conf/%s/accept_dad", ifname);
	fd = open(fname, O_WRONLY);
	if (fd == -1) perror(fname);
	/* stop here on failure rather than calling write()/close() on fd -1 */
	if (!test_assert(fd != -1, "open %s", fname)) return;
	test_assert(write(fd, "0", 1) == 1, "write");
	close(fd);
}

/*
 * Sender thread.
 *
 * Creates its own context, socket and channel (TEST_NAME) with multicast
 * loopback enabled, then waits on sem_recv until the receiver is ready.
 * It then creates a new TAP interface, disables DAD on it, brings it up,
 * binds the socket to it and sends send_buf on the channel.
 *
 * Returns NULL if the context could not be created, otherwise `arg`.
 * All failures are reported through test_assert().
 */
static void *thread_send(void *arg)
{
	lc_ctx_t *lctx;
	lc_socket_t *sock;
	lc_channel_t *chan;
	char ifname[IFNAMSIZ] = {0};
	ssize_t byt;
	unsigned int ifx;
	int fd;
	int rc;

	lctx = lc_ctx_new();
	if (!test_assert(lctx != NULL, "lc_ctx_new()")) return NULL;
	sock = lc_socket_new(lctx);
	if (!test_assert(sock != NULL, "lc_socket_new()")) goto free_ctx;
	chan = lc_channel_new(lctx, TEST_NAME);
	if (!test_assert(chan != NULL, "lc_channel_new()")) goto free_ctx;
	rc = lc_channel_bind(sock, chan);
	if (!test_assert(rc == 0, "lc_channel_bind()")) goto free_ctx;
	lc_socket_loop(sock, 1);

	sem_wait(&sem_recv); /* wait for receiver to be ready */

	/* create a new interface and use that to send */
	fd = lc_tap_create(ifname);
	if (fd == -1) perror("lc_tap_create");
	/* without an interface nothing below can work, so stop here */
	if (!test_assert(fd >= 0, "lc_tap_create() returned %i", fd)) goto free_ctx;
	/* NOTE(review): fd is never closed in this function; confirm whether
	 * it is meant to stay open until the process exits */
	disable_dad(ifname); /* otherwise we need to sleep 2s for DAD */
	rc = lc_link_set(lctx, ifname, 1);
	if (!test_assert(rc == 0, "bring up interface")) goto free_ctx;
	ifx = if_nametoindex(ifname);
	if (!test_assert(ifx > 0, "find ifx for %s == %u", ifname, ifx)) goto free_ctx;
	rc = lc_socket_bind(sock, ifx);
	if (!test_assert(rc == 0, "bind to interface %u", ifx)) goto free_ctx;

	/* allow time for recv to handle netlink event */
	if (RUNNING_ON_VALGRIND) sleep(2);
	else usleep(10000);

	byt = lc_channel_send(chan, send_buf, sizeof send_buf, 0);
	if (byt == -1) perror("lc_channel_send");
	/* TEST_SIZE is an int: cast it so the argument matches %zi */
	test_assert(byt == (ssize_t)TEST_SIZE, "lc_channel_send() returned %zi / %zi",
			byt, (ssize_t)TEST_SIZE);

free_ctx:
	lc_ctx_free(lctx);
	return arg;
}

static int test_netlink(void)
{
	pthread_t tid[2];
	void *(*thread_f[2])(void *) = { &thread_recv, &thread_send };
	int threads = 0;
	int rc;

	/* generate test data */
	arc4random_buf(send_buf, sizeof send_buf);

	rc = sem_init(&sem_done, 0, 0);
	if (!test_assert(rc == 0, "sem_init() sem_done")) return test_status;
	rc = sem_init(&sem_recv, 0, 0);
	if (!test_assert(rc == 0, "sem_init() sem_recv")) goto free_sem_done;

	/* start threads */
	for (int i = 0; i < 2; i++) {
		rc = pthread_create(&tid[i], NULL, thread_f[i], NULL);
		if (!test_assert(rc == 0, "%i: pthread_create", i)) goto join_threads;
		threads++;
	}

	/* timeout */
	struct timespec ts;
	if (!test_assert(!clock_gettime(CLOCK_REALTIME, &ts), "clock_gettime()")) goto join_threads;
	ts.tv_sec += WAITS;
	test_assert(!sem_timedwait(&sem_done, &ts), "timeout");

join_threads:
	for (int i = 0; i < threads; i++) {
		pthread_cancel(tid[i]);
		pthread_join(tid[i], NULL);
	}
	sem_destroy(&sem_recv);
free_sem_done:
	sem_destroy(&sem_done);
	return test_status;
}
#endif /* HAVE_LINUX_NETLINK_H */

/*
 * Test entry point: require CAP_NET_ADMIN, announce the test name and,
 * when built with netlink support, require basic networking and run
 * test_netlink(). The exit code is test_status.
 */
int main(void)
{
	char name[] = TEST_NAME;

	test_cap_require(CAP_NET_ADMIN);
	test_name(name);
#ifdef HAVE_LINUX_NETLINK_H
	test_require_net(TEST_NET_BASIC);

	/* the result is also recorded in test_status, which is returned below */
	(void)test_netlink();
#endif
	return test_status;
}
