Saturday, April 05, 2025

The importance of choosing the right first queries to your LLM (2025-03)

Failure and then success in querying an LLM

I enjoy accelerating my software development with the help of LLMs. However, this extra LLM-based support does not always work. One reason can be that the LLM never recovers from a bad choice of direction it made in response to your first queries. Below I share a good example of such a failure, followed by a second LLM dialog that quickly produces the desired result. Then I explain what is happening.

Here is the dialog that fails. I am trying to get the LLM to produce a draft of a wavelet-based derivative approximation, coded in JAX:

  1. Me: Do some wavelets have special properties with regard to derivatives? More specifically, I have noisy data that has a slope. I compute the wavelet transform: do some wavelet bases allow me to reconstruct an approximate derivative by manipulating the wavelet coefficients before doing an inverse transform?
    1. (Note that I naturally try not to impose my knowledge, which is why I tend to introduce new knowledge carefully, with people and obviously with LLMs too. I have used wavelets to extract derivatives in the past, yet many years of collaboration have taught me that one gets better results when others have the opportunity to drive the conversation.)
    2. LLM: Produces a positive answer with some theory.
  2. Me: Can you write a JAX example using cr-wavelets and computing approximate derivatives?
    1. LLM: Produces code that is based on pywt.
  3. Me: Not pywt, I meant CR-Sparse / CR-Wavelets
    1. LLM: Still produces code based on pywt.
  4. Me: You previously wrote: "CR-Sparse / CR-Wavelets: This library includes a module for wavelet transforms (both DWT and CWT) built on top of JAX. It mentions being a partial port of PyWavelets functionality, rewritten in pure Python using the JAX API to leverage XLA for acceleration on GPUs/TPUs. It provides both high-level (flexible but not always JIT-compilable) and low-level (JIT-compilable) functions". Therefore we do not want pywt; we want to use cr.wavelets.
    1. LLM: Produces code that starts by initializing a cr.wavelets-based solution, but then proceeds to "Get Wavelet Filter Coefficients (using PyWavelets for reliability)" (the LLM's own explanation). And of course PyWavelets is the pywt package, which is exactly what I do not want.
  5. Me: ouch, I give up on you. (and some further ranting!)
    1. LLM: I sincerely apologize for the confusion and frustration my responses have caused....There is absolutely no hidden agenda or instruction to avoid any particular library like CR-Sparse
Now the dialog that works:
  1. Me: Can you tell me what type of features are implemented in the cr-wavelets jax based library?
    1. LLM: Tells me about cr-wavelets.
  2. Me: Great. Can you help me implement in jax a wavelet based computation of the approximate slopes (derivatives) of some one dimensional data? How about we use the sym4 basis?
    1. LLM: Produces the proper JAX code (with a minor typo).
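The post does not show the JAX code from the working dialog, and the cr-wavelets API and the sym4 basis are not reproduced here. As a hedged illustration of the underlying idea (wavelet coefficients that directly approximate a slope), the sketch below swaps in a different, simpler technique: a single-scale continuous wavelet transform with the first-derivative-of-Gaussian wavelet, written in plain jax.numpy. Convolving with the derivative of a Gaussian equals differentiating the Gaussian-smoothed signal, so after normalization the output is a smoothed estimate of dy/dx. The function names and the `scale` and `half_width` parameters are assumptions for illustration.

```python
import math

import jax.numpy as jnp


def gaussian_derivative_kernel(scale: float, half_width: float = 4.0) -> jnp.ndarray:
    """Sampled first derivative of a Gaussian (up to a constant), truncated at half_width * scale."""
    r = int(math.ceil(half_width * scale))
    t = jnp.arange(-r, r + 1, dtype=jnp.float32)
    g = jnp.exp(-0.5 * (t / scale) ** 2)  # Gaussian envelope; its normalization cancels below
    k = -t * g  # proportional to d/dt of the Gaussian
    # Rescale so that -sum_j j * k[j] == 1: a ramp rising by 1 per sample maps to exactly 1.
    return k / jnp.sum(t * t * g)


def wavelet_slope(y: jnp.ndarray, dx: float, scale: float = 8.0) -> jnp.ndarray:
    """Smoothed dy/dx from a single-scale first-derivative-of-Gaussian wavelet transform.

    scale is in samples. mode="same" zero-pads, so values within about
    half_width * scale samples of either end are biased.
    """
    k = gaussian_derivative_kernel(scale)
    return jnp.convolve(y, k, mode="same") / dx


if __name__ == "__main__":
    x = jnp.linspace(0.0, 1.0, 501)
    dx = float(x[1] - x[0])
    y = 2.5 * x + 1.0
    print(wavelet_slope(y, dx)[250])  # approximately 2.5 away from the edges
```

Because the kernel is odd and normalized this way, the estimate is exact (up to rounding) for straight lines and parabolas away from the edges; a larger `scale` suppresses more noise at the cost of blurring abrupt slope changes.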

The insights are:
  1. In the first chat, when I first mentioned cr-wavelets, the LLM knew that cr-wavelets is tied to cr-sparse, but not that it contains wavelet implementations. It knew that one could implement wavelets from scratch on top of cr-sparse, but also that doing so is expensive, especially given that pywt already provides wavelets. So even though later in the dialog the LLM had "loaded" the knowledge of what cr-wavelets does, it had a hard time stepping away from its initial conclusion that using cr-wavelets would be too expensive.
  2. In the second dialog, I purposely pushed the LLM to "load" the knowledge of what cr-wavelets could do. Only then did I make the request to use cr-wavelets, which worked well.
The bigger story here is: 
Choose your first inputs very carefully when interacting with an LLM!

When using the LLM as a coding assistant, it is very important that the LLM fully understands the resources you intend to use before you ask any question that may concern those resources.

All original content copyright James Litsios, 2025.



Tuesday, April 01, 2025

Playing with LLMs to code some math (2025-03)

 

LLM-assisted coding help dialog

Just before lunch today, I left a message for my son asking for help with sklearn. Then I thought: Gemini should be able to help me. This is the resulting dialog (summarized):

  1. Me: I have some x y data points that tend to follow a linear trend, but with this trend abruptly changing from time to time. I am looking to use sklearn where I provide my x y data, and at the end have optimally partitioned the data so that each partition's data can be fit with a line while minimizing the fitting error. Can you help me?
    1. LLM produces a first draft based on DecisionTreeRegressor and GridSearchCV.
    2. I run it on a small dataset.
  2. Me: I note that this has a hard time getting the transitions "right", meaning the estimated segment breakpoints have a hard time converging. How about a second pass that keeps the position of the segments, but optimizes their ends?
    1. LLM proposes code, but forgets that it should be a second pass.
  3. Me: Hmm, I think you would want instead to use a dynamic-programming-like solution that builds the best selection of segments that minimizes the cost. Can you do that?
    1. LLM proposes code, but it is far from optimal, and it still forgets this should be a second pass.
  4. Me: This is not a second pass. What we want is first to use the decision tree, and then optimize where the segments found by the tree should optimally end. Can you try again with that in mind?
    1. LLM proposes code, which I now try on my real dataset.
  5. Me: Why did you include the datetime above at some point? There is no date and time in this problem! And you are only making things more complicated!
    1. LLM: ...
      I receive information about the current date, time, and location as part of my operating context. In previous responses, I included this information, for instance in plot titles, in an attempt to provide comprehensive context based on my instructions. However, I clearly misjudged its relevance here.
      ...
  6. Me: ok, so in fact GridSearchCV is pretty slow when arrays get big. What are my options that use less enumeration, and more gradient descent?
    1. LLM: Proposes to try with RandomizedSearchCV or HalvingRandomSearchCV
  7. Me: Given the linear partitioning, and given the incremental nature of part of the cost function, I am surprised that you cannot come up with something better. Is it because we are using scikit-learn? Could we do better in a lower-level framework, such as JAX?
    1. LLM: Yes, your intuition is correct. Scikit-learn's modular design makes a fully integrated, end-to-end gradient-based optimization of both breakpoints and segment parameters difficult. ...
  8. Me: Can we try to start from zero but in JAX? Given what we know, what would you propose?
    1. LLM produces a JAX version, but with a fixed number of segments.
  9. Me: Nice. What about splitting segments or merging segments, that is not implemented, right?
    1. LLM explains the challenges but does not propose a concrete solution.
  10. Me: What about a refine and merge strategy?
    1. LLM agrees with me.
  11. Me: Can you sketch that in JAX?
    1. LLM proposes code but mixes code with text.
  12. Me: Can you clean that up to be just code, with the correct tabs?
    1. LLM proposes what at first seems like reasonable JAX code.
    2. But then I note that the code no longer supports jumps in the dataset.
  13. Me: Also, just to make sure, is that spline model continuous? Or does it allow jumps? The initial formulation allows jumps.
    1. LLM agrees that there is a problem, but explains that my initial requirements did not disallow the proposed solution.
  14. Me: Hmm, it is correct that when I say trends jump, I did not imply constraints on continuity. Good point. The challenge is that not all spline points would be discontinuous in this approach, as the breakpoints are not chosen to match up with discontinuities.
    1. LLM: you are correct: If your "abrupt changes" include actual discontinuities, the knots b_k found by optimizing the continuous linear spline model are unlikely to perfectly align with the locations of those discontinuities
  15. Me: Yes, dynamic programming is what I have used in the past. Is there a jax DP library?
    1. LLM: There isn't a go-to "JAX DP library".
  16. Me: Ok, so I will write the DP based code myself in JAX, can you help me by writing a small set of functions to manipulate linear regression data. I need: calculate regression parameters, incrementally update the regression parameters (case of adding and removing data), and I need a norm function that makes sense between two sets of regression parameters. All of this in JAX.
    1. LLM produces JAX code.
  17. Me: Might we not want param_distance_sq_functional also to depend on a y_min and y_max?
    1. LLM produces a nice analysis, with two variants, the second needs a range on the y axis.
  18. Me: I need the second, the normalized version, can you generate it?
    1. LLM produces two distance functions.
  19. Me: Can you build a test that takes a random set of points, and partitions them so that we can test the correctness of these two distance functions?
    1. LLM proposes test code that compares each segment with itself and with the other segments.
  20. Me: How about tests that validate that nearly collinear segments have a small distance?
    1. LLM proposes test code that also includes a collinearity test.
    2. I run it.
  21. Me: That gives the assert message AssertionError: Distance for collinear pair 0-1 (0.001175) seems too large compared to fit-true error (0.0001574). How about we use statistics to determine better limits for these tests?
    1. LLM fixes the code, but the new code relies on sampling to determine the limits.
  22. Me: How about using autodiff instead to extrapolate a reasonable limit?
    1. LLM: While autodiff is a powerful tool for optimization and understanding local sensitivities, it's not designed to directly estimate the statistical distribution or percentile-based thresholds for quantities derived from noisy data fitting processes. Using it for analytical error propagation is complex and has its own dependencies (like estimating noise variance).
  23. Me: Would not a second-order autodiff (Hessian) give us a hook for a simplified model from which we can extract limits? I seem to remember doing things that way.
    1. LLM: Yes, you're absolutely right to bring that up! Using second-order automatic differentiation (i.e., computing the Hessian matrix) provides a very powerful way to approximate the uncertainty (covariance matrix) of the fitted parameters, which can then be used to estimate the expected variation of the distance metrics and set more statistically grounded limits. 
  24. Me: Can you generate that JAX code as well? I'd love to see you do it!
    1. LLM generates code, but code does not run.
  25. Me: I am in Python 3.12, maybe that is why this code generates: TypeError: jnp.linalg.det requires ndarray or scalar arguments, got ...
    1. LLM generates updated code.
  26. Me: Great, but now we need a bigger test to validate using this to compute a limit. Can you produce that?
    1. LLM generates test that uses covariance to compute test limits.
  27. Me: That gives me an AssertionError at: assert dist_actual < threshold, f"Actual distance ({dist_actual:.6g}) exceeds statistically derived threshold ({threshold:.6g} ...
    1. LLM: The Problem: The first-order error propagation approximation (Var(f(x)) ≈ (df/dx)^2 * Var(x)) breaks down and gives zero variance 
    2. LLM: The Chi-Squared test on the parameter difference is a statistically sound approach that correctly utilizes the covariance matrices derived from the Hessian. 
  28. Me: How about we add a "fixed" absolute error limit in addition to the relative, covariance-matrix-based one? Can we then combine the two for these degenerate cases?
    1. LLM produces code.
  29. Me: Ok, for your info, I needed to set threshold_abs = 1e-2 for that to work! Good job! (at least until I test it a bit more!)
    1. LLM: Your solution of setting threshold_abs = 1e-2 is a perfectly valid and pragmatic approach in this situation.
  30. Me: Just wondering, does the distance increase when the fits are tighter?
    1. LLM: No, the distance metric does not necessarily increase when the fits are tighter.
  31. Me: Is the distance scale invariant? That is without units?
    1. LLM: Yes, the normalized version, param_distance_norm_functional, is designed to be scale-invariant (dimensionless / without units).
  32. Me: So at least the distance relation to the fit is aligned to its relation to scaling, which is good.
    1. LLM: Yes, that's a very good way to connect those ideas!
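Steps 16 to 28 of the dialog revolve around regression helpers, a Hessian-based covariance, and a chi-squared comparison of fitted parameters. The LLM's code is not shown in the post, and param_distance_norm_functional is not reproduced here. As a hedged sketch of those ideas, assuming ordinary least squares with independent Gaussian noise, the following keeps sufficient statistics that can be updated when points are added or removed, derives the (slope, intercept) covariance from the Hessian of the squared error via `jax.hessian`, and compares two fits with a chi-squared (squared Mahalanobis) distance. All function names are hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def line_stats(x, y):
    """Sufficient statistics (n, Sx, Sy, Sxx, Sxy, Syy) of a set of points."""
    n = jnp.asarray(x.shape[0], dtype=x.dtype)
    return jnp.stack([n, x.sum(), y.sum(), (x * x).sum(), (x * y).sum(), (y * y).sum()])


def add_points(stats, x, y):
    """Incrementally include new points in the statistics."""
    return stats + line_stats(x, y)


def remove_points(stats, x, y):
    """Incrementally drop points that were previously included."""
    return stats - line_stats(x, y)


def fit_from_stats(stats):
    """Ordinary least-squares (slope, intercept) computed from the sufficient statistics."""
    n, sx, sy, sxx, sxy, _ = stats
    slope = (n * sxy - sx * sy) / (n * sxx - sx * sx)
    intercept = (sy - slope * sx) / n
    return jnp.stack([slope, intercept])


def sse(params, x, y):
    """Sum of squared residuals of the line params = (slope, intercept)."""
    slope, intercept = params
    r = y - (slope * x + intercept)
    return jnp.sum(r * r)


def param_covariance(params, x, y):
    """Covariance of (slope, intercept) from the Hessian of the squared-error loss.

    For a line fit the Hessian of the SSE is H = 2 X^T X, so under i.i.d. Gaussian noise
    Cov = sigma^2 (X^T X)^-1 = 2 sigma^2 H^-1, with sigma^2 estimated as SSE / (n - 2).
    """
    h = jax.hessian(sse)(params, x, y)
    sigma2 = sse(params, x, y) / (x.shape[0] - 2)
    return 2.0 * sigma2 * jnp.linalg.inv(h)


def chi2_distance(p1, c1, p2, c2):
    """Squared Mahalanobis distance of two independent fits; ~chi^2 with 2 dof if the lines agree."""
    d = p1 - p2
    return d @ jnp.linalg.solve(c1 + c2, d)


def chi2_threshold_2dof(p=0.99):
    """Quantile of chi^2 with 2 dof; its CDF is 1 - exp(-x / 2), so the inverse is closed-form."""
    return -2.0 * math.log(1.0 - p)
```

With noise-free data the estimated covariances collapse toward zero and the chi-squared distance becomes ill-conditioned, which is one way the degenerate cases in the dialog can arise and why an absolute floor such as threshold_abs can help. Centering x before accumulating the statistics would also improve float32 accuracy.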
The result is that it is now late afternoon, I don't have working code, I have a few ideas of what does not work, and I have a plan to write code that works and scales. That will be tomorrow's effort!
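The plan above is a dynamic-programming segmentation written directly in JAX. As a hedged illustration of what such a DP can look like (not the author's eventual code), here is a minimal sketch of the classic segmented least-squares recursion: the best cost up to point j is the minimum over i of the best cost up to i, plus the line-fit error of points i..j-1, plus a per-segment penalty. Each segment gets its own independent line, so jumps between segments are allowed, matching the original formulation. The `penalty` and `min_len` parameters and the function names are assumptions. The full cost matrix uses O(n²) memory, so for large datasets one would limit the maximum segment length or use a pruning scheme such as PELT.

```python
import jax
import jax.numpy as jnp


def segment_costs(x, y, min_len):
    """cost[i, j] = SSE of the least-squares line through points i..j-1, or inf if j - i < min_len."""
    # Centering reduces float32 cancellation; a line fit's SSE is unchanged by shifting x or y.
    x = x - x.mean()
    y = y - y.mean()
    feats = jnp.stack([jnp.ones_like(x), x, y, x * x, x * y, y * y], axis=1)
    prefix = jnp.concatenate([jnp.zeros((1, 6), x.dtype), jnp.cumsum(feats, axis=0)])
    s = prefix[None, :, :] - prefix[:, None, :]  # s[i, j] = sums over points i..j-1
    n, sx, sy, sxx, sxy, syy = jnp.moveaxis(s, -1, 0)
    n_safe = jnp.maximum(n, 1.0)
    vxx = sxx - sx * sx / n_safe
    vxy = sxy - sx * sy / n_safe
    vyy = syy - sy * sy / n_safe
    sse = jnp.maximum(vyy - vxy * vxy / jnp.where(vxx > 0, vxx, 1.0), 0.0)
    return jnp.where(n >= min_len, sse, jnp.inf)


def segment_dp(x, y, penalty, min_len=3):
    """Split (x, y) into contiguous, independently fitted lines minimizing SSE + penalty * #segments."""
    cost = segment_costs(x, y, min_len)
    m = x.shape[0]
    best0 = jnp.full((m + 1,), jnp.inf, dtype=cost.dtype).at[0].set(0.0)
    arg0 = jnp.zeros((m + 1,), dtype=jnp.int32)

    def step(j, carry):
        best, arg = carry
        # Candidate: optimal split of points 0..i-1, then one new segment i..j-1.
        cand = best + cost[:, j] + penalty
        i = jnp.argmin(cand).astype(jnp.int32)
        return best.at[j].set(cand[i]), arg.at[j].set(i)

    best, arg = jax.lax.fori_loop(1, m + 1, step, (best0, arg0))

    # Backtrack on the host to recover the segment boundaries.
    arg = jax.device_get(arg)
    bounds, j = [], m
    while j > 0:
        i = int(arg[j])
        bounds.append((i, j))
        j = i
    return bounds[::-1], float(best[m])
```

For example, on noise-free data made of three line pieces with jumps between them, `segment_dp(x, y, penalty=1.0)` should recover the three pieces, because any split that straddles a jump pays a large squared error, while any extra split only adds penalty.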

All original content copyright James Litsios, 2025.