Files

164 lines
4.0 KiB
C
Raw Permalink Normal View History

2026-06-09 06:43:13 +02:00
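/*
 * 5. 编写 N×N 矩阵乘法函数的并行线程化版本（matmult.c）。
 * 设 2^k ≤ N < 2^(k+1)，
 * 测量线程数为 1、2、4、…、2^(k-1)、2^k 等情况下的运行时间。
 *
 * (5) Write a parallel, threaded version of an N×N matrix-multiplication
 * function (matmult.c). Let 2^k <= N < 2^(k+1); measure the running time
 * with 1, 2, 4, ..., 2^(k-1), 2^k threads.
 */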
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <pthread.h>
#include <sys/time.h>
/* Returns the current time in seconds, with microsecond resolution, as
 * reported by gettimeofday(). Only the difference between two calls is
 * meaningful here (used to time the parallel and serial runs).
 * NOTE(review): gettimeofday is a wall clock and can jump if the system time
 * is adjusted during a measurement. */
double get_time(void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec / 1000000.0;
}
/* Work descriptor passed to each mul_thread worker through pthread_create.
 * A, B and C are row-major N×N matrices (element (i, j) at index i * N + j).
 * The worker computes rows [start_row, end_row) of C = A * B. A and B are
 * only read, so they are const-qualified. */
typedef struct thread_arg {
    int id;          /* worker index assigned by main (not read by the kernel) */
    int start_row;   /* first row of C to compute (inclusive) */
    int end_row;     /* one past the last row of C to compute */
    int N;           /* matrix dimension */
    const double *A; /* left operand, read-only */
    const double *B; /* right operand, read-only */
    double *C;       /* result; only rows [start_row, end_row) are written */
} thread_arg_t;
/* pthread start routine. arg points to a thread_arg_t describing the work.
 * Writes rows [start_row, end_row) of C = A * B, where all three matrices
 * are row-major N×N arrays (element (r, c) at index r * N + c). Each output
 * element is accumulated over the inner index in ascending order.
 * Always returns NULL. */
void *mul_thread(void *arg)
{
    const thread_arg_t *job = arg;
    const int n = job->N;

    for (int row = job->start_row; row < job->end_row; ++row)
    {
        const double *a_row = job->A + row * n;
        double *c_row = job->C + row * n;
        for (int col = 0; col < n; ++col)
        {
            double acc = 0.0;
            for (int t = 0; t < n; ++t)
            {
                acc += a_row[t] * job->B[t * n + col];
            }
            c_row[col] = acc;
        }
    }
    return NULL;
}
/* Serial reference implementation, used to verify the threaded result and
 * as the speedup baseline.
 *
 * Computes C = A * B for row-major N×N matrices (element (i, j) at index
 * i * N + j). It uses the same i-j-k loop nest as mul_thread, so each
 * element is accumulated in the same order. A and B are only read. C should
 * not overlap A or B, because writes to C would change operands still being
 * read. */
void matmul_serial(const double *A, const double *B, double *C, int N)
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            double sum = 0.0;
            for (int k = 0; k < N; k++)
            {
                sum += A[i * N + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}
/* Compares two N×N result matrices element by element.
 *
 * Returns 1 when every pair of elements differs by at most 1e-6 in absolute
 * value, and 0 otherwise. On the first mismatch it prints the index and both
 * values. A NaN in either matrix counts as a mismatch; the original test
 * reported NaN (e.g. from unwritten output) as a match. N <= 0 leaves
 * nothing to compare and returns 1. */
int verify(const double *C1, const double *C2, int N)
{
    const double tol = 1e-6;
    if (N <= 0)
    {
        return 1;
    }
    /* size_t avoids int overflow in N * N for large N. */
    size_t count = (size_t)N * (size_t)N;
    for (size_t i = 0; i < count; i++)
    {
        double diff = C1[i] - C2[i];
        /* Negated "within tolerance" test, so NaN (which fails every
         * comparison) is treated as a mismatch. */
        if (!(diff <= tol && diff >= -tol))
        {
            printf("Mismatch at index %zu: %f vs %f\n", i, C1[i], C2[i]);
            return 0;
        }
    }
    return 1;
}
/* Parses s as a base-10 integer in [1, INT_MAX].
 * On success, stores the value in *out and returns 0. Returns -1 (leaving
 * *out untouched) if s has no digits, has trailing characters, or is out
 * of range. */
static int parse_positive_int(const char *s, int *out)
{
    char *end;
    errno = 0;
    long v = strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE || v < 1 || v > INT_MAX)
    {
        return -1;
    }
    *out = (int)v;
    return 0;
}

/* Computes C = A * B (row-major N×N) with nthreads POSIX threads running
 * mul_thread. Each thread handles a contiguous band of rows, and the first
 * N % nthreads threads get one extra row. Every thread that was started is
 * joined before returning. Returns 0 on success, or -1 (after printing to
 * stderr) if allocation or pthread_create fails. Requires nthreads >= 1. */
static int matmul_parallel(double *A, double *B, double *C, int N, int nthreads)
{
    int rc = -1;
    int created = 0;
    pthread_t *tid = calloc((size_t)nthreads, sizeof *tid);
    thread_arg_t *targs = calloc((size_t)nthreads, sizeof *targs);
    if (tid == NULL || targs == NULL)
    {
        fprintf(stderr, "Error: out of memory allocating %d thread descriptors.\n", nthreads);
        goto out;
    }

    int rows_per_thread = N / nthreads;
    int extra = N % nthreads;
    int current_row = 0;
    for (; created < nthreads; created++)
    {
        thread_arg_t *ta = &targs[created];
        ta->id = created;
        ta->start_row = current_row;
        ta->end_row = current_row + rows_per_thread + (created < extra ? 1 : 0);
        ta->N = N;
        ta->A = A;
        ta->B = B;
        ta->C = C;
        current_row = ta->end_row;

        /* pthread_create reports failure through its return value, not errno. */
        int err = pthread_create(&tid[created], NULL, mul_thread, ta);
        if (err != 0)
        {
            fprintf(stderr, "Error: pthread_create failed for thread %d: %s\n", created, strerror(err));
            break;
        }
    }

    /* Join every thread that did start, even after a failure, so no worker
     * still references targs when it is freed. */
    for (int i = 0; i < created; i++)
    {
        pthread_join(tid[i], NULL);
    }
    if (created == nthreads)
    {
        rc = 0;
    }

out:
    free(tid);
    free(targs);
    return rc;
}

/* Benchmark driver. Usage: <prog> <N> <nthreads>.
 * Fills two N×N matrices with reproducible pseudo-random values
 * (srand(42)), times the threaded multiplication and the serial reference,
 * checks that the results agree, and prints times, speedup and parallel
 * efficiency. Returns EXIT_FAILURE on bad arguments, allocation or thread
 * failure, or a verification mismatch. */
int main(int argc, char **argv)
{
    if (argc != 3)
    {
        fprintf(stderr, "Usage: %s <N> <nthreads>\n", argv[0]);
        return EXIT_FAILURE;
    }

    int N, nthreads;
    if (parse_positive_int(argv[1], &N) != 0 || parse_positive_int(argv[2], &nthreads) != 0)
    {
        fprintf(stderr, "Error: <N> and <nthreads> must be positive integers.\n");
        return EXIT_FAILURE;
    }
    /* The kernels index with int arithmetic (i * N + k), so N*N must fit in an int. */
    if (N > INT_MAX / N)
    {
        fprintf(stderr, "Error: N = %d is too large (N*N must not exceed %d).\n", N, INT_MAX);
        return EXIT_FAILURE;
    }

    int status = EXIT_FAILURE;
    size_t count = (size_t)N * (size_t)N;
    /* calloc also guards count * sizeof(double) against overflow. */
    double *A = calloc(count, sizeof *A);
    double *B = calloc(count, sizeof *B);
    double *C = calloc(count, sizeof *C);
    double *C_serial = calloc(count, sizeof *C_serial);
    if (A == NULL || B == NULL || C == NULL || C_serial == NULL)
    {
        fprintf(stderr, "Error: out of memory allocating %d x %d matrices.\n", N, N);
        goto cleanup;
    }

    /* Fixed seed so every run multiplies the same matrices. */
    srand(42);
    for (size_t i = 0; i < count; i++)
    {
        A[i] = (double)(rand() % 100) / 10.0;
        B[i] = (double)(rand() % 100) / 10.0;
    }

    /* Parallel run. The timing covers allocating the thread descriptors and
     * creating and joining the threads. */
    double t_start = get_time();
    if (matmul_parallel(A, B, C, N, nthreads) != 0)
    {
        goto cleanup;
    }
    double t_parallel = get_time() - t_start;

    /* Serial reference run: the speedup baseline and the verification oracle. */
    double t_ser_start = get_time();
    matmul_serial(A, B, C_serial, N);
    double t_serial = get_time() - t_ser_start;

    printf("Matrix size: %d x %d, Threads: %d\n", N, N, nthreads);
    printf("Serial time: %f seconds\n", t_serial);
    printf("Parallel time: %f seconds\n", t_parallel);
    int ok = verify(C, C_serial, N);
    if (ok)
    {
        printf("Verification: PASSED - parallel result matches serial.\n");
    }
    else
    {
        printf("Verification: FAILED.\n");
    }
    /* Skip the ratios if the timer measured no elapsed time. */
    if (t_parallel > 0)
    {
        printf("Speedup: %f\n", t_serial / t_parallel);
        printf("Efficiency: %f\n", t_serial / (nthreads * t_parallel));
    }
    status = ok ? EXIT_SUCCESS : EXIT_FAILURE;

cleanup:
    free(A);
    free(B);
    free(C);
    free(C_serial);
    return status;
}