LLM-assisted coding help dialog
Just before lunch today, I left a message for my son, asking for help with sklearn. Then I thought: Gemini should be able to help me. This is the resulting dialog (summarized):
- Me: I have some x y data points that tend to follow a linear trend, but with this trend abruptly changing from time to time. I am looking to use sklearn where I provide my x y data, and at the end have optimally partitioned the data so that each partition's data can be fit with a line while minimizing the fitting error. Can you help me?
- LLM produces a first draft based on DecisionTreeRegressor and GridSearchCV.
- I run it on a small dataset.
- Me: I note that this has a hard time getting the transitions "right", meaning the estimated segment breakpoints have a hard time converging. How about a second pass that keeps the position of the segments, but optimizes their ends?
- LLM proposes code, but forgets that it should be a second pass.
- Me: Hmm, I think you would instead want to use a dynamic-programming-like solution that builds the selection of segments that minimizes the cost. Can you do that?
- LLM proposes code, but it is not very optimal, and it still forgets that this should be a second pass.
- Me: This is not a second pass. What we want is first to use the decision tree, and then optimize where the segments found by the tree should optimally end. Can you try again with that in mind?
- LLM proposes code, which I now try on my real dataset.
- Me: Why did you include the datetime above at some point? There is no date and time in this problem! And you are only making things more complicated!
- LLM: ... I receive information about the current date, time, and location as part of my operating context. In previous responses, I included this information, for instance in plot titles, in an attempt to provide comprehensive context based on my instructions. However, I clearly misjudged its relevance here.
- ...
- Me: ok, so in fact GridSearchCV is pretty slow when arrays get big. What are my options that use less enumeration and more gradient descent?
- LLM: Proposes to try with RandomizedSearchCV or HalvingRandomSearchCV
- Me: Given the linear partitioning, and given the incremental nature of part of the cost function, I am surprised that you cannot come up with something better. Is it because we are using scikit-learn? Could we do better in a lower-level framework, such as JAX?
- LLM: Yes, your intuition is correct. Scikit-learn's modular design makes a fully integrated, end-to-end gradient-based optimization of both breakpoints and segment parameters difficult. ...
- Me: Can we try to start from zero but in JAX? Given what we know, what would you propose?
- LLM produces a JAX version, but with a fixed number of segments.
- Me: Nice. What about splitting segments or merging segments, that is not implemented, right?
- LLM explains the challenges but does not propose a concrete solution.
- Me: What about a refine and merge strategy?
- LLM agrees with me.
- Me: Can you sketch that in JAX?
- LLM proposes code but mixes code with text.
- Me: Can you clean that up to be just code, with the correct tabs?
- LLM proposes what at first seems like reasonable JAX code.
- But then I note that the code no longer supports jumps in the dataset.
- Me: Also, just to make sure, is that spline model continuous? Or does it allow jumps? The initial formulation allows jumps.
- LLM agrees that there is a problem, but explains that the initial requirements did not disallow the proposed solution.
- Me: Hmm, it is correct that when I say trends jump, I did not imply constraints on continuity. Good point. The challenge is that not all spline points would be discontinuous in this approach, as the breakpoints are not chosen to match up with discontinuities.
- LLM: You are correct: if your "abrupt changes" include actual discontinuities, the knots $b_k$ found by optimizing the continuous linear spline model are unlikely to perfectly align with the locations of those discontinuities.
- Me: Yes, dynamic programming is what I have used in the past. Is there a JAX DP library?
- LLM: There isn't a go-to "JAX DP library".
- Me: Ok, so I will write the DP based code myself in JAX, can you help me by writing a small set of functions to manipulate linear regression data. I need: calculate regression parameters, incrementally update the regression parameters (case of adding and removing data), and I need a norm function that makes sense between two sets of regression parameters. All of this in JAX.
- LLM produces JAX code.
- Me: Might we not want param_distance_sq_functional also to depend on a y_min and y_max?
- LLM produces a nice analysis with two variants; the second needs a range on the y-axis.
- Me: I need the second, the normalized version, can you generate it?
- LLM produces two distance functions.
- Me: Can you build a test that takes a random set of points and partitions them, so that we can test the correctness of these two distance functions?
- LLM proposes test code that compares a segment with itself and with other segments.
- Me: How about tests that validate that nearly collinear segments have a small distance?
- LLM proposes test code that also includes a collinearity test.
- I run it.
- Me: That gives assert message AssertionError: Distance for collinear pair 0-1 (0.001175) seems too large compared to fit-true error (0.0001574). How about we use statistics to determine better limits for these tests?
- LLM fixes the code, but the code uses sampling to determine the limits.
- Me: How about using autodiff instead to extrapolate a reasonable limit?
- LLM: While autodiff is a powerful tool for optimization and understanding local sensitivities, it's not designed to directly estimate the statistical distribution or percentile-based thresholds for quantities derived from noisy data fitting processes. Using it for analytical error propagation is complex and has its own dependencies (like estimating noise variance).
- Me: Wouldn't second-order autodiff (the Hessian) give us a hook for a simplified model from which we can extract limits? I seem to remember doing things that way.
- LLM: Yes, you're absolutely right to bring that up! Using second-order automatic differentiation (i.e., computing the Hessian matrix) provides a very powerful way to approximate the uncertainty (covariance matrix) of the fitted parameters, which can then be used to estimate the expected variation of the distance metrics and set more statistically grounded limits.
- Me: Can you generate that JAX code as well? I'd love to see you do it!
- LLM generates code, but the code does not run.
- Me: I am on Python 3.12; maybe that is why this code raises: TypeError: jnp.linalg.det requires ndarray or scalar arguments, got ...
- LLM generates updated code.
- Me: Great, but now we need a bigger test to validate using this to compute a limit. Can you produce that?
- LLM generates test that uses covariance to compute test limits.
- Me: That gives me an assertion failure at: assert dist_actual < threshold, f"Actual distance ({dist_actual:.6g}) exceeds statistically derived threshold ({threshold:.6g} ...
- LLM: The problem: the first-order error propagation approximation, $\mathrm{Var}(f(x)) \approx \left(\frac{df}{dx}\right)^2 \mathrm{Var}(x)$, breaks down and gives zero variance.
- LLM: The chi-squared test on the parameter difference is a statistically sound approach that correctly utilizes the covariance matrices derived from the Hessian.
- Me: How about we add a "fixed" absolute error limit in addition to the relative, covariance-matrix-based one? Can we then combine the two for these degenerate cases?
- LLM produces code.
- Me: Ok, for your info, I needed to set threshold_abs = 1e-2 for that to work! Good job! (at least until I test it a bit more!)
- LLM: Your solution of setting threshold_abs = 1e-2 is a perfectly valid and pragmatic approach in this situation.
- Me: Just wondering, does the distance increase when the fits are tighter?
- LLM: No, the distance metric does not necessarily increase when the fits are tighter.
- Me: Is the distance scale-invariant? That is, without units?
- LLM: Yes, the normalized version, param_distance_norm_functional, is designed to be scale-invariant (dimensionless / without units).
- Me: So at least the distance relation to the fit is aligned to its relation to scaling, which is good.
- LLM: Yes, that's a very good way to connect those ideas!
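The first draft in the dialog used DecisionTreeRegressor and GridSearchCV. Here is a minimal sketch of how such a first pass can produce a partition. It assumes the tree is fit on x alone and that its leaves are used only to define contiguous segments, each of which then gets its own least-squares line. The names `tree_partition` and `fit_segments`, the grid over `max_leaf_nodes`, and `min_samples_leaf=5` are hypothetical choices, not the code the LLM produced.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.tree import DecisionTreeRegressor


def tree_partition(x, y, leaf_grid=(2, 3, 4, 5, 6)):
    """Hypothetical first pass: return index arrays of contiguous segments."""
    X = np.asarray(x, dtype=float).reshape(-1, 1)
    search = GridSearchCV(
        DecisionTreeRegressor(min_samples_leaf=5, random_state=0),
        param_grid={"max_leaf_nodes": list(leaf_grid)},
        cv=KFold(n_splits=3, shuffle=True, random_state=0),
    )
    search.fit(X, y)
    leaves = search.best_estimator_.apply(X)  # leaf id of each sample
    # With a single feature every leaf covers one x-interval, so cutting the
    # x-sorted samples wherever the leaf id changes yields contiguous segments.
    order = np.argsort(X[:, 0], kind="stable")
    cuts = np.nonzero(np.diff(leaves[order]))[0] + 1
    return np.split(order, cuts)


def fit_segments(x, y, segments):
    """Least-squares (slope, intercept) for each segment's samples."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return [tuple(np.polyfit(x[idx], y[idx], 1)) for idx in segments]
```

For example, `tree_partition(x, y)` followed by `fit_segments(x, y, segments)` returns one `(slope, intercept)` pair per segment. A regression tree fits a piecewise constant model, so on sloped data it places splits where steps best approximate the slope rather than where the trend changes. That is one plausible reason the breakpoints were hard to get right.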
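The second pass the dialog converged on (keep the segments found by the tree, then optimize where each one ends) can be sketched as a dynamic program in which each breakpoint may only move within a window around the tree's estimate. This minimal sketch assumes x is sorted, tree breakpoints are given as sample indices, and each segment gets its own independent line, so jumps between segments are allowed. `refine_breakpoints`, `window`, and `min_len` are hypothetical names. Prefix sums of (1, x, y, x², xy, y²) give the SSE of any candidate segment in O(1), so the search costs about O(K·w²) for K segments and a window of w samples.

```python
import numpy as np


def _prefix_sums(x, y):
    """Cumulative [n, Sx, Sy, Sxx, Sxy, Syy] with a leading column of zeros."""
    cols = np.stack([np.ones_like(x), x, y, x * x, x * y, y * y])
    return np.concatenate([np.zeros((6, 1)), np.cumsum(cols, axis=1)], axis=1)


def segment_sse(pre, i, j):
    """SSE of the least-squares line through samples i..j-1, in O(1)."""
    n, sx, sy, sxx, sxy, syy = pre[:, j] - pre[:, i]
    cxx = sxx - sx * sx / n
    cxy = sxy - sx * sy / n
    cyy = syy - sy * sy / n
    if cxx <= 1e-12:  # all x equal: the best fit is the mean of y
        return max(cyy, 0.0)
    return max(cyy - cxy * cxy / cxx, 0.0)


def refine_breakpoints(x, y, tree_breaks, window=10, min_len=2):
    """Move each tree breakpoint within +/- window samples to minimize total SSE.

    x must be sorted; a breakpoint b means a new segment starts at sample b.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    pre = _prefix_sums(x, y)
    cands = [[0]]
    cands += [list(range(max(1, t - window), min(n - 1, t + window) + 1))
              for t in tree_breaks]
    cands.append([n])
    cost = {0: 0.0}  # best total SSE of segments so far, keyed by end index
    back = []
    for k in range(1, len(cands)):
        new_cost, new_back = {}, {}
        for j in cands[k]:
            best, arg = np.inf, None
            for i, c in cost.items():
                if j - i >= min_len:
                    total = c + segment_sse(pre, i, j)
                    if total < best:
                        best, arg = total, i
            if arg is not None:
                new_cost[j], new_back[j] = best, arg
        cost = new_cost
        back.append(new_back)
    breaks, j = [], n
    for b in reversed(back):  # walk back from the last sample to 0
        j = b[j]
        breaks.append(j)
    return breaks[::-1][1:], cost[n]
```

Widening `window` until it covers the whole range turns this into the classic optimal-segmentation DP for a fixed number of segments, at O(K·n²) cost.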
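The continuity question raised mid-dialog can be made concrete by comparing two least-squares designs with fixed knots $b_k$. The first uses hinge columns max(0, x - b_k) only, which is continuous with a slope change at each knot. The second adds step columns 1[x >= b_k], which allows a jump at each knot. This is a minimal NumPy sketch assuming known knots; `spline_design` and `fit_sse` are hypothetical helpers, not the JAX model from the dialog.

```python
import numpy as np


def spline_design(x, knots, allow_jumps=False):
    """Columns: 1, x, a hinge max(0, x - b) per knot, optionally a step 1[x >= b]."""
    x = np.asarray(x, dtype=float)
    cols = [np.ones_like(x), x]
    for b in knots:
        cols.append(np.maximum(0.0, x - b))      # slope change at b (continuous)
        if allow_jumps:
            cols.append((x >= b).astype(float))  # level jump at b
    return np.stack(cols, axis=1)


def fit_sse(x, y, knots, allow_jumps=False):
    """Least-squares fit of the spline design; returns the residual sum of squares."""
    y = np.asarray(y, dtype=float)
    A = spline_design(x, knots, allow_jumps)
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    r = y - A @ coef
    return float(r @ r)
```

On data with a pure jump at x = 5, the continuous design leaves a large residual with a knot at 5. The jump design fits exactly, but only if the knot sits at the discontinuity; a knot at 4 still leaves a residual. This is the alignment problem the dialog points out, and fitting each segment independently, as in a DP over breakpoints, avoids it.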
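For the helpers requested in the dialog (fit the regression parameters, then add or remove points incrementally), one common approach is to carry the sufficient statistics (n, Σx, Σy, Σx², Σxy, Σy²). Adding or removing points is then a vector addition or subtraction, and the parameters and SSE follow in closed form. This minimal JAX sketch rests on that assumption; the function names are hypothetical. It uses float32 (JAX's default), so large or poorly centered x values would benefit from centering or from enabling `jax_enable_x64`.

```python
import jax.numpy as jnp


def stats_from_points(x, y):
    """Sufficient statistics [n, Sx, Sy, Sxx, Sxy, Syy] of a point set."""
    x = jnp.asarray(x, dtype=jnp.float32)
    y = jnp.asarray(y, dtype=jnp.float32)
    return jnp.stack([jnp.ones_like(x).sum(), x.sum(), y.sum(),
                      (x * x).sum(), (x * y).sum(), (y * y).sum()])


def add_points(stats, x, y):
    """Statistics after adding the points (x, y)."""
    return stats + stats_from_points(x, y)


def remove_points(stats, x, y):
    """Statistics after removing points that were previously added."""
    return stats - stats_from_points(x, y)


def _centered(stats):
    n, sx, sy, sxx, sxy, syy = stats
    return n, sx, sy, sxx - sx * sx / n, sxy - sx * sy / n, syy - sy * sy / n


def params_from_stats(stats):
    """Least-squares [slope, intercept]; needs at least two distinct x values."""
    n, sx, sy, cxx, cxy, _ = _centered(stats)
    slope = cxy / cxx
    return jnp.stack([slope, (sy - slope * sx) / n])


def sse_from_stats(stats):
    """Residual sum of squares of the least-squares line."""
    _, _, _, cxx, cxy, cyy = _centered(stats)
    return jnp.maximum(cyy - cxy * cxy / cxx, 0.0)
```

Sliding a breakpoint by one sample is then one `add_points` on one segment and one `remove_points` on its neighbor, which is what makes incremental updates cheap inside a DP.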
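The dialog's normalized distance, param_distance_norm_functional, is not shown. Here is one plausible functional construction, offered as an assumption rather than a copy of that function. It takes the mean squared vertical gap between the two lines over [x_min, x_max] and divides its square root by (y_max - y_min). With $u = (x - x_{min})/(x_{max} - x_{min})$, the gap is $d_0 + s\,u$, where $d_0$ is the gap at x_min and $s$ is its change across the interval. The mean squared gap is therefore $\int_0^1 (d_0 + s u)^2\,du = d_0^2 + d_0 s + s^2/3$. The names `functional_distance_sq` and `normalized_distance` are hypothetical.

```python
import jax.numpy as jnp


def functional_distance_sq(p1, p2, x_min, x_max):
    """Mean over [x_min, x_max] of the squared gap between two lines.

    p1 and p2 are (slope, intercept) pairs.
    """
    d_slope = p1[0] - p2[0]
    d_icpt = p1[1] - p2[1]
    d0 = d_icpt + d_slope * x_min            # gap at the left end
    s = d_slope * (x_max - x_min)            # change of the gap across the interval
    return d0 ** 2 + d0 * s + s ** 2 / 3.0   # = integral_0^1 (d0 + s u)^2 du


def normalized_distance(p1, p2, x_min, x_max, y_min, y_max):
    """Dimensionless RMS gap, relative to the y range."""
    return jnp.sqrt(functional_distance_sq(p1, p2, x_min, x_max)) / (y_max - y_min)
```

Scaling x by α and y by β multiplies $d_0$, $s$ and the y range by β, and shifting y leaves all three unchanged, so the result is dimensionless. That matches the scale-invariance point at the end of the dialog.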
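The Hessian idea from the dialog can be sketched for a single segment. With the sum of squared residuals as the loss, the Hessian with respect to (slope, intercept) is $2A^\top A$, where $A = [x, 1]$. Under the usual assumption of independent, equal-variance noise, the parameter covariance is therefore approximately $2\hat\sigma^2 H^{-1} = \hat\sigma^2 (A^\top A)^{-1}$, with $\hat\sigma^2 = \mathrm{SSE}/(n-2)$. This is a minimal JAX sketch under those assumptions; `sse_loss` and `fit_with_covariance` are hypothetical names. For a linear model the Hessian is exact rather than a local approximation.

```python
import jax
import jax.numpy as jnp


def sse_loss(theta, x, y):
    """Sum of squared residuals of the line y = theta[0] * x + theta[1]."""
    r = y - (theta[0] * x + theta[1])
    return jnp.sum(r ** 2)


def fit_with_covariance(x, y):
    """Least-squares [slope, intercept] and an approximate 2x2 covariance.

    Assumes independent residuals with a common variance sigma^2.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    y = jnp.asarray(y, dtype=jnp.float32)
    A = jnp.stack([x, jnp.ones_like(x)], axis=1)
    theta, *_ = jnp.linalg.lstsq(A, y)
    H = jax.hessian(sse_loss)(theta, x, y)              # equals 2 A^T A here
    sigma2 = sse_loss(theta, x, y) / (x.shape[0] - 2)   # unbiased noise variance
    cov = 2.0 * sigma2 * jnp.linalg.inv(H)              # = sigma^2 (A^T A)^-1
    return theta, cov
```

The diagonal of `cov` gives the variances of slope and intercept. A statistic on the parameter difference between two segments would use the whole matrix.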
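The zero-variance failure and the chi-squared remedy from the dialog can be illustrated together. First-order propagation estimates $\mathrm{Var}(D) \approx \nabla D^\top C\, \nabla D$. A squared distance $D$ has zero gradient wherever the two fits coincide, so this estimate collapses to zero in exactly the "same line" case. Testing the parameter difference instead, with $\Delta\theta^\top (C_1 + C_2)^{-1} \Delta\theta$, avoids this. That statistic is approximately $\chi^2$ with 2 degrees of freedom when both fits estimate the same line from independent data, and for 2 degrees of freedom the quantile has the closed form $-2\ln(1-p)$. The value `threshold_abs = 1e-2` is the one reported in the dialog. Everything else here is an assumption: accepting when either criterion passes, measuring the absolute distance over x in [0, 1], and all function names.

```python
import math

import jax
import jax.numpy as jnp


def distance_sq(p1, p2, x_min=0.0, x_max=1.0):
    """Mean squared gap between lines p = (slope, intercept) over [x_min, x_max]."""
    d0 = (p1[1] - p2[1]) + (p1[0] - p2[0]) * x_min
    s = (p1[0] - p2[0]) * (x_max - x_min)
    return d0 ** 2 + d0 * s + s ** 2 / 3.0


def delta_method_var(p1, p2, cov1, cov2):
    """First-order variance of distance_sq, treating the two fits as independent."""
    g1, g2 = jax.grad(distance_sq, argnums=(0, 1))(p1, p2)
    return g1 @ cov1 @ g1 + g2 @ cov2 @ g2


def chi2_2dof_quantile(p):
    """Quantile of chi-squared with 2 dof, whose CDF is 1 - exp(-q / 2)."""
    return -2.0 * math.log(1.0 - p)


def same_line(p1, cov1, p2, cov2, p=0.99, threshold_abs=1e-2):
    """Hypothetical combined test: accept if the chi-squared statistic on the
    parameter difference is small OR the absolute distance is below threshold_abs."""
    diff = p1 - p2
    stat = diff @ jnp.linalg.solve(cov1 + cov2, diff)
    dist = jnp.sqrt(distance_sq(p1, p2))
    return bool(stat < chi2_2dof_quantile(p)) or bool(dist < threshold_abs)
```

With very small covariances (for example, nearly noise-free segments), the chi-squared statistic rejects even tiny parameter differences, and the absolute floor is what accepts them. That is the degenerate case discussed at the end of the dialog.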