import matplotlib.pyplot as plt
from pymembrane.membrane import membrane

# Build the spiral-wound membrane model. All keyword arguments are passed
# straight through to pymembrane's `spiral_membrane` constructor.
# NOTE(review): L is presumably the module length in metres (the plot's
# x-axis is labelled "[m]"); confirm the units of DP, S, Pin, Vin and T
# against the pymembrane documentation.
sm = membrane.spiral_membrane(
    L=4.5,
    DP=0.5,
    S=118.5,
    Pin=9.5,
    Vin=10.0,
    T=25,
)

# Per-solute inputs. NOTE(review): B, k and Cin are presumably parallel to
# `sm.solutes` (entry j belongs to solute j); verify against pymembrane.
sm.solutes = ['sucrose', 'fructose', 'lactic acid']
sm.B = [1.44e-4, 5.4e-5, 2.7e-4]
sm.k = [0.036, 0.0432, 0.0684]
sm.Cin = [0.1454, 2.4083, 3.5628]

# Solve the model; the profiles are read back from `sm.res` afterwards.
sm.calcul()

# Plot one retention profile per solute along the membrane.
# `sm.res.R` is indexed as R[solute_index, position_index], with one row per
# entry of `sm.solutes`. The first sample (position index 0) is skipped in
# both x and R. NOTE(review): presumably retention is not meaningful at the
# inlet point; confirm against pymembrane.
for i, solute in enumerate(sm.solutes):
    plt.plot(sm.res.x[1:], sm.res.R[i, 1:], label=solute)

plt.xlabel("Membrane position [m]", fontsize=14)
plt.ylabel("Retention rate", fontsize=14)
plt.grid()
plt.legend(fontsize=14)
plt.tight_layout()
plt.show()